Skip to content

Commit 982a9ba

Browse files
committed
✨ feat: remove unused NamedExpr in CSE
1 parent f1a9b06 commit 982a9ba

File tree

9 files changed

+104
-26
lines changed

9 files changed

+104
-26
lines changed

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ version = "0.1.0"
44
description = ""
55
readme = "README.md"
66
requires-python = ">=3.9"
7-
dependencies = [
8-
"typing-extensions>=4.12.2",
9-
]
7+
dependencies = ["typing-extensions>=4.12.2"]
108
authors = [{ name = "Nyakku Shigure", email = "sigure.qaq@gmail.com" }]
119
keywords = []
1210
license = { text = "MIT" }
@@ -77,6 +75,7 @@ ignore = [
7775
[tool.ruff.lint.isort]
7876
required-imports = ["from __future__ import annotations"]
7977
known-first-party = ["expr_simplifier"]
78+
combine-as-imports = true
8079

8180
[tool.ruff.lint.per-file-ignores]
8281
"setup.py" = ["I"]

src/expr_simplifier/transforms/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,10 @@
22

33
from expr_simplifier.transforms.constant_folding import apply_constant_folding as apply_constant_folding
44
from expr_simplifier.transforms.cse import apply_cse as apply_cse
5-
from expr_simplifier.transforms.inline_named_expr import apply_inline_named_expr as apply_inline_named_expr
5+
from expr_simplifier.transforms.inline_named_expr import (
6+
apply_constant_propagation as apply_constant_propagation,
7+
apply_inline_all_named_expr as apply_inline_all_named_expr,
8+
)
9+
from expr_simplifier.transforms.remove_unused_named_expr import (
10+
apply_remove_unused_named_expr as apply_remove_unused_named_expr,
11+
)

src/expr_simplifier/transforms/constant_folding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing_extensions import TypeAlias
66

7-
from expr_simplifier.transforms.inline_named_expr import apply_inline_named_expr
7+
from expr_simplifier.transforms.inline_named_expr import apply_constant_propagation
88

99
SubExpressionTable: TypeAlias = dict[str, tuple[str, int]]
1010

@@ -35,7 +35,7 @@ def visit(self, node: ast.AST) -> ast.AST:
3535

3636
def apply_constant_folding(expr: ast.AST) -> ast.AST:
3737
# Constant propagation
38-
expr = apply_inline_named_expr(expr, constant_only=True)
38+
expr = apply_constant_propagation(expr)
3939
# Constant folding
4040
expr = ConstantFolding().visit(expr)
4141
return expr

src/expr_simplifier/transforms/cse.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from typing_extensions import TypeAlias
66

77
from expr_simplifier.symbol_table import SymbolTable
8-
from expr_simplifier.transforms.inline_named_expr import apply_inline_named_expr
8+
from expr_simplifier.transforms.inline_named_expr import apply_inline_all_named_expr
9+
from expr_simplifier.transforms.remove_unused_named_expr import apply_remove_unused_named_expr
910

1011
SubExpressionTable: TypeAlias = dict[str, tuple[str, int]]
1112

@@ -45,9 +46,9 @@ def visit(self, node: ast.AST) -> ast.expr:
4546
if count > 1:
4647
if symbol not in self.declared_symbols:
4748
self.declared_symbols.add(symbol)
48-
assign_node = ast.NamedExpr(target=ast.Name(id=symbol), value=transformed_node)
49+
assign_node = ast.NamedExpr(target=ast.Name(id=symbol, ctx=ast.Store()), value=transformed_node)
4950
return assign_node
50-
return ast.Name(id=symbol)
51+
return ast.Name(id=symbol, ctx=ast.Load())
5152
return transformed_node
5253

5354

@@ -57,8 +58,9 @@ def show_subexpressions(subexpressions: SubExpressionTable) -> None:
5758

5859

5960
def apply_cse(expr: ast.AST) -> ast.AST:
60-
expr = apply_inline_named_expr(expr)
61+
expr = apply_inline_all_named_expr(expr)
6162
cse_pre_analyzer = CSEPreAnalyzer()
6263
cse_pre_analyzer.visit(expr)
6364
cse = CommonSubexpressionElimination(cse_pre_analyzer.subexpressions)
64-
return cse.visit(expr)
65+
expr = cse.visit(expr)
66+
return apply_remove_unused_named_expr(expr)

src/expr_simplifier/transforms/inline_named_expr.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,43 @@
22

33
import ast
44
import copy
5+
from collections.abc import Callable
6+
7+
from typing_extensions import TypeAlias
8+
9+
NamedExpressions: TypeAlias = dict[str, ast.expr]
10+
ShouldReplaceFn: TypeAlias = Callable[[str, NamedExpressions], bool]
511

612

713
class InlineNamedExpr(ast.NodeTransformer):
8-
def __init__(self, constant_only: bool = False) -> None:
14+
def __init__(self, should_replace_fn: ShouldReplaceFn) -> None:
915
super().__init__()
16+
self.should_replace_fn = should_replace_fn
1017
self.named_expressions = dict[str, ast.expr]()
11-
self.constant_only = constant_only
12-
13-
def should_replace(self, symbol: str) -> bool:
14-
return not self.constant_only or isinstance(self.named_expressions[symbol], ast.Constant)
1518

1619
def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.expr:
1720
value = self.visit(node.value)
1821
name = node.target.id
1922
self.named_expressions[name] = value
20-
if not self.should_replace(name):
23+
if not self.should_replace_fn(name, self.named_expressions):
2124
return node
2225
return value
2326

2427
def visit_Name(self, node: ast.Name) -> ast.expr:
2528
if node.id in self.named_expressions:
26-
if not self.should_replace(node.id):
29+
if not self.should_replace_fn(node.id, self.named_expressions):
2730
return node
2831
return copy.deepcopy(self.named_expressions[node.id])
2932
return node
3033

3134

32-
def apply_inline_named_expr(expr: ast.AST, constant_only: bool = False) -> ast.AST:
33-
inline_named_expr = InlineNamedExpr(constant_only)
35+
def apply_inline_all_named_expr(expr: ast.AST) -> ast.AST:
36+
inline_named_expr = InlineNamedExpr(lambda symbol, named_expressions: True)
37+
return inline_named_expr.visit(expr)
38+
39+
40+
def apply_constant_propagation(expr: ast.AST) -> ast.AST:
41+
inline_named_expr = InlineNamedExpr(
42+
lambda symbol, named_expressions: isinstance(named_expressions[symbol], ast.Constant)
43+
)
3444
return inline_named_expr.visit(expr)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
6+
class UsedSymbolsAnalyzer(ast.NodeVisitor):
7+
def __init__(self) -> None:
8+
super().__init__()
9+
self.used_symbols = set[str]()
10+
11+
def visit_Name(self, node: ast.Name) -> None:
12+
symbol = node.id
13+
if isinstance(node.ctx, ast.Load):
14+
self.used_symbols.add(symbol)
15+
16+
17+
class RemoveUnusedNamedExpr(ast.NodeTransformer):
18+
def __init__(self, used_symbols: set[str]) -> None:
19+
super().__init__()
20+
self.used_symbols = used_symbols
21+
22+
def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.expr:
23+
value = self.visit(node.value)
24+
name = node.target.id
25+
if name not in self.used_symbols:
26+
return value
27+
return node
28+
29+
30+
def apply_remove_unused_named_expr(expr: ast.AST) -> ast.AST:
31+
used_symbols_analyzer = UsedSymbolsAnalyzer()
32+
used_symbols_analyzer.visit(expr)
33+
used_symbols = used_symbols_analyzer.used_symbols
34+
35+
remove_unused_named_expr = RemoveUnusedNamedExpr(used_symbols)
36+
return remove_unused_named_expr.visit(expr)

tests/test_transforms/test_cse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
("1 + 1 + 2", "1 + 1 + 2"),
2020
(
2121
"(___x := a.b).c + ___x.d.e + (___y := (___z := a.b).c) + ___z.d.e + ___y.f.g",
22-
"(___t_1 := (___t_0 := a.b).c) + (___t_3 := (___t_2 := ___t_0.d).e) + ___t_1 + ___t_3 + ___t_1.f.g",
22+
"(___t_1 := (___t_0 := a.b).c) + (___t_3 := ___t_0.d.e) + ___t_1 + ___t_3 + ___t_1.f.g",
2323
),
2424
(
2525
"a.b.c and (fn1(fn2((___x := a.b.c.d.e.f.g)))) and (___x == False) and (___y := a.b.c)",
26-
"(___t_1 := (___t_0 := a.b).c) and fn1(fn2((___t_5 := (___t_4 := (___t_3 := (___t_2 := ___t_1.d).e).f).g))) and (___t_5 == False) and ___t_1",
26+
"(___t_1 := a.b.c) and fn1(fn2((___t_5 := ___t_1.d.e.f.g))) and (___t_5 == False) and ___t_1",
2727
),
2828
(
2929
"(___x := a.b.c) + ___x",
30-
"(___t_1 := (___t_0 := a.b).c) + ___t_1",
30+
"(___t_1 := a.b.c) + ___t_1",
3131
),
3232
],
3333
)

