🐐 Tests
This commit is contained in:
135
tests/_helpers.py
Normal file
135
tests/_helpers.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from parameterized import parameterized
|
||||
|
||||
from fractions import Fraction
|
||||
|
||||
from centvrion.ast_nodes import (
|
||||
ArrayIndex, ArraySlice, Bool, BinOp, BuiltIn, DataArray, DataDict, DataRangeArray,
|
||||
Defini, Continva, Designa, DesignaDestructure, DesignaIndex, DumStatement,
|
||||
Erumpe, ExpressionStatement, Fvnctio, ID, InterpolatedString, Invoca,
|
||||
ModuleCall, Nullus, Numeral, PerStatement, Program, Redi, SiStatement,
|
||||
String, TemptaStatement, UnaryMinus, UnaryNot, Fractio, frac_to_fraction,
|
||||
fraction_to_frac, num_to_int, int_to_num, make_string,
|
||||
_cent_rng,
|
||||
)
|
||||
from centvrion.compiler.emitter import compile_program
|
||||
from centvrion.errors import CentvrionError
|
||||
from centvrion.lexer import Lexer
|
||||
from centvrion.parser import Parser
|
||||
from centvrion.values import ValInt, ValStr, ValBool, ValList, ValDict, ValNul, ValFunc, ValFrac
|
||||
|
||||
_RUNTIME_C = os.path.join(
|
||||
os.path.dirname(__file__), "..",
|
||||
"centvrion", "compiler", "runtime", "cent_runtime.c"
|
||||
)
|
||||
|
||||
def run_test(self, source, target_nodes, target_value, target_output="", input_lines=[]):
|
||||
_cent_rng.seed(1)
|
||||
|
||||
lexer = Lexer().get_lexer()
|
||||
tokens = lexer.lex(source + "\n")
|
||||
program = Parser().parse(tokens)
|
||||
|
||||
##########################
|
||||
####### Parser Test ######
|
||||
##########################
|
||||
if target_nodes is not None:
|
||||
self.assertEqual(
|
||||
program,
|
||||
target_nodes,
|
||||
f"Parser test:\n{program}\n{target_nodes}"
|
||||
)
|
||||
|
||||
##########################
|
||||
#### Interpreter Test ####
|
||||
##########################
|
||||
captured = StringIO()
|
||||
try:
|
||||
if input_lines:
|
||||
inputs = iter(input_lines)
|
||||
with patch("builtins.input", lambda: next(inputs)), patch("sys.stdout", captured):
|
||||
result = program.eval()
|
||||
else:
|
||||
with patch("sys.stdout", captured):
|
||||
result = program.eval()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
self.assertEqual(result, target_value, "Return value test")
|
||||
self.assertEqual(captured.getvalue(), target_output, "Output test")
|
||||
|
||||
##########################
|
||||
###### Printer Test ######
|
||||
##########################
|
||||
try:
|
||||
new_text = program.print()
|
||||
new_tokens = Lexer().get_lexer().lex(new_text + "\n")
|
||||
new_nodes = Parser().parse(new_tokens)
|
||||
except Exception as e:
|
||||
raise Exception(f"###Printer test###\n{new_text}") from e
|
||||
self.assertEqual(
|
||||
program,
|
||||
new_nodes,
|
||||
f"Printer test\n{source}\n{new_text}"
|
||||
)
|
||||
|
||||
##########################
|
||||
###### Compiler Test #####
|
||||
##########################
|
||||
c_source = compile_program(program)
|
||||
# Force deterministic RNG seed=1 for test reproducibility
|
||||
c_source = c_source.replace("cent_init();", "cent_init(); cent_semen((CentValue){.type=CENT_INT, .ival=1});", 1)
|
||||
with tempfile.NamedTemporaryFile(suffix=".c", delete=False, mode="w") as tmp_c:
|
||||
tmp_c.write(c_source)
|
||||
tmp_c_path = tmp_c.name
|
||||
with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp_bin:
|
||||
tmp_bin_path = tmp_bin.name
|
||||
try:
|
||||
subprocess.run(
|
||||
["gcc", "-O2", tmp_c_path, _RUNTIME_C, "-o", tmp_bin_path, "-lcurl", "-lmicrohttpd"],
|
||||
check=True, capture_output=True,
|
||||
)
|
||||
stdin_data = "".join(f"{l}\n" for l in input_lines)
|
||||
proc = subprocess.run(
|
||||
[tmp_bin_path],
|
||||
input=stdin_data, capture_output=True, text=True,
|
||||
)
|
||||
self.assertEqual(proc.returncode, 0, f"Compiler binary exited non-zero:\n{proc.stderr}")
|
||||
self.assertEqual(proc.stdout, target_output, "Compiler output test")
|
||||
finally:
|
||||
os.unlink(tmp_c_path)
|
||||
os.unlink(tmp_bin_path)
|
||||
|
||||
assert target_nodes is not None, "All tests must have target nodes"
|
||||
|
||||
|
||||
def run_compiler_error_test(self, source):
|
||||
lexer = Lexer().get_lexer()
|
||||
tokens = lexer.lex(source + "\n")
|
||||
program = Parser().parse(tokens)
|
||||
try:
|
||||
c_source = compile_program(program)
|
||||
except CentvrionError:
|
||||
return # compile-time detection is valid
|
||||
with tempfile.NamedTemporaryFile(suffix=".c", delete=False, mode="w") as tmp_c:
|
||||
tmp_c.write(c_source)
|
||||
tmp_c_path = tmp_c.name
|
||||
with tempfile.NamedTemporaryFile(suffix="", delete=False) as tmp_bin:
|
||||
tmp_bin_path = tmp_bin.name
|
||||
try:
|
||||
subprocess.run(
|
||||
["gcc", "-O2", tmp_c_path, _RUNTIME_C, "-o", tmp_bin_path, "-lcurl", "-lmicrohttpd"],
|
||||
check=True, capture_output=True,
|
||||
)
|
||||
proc = subprocess.run([tmp_bin_path], capture_output=True, text=True)
|
||||
self.assertNotEqual(proc.returncode, 0, "Expected non-zero exit for error program")
|
||||
self.assertTrue(proc.stderr.strip(), "Expected error message on stderr")
|
||||
finally:
|
||||
os.unlink(tmp_c_path)
|
||||
os.unlink(tmp_bin_path)
|
||||
Reference in New Issue
Block a user