Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions src/eigenscript/compiler/analysis/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
This implements zero-cost abstraction: pay for geometric semantics only when used.
"""

from typing import Set
from typing import Optional, Set
from eigenscript.parser.ast_builder import (
ASTNode,
Identifier,
Expand Down Expand Up @@ -40,10 +40,19 @@ class ObserverAnalyzer:
Unobserved variables can be compiled to raw doubles for maximum performance.
"""

PREDICATES = {
"converged",
"diverging",
"oscillating",
"stable",
"improving",
}

def __init__(self):
self.observed: Set[str] = set()
self.user_functions: Set[str] = set()
self.current_function: str = None
self.last_assigned: Optional[str] = None

def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]:
"""Analyze AST and return set of variable names that need EigenValue tracking.
Expand All @@ -58,6 +67,7 @@ def analyze(self, ast_nodes: list[ASTNode]) -> Set[str]:
self.observed = set()
self.user_functions = set()
self.current_function = None
self.last_assigned = None

# First pass: collect all user-defined function names
for node in ast_nodes:
Expand All @@ -82,7 +92,9 @@ def _visit(self, node: ASTNode):
elif isinstance(node, FunctionDef):
# Function parameters are always observed (might be interrogated inside)
prev_function = self.current_function
prev_last_assigned = self.last_assigned
self.current_function = node.name
self.last_assigned = None

# In EigenScript, functions implicitly have parameter 'n'
self.observed.add("n")
Expand All @@ -91,9 +103,11 @@ def _visit(self, node: ASTNode):
self._visit(stmt)

self.current_function = prev_function
self.last_assigned = prev_last_assigned

elif isinstance(node, Assignment):
self._visit(node.expression)
self.last_assigned = node.identifier

elif isinstance(node, Interrogative):
# Direct observation: "why is x" marks x as observed
Expand Down Expand Up @@ -152,32 +166,19 @@ def _visit(self, node: ASTNode):
self._visit(node.list_expr)
self._visit(node.index_expr)

elif isinstance(node, Identifier):
# Check if this identifier is a predicate
if node.name in [
"converged",
"diverging",
"oscillating",
"stable",
"improving",
]:
# Predicates require the last variable to be observed
# This is a simplified heuristic - ideally we'd track scope
pass

def _check_for_predicates(self, node: ASTNode):
"""Check if condition uses predicates (converged, diverging, etc.)."""
if node is None:
return

if isinstance(node, Identifier):
if node.name in [
"converged",
"diverging",
"oscillating",
"stable",
"improving",
]:
# TODO: Mark the variable being tested as observed
# For now, this is handled by the codegen heuristic of "last variable"
pass
if node.name in self.PREDICATES and self.last_assigned:
self.observed.add(self.last_assigned)
elif isinstance(node, UnaryOp):
self._check_for_predicates(node.operand)
elif isinstance(node, BinaryOp):
self._check_for_predicates(node.left)
self._check_for_predicates(node.right)

def _mark_expression_observed(self, node: ASTNode):
"""Recursively mark all identifiers in an expression as observed."""
Expand Down
21 changes: 21 additions & 0 deletions tests/compiler/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
try:
from llvmlite import binding as llvm
from eigenscript.compiler.codegen.llvm_backend import LLVMCodeGenerator
from eigenscript.compiler.analysis.observer import ObserverAnalyzer
from eigenscript.lexer import Tokenizer
from eigenscript.parser.ast_builder import Parser

Expand Down Expand Up @@ -198,6 +199,26 @@ def test_module_verification(self):
llvm_ir = self.compile_source(source)
assert self.verify_llvm_ir(llvm_ir), f"Failed to verify: {source}"

def test_predicate_condition_marks_variable_observed(self):
"""Predicates should mark the last assigned variable as observed."""
source = """x is 0
loop while not converged:
x is x + 1"""

tokenizer = Tokenizer(source)
tokens = tokenizer.tokenize()
parser = Parser(tokens)
ast = parser.parse()

analyzer = ObserverAnalyzer()
observed_vars = analyzer.analyze(ast.statements)
assert "x" in observed_vars

codegen = LLVMCodeGenerator(observed_variables=observed_vars)
llvm_ir = codegen.compile(ast.statements)

assert 'call i1 @"eigen_check_converged"' in llvm_ir

def test_runtime_functions_declared(self):
"""Test that runtime functions are properly declared."""
source = "x is 42"
Expand Down