diff --git a/README.md b/README.md index 01b3a94..e80507f 100644 --- a/README.md +++ b/README.md @@ -244,6 +244,28 @@ Calling a function is done with the `INVOCA` keyword. > CXXI ``` +## First-class functions +Functions are first-class values in CENTVRION. They can be assigned to variables, passed as arguments, returned from functions, and stored in arrays or dicts. + +Anonymous functions are created with the `FVNCTIO` keyword: + +![FVNCTIO](snippets/fvnctio.png) + +``` +> XIV +``` + +`INVOCA` accepts any expression as the callee, not just a name: + +![INVOCA expressions](snippets/invoca_expr.png) +``` +> VI +> VI +> XVI +``` + +Note: CENTVRION does **not** have closures. When a function is called, it receives a copy of the *caller's* scope, not the scope where it was defined. Variables from a function's definition site are only available if they also exist in the caller's scope at call time. + ## Built-ins ### DICE `DICE(value, ...)` diff --git a/centvrion/ast_nodes.py b/centvrion/ast_nodes.py index d87ddcb..4c6df95 100644 --- a/centvrion/ast_nodes.py +++ b/centvrion/ast_nodes.py @@ -144,6 +144,8 @@ def make_string(val, magnvm=False, svbnvlla=False) -> str: for k, v in val.value().items() ) return "{" + inner + "}" + elif isinstance(val, ValFunc): + return "FVNCTIO" else: raise CentvrionError(f"Cannot display {val!r}") @@ -577,6 +579,31 @@ class Defini(Node): return vtable, ValNul() +class Fvnctio(Node): + def __init__(self, parameters: list[ID], statements: list[Node]) -> None: + self.parameters = parameters + self.statements = statements + + def __eq__(self, other): + return (type(self) == type(other) + and self.parameters == other.parameters + and self.statements == other.statements) + + def __repr__(self) -> str: + parameter_string = f"parameters([{rep_join(self.parameters)}])" + statements_string = f"statements([{rep_join(self.statements)}])" + fvn_string = rep_join([parameter_string, statements_string]) + return f"Fvnctio({fvn_string})" + + def print(self): + params = ", ".join(p.print() for p in self.parameters) + body = "\n".join(s.print() for s in self.statements) + return f"FVNCTIO ({params}) VT {{\n{body}\n}}" + + def _eval(self, vtable): + return vtable, ValFunc(self.parameters, self.statements) + + class Redi(Node): def __init__(self, values: list[Node]) -> None: self.values = values @@ -954,32 +981,36 @@ class PerStatement(Node): class Invoca(Node): - def __init__(self, name, parameters) -> None: - self.name = name + def __init__(self, callee, parameters) -> None: + self.callee = callee self.parameters = parameters def __eq__(self, other): - return type(self) == type(other) and self.name == other.name and self.parameters == other.parameters + return (type(self) == type(other) + and self.callee == other.callee + and self.parameters == other.parameters) def __repr__(self) -> str: parameters_string = f"parameters([{rep_join(self.parameters)}])" - invoca_string = rep_join([self.name, parameters_string]) + invoca_string = rep_join([self.callee, parameters_string]) return f"Invoca({invoca_string})" def print(self): args = ", ".join(p.print() for p in self.parameters) - return f"INVOCA {self.name.print()} ({args})" + return f"INVOCA {self.callee.print()} ({args})" def _eval(self, vtable): params = [p.eval(vtable)[1] for p in self.parameters] - if self.name.name not in vtable: - raise CentvrionError(f"Undefined function: {self.name.name}") - func = vtable[self.name.name] + vtable, func = self.callee.eval(vtable) if not isinstance(func, ValFunc): - raise CentvrionError(f"{self.name.name} is not a function") + callee_desc = (self.callee.name + if isinstance(self.callee, ID) else "expression") + raise CentvrionError(f"{callee_desc} is not a function") if len(params) != len(func.params): + callee_desc = (self.callee.name + if isinstance(self.callee, ID) else "FVNCTIO") raise CentvrionError( - f"{self.name.name} expects {len(func.params)} argument(s), got {len(params)}" + f"{callee_desc} expects {len(func.params)} argument(s), got {len(params)}" ) func_vtable = vtable.copy() for i, param in enumerate(func.params): diff --git a/centvrion/compiler/context.py b/centvrion/compiler/context.py index fb26b27..e900e54 100644 --- a/centvrion/compiler/context.py +++ b/centvrion/compiler/context.py @@ -7,6 +7,10 @@ class EmitContext: self.functions = {} # source-level name / alias → c_func_name; populated by emitter pre-pass self.func_resolve = {} + # id(Fvnctio_node) → c_func_name; populated by lambda lifting pass + self.lambda_names = {} + # [(c_name, Fvnctio_node), ...]; populated by lambda lifting pass + self.lambdas = [] def fresh_tmp(self): name = f"_t{self._tmp_counter}" diff --git a/centvrion/compiler/emit_expr.py b/centvrion/compiler/emit_expr.py index 24d0219..1d635d9 100644 --- a/centvrion/compiler/emit_expr.py +++ b/centvrion/compiler/emit_expr.py @@ -3,7 +3,7 @@ from centvrion.ast_nodes import ( String, InterpolatedString, Numeral, Fractio, Bool, Nullus, ID, BinOp, UnaryMinus, UnaryNot, ArrayIndex, DataArray, DataRangeArray, DataDict, - BuiltIn, Invoca, + BuiltIn, Invoca, Fvnctio, num_to_int, frac_to_fraction, ) @@ -162,6 +162,9 @@ def emit_expr(node, ctx): if isinstance(node, Invoca): return _emit_invoca(node, ctx) + if isinstance(node, Fvnctio): + return _emit_fvnctio(node, ctx) + raise NotImplementedError(type(node).__name__) @@ -261,7 +264,8 @@ def _emit_builtin(node, ctx): def _emit_invoca(node, ctx): """ Emits a user-defined function call. - Requires ctx.functions[name] = [param_names] populated by the emitter pre-pass. + Supports both static resolution (ID callee with known function) and + dynamic dispatch (arbitrary expression callee via CENT_FUNC values). """ lines = [] param_vars = [] @@ -270,21 +274,59 @@ def _emit_invoca(node, ctx): lines.extend(p_lines) param_vars.append(p_var) - func_name = node.name.name - c_func_name = ctx.func_resolve.get(func_name) - if c_func_name is None: - raise CentvrionError(f"Undefined function: {func_name}") + # Try static resolution for simple ID callees + if isinstance(node.callee, ID): + c_func_name = ctx.func_resolve.get(node.callee.name) + if c_func_name is not None: + call_scope_var = ctx.fresh_tmp() + "_sc" + lines.append(f"CentScope {call_scope_var} = cent_scope_copy(&_scope);") + param_names = ctx.functions[c_func_name] + if len(param_vars) != len(param_names): + raise CentvrionError( + f"Function '{node.callee.name}' expects {len(param_names)} argument(s), " + f"got {len(param_vars)}" + ) + for i, pname in enumerate(param_names): + lines.append(f'cent_scope_set(&{call_scope_var}, "{pname}", {param_vars[i]});') + tmp = ctx.fresh_tmp() + lines.append(f"CentValue {tmp} = {c_func_name}({call_scope_var});") + return lines, tmp + + # Dynamic dispatch: evaluate callee, call via function pointer + callee_lines, callee_var = emit_expr(node.callee, ctx) + lines.extend(callee_lines) + lines.append(f'if ({callee_var}.type != CENT_FUNC) cent_type_error("cannot call non-function");') call_scope_var = ctx.fresh_tmp() + "_sc" lines.append(f"CentScope {call_scope_var} = cent_scope_copy(&_scope);") - - param_names = ctx.functions[c_func_name] - if len(param_vars) != len(param_names): - raise CentvrionError( - f"Function '{func_name}' expects {len(param_names)} argument(s), got {len(param_vars)}" + nargs = len(param_vars) + lines.append( + f"if ({callee_var}.fnval.param_count != {nargs}) " + f'cent_runtime_error("wrong number of arguments");' + ) + for i, pv in enumerate(param_vars): + lines.append( + f'cent_scope_set(&{call_scope_var}, ' + f'{callee_var}.fnval.param_names[{i}], {pv});' ) - for i, pname in enumerate(param_names): - lines.append(f'cent_scope_set(&{call_scope_var}, "{pname}", {param_vars[i]});') - tmp = ctx.fresh_tmp() - lines.append(f"CentValue {tmp} = {c_func_name}({call_scope_var});") + lines.append(f"CentValue {tmp} = {callee_var}.fnval.fn({call_scope_var});") + return lines, tmp + + +def _emit_fvnctio(node, ctx): + """Emit a FVNCTIO lambda expression as a CENT_FUNC value.""" + c_name = ctx.lambda_names[id(node)] + param_names = ctx.functions[c_name] + tmp = ctx.fresh_tmp() + lines = [] + # Build static param name array + params_arr = ctx.fresh_tmp() + "_pn" + lines.append( + f"static const char *{params_arr}[] = {{" + + ", ".join(f'"{p}"' for p in param_names) + + "};" + ) + lines.append( + f"CentValue {tmp} = cent_func_val({c_name}, {params_arr}, {len(param_names)});" + ) return lines, tmp diff --git a/centvrion/compiler/emit_stmt.py b/centvrion/compiler/emit_stmt.py index 877f65e..9f71a45 100644 --- a/centvrion/compiler/emit_stmt.py +++ b/centvrion/compiler/emit_stmt.py @@ -11,9 +11,6 @@ def emit_stmt(node, ctx): Returns lines — list of C statements. """ if isinstance(node, Designa): - # Function alias: resolved at compile time, no runtime code needed - if isinstance(node.value, ID) and node.value.name in ctx.func_resolve: - return [] val_lines, val_var = emit_expr(node.value, ctx) return val_lines + [f'cent_scope_set(&_scope, "{node.id.name}", {val_var});'] @@ -87,7 +84,20 @@ def emit_stmt(node, ctx): return lines if isinstance(node, Defini): - # Function definitions are hoisted by emitter.py; no-op here. + # Top-level definitions are handled by emitter.py (hoisted + scope-set). + # Nested definitions (inside another function) need runtime scope-set. + if ctx.current_function is not None: + name = node.name.name + c_name = ctx.func_resolve[name] + param_names = ctx.functions[c_name] + pn_var = ctx.fresh_tmp() + "_pn" + return [ + f"static const char *{pn_var}[] = {{" + + ", ".join(f'"{p}"' for p in param_names) + + "};", + f'cent_scope_set(&_scope, "{name}", ' + f"cent_func_val({c_name}, {pn_var}, {len(param_names)}));", + ] return [] if isinstance(node, Redi): diff --git a/centvrion/compiler/emitter.py b/centvrion/compiler/emitter.py index 72df9b9..5ce1bcd 100644 --- a/centvrion/compiler/emitter.py +++ b/centvrion/compiler/emitter.py @@ -1,11 +1,32 @@ import os -from centvrion.ast_nodes import Defini, Designa, ID +from centvrion.ast_nodes import Defini, Designa, Fvnctio, ID, Node from centvrion.compiler.context import EmitContext from centvrion.compiler.emit_stmt import emit_stmt, _emit_body _RUNTIME_DIR = os.path.join(os.path.dirname(__file__), "runtime") +def _collect_lambdas(node, ctx, counter): + """Walk AST recursively, find all Fvnctio nodes, assign C names.""" + if isinstance(node, Fvnctio): + c_name = f"_cent_lambda_{counter[0]}" + counter[0] += 1 + ctx.lambda_names[id(node)] = c_name + ctx.functions[c_name] = [p.name for p in node.parameters] + ctx.lambdas.append((c_name, node)) + for attr in vars(node).values(): + if isinstance(attr, Node): + _collect_lambdas(attr, ctx, counter) + elif isinstance(attr, list): + for item in attr: + if isinstance(item, Node): + _collect_lambdas(item, ctx, counter) + elif isinstance(item, tuple): + for elem in item: + if isinstance(elem, Node): + _collect_lambdas(elem, ctx, counter) + + def compile_program(program): """Return a complete C source string for the given Program AST node.""" ctx = EmitContext() @@ -26,10 +47,11 @@ def compile_program(program): ctx.functions[c_name] = [p.name for p in stmt.parameters] ctx.func_resolve[name] = c_name func_definitions.append((c_name, stmt)) - elif isinstance(stmt, Designa) and isinstance(stmt.value, ID): - rhs = stmt.value.name - if rhs in ctx.func_resolve: - ctx.func_resolve[stmt.id.name] = ctx.func_resolve[rhs] + + # Lambda lifting: find all Fvnctio nodes in the entire AST + counter = [0] + for stmt in program.statements: + _collect_lambdas(stmt, ctx, counter) lines = [] @@ -39,13 +61,13 @@ def compile_program(program): "", ] - # Forward declarations + # Forward declarations (named functions + lambdas) for c_name in ctx.functions: lines.append(f"CentValue {c_name}(CentScope _scope);") if ctx.functions: lines.append("") - # Hoisted function definitions + # Hoisted named function definitions for c_name, stmt in func_definitions: ctx.current_function = c_name lines.append(f"CentValue {c_name}(CentScope _scope) {{") @@ -55,6 +77,16 @@ def compile_program(program): lines += ["_func_return:", " return _return_val;", "}", ""] ctx.current_function = None + # Hoisted lambda definitions + for c_name, fvnctio_node in ctx.lambdas: + ctx.current_function = c_name + lines.append(f"CentValue {c_name}(CentScope _scope) {{") + lines.append(" CentValue _return_val = cent_null();") + for l in _emit_body(fvnctio_node.statements, ctx): + lines.append(f" {l}") + lines += ["_func_return:", " return _return_val;", "}", ""] + ctx.current_function = None + # main() lines.append("int main(void) {") lines.append(" cent_init();") @@ -62,8 +94,25 @@ def compile_program(program): lines.append(" cent_magnvm = 1;") lines.append(" CentScope _scope = {0};") lines.append(" CentValue _return_val = cent_null();") + + # Build a map from id(Defini_node) → c_name for scope registration + defini_c_names = {id(stmt): c_name for c_name, stmt in func_definitions} + for stmt in program.statements: if isinstance(stmt, Defini): + name = stmt.name.name + c_name = defini_c_names[id(stmt)] + param_names = ctx.functions[c_name] + pn_var = f"_pn_{c_name}" + lines.append( + f" static const char *{pn_var}[] = {{" + + ", ".join(f'"{p}"' for p in param_names) + + "};" + ) + lines.append( + f' cent_scope_set(&_scope, "{name}", ' + f"cent_func_val({c_name}, {pn_var}, {len(param_names)}));" + ) continue for l in emit_stmt(stmt, ctx): lines.append(f" {l}") diff --git a/centvrion/compiler/runtime/cent_runtime.c b/centvrion/compiler/runtime/cent_runtime.c index ebbb93b..c7099a4 100644 --- a/centvrion/compiler/runtime/cent_runtime.c +++ b/centvrion/compiler/runtime/cent_runtime.c @@ -290,6 +290,10 @@ static int write_val(CentValue v, char *buf, int bufsz) { return total; } + case CENT_FUNC: + if (buf && bufsz > 7) { memcpy(buf, "FVNCTIO", 7); buf[7] = '\0'; } + return 7; + case CENT_DICT: { /* "{key VT val, key VT val}" */ int total = 2; /* '{' + '}' */ @@ -454,6 +458,7 @@ CentValue cent_eq(CentValue a, CentValue b) { switch (a.type) { case CENT_STR: return cent_bool(strcmp(a.sval, b.sval) == 0); case CENT_BOOL: return cent_bool(a.bval == b.bval); + case CENT_FUNC: return cent_bool(a.fnval.fn == b.fnval.fn); case CENT_NULL: return cent_bool(1); default: cent_type_error("'EST' not supported for this type"); diff --git a/centvrion/compiler/runtime/cent_runtime.h b/centvrion/compiler/runtime/cent_runtime.h index b20eee2..673a9c8 100644 --- a/centvrion/compiler/runtime/cent_runtime.h +++ b/centvrion/compiler/runtime/cent_runtime.h @@ -15,12 +15,23 @@ typedef enum { CENT_LIST, CENT_FRAC, CENT_DICT, + CENT_FUNC, CENT_NULL } CentType; typedef struct CentValue CentValue; typedef struct CentList CentList; typedef struct CentDict CentDict; +struct CentScope; /* forward declaration */ + +/* First-class function value */ +typedef CentValue (*CentFuncPtr)(struct CentScope); + +typedef struct { + CentFuncPtr fn; + const char **param_names; + int param_count; +} CentFuncInfo; /* Duodecimal fraction: num/den stored as exact integers */ typedef struct { @@ -48,14 +59,15 @@ struct CentValue { char *sval; /* CENT_STR */ int bval; /* CENT_BOOL */ CentList lval; /* CENT_LIST */ - CentFrac fval; /* CENT_FRAC */ - CentDict dval; /* CENT_DICT */ + CentFrac fval; /* CENT_FRAC */ + CentDict dval; /* CENT_DICT */ + CentFuncInfo fnval; /* CENT_FUNC */ }; }; /* Scope: flat name→value array. Stack-allocated by the caller; cent_scope_set uses cent_arena when it needs to grow. */ -typedef struct { +typedef struct CentScope { const char **names; CentValue *vals; int len; @@ -111,6 +123,14 @@ static inline CentValue cent_list(CentValue *items, int len, int cap) { r.lval.cap = cap; return r; } +static inline CentValue cent_func_val(CentFuncPtr fn, const char **param_names, int param_count) { + CentValue r; + r.type = CENT_FUNC; + r.fnval.fn = fn; + r.fnval.param_names = param_names; + r.fnval.param_count = param_count; + return r; +} static inline CentValue cent_dict_val(CentValue *keys, CentValue *vals, int len, int cap) { CentValue r; r.type = CENT_DICT; diff --git a/centvrion/lexer.py b/centvrion/lexer.py index bc85dcc..6a46c80 100644 --- a/centvrion/lexer.py +++ b/centvrion/lexer.py @@ -18,6 +18,7 @@ keyword_tokens = [("KEYWORD_"+i, i) for i in [ "ET", "FACE", "FALSITAS", + "FVNCTIO", "INVOCA", "IN", "MINVE", diff --git a/centvrion/parser.py b/centvrion/parser.py index ea6e04f..3c207c5 100644 --- a/centvrion/parser.py +++ b/centvrion/parser.py @@ -287,10 +287,14 @@ class Parser(): def unary_not(tokens): return ast_nodes.UnaryNot(tokens[1]) - @self.pg.production('expression : KEYWORD_INVOCA id expressions') + @self.pg.production('expression : KEYWORD_INVOCA expression expressions') def invoca(tokens): return ast_nodes.Invoca(tokens[1], tokens[2]) + @self.pg.production('expression : KEYWORD_FVNCTIO ids KEYWORD_VT SYMBOL_LCURL statements SYMBOL_RCURL') + def fvnctio(tokens): + return ast_nodes.Fvnctio(tokens[1], tokens[4]) + @self.pg.production('expression : SYMBOL_LPARENS expression SYMBOL_RPARENS') def parens(tokens): return tokens[1] diff --git a/language/main.tex b/language/main.tex index 99d4502..415083f 100644 --- a/language/main.tex +++ b/language/main.tex @@ -55,7 +55,8 @@ \languageline{expression}{\texttt{(} \textit{expression} \texttt{)}} \\ \languageline{expression}{\textbf{id}} \\ \languageline{expression}{\textbf{builtin} \texttt{(} \textit{optional-expressions} \texttt{)}} \\ - \languageline{expression}{\texttt{INVOCA} \textbf{id} \texttt{(} \textit{optional-expressions} \texttt{)}} \\ + \languageline{expression}{\texttt{INVOCA} \textit{expression} \texttt{(} \textit{optional-expressions} \texttt{)}} \\ + \languageline{expression}{\texttt{FVNCTIO} \texttt{(} \textit{optional-ids} \texttt{)} \texttt{VT} \textit{scope}} \\ \languageline{expression}{\textit{literal}} \\ \languageline{expression}{\textit{expression} \texttt{[} \textit{expression} \texttt{]}} \\ \languageline{expression}{\textit{expression} \textbf{binop} \textit{expression}} \\ diff --git a/snippets/fvnctio.cent b/snippets/fvnctio.cent new file mode 100644 index 0000000..917099e --- /dev/null +++ b/snippets/fvnctio.cent @@ -0,0 +1,9 @@ +DEFINI apply (f, x) VT { + REDI (INVOCA f (x)) +} + +DESIGNA dbl VT FVNCTIO (n) VT { + REDI (n * II) +} + +DICE(INVOCA apply (dbl, VII)) diff --git a/snippets/fvnctio.png b/snippets/fvnctio.png new file mode 100644 index 0000000..f3f2774 Binary files /dev/null and b/snippets/fvnctio.png differ diff --git a/snippets/invoca_expr.cent b/snippets/invoca_expr.cent new file mode 100644 index 0000000..fa0952c --- /dev/null +++ b/snippets/invoca_expr.cent @@ -0,0 +1,11 @@ +// Immediately invoked +DICE(INVOCA FVNCTIO (x) VT { REDI (x + I) } (V)) + +// From an array +DESIGNA fns VT [FVNCTIO (x) VT { REDI (x + I) }] +DICE(INVOCA fns[I] (V)) + +// Passing a named function as an argument +DEFINI apply (f, x) VT { REDI (INVOCA f (x)) } +DEFINI sqr (x) VT { REDI (x * x) } +DICE(INVOCA apply (sqr, IV)) diff --git a/snippets/invoca_expr.png b/snippets/invoca_expr.png new file mode 100644 index 0000000..f10f3ee Binary files /dev/null and b/snippets/invoca_expr.png differ diff --git a/snippets/syntaxes/centvrion.sublime-syntax b/snippets/syntaxes/centvrion.sublime-syntax index 0390bcf..c794c8b 100644 --- a/snippets/syntaxes/centvrion.sublime-syntax +++ b/snippets/syntaxes/centvrion.sublime-syntax @@ -78,7 +78,7 @@ contexts: scope: support.class.module.centvrion keywords: - - match: '\b(AETERNVM|ALVID|AVGE|AVT|CONTINVA|DEFINI|DESIGNA|DISPAR|DONICVM|DVM|ERVMPE|EST|ET|FACE|INVOCA|IN|MINVE|MINVS|NON|PER|PLVS|REDI|RELIQVVM|SI|TABVLA|TVNC|VSQVE|VT|CVM)\b' + - match: '\b(AETERNVM|ALVID|AVGE|AVT|CONTINVA|DEFINI|DESIGNA|DISPAR|DONICVM|DVM|ERVMPE|EST|ET|FACE|FVNCTIO|INVOCA|IN|MINVE|MINVS|NON|PER|PLVS|REDI|RELIQVVM|SI|TABVLA|TVNC|VSQVE|VT|CVM)\b' scope: keyword.control.centvrion operators: diff --git a/tests.py b/tests.py index dfd2a5d..a064a17 100644 --- a/tests.py +++ b/tests.py @@ -12,9 +12,9 @@ from fractions import Fraction from centvrion.ast_nodes import ( ArrayIndex, Bool, BinOp, BuiltIn, DataArray, DataDict, DataRangeArray, Defini, Continva, Designa, DesignaDestructure, DesignaIndex, DumStatement, - Erumpe, ExpressionStatement, ID, InterpolatedString, Invoca, ModuleCall, - Nullus, Numeral, PerStatement, Program, Redi, SiStatement, String, - UnaryMinus, UnaryNot, Fractio, frac_to_fraction, fraction_to_frac, + Erumpe, ExpressionStatement, Fvnctio, ID, InterpolatedString, Invoca, + ModuleCall, Nullus, Numeral, PerStatement, Program, Redi, SiStatement, + String, UnaryMinus, UnaryNot, Fractio, frac_to_fraction, fraction_to_frac, num_to_int, int_to_num, make_string, ) from centvrion.compiler.emitter import compile_program @@ -2047,5 +2047,157 @@ class TestDictDisplay(unittest.TestCase): run_test(self, source, nodes, value, output) +# --- First-class functions / FVNCTIO --- + +fvnctio_tests = [ + # Lambda assigned to variable, then called + ( + "DESIGNA f VT FVNCTIO (x) VT { REDI (x + I) }\nINVOCA f (V)", + Program([], [ + Designa(ID("f"), Fvnctio([ID("x")], [Redi([BinOp(ID("x"), Numeral("I"), "SYMBOL_PLUS")])])), + ExpressionStatement(Invoca(ID("f"), [Numeral("V")])), + ]), + ValInt(6), + ), + # IIFE: immediately invoked lambda + ( + "INVOCA FVNCTIO (x) VT { REDI (x * II) } (III)", + Program([], [ + ExpressionStatement(Invoca( + Fvnctio([ID("x")], [Redi([BinOp(ID("x"), Numeral("II"), "SYMBOL_TIMES")])]), + [Numeral("III")], + )), + ]), + ValInt(6), + ), + # Zero-arg lambda + ( + "INVOCA FVNCTIO () VT { REDI (XLII) } ()", + Program([], [ + ExpressionStatement(Invoca( + Fvnctio([], [Redi([Numeral("XLII")])]), + [], + )), + ]), + ValInt(42), + ), + # Function passed as argument + ( + "DEFINI apply (f, x) VT { REDI (INVOCA f (x)) }\n" + "DESIGNA dbl VT FVNCTIO (n) VT { REDI (n * II) }\n" + "INVOCA apply (dbl, V)", + Program([], [ + Defini(ID("apply"), [ID("f"), ID("x")], [ + Redi([Invoca(ID("f"), [ID("x")])]) + ]), + Designa(ID("dbl"), Fvnctio([ID("n")], [ + Redi([BinOp(ID("n"), Numeral("II"), "SYMBOL_TIMES")]) + ])), + ExpressionStatement(Invoca(ID("apply"), [ID("dbl"), Numeral("V")])), + ]), + ValInt(10), + ), + # Lambda uses caller-scope variable (copy-caller semantics) + ( + "DESIGNA n VT III\n" + "DESIGNA f VT FVNCTIO (x) VT { REDI (x + n) }\n" + "INVOCA f (V)", + Program([], [ + Designa(ID("n"), Numeral("III")), + Designa(ID("f"), Fvnctio([ID("x")], [ + Redi([BinOp(ID("x"), ID("n"), "SYMBOL_PLUS")]) + ])), + ExpressionStatement(Invoca(ID("f"), [Numeral("V")])), + ]), + ValInt(8), + ), + # Named function passed as value + ( + "DEFINI sqr (x) VT { REDI (x * x) }\n" + "DESIGNA f VT sqr\n" + "INVOCA f (IV)", + Program([], [ + Defini(ID("sqr"), [ID("x")], [Redi([BinOp(ID("x"), ID("x"), "SYMBOL_TIMES")])]), + Designa(ID("f"), ID("sqr")), + ExpressionStatement(Invoca(ID("f"), [Numeral("IV")])), + ]), + ValInt(16), + ), + # Nested lambdas + ( + "INVOCA FVNCTIO (x) VT { REDI (INVOCA FVNCTIO (y) VT { REDI (y + I) } (x)) } (V)", + Program([], [ + ExpressionStatement(Invoca( + Fvnctio([ID("x")], [ + Redi([Invoca( + Fvnctio([ID("y")], [Redi([BinOp(ID("y"), Numeral("I"), "SYMBOL_PLUS")])]), + [ID("x")], + )]) + ]), + [Numeral("V")], + )), + ]), + ValInt(6), + ), + # DICE on a function value + ( + "DESIGNA f VT FVNCTIO (x) VT { REDI (x) }\nDICE(f)", + Program([], [ + Designa(ID("f"), Fvnctio([ID("x")], [Redi([ID("x")])])), + ExpressionStatement(BuiltIn("DICE", [ID("f")])), + ]), + ValStr("FVNCTIO"), + "FVNCTIO\n", + ), + # Lambda stored in array, called via index + ( + "DESIGNA fns VT [FVNCTIO (x) VT { REDI (x + I) }, FVNCTIO (x) VT { REDI (x * II) }]\n" + "INVOCA fns[I] (V)", + Program([], [ + Designa(ID("fns"), DataArray([ + Fvnctio([ID("x")], [Redi([BinOp(ID("x"), Numeral("I"), "SYMBOL_PLUS")])]), + Fvnctio([ID("x")], [Redi([BinOp(ID("x"), Numeral("II"), "SYMBOL_TIMES")])]), + ])), + ExpressionStatement(Invoca( + ArrayIndex(ID("fns"), Numeral("I")), + [Numeral("V")], + )), + ]), + ValInt(6), + ), + # Lambda stored in dict, called via key + ( + 'DESIGNA d VT TABVLA {"add" VT FVNCTIO (x) VT { REDI (x + I) }}\n' + 'INVOCA d["add"] (V)', + Program([], [ + Designa(ID("d"), DataDict([ + (String("add"), Fvnctio([ID("x")], [Redi([BinOp(ID("x"), Numeral("I"), "SYMBOL_PLUS")])])), + ])), + ExpressionStatement(Invoca( + ArrayIndex(ID("d"), String("add")), + [Numeral("V")], + )), + ]), + ValInt(6), + ), + # Multi-param lambda + ( + "DESIGNA add VT FVNCTIO (a, b) VT { REDI (a + b) }\nINVOCA add (III, IV)", + Program([], [ + Designa(ID("add"), Fvnctio([ID("a"), ID("b")], [ + Redi([BinOp(ID("a"), ID("b"), "SYMBOL_PLUS")]) + ])), + ExpressionStatement(Invoca(ID("add"), [Numeral("III"), Numeral("IV")])), + ]), + ValInt(7), + ), +] + +class TestFvnctio(unittest.TestCase): + @parameterized.expand(fvnctio_tests) + def test_fvnctio(self, source, nodes, value, output=""): + run_test(self, source, nodes, value, output) + + if __name__ == "__main__": unittest.main() diff --git a/vscode-extension/syntaxes/cent.tmLanguage.json b/vscode-extension/syntaxes/cent.tmLanguage.json index 7f279ae..0b3a706 100644 --- a/vscode-extension/syntaxes/cent.tmLanguage.json +++ b/vscode-extension/syntaxes/cent.tmLanguage.json @@ -45,7 +45,7 @@ "patterns": [ { "name": "keyword.control.cent", - "match": "\\b(AETERNVM|ALVID|AVT|CONTINVA|CVM|DEFINI|DESIGNA|DONICVM|DVM|ERVMPE|ET|FACE|IN|INVOCA|NON|PER|REDI|SI|TVNC|VSQVE|VT)\\b" + "match": "\\b(AETERNVM|ALVID|AVT|CONTINVA|CVM|DEFINI|DESIGNA|DONICVM|DVM|ERVMPE|ET|FACE|FVNCTIO|IN|INVOCA|NON|PER|REDI|SI|TVNC|VSQVE|VT)\\b" }, { "name": "keyword.operator.comparison.cent",