diff --git a/src/eigenscript/compiler/analysis/observer.py b/src/eigenscript/compiler/analysis/observer.py index dfbb8bb..223ec3c 100644 --- a/src/eigenscript/compiler/analysis/observer.py +++ b/src/eigenscript/compiler/analysis/observer.py @@ -10,7 +10,7 @@ This implements zero-cost abstraction: pay for geometric semantics only when used. """ -from typing import Set +from typing import Optional, Set from eigenscript.parser.ast_builder import ( ASTNode, Identifier, @@ -40,10 +40,19 @@ class ObserverAnalyzer: Unobserved variables can be compiled to raw doubles for maximum performance. """ + PREDICATES = { + "converged", + "diverging", + "oscillating", + "stable", + "improving", + } + def __init__(self): self.observed: Set[str] = set() self.user_functions: Set[str] = set() self.current_function: str = None + self.last_assigned: Optional[str] = None def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]: """Analyze AST and return set of variable names that need EigenValue tracking. @@ -58,6 +67,7 @@ def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]: self.observed = set() self.user_functions = set() self.current_function = None + self.last_assigned = None # First pass: collect all user-defined function names for node in ast_nodes: @@ -82,7 +92,9 @@ def _visit(self, node: ASTNode): elif isinstance(node, FunctionDef): # Function parameters are always observed (might be interrogated inside) prev_function = self.current_function + prev_last_assigned = self.last_assigned self.current_function = node.name + self.last_assigned = None # In EigenScript, functions implicitly have parameter 'n' self.observed.add("n") @@ -91,9 +103,11 @@ def _visit(self, node: ASTNode): self._visit(stmt) self.current_function = prev_function + self.last_assigned = prev_last_assigned elif isinstance(node, Assignment): self._visit(node.expression) + self.last_assigned = node.identifier elif isinstance(node, Interrogative): # Direct observation: "why is x" marks x as observed @@ -152,32 +166,19 @@ def _visit(self, node: ASTNode): self._visit(node.list_expr) self._visit(node.index_expr) - elif isinstance(node, Identifier): - # Check if this identifier is a predicate - if node.name in [ - "converged", - "diverging", - "oscillating", - "stable", - "improving", - ]: - # Predicates require the last variable to be observed - # This is a simplified heuristic - ideally we'd track scope - pass - def _check_for_predicates(self, node: ASTNode): """Check if condition uses predicates (converged, diverging, etc.).""" + if node is None: + return + if isinstance(node, Identifier): - if node.name in [ - "converged", - "diverging", - "oscillating", - "stable", - "improving", - ]: - # TODO: Mark the variable being tested as observed - # For now, this is handled by the codegen heuristic of "last variable" - pass + if node.name in self.PREDICATES and self.last_assigned: + self.observed.add(self.last_assigned) + elif isinstance(node, UnaryOp): + self._check_for_predicates(node.operand) + elif isinstance(node, BinaryOp): + self._check_for_predicates(node.left) + self._check_for_predicates(node.right) def _mark_expression_observed(self, node: ASTNode): """Recursively mark all identifiers in an expression as observed.""" diff --git a/tests/compiler/test_codegen.py b/tests/compiler/test_codegen.py index 61352fa..8d55ef2 100644 --- a/tests/compiler/test_codegen.py +++ b/tests/compiler/test_codegen.py @@ -10,6 +10,7 @@ try: from llvmlite import binding as llvm from eigenscript.compiler.codegen.llvm_backend import LLVMCodeGenerator + from eigenscript.compiler.analysis.observer import ObserverAnalyzer from eigenscript.lexer import Tokenizer from eigenscript.parser.ast_builder import Parser @@ -198,6 +199,26 @@ def test_module_verification(self): llvm_ir = self.compile_source(source) assert self.verify_llvm_ir(llvm_ir), f"Failed to verify: {source}" + def test_predicate_condition_marks_variable_observed(self): + """Predicates should mark the last assigned variable as observed.""" + source = """x is 0 +loop while not converged: + x is x + 1""" + + tokenizer = Tokenizer(source) + tokens = tokenizer.tokenize() + parser = Parser(tokens) + ast = parser.parse() + + analyzer = ObserverAnalyzer() + observed_vars = analyzer.analyze(ast.statements) + assert "x" in observed_vars + + codegen = LLVMCodeGenerator(observed_variables=observed_vars) + llvm_ir = codegen.compile(ast.statements) + + assert 'call i1 @"eigen_check_converged"' in llvm_ir + def test_runtime_functions_declared(self): """Test that runtime functions are properly declared.""" source = "x is 42"