|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import ast |
| 4 | + |
| 5 | +from typing_extensions import TypeAlias |
| 6 | + |
| 7 | +from expr_simplifier.symbol_table import SymbolTable |
| 8 | +from expr_simplifier.transforms.remove_named_expr import apply_remove_named_expr |
| 9 | + |
# Maps the ``ast.unparse`` text of a sub-expression to ``(symbol, count)``:
# the replacement symbol assigned to that text and how many nodes produced it.
SubExpressionTable: TypeAlias = dict[str, tuple[str, int]]
| 11 | + |
| 12 | + |
class CSEPreAnalyzer(ast.NodeVisitor):
    """Count how often each sub-expression occurs in an AST.

    After ``visit(tree)`` returns, ``subexpressions`` maps the
    ``ast.unparse`` text of every counted expression node to
    ``(symbol, count)``: ``symbol`` is whatever
    ``self.symbols.request_new_symbol()`` returned the first time that text
    was seen, and ``count`` is the number of nodes that unparse to it.

    ``ast.Name`` and ``ast.Constant`` nodes are never counted. Every name
    encountered is passed to ``self.symbols.define_symbol``.
    """

    def __init__(self):
        self.subexpressions: SubExpressionTable = {}
        self.symbols = SymbolTable()
        super().__init__()

    def visit(self, node: ast.AST) -> None:
        """Visit *node*'s children, then record *node* if it is countable."""
        # Post-order: children are recorded before their parent.
        super().visit(node)
        if not isinstance(node, ast.expr):
            # Operators, contexts, keywords, etc. are not sub-expressions.
            return
        if isinstance(node, ast.Name):
            self.symbols.define_symbol(node.id)
            return
        if isinstance(node, ast.Constant):
            return
        if isinstance(getattr(node, "ctx", None), (ast.Store, ast.Del)):
            # Assignment/deletion targets (e.g. the ``x, y`` in a
            # comprehension's ``for x, y in ...``) are not values and must
            # not be treated as reusable sub-expressions.
            return
        # Unparse only the nodes that are actually counted.
        expr_string = ast.unparse(node)
        if expr_string not in self.subexpressions:
            # NOTE(review): the symbol is requested before the whole tree has
            # been scanned; presumably SymbolTable guards against clashing
            # with names defined later -- confirm.
            self.subexpressions[expr_string] = (self.symbols.request_new_symbol(), 0)
        symbol, count = self.subexpressions[expr_string]
        self.subexpressions[expr_string] = (symbol, count + 1)
| 32 | + |
| 33 | + |
class CommonSubexpressionElimination(ast.NodeTransformer):
    """Replace repeated sub-expressions with a named temporary.

    ``subexpressions`` maps the ``ast.unparse`` text of an expression to
    ``(symbol, count)``. For every text whose count is greater than one, the
    first node visited with that text becomes ``(symbol := <expr>)`` and each
    later one becomes a plain reference to ``symbol``.

    NOTE(review): "first visited" follows traversal order, not evaluation
    order or scope. Expressions with side effects, conditionally evaluated
    branches (``a if c else b``, ``and``/``or``) and lambda/comprehension
    bodies are not treated specially -- confirm callers only pass
    expressions where that is acceptable.
    """

    def __init__(self, subexpressions: dict[str, tuple[str, int]]):
        self.subexpressions = subexpressions
        # Symbols for which a ``symbol := ...`` has already been emitted.
        self.declared_symbols = set[str]()
        super().__init__()

    def visit(self, node: ast.AST) -> ast.AST:
        """Transform *node*'s children, then substitute *node* if repeated."""
        if not isinstance(node, ast.expr):
            # Only expressions can be looked up; skip the unparse entirely.
            return super().visit(node)

        # Take the text BEFORE descending: NodeTransformer rewrites children
        # in place, so afterwards the node no longer unparses to its
        # original source form.
        expr_string = ast.unparse(node)
        transformed_node = super().visit(node)

        if isinstance(getattr(node, "ctx", None), (ast.Store, ast.Del)):
            # Assignment/deletion targets must stay targets; swapping them
            # for a NamedExpr or a Load reference yields an invalid tree.
            return transformed_node

        symbol, count = self.subexpressions.get(expr_string, ("", 0))
        if count <= 1:
            return transformed_node

        if symbol in self.declared_symbols:
            return ast.Name(id=symbol, ctx=ast.Load())

        self.declared_symbols.add(symbol)
        # Explicit contexts are required for the result to compile.
        return ast.NamedExpr(
            target=ast.Name(id=symbol, ctx=ast.Store()),
            value=transformed_node,
        )
| 52 | + |
| 53 | + |
def show_subexpressions(subexpressions: SubExpressionTable) -> None:
    """Print one ``symbol: expression (count)`` line per table entry.

    Entries are printed in the table's iteration order; an empty table
    prints nothing.
    """
    for expression_text in subexpressions:
        name, occurrences = subexpressions[expression_text]
        line = f"{name}: {expression_text} ({occurrences})"
        print(line)
| 57 | + |
| 58 | + |
def apply_cse(expr: ast.AST) -> ast.AST:
    """Run common-subexpression elimination over *expr* and return the result.

    Steps:
      1. ``apply_remove_named_expr`` is applied first (presumably to strip
         existing ``:=`` expressions, per its name -- confirm).
      2. ``CSEPreAnalyzer`` counts every sub-expression.
      3. ``CommonSubexpressionElimination`` rewrites repeated ones.

    The returned tree has location info filled in for the synthesized nodes,
    so it can be passed to ``compile()`` as well as ``ast.unparse()``.
    """
    expr = apply_remove_named_expr(expr)
    cse_pre_analyzer = CSEPreAnalyzer()
    cse_pre_analyzer.visit(expr)
    cse = CommonSubexpressionElimination(cse_pre_analyzer.subexpressions)
    # Nodes created by the transformer carry no lineno/col_offset, and
    # compile() rejects trees with missing locations.
    return ast.fix_missing_locations(cse.visit(expr))
0 commit comments