Skip to content

Commit edc198b

Browse files
committed
✨ feat: support logical_simplification
1 parent 75a430f commit edc198b

File tree

5 files changed

+230
-9
lines changed

5 files changed

+230
-9
lines changed

README.md

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,27 @@
1313
<a href="https://gitmoji.dev"><img alt="Gitmoji" src="https://img.shields.io/badge/gitmoji-%20😜%20😍-FFDD67?style=flat-square"></a>
1414
</p>
1515

16-
## Installation
16+
## Usage
1717

18-
TODO...
18+
### Quick start in CLI
1919

20-
## Usage
20+
We recommend using the tool `uv` to run the without manually installing the package:
2121

2222
```console
23-
$ expr_simplifier cse "a * 4 + (a * 4)"
23+
$ uvx expr_simplifier cse "a * 4 + (a * 4)"
2424
(___t_0 := (a * 4)) + ___t_0
25-
$ expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=1
25+
$ uvx expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=1
2626
(___x := 2) + ___x
27-
$ expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=2
27+
$ uvx expr_simplifier constant_folding "(___x := 1 + 1) + ___x" --max-iter=2
2828
4
29+
# uvx expr_simplifier logical_simplification "a and b and a"
30+
a and b
2931
```
3032

33+
### As a library
34+
35+
TODO...
36+
3137
## TODOs
3238

33-
- [ ] Fold same logic operations (`a and b and a` -> `a and b`)
3439
- [ ] Add runtime checks in uts

src/expr_simplifier/__main__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Callable
66

77
from expr_simplifier import __version__
8-
from expr_simplifier.transforms import apply_constant_folding, apply_cse
8+
from expr_simplifier.transforms import apply_constant_folding, apply_cse, apply_logical_simplification
99
from expr_simplifier.typing import Pass
1010
from expr_simplifier.utils import loop_until_stable
1111

@@ -39,7 +39,10 @@ def main() -> None:
3939

4040
create_pass_parser("cse", [apply_cse], "Common Subexpression Elimination", sub_parsers)
4141
create_pass_parser("constant_folding", [apply_constant_folding], "Constant Folding", sub_parsers)
42-
create_pass_parser("auto", [apply_constant_folding, apply_cse], "Auto Simplification", sub_parsers)
42+
create_pass_parser("logical_simplification", [apply_logical_simplification], "Logical Simplification", sub_parsers)
43+
create_pass_parser(
44+
"auto", [apply_constant_folding, apply_logical_simplification, apply_cse], "Auto Simplification", sub_parsers
45+
)
4346

4447
args = parser.parse_args()
4548
args.func(args)