tests/test_transforms/test_inline_named_expr.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44

55
import pytest
66

7-
from expr_simplifier.transforms import apply_inline_named_expr
7+
from expr_simplifier.transforms import (
8+
apply_constant_propagation,
9+
apply_inline_all_named_expr,
10+
)
811

912

1013
@pytest.mark.parametrize(
@@ -19,7 +22,7 @@
1922
)
2023
def test_inline_named_expr(expr: str, expected: str):
2124
tree = ast.parse(expr, mode="eval")
22-
transformed_tree = apply_inline_named_expr(tree)
25+
transformed_tree = apply_inline_all_named_expr(tree)
2326
transformed_expr = ast.unparse(transformed_tree)
2427
assert transformed_expr == expected
2528

@@ -37,6 +40,6 @@ def test_inline_named_expr(expr: str, expected: str):
3740
)
3841
def test_constant_propagation(expr: str, expected: str):
3942
tree = ast.parse(expr, mode="eval")
40-
transformed_tree = apply_inline_named_expr(tree, constant_only=True)
43+
transformed_tree = apply_constant_propagation(tree)
4144
transformed_expr = ast.unparse(transformed_tree)
4245
assert transformed_expr == expected
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
import pytest
6+
7+
from expr_simplifier.transforms import apply_remove_unused_named_expr
8+
9+
10+
@pytest.mark.parametrize(
11+
["expr", "expected"],
12+
[
13+
("(___x := a.b) + ___x", "a.b + ___x"),
14+
("(___y := (___x := a.b)) + ___y", "(___y := a.b) + ___y"),
15+
("(___y := (___x := a.b))", "a.b"),
16+
],
17+
)
18+
def test_inline_named_expr(expr: str, expected: str):
19+
tree = ast.parse(expr, mode="eval")
20+
transformed_tree = apply_remove_unused_named_expr(tree)
21+
transformed_expr = ast.unparse(transformed_tree)
22+
assert transformed_expr == expected

0 commit comments

Comments
 (0)