Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

## 13.1.0 (2026-03-25)

- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416)
- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397)
- Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection.
- Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.
Expand Down
17 changes: 8 additions & 9 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import RunReport
from .runtime import *
from .thunk import *

Expand Down Expand Up @@ -953,36 +954,34 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ...

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> RunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> RunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return RunReport._from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return RunReport._from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
37 changes: 31 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def _normalize_global_let_name(name: str) -> str:
return name if name.startswith("$") else f"${name}"


def _normalize_rule_key(key: str) -> str:
"""Normalize an egglog rule string for consistent matching."""
key = key.replace("'", '"')
return re.sub(r"\s+", " ", key).strip()


@dataclass
class EGraphState:
"""
Expand Down Expand Up @@ -76,6 +82,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)

# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand Down Expand Up @@ -247,6 +255,13 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> CommandDecl:
"""
Look up the original Python CommandDecl for an egglog rule key.
"""
normalized = _normalize_rule_key(egglog_key)
return self.egg_rule_to_command_decl[normalized]

def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -289,13 +304,19 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)

normalized = _normalize_rule_key(str(egg_cmd))
self.egg_rule_to_command_decl[normalized] = cmd
if isinstance(cmd, BiRewriteDecl):
self.egg_rule_to_command_decl[normalized + "=>"] = cmd
self.egg_rule_to_command_decl[normalized + "<="] = cmd
Comment on lines +315 to +316
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you leave a comment here that the birewrite is de-sugared into these two names in egglog which is why we have both of them?

return egg_cmd
case RuleDecl(head, body, name):
return bindings.RuleCommand(
egg_cmd = bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
Expand All @@ -304,6 +325,10 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
str(ruleset),
)
)
self.egg_rule_to_command_decl[_normalize_rule_key(str(egg_cmd))] = cmd
if name:
self.egg_rule_to_command_decl[name] = cmd
Comment on lines +328 to +330
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a name is provided, we don't need to save the normalized version right?

return egg_cmd
# TODO: Replace with just constants value and looking at REF of function
case DefaultRewriteDecl(ref, expr, subsume):
sig = self.__egg_decls__.get_callable_decl(ref).signature
Expand Down
121 changes: 121 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import timedelta

from . import bindings
from .declarations import CommandDecl, Declarations
from .egraph_state import EGraphState
from .pretty import pretty_decl


def _format_rule_key(decls: Declarations, key: CommandDecl) -> str:
return pretty_decl(decls, key)
Comment on lines +13 to +14
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove this extra function



@dataclass
class RuleReport:
plan: bindings.Plan | None
search_and_apply_time: timedelta
num_matches: int

@classmethod
def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
return cls(
plan=report.plan,
search_and_apply_time=report.search_and_apply_time,
num_matches=report.num_matches,
)


@dataclass
class RuleSetReport:
changed: bool
rule_reports: dict[CommandDecl, list[RuleReport]]
search_and_apply_time: timedelta
merge_time: timedelta
_decls: Declarations = field(repr=False, default=None)

@classmethod
def _from_bindings(
cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
) -> RuleSetReport:
return cls(
changed=report.changed,
rule_reports={
translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A birewrite will actually show up twice right? So if we want to keep both of them as one Python rule, then I think then we would need to add the timings for both of them?

},
search_and_apply_time=report.search_and_apply_time,
merge_time=report.merge_time,
_decls=decls,
)

def __repr__(self) -> str:
rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
return (
f"RuleSetReport(changed={self.changed}, "
f"rule_reports={rule_reports_str}, "
f"search_and_apply_time={self.search_and_apply_time}, "
f"merge_time={self.merge_time})"
)


@dataclass
class IterationReport:
rule_set_report: RuleSetReport
rebuild_time: timedelta

@classmethod
def _from_bindings(
cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
) -> IterationReport:
return cls(
rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls),
rebuild_time=report.rebuild_time,
)


@dataclass
class RunReport:
"""Python-friendly wrapper around bindings.RunReport."""

iterations: list[IterationReport]
updated: bool
search_and_apply_time_per_rule: dict[CommandDecl, timedelta]
num_matches_per_rule: dict[CommandDecl, int]
search_and_apply_time_per_ruleset: dict[str, timedelta]
merge_time_per_ruleset: dict[str, timedelta]
rebuild_time_per_ruleset: dict[str, timedelta]
Comment on lines +87 to +89
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have to do this, but we could also change these strings into actual ruleset objects? I am not sure if that would be too hard or awkward, and isn't essential for this PR.

_decls: Declarations = field(repr=False, default=None)

def __repr__(self) -> str:
time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
return (
f"RunReport(iterations={self.iterations}, "
f"updated={self.updated}, "
f"search_and_apply_time_per_rule={time_per_rule}, "
f"num_matches_per_rule={matches_per_rule}, "
f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
)

@classmethod
def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
return cls(
iterations=[
IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__)
for it in report.iterations
],
updated=report.updated,
search_and_apply_time_per_rule={
state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items()
},
num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()},
search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
merge_time_per_ruleset=report.merge_time_per_ruleset,
rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
_decls=state.__egg_decls__,
)
Loading
Loading