src/expr_simplifier/transforms/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
apply_constant_propagation as apply_constant_propagation,
77
apply_inline_all_named_expr as apply_inline_all_named_expr,
88
)
9+
from expr_simplifier.transforms.logical_simplification import (
10+
apply_logical_short_circuiting as apply_logical_short_circuiting,
11+
apply_logical_simplification as apply_logical_simplification,
12+
apply_remove_same_subexpression_in_logical_op as apply_remove_same_subexpression_in_logical_op,
13+
)
914
from expr_simplifier.transforms.remove_unused_named_expr import (
1015
apply_remove_unused_named_expr as apply_remove_unused_named_expr,
1116
)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
6+
def create_logical_values_node(op: ast.boolop, values: list[ast.expr], default_value: bool) -> ast.expr:
7+
if not values:
8+
return ast.Constant(value=default_value)
9+
if len(values) == 1:
10+
return values[0]
11+
return ast.BoolOp(op=op, values=values)
12+
13+
14+
class LogicalShortCircuiting(ast.NodeTransformer):
15+
def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr:
16+
if isinstance(node.op, ast.And):
17+
return self._visit_and(node)
18+
elif isinstance(node.op, ast.Or):
19+
return self._visit_or(node)
20+
return node
21+
22+
def _visit_and(self, node: ast.BoolOp) -> ast.expr:
23+
and_values: list[ast.expr] = []
24+
for value in node.values:
25+
new_value = self.visit(value)
26+
if isinstance(new_value, ast.Constant):
27+
if new_value.value is False:
28+
return new_value
29+
if new_value.value is True:
30+
continue
31+
and_values.append(new_value)
32+
return create_logical_values_node(ast.And(), and_values, True)
33+
34+
def _visit_or(self, node: ast.BoolOp) -> ast.expr:
35+
or_values: list[ast.expr] = []
36+
for value in node.values:
37+
new_value = self.visit(value)
38+
if isinstance(new_value, ast.Constant):
39+
if new_value.value is True:
40+
return new_value
41+
if new_value.value is False:
42+
continue
43+
or_values.append(new_value)
44+
return create_logical_values_node(ast.Or(), or_values, False)
45+
46+
47+
class LogicCluster:
48+
def __init__(self, ids: set[int]) -> None:
49+
self.ids = ids
50+
self.subexpressions = set[str]()
51+
52+
def match(self, other: LogicCluster) -> bool:
53+
return bool(self.ids & other.ids)
54+
55+
def add_expression(self, expr: str) -> None:
56+
self.subexpressions.add(expr)
57+
58+
def update(self, other: LogicCluster) -> None:
59+
self.ids.update(other.ids)
60+
self.subexpressions.update(other.subexpressions)
61+
62+
def __repr__(self) -> str:
63+
return f"LogicCluster(ids={self.ids}, subexpressions={self.subexpressions})"
64+
65+
66+
class RemoveSameSubExpressionInLogicalOp(ast.NodeTransformer):
67+
def __init__(self):
68+
self.and_clusters: list[LogicCluster] = []
69+
self.or_clusters: list[LogicCluster] = []
70+
71+
@staticmethod
72+
def update_cluster(regisitry: list[LogicCluster], new_cluster: LogicCluster) -> LogicCluster:
73+
for cluster in regisitry:
74+
if cluster.match(new_cluster):
75+
cluster.update(new_cluster)
76+
return cluster
77+
regisitry.append(new_cluster)
78+
return new_cluster
79+
80+
def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr:
81+
if isinstance(node.op, ast.And):
82+
return self._visit_and(node)
83+
elif isinstance(node.op, ast.Or):
84+
return self._visit_or(node)
85+
return node
86+
87+
def _visit_and(self, node: ast.BoolOp) -> ast.expr:
88+
node_id = id(node)
89+
cluster_ids = {id(value) for value in node.values}
90+
cluster = LogicCluster({*cluster_ids, node_id})
91+
cluster = self.update_cluster(self.and_clusters, cluster)
92+
transformed_values: list[ast.expr] = []
93+
for value in node.values:
94+
known_exprs = cluster.subexpressions.copy()
95+
transformed_value = self.visit(value)
96+
stringified_value = ast.unparse(transformed_value)
97+
if stringified_value not in known_exprs:
98+
cluster.add_expression(stringified_value)
99+
transformed_values.append(transformed_value)
100+
return create_logical_values_node(ast.And(), transformed_values, True)
101+
102+
def _visit_or(self, node: ast.BoolOp) -> ast.expr:
103+
node_id = id(node)
104+
cluster_ids = {id(value) for value in node.values}
105+
cluster = LogicCluster({*cluster_ids, node_id})
106+
cluster = self.update_cluster(self.or_clusters, cluster)
107+
transformed_values: list[ast.expr] = []
108+
for value in node.values:
109+
known_exprs = cluster.subexpressions.copy()
110+
transformed_value = self.visit(value)
111+
stringified_value = ast.unparse(transformed_value)
112+
if stringified_value not in known_exprs:
113+
cluster.add_expression(stringified_value)
114+
transformed_values.append(transformed_value)
115+
return create_logical_values_node(ast.Or(), transformed_values, False)
116+
117+
118+
def apply_logical_short_circuiting(expr: ast.AST) -> ast.AST:
119+
logical_short_circuiting = LogicalShortCircuiting()
120+
return logical_short_circuiting.visit(expr)
121+
122+
123+
def apply_remove_same_subexpression_in_logical_op(expr: ast.AST) -> ast.AST:
124+
remove_same_subexpression_in_logical_op = RemoveSameSubExpressionInLogicalOp()
125+
res = remove_same_subexpression_in_logical_op.visit(expr)
126+
return res
127+
128+
129+
def apply_logical_simplification(expr: ast.AST) -> ast.AST:
130+
expr = apply_remove_same_subexpression_in_logical_op(expr)
131+
expr = apply_logical_short_circuiting(expr)
132+
return expr
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
import pytest
6+
7+
from expr_simplifier.transforms import (
8+
apply_logical_short_circuiting,
9+
apply_logical_simplification,
10+
apply_remove_same_subexpression_in_logical_op,
11+
)
12+
13+
14+
@pytest.mark.parametrize(
15+
["expr", "expected"],
16+
[
17+
("a and False and b", "False"),
18+
("(a and False) and b", "False"),
19+
("a and (False and b)", "False"),
20+
("a or True or b", "True"),
21+
("(a or True) or b", "True"),
22+
("a or (True or b)", "True"),
23+
("a and True and b", "a and b"),
24+
("a or False or b", "a or b"),
25+
("a and True or b", "a or b"),
26+
("not (a and False)", "not False"),
27+
("True and True", "True"),
28+
("True or True", "True"),
29+
("False and False", "False"),
30+
("False or False", "False"),
31+
],
32+
)
33+
def test_logical_short_circuiting(expr: str, expected: str):
34+
tree = ast.parse(expr, mode="eval")
35+
transformed_tree = apply_logical_short_circuiting(tree)
36+
transformed_expr = ast.unparse(transformed_tree)
37+
assert transformed_expr == expected
38+
39+
40+
@pytest.mark.parametrize(
41+
["expr", "expected"],
42+
[
43+
("a and b and a", "a and b"),
44+
("(a and b) and a", "a and b"),
45+
("a and (b and a)", "a and b"),
46+
("a and b and a and b", "a and b"),
47+
("(a and b) and (a and b)", "(a and b) and True"),
48+
("a or b or a", "a or b"),
49+
("(a or b) or a", "a or b"),
50+
("a or (b or a)", "a or b"),
51+
("a or b or a or b", "a or b"),
52+
("(a or b) or (a or b)", "(a or b) or False"),
53+
("a and b and c and a and b and c", "a and b and c"),
54+
("(a and b) and (c and a) and b and c", "(a and b) and c"),
55+
("a and (b and c) and a and b and c", "a and (b and c)"),
56+
],
57+
)
58+
def test_remove_same_subexpression_in_logical_op(expr: str, expected: str):
59+
tree = ast.parse(expr, mode="eval")
60+
transformed_tree = apply_remove_same_subexpression_in_logical_op(tree)
61+
transformed_expr = ast.unparse(transformed_tree)
62+
assert transformed_expr == expected
63+
64+
65+
@pytest.mark.parametrize(
66+
["expr", "expected"],
67+
[
68+
("(a and b) and (a and b)", "a and b"),
69+
("(a or b) or (a or b)", "a or b"),
70+
],
71+
)
72+
def test_logical_simplification(expr: str, expected: str):
73+
tree = ast.parse(expr, mode="eval")
74+
transformed_tree = apply_logical_simplification(tree)
75+
transformed_expr = ast.unparse(transformed_tree)
76+
assert transformed_expr == expected

0 commit comments

Comments
 (0)