Skip to content

Commit 48cddd7

Browse files
committed
✨ feat: add cse and constant_folding
1 parent b56386d commit 48cddd7

File tree

16 files changed

+393
-3
lines changed

16 files changed

+393
-3
lines changed

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@
55
"python.analysis.inlayHints.variableTypes": true,
66
"[python]": {
77
"editor.defaultFormatter": "charliermarsh.ruff"
8-
}
8+
},
9+
"ruff.interpreter": ["${workspaceFolder}/.venv/bin/python"]
910
}

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,18 @@
1212
<a href="https://github.com/astral-sh/ruff"><img alt="ruff" src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json&style=flat-square"></a>
1313
<a href="https://gitmoji.dev"><img alt="Gitmoji" src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67?style=flat-square"></a>
1414
</p>
15+
16+
## Installation
17+
18+
TODO...
19+
20+
## Usage
21+
22+
```console
23+
$ expr_simplifier cse "a * 4 + (a * 4)"
24+
(___t_0 := (a * 4)) + ___t_0
25+
$ expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=2
26+
(___x := 2) + ___x
27+
$ expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=2
28+
4
29+
```

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ version = "0.1.0"
44
description = ""
55
readme = "README.md"
66
requires-python = ">=3.9"
7-
dependencies = []
7+
dependencies = [
8+
"typing-extensions>=4.12.2",
9+
]
810
authors = [{ name = "Nyakku Shigure", email = "sigure.qaq@gmail.com" }]
911
keywords = []
1012
license = { text = "MIT" }

src/expr_simplifier/__main__.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,48 @@
11
from __future__ import annotations
22

33
import argparse
4+
import ast
5+
from collections.abc import Callable
46

57
from expr_simplifier import __version__
8+
from expr_simplifier.transforms import apply_constant_folding, apply_cse
9+
from expr_simplifier.typing import Pass
10+
from expr_simplifier.utils import loop_until_stable
11+
12+
13+
def create_pass_command(name: str, passes: list[Pass]) -> Callable[[argparse.Namespace], None]:
14+
def pass_command(args: argparse.Namespace) -> None:
15+
expr = ast.parse(args.input, mode="eval")
16+
simplified_expr = loop_until_stable(expr, passes, args.max_iter)
17+
print(ast.unparse(simplified_expr))
18+
19+
pass_command.__name__ = name
20+
return pass_command
21+
22+
23+
def create_pass_parser(
24+
name: str,
25+
passes: list[Pass],
26+
description: str,
27+
subparser: argparse._SubParsersAction[argparse.ArgumentParser], # pyright: ignore [reportPrivateUsage]
28+
) -> None:
29+
parser = subparser.add_parser(name, help=description)
30+
parser.add_argument("input", help="The expression to simplify")
31+
parser.add_argument("--max-iter", type=int, default=100, help="The maximum number of iterations")
32+
parser.set_defaults(func=create_pass_command(name, passes))
633

734

835
def main() -> None:
936
parser = argparse.ArgumentParser(prog="moelib", description="A moe moe project")
1037
parser.add_argument("-v", "--version", action="version", version=__version__)
11-
args = parser.parse_args() # type: ignore
38+
sub_parsers = parser.add_subparsers(help="sub-command help", dest="sub_command")
39+
40+
create_pass_parser("cse", [apply_cse], "Common Subexpression Elimination", sub_parsers)
41+
create_pass_parser("constant_folding", [apply_constant_folding], "Constant Folding", sub_parsers)
42+
create_pass_parser("auto", [apply_constant_folding, apply_cse], "Auto Simplification", sub_parsers)
43+
44+
args = parser.parse_args()
45+
args.func(args)
1246

1347

