|
2 | 2 |
|
3 | 3 | import ast |
4 | 4 | import copy |
| 5 | +from collections.abc import Callable |
| 6 | + |
| 7 | +from typing_extensions import TypeAlias |
| 8 | + |
| 9 | +NamedExpressions: TypeAlias = dict[str, ast.expr] |
| 10 | +ShouldReplaceFn: TypeAlias = Callable[[str, NamedExpressions], bool] |
5 | 11 |
|
6 | 12 |
|
7 | 13 | class InlineNamedExpr(ast.NodeTransformer): |
8 | | - def __init__(self, constant_only: bool = False) -> None: |
| 14 | + def __init__(self, should_replace_fn: ShouldReplaceFn) -> None: |
9 | 15 | super().__init__() |
| 16 | + self.should_replace_fn = should_replace_fn |
10 | 17 | self.named_expressions = dict[str, ast.expr]() |
11 | | - self.constant_only = constant_only |
12 | | - |
13 | | - def should_replace(self, symbol: str) -> bool: |
14 | | - return not self.constant_only or isinstance(self.named_expressions[symbol], ast.Constant) |
15 | 18 |
|
16 | 19 | def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.expr: |
17 | 20 | value = self.visit(node.value) |
18 | 21 | name = node.target.id |
19 | 22 | self.named_expressions[name] = value |
20 | | - if not self.should_replace(name): |
| 23 | + if not self.should_replace_fn(name, self.named_expressions): |
21 | 24 | return node |
22 | 25 | return value |
23 | 26 |
|
24 | 27 | def visit_Name(self, node: ast.Name) -> ast.expr: |
25 | 28 | if node.id in self.named_expressions: |
26 | | - if not self.should_replace(node.id): |
| 29 | + if not self.should_replace_fn(node.id, self.named_expressions): |
27 | 30 | return node |
28 | 31 | return copy.deepcopy(self.named_expressions[node.id]) |
29 | 32 | return node |
30 | 33 |
|
31 | 34 |
|
32 | | -def apply_inline_named_expr(expr: ast.AST, constant_only: bool = False) -> ast.AST: |
33 | | - inline_named_expr = InlineNamedExpr(constant_only) |
| 35 | +def apply_inline_all_named_expr(expr: ast.AST) -> ast.AST: |
| 36 | + inline_named_expr = InlineNamedExpr(lambda symbol, named_expressions: True) |
| 37 | + return inline_named_expr.visit(expr) |
| 38 | + |
| 39 | + |
| 40 | +def apply_constant_propagation(expr: ast.AST) -> ast.AST: |
| 41 | + inline_named_expr = InlineNamedExpr( |
| 42 | + lambda symbol, named_expressions: isinstance(named_expressions[symbol], ast.Constant) |
| 43 | + ) |
34 | 44 | return inline_named_expr.visit(expr) |
0 commit comments