From 6aafab47a24f426f0ef83740b6e9587035461897 Mon Sep 17 00:00:00 2001 From: NikolajDanger Date: Wed, 1 Apr 2026 15:56:27 +0200 Subject: [PATCH] :goat: Conditions have to be bools --- centvrion/ast_nodes.py | 6 ++++++ tests.py | 38 ++++++++++++-------------------------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/centvrion/ast_nodes.py b/centvrion/ast_nodes.py index fda4716..a732b14 100644 --- a/centvrion/ast_nodes.py +++ b/centvrion/ast_nodes.py @@ -679,6 +679,8 @@ class SiStatement(Node): def _eval(self, vtable): vtable, cond = self.test.eval(vtable) + if not isinstance(cond, ValBool): + raise CentvrionError("SI condition must be a boolean") last_val = ValNul() if cond: for statement in self.statements: @@ -714,6 +716,8 @@ class DumStatement(Node): def _eval(self, vtable): last_val = ValNul() vtable, cond = self.test.eval(vtable) + if not isinstance(cond, ValBool): + raise CentvrionError("DVM condition must be a boolean") while not cond: for statement in self.statements: vtable, val = statement.eval(vtable) @@ -726,6 +730,8 @@ class DumStatement(Node): if vtable["#return"] is not None: break vtable, cond = self.test.eval(vtable) + if not isinstance(cond, ValBool): + raise CentvrionError("DVM condition must be a boolean") return vtable, last_val diff --git a/tests.py b/tests.py index c8f9c98..4e75d45 100644 --- a/tests.py +++ b/tests.py @@ -401,6 +401,11 @@ error_tests = [ ("PER i IN I FACE { DICE(i) }", CentvrionError), # PER over non-array ("LONGITVDO(I)", CentvrionError), # LONGITVDO on non-array ("DESIGNA x VT I\nINVOCA x ()", CentvrionError), # invoking a non-function + ("SI I TVNC { DESIGNA r VT I }", CentvrionError), # non-bool SI condition: int + ("DESIGNA z VT I - I\nSI z TVNC { DESIGNA r VT I }", CentvrionError), # non-bool SI condition: zero int + ("SI [I] TVNC { DESIGNA r VT I }", CentvrionError), # non-bool SI condition: non-empty list + ("SI [] TVNC { DESIGNA r VT I }", CentvrionError), # non-bool SI condition: empty list + ("DESIGNA x VT I\nDVM x FACE {\nDESIGNA x VT x + I\n}", CentvrionError), # non-bool DVM condition: int ] class TestErrors(unittest.TestCase): @@ -557,30 +562,10 @@ class TestDiceTypes(unittest.TestCase): run_test(self, source, nodes, value, output) -# --- SI/DVM: truthiness of non-bool conditions --- +# --- SI/DVM: boolean condition enforcement --- -truthiness_tests = [ - # nonzero int is truthy - ("SI I TVNC { DESIGNA r VT I } ALVID { DESIGNA r VT II }\nr", - Program([], [SiStatement(Numeral("I"), [Designa(ID("r"), Numeral("I"))], [Designa(ID("r"), Numeral("II"))]), ExpressionStatement(ID("r"))]), - ValInt(1)), - # zero int is falsy - ("DESIGNA z VT I - I\nSI z TVNC { DESIGNA r VT I } ALVID { DESIGNA r VT II }\nr", - Program([], [ - Designa(ID("z"), BinOp(Numeral("I"), Numeral("I"), "SYMBOL_MINUS")), - SiStatement(ID("z"), [Designa(ID("r"), Numeral("I"))], [Designa(ID("r"), Numeral("II"))]), - ExpressionStatement(ID("r")), - ]), - ValInt(2)), - # non-empty list is truthy - ("SI [I] TVNC { DESIGNA r VT I } ALVID { DESIGNA r VT II }\nr", - Program([], [SiStatement(DataArray([Numeral("I")]), [Designa(ID("r"), Numeral("I"))], [Designa(ID("r"), Numeral("II"))]), ExpressionStatement(ID("r"))]), - ValInt(1)), - # empty list is falsy - ("SI [] TVNC { DESIGNA r VT II } ALVID { DESIGNA r VT I }\nr", - Program([], [SiStatement(DataArray([]), [Designa(ID("r"), Numeral("II"))], [Designa(ID("r"), Numeral("I"))]), ExpressionStatement(ID("r"))]), - ValInt(1)), - # DVM exits when condition becomes truthy +dvm_bool_condition_tests = [ + # DVM exits when condition becomes true (boolean comparison) ( "DESIGNA x VT I\nDVM x PLVS III FACE {\nDESIGNA x VT x + I\n}\nx", Program([], [ @@ -592,12 +577,13 @@ truthiness_tests = [ ), ] -class TestTruthiness(unittest.TestCase): - @parameterized.expand(truthiness_tests) - def test_truthiness(self, source, nodes, value): +class TestDvmBoolCondition(unittest.TestCase): + @parameterized.expand(dvm_bool_condition_tests) + def test_dvm_bool_condition(self, source, nodes, value): run_test(self, source, nodes, value) + # --- Arithmetic: edge cases --- arithmetic_edge_tests = [