1448
if __name__ == "__main__":
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
4+
class UniqueNameGenerator:
5+
def __init__(self, prefix: str):
6+
self.prefix = prefix
7+
self._counter = 0
8+
9+
def generate_name(self) -> str:
10+
name = f"{self.prefix}{self._counter}"
11+
self._counter += 1
12+
return name
13+
14+
15+
class SymbolTable:
16+
def __init__(self):
17+
self._symbols = set[str]()
18+
self._name_generator = UniqueNameGenerator("___t_")
19+
20+
def define_symbol(self, symbol: str):
21+
self._symbols.add(symbol)
22+
23+
def request_new_symbol(self) -> str:
24+
while True:
25+
new_symbol = self._name_generator.generate_name()
26+
if self.is_symbol_defined(new_symbol):
27+
continue
28+
self.define_symbol(new_symbol)
29+
return new_symbol
30+
31+
def is_symbol_defined(self, symbol: str) -> bool:
32+
return symbol in self._symbols
33+
34+
def __contains__(self, symbol: str) -> bool:
35+
return self.is_symbol_defined(symbol)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import annotations
2+
3+
from expr_simplifier.transforms.constant_folding import apply_constant_folding as apply_constant_folding
4+
from expr_simplifier.transforms.cse import apply_cse as apply_cse
5+
from expr_simplifier.transforms.remove_named_expr import apply_remove_named_expr as apply_remove_named_expr
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
from typing_extensions import TypeAlias
6+
7+
from expr_simplifier.transforms.remove_named_expr import apply_remove_named_expr
8+
9+
SubExpressionTable: TypeAlias = dict[str, tuple[str, int]]
10+
11+
12+
def fold_to_constant(node: ast.AST) -> ast.Constant:
13+
return ast.Constant(value=eval(ast.unparse(node)))
14+
15+
16+
class ConstantFolding(ast.NodeTransformer):
17+
def visit(self, node: ast.AST) -> ast.AST:
18+
transformed_node = super().visit(node)
19+
if isinstance(node, ast.BinOp) and isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
20+
return fold_to_constant(node)
21+
if isinstance(node, ast.UnaryOp) and isinstance(node.operand, ast.Constant):
22+
return fold_to_constant(node)
23+
if isinstance(node, ast.BoolOp) and all(isinstance(value, ast.Constant) for value in node.values):
24+
return fold_to_constant(node)
25+
if isinstance(node, ast.Compare) and all(isinstance(comp, ast.Constant) for comp in node.comparators):
26+
return fold_to_constant(node)
27+
if isinstance(node, ast.JoinedStr) and all(
28+
isinstance(value, ast.Constant)
29+
or (isinstance(value, ast.FormattedValue) and isinstance(value.value, ast.Constant))
30+
for value in node.values
31+
):
32+
return fold_to_constant(node)
33+
return transformed_node
34+
35+
36+
def apply_constant_folding(expr: ast.AST) -> ast.AST:
37+
# Constant propagation
38+
expr = apply_remove_named_expr(expr, constant_only=True)
39+
# Constant folding
40+
expr = ConstantFolding().visit(expr)
41+
return expr
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
from typing_extensions import TypeAlias
6+
7+
from expr_simplifier.symbol_table import SymbolTable
8+
from expr_simplifier.transforms.remove_named_expr import apply_remove_named_expr
9+
10+
SubExpressionTable: TypeAlias = dict[str, tuple[str, int]]
11+
12+
13+
class CSEPreAnalyzer(ast.NodeVisitor):
14+
def __init__(self):
15+
self.subexpressions: SubExpressionTable = {}
16+
self.symbols = SymbolTable()
17+
super().__init__()
18+
19+
def visit(self, node: ast.AST) -> None:
20+
super().visit(node)
21+
expr_string = ast.unparse(node)
22+
if isinstance(node, ast.expr):
23+
if isinstance(node, ast.Name):
24+
self.symbols.define_symbol(node.id)
25+
return
26+
if isinstance(node, ast.Constant):
27+
return
28+
if expr_string not in self.subexpressions:
29+
self.subexpressions[expr_string] = (self.symbols.request_new_symbol(), 0)
30+
symbol, count = self.subexpressions[expr_string]
31+
self.subexpressions[expr_string] = (symbol, count + 1)
32+
33+
34+
class CommonSubexpressionElimination(ast.NodeTransformer):
35+
def __init__(self, subexpressions: dict[str, tuple[str, int]]):
36+
self.subexpressions = subexpressions
37+
self.declared_symbols = set[str]()
38+
super().__init__()
39+
40+
def visit(self, node: ast.AST) -> ast.expr:
41+
expr_string = ast.unparse(node)
42+
transformed_node = super().visit(node)
43+
if isinstance(node, ast.expr) and expr_string in self.subexpressions:
44+
symbol, count = self.subexpressions[expr_string]
45+
if count > 1:
46+
if symbol not in self.declared_symbols:
47+
self.declared_symbols.add(symbol)
48+
assign_node = ast.NamedExpr(target=ast.Name(id=symbol), value=transformed_node)
49+
return assign_node
50+
return ast.Name(id=symbol)
51+
return transformed_node
52+
53+
54+
def show_subexpressions(subexpressions: SubExpressionTable) -> None:
55+
for subexpression, (symbol, count) in subexpressions.items():
56+
print(f"{symbol}: {subexpression} ({count})")
57+
58+
59+
def apply_cse(expr: ast.AST) -> ast.AST:
60+
expr = apply_remove_named_expr(expr)
61+
cse_pre_analyzer = CSEPreAnalyzer()
62+
cse_pre_analyzer.visit(expr)
63+
cse = CommonSubexpressionElimination(cse_pre_analyzer.subexpressions)
64+
return cse.visit(expr)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import copy
5+
6+
7+
class RemoveNamedExpr(ast.NodeTransformer):
8+
def __init__(self, constant_only: bool = False) -> None:
9+
super().__init__()
10+
self.named_expressions = dict[str, ast.expr]()
11+
self.constant_only = constant_only
12+
13+
def should_replace(self, symbol: str) -> bool:
14+
return not self.constant_only or isinstance(self.named_expressions[symbol], ast.Constant)
15+
16+
def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.expr:
17+
value = self.visit(node.value)
18+
name = node.target.id
19+
self.named_expressions[name] = value
20+
if not self.should_replace(name):
21+
return node
22+
return value
23+
24+
def visit_Name(self, node: ast.Name) -> ast.expr:
25+
if node.id in self.named_expressions:
26+
if not self.should_replace(node.id):
27+
return node
28+
return copy.deepcopy(self.named_expressions[node.id])
29+
return node
30+
31+
32+
def apply_remove_named_expr(expr: ast.AST, constant_only: bool = False) -> ast.AST:
33+
remove_named_expr = RemoveNamedExpr(constant_only)
34+
return remove_named_expr.visit(expr)

src/expr_simplifier/typing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from collections.abc import Callable
5+
6+
from typing_extensions import TypeAlias
7+
8+
Pass: TypeAlias = Callable[[ast.AST], ast.AST]

0 commit comments

Comments
 (0)