Skip to content

Commit 6a892ec

Browse files
Merge pull request #382 from egraphs-good/demo
Convert to multisets, facts as actions, and update multiset example
2 parents 6daa0ae + e6ad4fa commit 6a892ec

File tree

5 files changed

+92
-39
lines changed

5 files changed

+92
-39
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7+
- Support using facts as union actions, add conversions to multisets, and update multiset examlpe [#382](https://github.com/egraphs-good/egglog-python/pull/382)
8+
79
## 12.0.0 (2025-11-16)
810

911
- Add support for setting report level with `egraph.set_report_level` [#375](https://github.com/egraphs-good/egglog-python/pull/375)

python/egglog/builtins.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"Map",
4040
"MapLike",
4141
"MultiSet",
42+
"MultiSetLike",
4243
"Primitive",
4344
"PyObject",
4445
"Rational",
@@ -583,6 +584,17 @@ def __add__(self, other: MultiSet[T]) -> MultiSet[T]: ...
583584
def map(self, f: Callable[[T], T]) -> MultiSet[T]: ...
584585

585586

587+
converter(
588+
tuple,
589+
MultiSet,
590+
lambda t: MultiSet[get_type_args()[0]]( # type: ignore[misc,operator]
591+
*(convert(x, get_type_args()[0]) for x in t)
592+
),
593+
)
594+
595+
MultiSetLike: TypeAlias = MultiSet[T] | tuple[TO, ...]
596+
597+
586598
class Rational(BuiltinExpr, egg_sort="Rational"):
587599
@method(preserve=True)
588600
@deprecated("use .value")

python/egglog/egraph.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,8 @@ def __replace_expr__(self, new_expr: Self) -> None:
367367
Replace the current expression with the new expression in place.
368368
"""
369369

370+
def __hash__(self) -> int: ... # type: ignore[empty-body]
371+
370372

371373
class BuiltinExpr(BaseExpr, metaclass=_ExprMetaclass):
372374
"""
@@ -1933,23 +1935,32 @@ def seq(*schedules: Schedule) -> Schedule:
19331935
return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
19341936

19351937

1936-
ActionLike: TypeAlias = Action | BaseExpr
1937-
1938-
19391938
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
19401939
return tuple(map(_action_like, action_likes))
19411940

19421941

19431942
def _action_like(action_like: ActionLike) -> Action:
19441943
if isinstance(action_like, Action):
19451944
return action_like
1945+
if isinstance(action_like, Fact):
1946+
match action_like.fact:
1947+
case EqDecl(tp, left, right):
1948+
return Action(
1949+
action_like.__egg_decls__,
1950+
UnionDecl(tp, left, right),
1951+
)
1952+
case ExprFactDecl(expr):
1953+
return Action(
1954+
action_like.__egg_decls__,
1955+
ExprActionDecl(expr),
1956+
)
1957+
case _:
1958+
assert_never(action_like.fact)
19461959
return expr_action(action_like)
19471960

19481961

19491962
Command: TypeAlias = Action | RewriteOrRule
19501963

1951-
CommandLike: TypeAlias = ActionLike | RewriteOrRule
1952-
19531964

19541965
def _command_like(command_like: CommandLike) -> Command:
19551966
if isinstance(command_like, RewriteOrRule):
@@ -1976,6 +1987,8 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
19761987

19771988

19781989
FactLike = Fact | BaseExpr
1990+
ActionLike: TypeAlias = Action | BaseExpr | Fact
1991+
CommandLike: TypeAlias = ActionLike | RewriteOrRule
19791992

19801993

19811994
def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]:

python/egglog/examples/multiset.py

Lines changed: 59 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,81 @@
66

77
from __future__ import annotations
88

9-
from collections import Counter
10-
119
from egglog import *
1210

1311

1412
class Math(Expr):
1513
def __init__(self, x: i64Like) -> None: ...
14+
def __add__(self, other: MathLike) -> Math: ...
15+
def __radd__(self, other: MathLike) -> Math: ...
16+
def __mul__(self, other: MathLike) -> Math: ...
17+
def __rmul__(self, other: MathLike) -> Math: ...
1618

1719

18-
@function
19-
def square(x: Math) -> Math: ...
20-
21-
22-
@ruleset
23-
def math_ruleset(i: i64):
24-
yield rewrite(square(Math(i))).to(Math(i * i))
25-
20+
MathLike = Math | i64Like
21+
converter(i64, Math, Math)
2622

27-
egraph = EGraph()
2823

29-
xs = MultiSet(Math(1), Math(2), Math(3))
30-
egraph.register(xs)
24+
@function
25+
def sum(xs: MultiSetLike[Math, MathLike]) -> Math: ...
3126

32-
egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33-
egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
3427

35-
assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
28+
@function
29+
def product(xs: MultiSetLike[Math, MathLike]) -> Math: ...
3630

3731

38-
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39-
egraph.register(inserted)
40-
egraph.check(xs.insert(Math(4)) == inserted)
41-
egraph.check(xs.contains(Math(1)))
42-
egraph.check(xs.not_contains(Math(4)))
43-
assert Math(1) in xs
44-
assert Math(4) not in xs
32+
@function
33+
def square(x: Math) -> Math: ...
4534

46-
egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
4735

48-
assert egraph.extract(xs.length()).value == 3
49-
assert len(xs) == 3
36+
x = constant("x", Math)
37+
expr1 = 2 * (x + 3)
38+
expr2 = 6 + 2 * x
5039

51-
egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
5240

53-
egraph.check(MultiSet(Math(1)).pick() == Math(1))
41+
@ruleset
42+
def math_ruleset(a: Math, b: Math, c: Math, i: i64, j: i64, xs: MultiSet[Math], ys: MultiSet[Math], zs: MultiSet[Math]):
43+
yield rewrite(a + b).to(sum(MultiSet(a, b)))
44+
yield rewrite(a * b).to(product(MultiSet(a, b)))
45+
# 0 or 1 elements sums/products also can be extracted back to numbers
46+
yield rule(a == sum(xs), xs.length() == i64(1)).then(a == xs.pick())
47+
yield rule(a == product(xs), xs.length() == i64(1)).then(a == xs.pick())
48+
yield rewrite(sum(MultiSet[Math]())).to(Math(0))
49+
yield rewrite(product(MultiSet[Math]())).to(Math(1))
50+
# distributive rule (a * (b + c) = a*b + a*c)
51+
yield rule(
52+
b == product(ys),
53+
a == sum(xs),
54+
ys.contains(a),
55+
ys.length() > 1,
56+
zs == ys.remove(a),
57+
).then(
58+
b == sum(xs.map(lambda x: product(zs.insert(x)))),
59+
)
60+
# constants
61+
yield rule(
62+
a == sum(xs),
63+
b == Math(i),
64+
xs.contains(b),
65+
ys == xs.remove(b),
66+
c == Math(j),
67+
ys.contains(c),
68+
).then(
69+
a == sum(ys.remove(c).insert(Math(i + j))),
70+
)
71+
yield rule(
72+
a == product(xs),
73+
b == Math(i),
74+
xs.contains(b),
75+
ys == xs.remove(b),
76+
c == Math(j),
77+
ys.contains(c),
78+
).then(
79+
a == product(ys.remove(c).insert(Math(i * j))),
80+
)
5481

55-
mapped = xs.map(square)
56-
egraph.register(mapped)
57-
egraph.run(math_ruleset)
58-
egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
5982

60-
egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
83+
egraph = EGraph()
84+
egraph.register(expr1, expr2)
85+
egraph.run(math_ruleset.saturate())
86+
egraph.check(expr1 == expr2)

0 commit comments

Comments
 (0)