From 2007b1afee2fef162267a92bad4bcc155ed9546f Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Wed, 6 May 2026 00:04:42 +0100 Subject: [PATCH 1/7] feat: add pretty run report --- python/egglog/egraph.py | 15 +++--- python/egglog/egraph_state.py | 26 ++++++++--- python/egglog/run_report.py | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 python/egglog/run_report.py diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 8fc6643b..a60f13c3 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -42,6 +42,7 @@ from .egraph_state import * from .ipython_magic import IN_IPYTHON from .pretty import pretty_decl +from .run_report import PrettyRunReport from .runtime import * from .thunk import * @@ -953,15 +954,15 @@ def output(self) -> None: raise NotImplementedError(msg) @overload - def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ... + def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> PrettyRunReport: ... @overload - def run(self, schedule: Schedule, /) -> bindings.RunReport: ... + def run(self, schedule: Schedule, /) -> PrettyRunReport: ... @_TRACER.start_as_current_span("run") def run( self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None - ) -> bindings.RunReport: + ) -> PrettyRunReport: """ Run the egraph until the given limit or until the given facts are true. """ @@ -969,20 +970,20 @@ def run( limit_or_schedule = run(ruleset, *until) * limit_or_schedule return self._run_schedule(limit_or_schedule) - def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: + def _run_schedule(self, schedule: Schedule) -> PrettyRunReport: self._add_decls(schedule) cmd = self._state.run_schedule_to_egg(schedule.schedule) (command_output,) = self._run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) - return command_output.report + return PrettyRunReport.from_bindings(command_output.report, self._state) - def stats(self) -> bindings.RunReport: + def stats(self) -> PrettyRunReport: """ Returns the overall run report for the egraph. """ (output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None)) assert isinstance(output, bindings.OverallStatistics) - return output.report + return PrettyRunReport.from_bindings(output.report, self._state) def check_bool(self, *facts: FactLike) -> bool: """ diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 1d65aeff..650d0d01 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -76,6 +76,8 @@ class EGraphState: type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) + egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) + # Cache of egg expressions for converting to egg expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) @@ -247,6 +249,14 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 case _: assert_never(schedule) + def translate_rule_key(self, egglog_key: str) -> str: + """ + Translate an egglog rule name to its Python representation. + """ + if egglog_key in self.egg_rule_to_command_decl: + return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key]) + return egglog_key + def ruleset_to_egg(self, ident: Ident) -> None: """ Registers a ruleset if it's not already registered. @@ -289,13 +299,15 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command self._expr_to_egg(rhs), [self.fact_to_egg(c) for c in conditions], ) - return ( - bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) - if isinstance(cmd, RewriteDecl) - else bindings.BiRewriteCommand(str(ruleset), rewrite) - ) + if isinstance(cmd, RewriteDecl): + egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) + else: + egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) + + self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + return egg_cmd case RuleDecl(head, body, name): - return bindings.RuleCommand( + egg_cmd = bindings.RuleCommand( bindings.Rule( span(), [self.action_to_egg(a) for a in head], @@ -304,6 +316,8 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command str(ruleset), ) ) + self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + return egg_cmd # TODO: Replace with just constants value and looking at REF of function case DefaultRewriteDecl(ref, expr, subsume): sig = self.__egg_decls__.get_callable_decl(ref).signature diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py new file mode 100644 index 00000000..233906c3 --- /dev/null +++ b/python/egglog/run_report.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta + +from . import bindings +from .egraph_state import EGraphState + + +@dataclass +class PrettyRuleReport: + plan: bindings.Plan | None + search_and_apply_time: timedelta + num_matches: int + + @classmethod + def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport: + return cls( + plan=report.plan, + search_and_apply_time=report.search_and_apply_time, + num_matches=report.num_matches, + ) + + +@dataclass +class PrettyRuleSetReport: + changed: bool + rule_reports: dict[str, list[PrettyRuleReport]] + search_and_apply_time: timedelta + merge_time: timedelta + + @classmethod + def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport: + return cls( + changed=report.changed, + rule_reports={ + translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v] + for k, v in report.rule_reports.items() + }, + search_and_apply_time=report.search_and_apply_time, + merge_time=report.merge_time, + ) + + +@dataclass +class PrettyIterationReport: + rule_set_report: PrettyRuleSetReport + rebuild_time: timedelta + + @classmethod + def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> PrettyIterationReport: + return cls( + rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key), + rebuild_time=report.rebuild_time, + ) + + +@dataclass +class PrettyRunReport: + """Python-friendly wrapper around bindings.RunReport.""" + + iterations: list[PrettyIterationReport] + updated: bool + search_and_apply_time_per_rule: dict[str, timedelta] + num_matches_per_rule: dict[str, int] + search_and_apply_time_per_ruleset: dict[str, timedelta] + merge_time_per_ruleset: dict[str, timedelta] + rebuild_time_per_ruleset: dict[str, timedelta] + + @classmethod + def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport: + return cls( + iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], + updated=report.updated, + search_and_apply_time_per_rule={ + state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() + }, + num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()}, + search_and_apply_time_per_ruleset={ + state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items() + }, + merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()}, + rebuild_time_per_ruleset={ + state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items() + }, + ) From 86133fa2a1d4e388100400f5ae0e91f2a4ff6668 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Wed, 6 May 2026 00:08:47 +0100 Subject: [PATCH 2/7] feat: add test for pretty run report --- python/tests/test_run_report.py | 200 ++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 python/tests/test_run_report.py diff --git a/python/tests/test_run_report.py b/python/tests/test_run_report.py new file mode 100644 index 00000000..38aa9223 --- /dev/null +++ b/python/tests/test_run_report.py @@ -0,0 +1,200 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from datetime import timedelta + +from egglog import * + + +class TestPrettyRunReport: + def _setup_simple_egraph(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + return egraph + + def test_run_returns_pretty_report(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + assert type(report).__name__ == "PrettyRunReport" + + def test_stats_returns_pretty_report(self): + egraph = self._setup_simple_egraph() + egraph.run(10) + report = egraph.stats() + assert type(report).__name__ == "PrettyRunReport" + + def test_rule_names_translated_in_top_level_dicts(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for key in report.search_and_apply_time_per_rule: + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + + for key in report.num_matches_per_rule: + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + + def test_rule_names_translated_in_iterations(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + assert len(report.iterations) > 0 + for iteration in report.iterations: + for key in iteration.rule_set_report.rule_reports: + assert "__main__" not in key, f"Iteration rule key not translated: {key}" + assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + + def test_updated_field(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + assert isinstance(report.updated, bool) + assert report.updated is True + + def test_num_matches(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + total_matches = sum(report.num_matches_per_rule.values()) + assert total_matches > 0 + + def test_timedelta_types(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for v in report.search_and_apply_time_per_rule.values(): + assert isinstance(v, timedelta) + for v in report.search_and_apply_time_per_ruleset.values(): + assert isinstance(v, timedelta) + for v in report.merge_time_per_ruleset.values(): + assert isinstance(v, timedelta) + for v in report.rebuild_time_per_ruleset.values(): + assert isinstance(v, timedelta) + + def test_iteration_reports_are_pretty(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + + for it in report.iterations: + assert type(it).__name__ == "PrettyIterationReport" + assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport" + for rule_reports in it.rule_set_report.rule_reports.values(): + for rr in rule_reports: + assert type(rr).__name__ == "PrettyRuleReport" + + def test_str_no_egglog_sexprs(self): + egraph = self._setup_simple_egraph() + report = egraph.run(10) + output = str(report) + + assert "(rewrite" not in output, f"str() still contains egglog s-expressions:\n{output}" + assert "__main__" not in output, f"str() still contains mangled names:\n{output}" + + def test_multiple_rules(self): + egraph = EGraph() + + class Math(Expr): + def __init__(self, value: i64Like) -> None: ... + def __add__(self, other: Math) -> Math: ... + def __mul__(self, other: Math) -> Math: ... + + a, b = vars_("a b", Math) + egraph.register( + rewrite(a + b).to(b + a), + rewrite(a * b).to(b * a), + ) + egraph.register(Math(1) + Math(2), Math(3) * Math(4)) + report = egraph.run(10) + + # should have two distinct translated rule keys + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) == 2 + for key in rule_keys: + assert "__main__" not in key, f"Key not translated: {key}" + + def test_empty_run(self): + egraph = EGraph() + report = egraph.run(1) + assert type(report).__name__ == "PrettyRunReport" + assert isinstance(report.updated, bool) + + def test_named_rule(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rule(x + y, name="comm").then(union(x + y).with_(y + x))) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"str() still contains mangled names:\n{output}" + + def test_unnamed_rule_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rule(x + y).then(union(x + y).with_(y + x))) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"Unnamed RuleDecl key not translated:\n{output}" + # Should contain Python rule() syntax somewhere in the keys + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) > 0 + for key in rule_keys: + assert "__main__" not in key, f"RuleDecl key not translated: {key}" + + def test_birewrite_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + def __mul__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(birewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + output = str(report) + assert "__main__" not in output, f"BiRewriteDecl key not translated:\n{output}" + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) > 0 + for key in rule_keys: + assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}" + assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}" + + def test_rewrite_decl(self): + egraph = EGraph() + + class Num(Expr): + def __init__(self, n: i64Like) -> None: ... + def __add__(self, other: Num) -> Num: ... + + x, y = vars_("x y", Num) + egraph.register(rewrite(x + y).to(y + x)) + egraph.register(Num(1) + Num(2)) + report = egraph.run(10) + + rule_keys = list(report.search_and_apply_time_per_rule.keys()) + assert len(rule_keys) == 1 + key = rule_keys[0] + assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}" + assert "__main__" not in key, f"RewriteDecl key not translated: {key}" From f554bb2081e2db4fc3b0265de24b73a879d563da Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 7 May 2026 01:07:48 +0100 Subject: [PATCH 3/7] chore: rename to runreport --- python/egglog/egraph.py | 16 ++++++++-------- python/egglog/run_report.py | 28 ++++++++++++++-------------- python/tests/test_run_report.py | 20 ++++++++++---------- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index a60f13c3..00210d41 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -42,7 +42,7 @@ from .egraph_state import * from .ipython_magic import IN_IPYTHON from .pretty import pretty_decl -from .run_report import PrettyRunReport +from .run_report import RunReport from .runtime import * from .thunk import * @@ -954,15 +954,15 @@ def output(self) -> None: raise NotImplementedError(msg) @overload - def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> PrettyRunReport: ... + def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ... @overload - def run(self, schedule: Schedule, /) -> PrettyRunReport: ... + def run(self, schedule: Schedule, /) -> RunReport: ... @_TRACER.start_as_current_span("run") def run( self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None - ) -> PrettyRunReport: + ) -> RunReport: """ Run the egraph until the given limit or until the given facts are true. """ @@ -970,20 +970,20 @@ def run( limit_or_schedule = run(ruleset, *until) * limit_or_schedule return self._run_schedule(limit_or_schedule) - def _run_schedule(self, schedule: Schedule) -> PrettyRunReport: + def _run_schedule(self, schedule: Schedule) -> RunReport: self._add_decls(schedule) cmd = self._state.run_schedule_to_egg(schedule.schedule) (command_output,) = self._run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) - return PrettyRunReport.from_bindings(command_output.report, self._state) + return RunReport.from_bindings(command_output.report, self._state) - def stats(self) -> PrettyRunReport: + def stats(self) -> RunReport: """ Returns the overall run report for the egraph. """ (output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None)) assert isinstance(output, bindings.OverallStatistics) - return PrettyRunReport.from_bindings(output.report, self._state) + return RunReport.from_bindings(output.report, self._state) def check_bool(self, *facts: FactLike) -> bool: """ diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py index 233906c3..44f693e0 100644 --- a/python/egglog/run_report.py +++ b/python/egglog/run_report.py @@ -8,13 +8,13 @@ @dataclass -class PrettyRuleReport: +class RuleReport: plan: bindings.Plan | None search_and_apply_time: timedelta num_matches: int @classmethod - def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport: + def from_bindings(cls, report: bindings.RuleReport) -> RuleReport: return cls( plan=report.plan, search_and_apply_time=report.search_and_apply_time, @@ -23,18 +23,18 @@ def from_bindings(cls, report: bindings.RuleReport) -> PrettyRuleReport: @dataclass -class PrettyRuleSetReport: +class RuleSetReport: changed: bool - rule_reports: dict[str, list[PrettyRuleReport]] + rule_reports: dict[str, list[RuleReport]] search_and_apply_time: timedelta merge_time: timedelta @classmethod - def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> PrettyRuleSetReport: + def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> RuleSetReport: return cls( changed=report.changed, rule_reports={ - translate_key(k): [PrettyRuleReport.from_bindings(rr) for rr in v] + translate_key(k): [RuleReport.from_bindings(rr) for rr in v] for k, v in report.rule_reports.items() }, search_and_apply_time=report.search_and_apply_time, @@ -43,23 +43,23 @@ def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) @dataclass -class PrettyIterationReport: - rule_set_report: PrettyRuleSetReport +class IterationReport: + rule_set_report: RuleSetReport rebuild_time: timedelta @classmethod - def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> PrettyIterationReport: + def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> IterationReport: return cls( - rule_set_report=PrettyRuleSetReport.from_bindings(report.rule_set_report, translate_key), + rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key), rebuild_time=report.rebuild_time, ) @dataclass -class PrettyRunReport: +class RunReport: """Python-friendly wrapper around bindings.RunReport.""" - iterations: list[PrettyIterationReport] + iterations: list[IterationReport] updated: bool search_and_apply_time_per_rule: dict[str, timedelta] num_matches_per_rule: dict[str, int] @@ -68,9 +68,9 @@ class PrettyRunReport: rebuild_time_per_ruleset: dict[str, timedelta] @classmethod - def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> PrettyRunReport: + def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport: return cls( - iterations=[PrettyIterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], + iterations=[IterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], updated=report.updated, search_and_apply_time_per_rule={ state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() diff --git a/python/tests/test_run_report.py b/python/tests/test_run_report.py index 38aa9223..e804b73a 100644 --- a/python/tests/test_run_report.py +++ b/python/tests/test_run_report.py @@ -6,7 +6,7 @@ from egglog import * -class TestPrettyRunReport: +class TestRunReport: def _setup_simple_egraph(self): egraph = EGraph() @@ -19,16 +19,16 @@ def __add__(self, other: Num) -> Num: ... egraph.register(Num(1) + Num(2)) return egraph - def test_run_returns_pretty_report(self): + def test_run_returns_report(self): egraph = self._setup_simple_egraph() report = egraph.run(10) - assert type(report).__name__ == "PrettyRunReport" + assert type(report).__name__ == "RunReport" - def test_stats_returns_pretty_report(self): + def test_stats_returns_report(self): egraph = self._setup_simple_egraph() egraph.run(10) report = egraph.stats() - assert type(report).__name__ == "PrettyRunReport" + assert type(report).__name__ == "RunReport" def test_rule_names_translated_in_top_level_dicts(self): egraph = self._setup_simple_egraph() @@ -78,16 +78,16 @@ def test_timedelta_types(self): for v in report.rebuild_time_per_ruleset.values(): assert isinstance(v, timedelta) - def test_iteration_reports_are_pretty(self): + def test_iteration_reports(self): egraph = self._setup_simple_egraph() report = egraph.run(10) for it in report.iterations: - assert type(it).__name__ == "PrettyIterationReport" - assert type(it.rule_set_report).__name__ == "PrettyRuleSetReport" + assert type(it).__name__ == "IterationReport" + assert type(it.rule_set_report).__name__ == "RuleSetReport" for rule_reports in it.rule_set_report.rule_reports.values(): for rr in rule_reports: - assert type(rr).__name__ == "PrettyRuleReport" + assert type(rr).__name__ == "RuleReport" def test_str_no_egglog_sexprs(self): egraph = self._setup_simple_egraph() @@ -122,7 +122,7 @@ def __mul__(self, other: Math) -> Math: ... def test_empty_run(self): egraph = EGraph() report = egraph.run(1) - assert type(report).__name__ == "PrettyRunReport" + assert type(report).__name__ == "RunReport" assert isinstance(report.updated, bool) def test_named_rule(self): From c554d1add1b568bd415cfd2b5c5ef85db6a6ee36 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 7 May 2026 01:08:52 +0100 Subject: [PATCH 4/7] chore: pre-commit --- python/egglog/egraph.py | 4 +--- python/egglog/run_report.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 00210d41..bb7dfb73 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -960,9 +960,7 @@ def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> Ru def run(self, schedule: Schedule, /) -> RunReport: ... @_TRACER.start_as_current_span("run") - def run( - self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None - ) -> RunReport: + def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: """ Run the egraph until the given limit or until the given facts are true. """ diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py index 44f693e0..703c7006 100644 --- a/python/egglog/run_report.py +++ b/python/egglog/run_report.py @@ -34,8 +34,7 @@ def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) return cls( changed=report.changed, rule_reports={ - translate_key(k): [RuleReport.from_bindings(rr) for rr in v] - for k, v in report.rule_reports.items() + translate_key(k): [RuleReport.from_bindings(rr) for rr in v] for k, v in report.rule_reports.items() }, search_and_apply_time=report.search_and_apply_time, merge_time=report.merge_time, From 810bd9580e4f04b1d755f96675bf1aa5393fabf6 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 7 May 2026 02:03:58 +0100 Subject: [PATCH 5/7] fix: store CommandDecl --- python/egglog/egraph_state.py | 25 +++++++++---- python/egglog/run_report.py | 65 +++++++++++++++++++++++++-------- python/tests/test_run_report.py | 35 ++++-------------- 3 files changed, 75 insertions(+), 50 deletions(-) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 650d0d01..62b74a28 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -48,6 +48,12 @@ def _normalize_global_let_name(name: str) -> str: return name if name.startswith("$") else f"${name}" +def _normalize_rule_key(key: str) -> str: + """Normalize an egglog rule string for consistent matching.""" + key = key.replace("'", '"') + return re.sub(r"\s+", " ", key).strip() + + @dataclass class EGraphState: """ @@ -249,13 +255,12 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 case _: assert_never(schedule) - def translate_rule_key(self, egglog_key: str) -> str: + def translate_rule_key(self, egglog_key: str) -> CommandDecl: """ - Translate an egglog rule name to its Python representation. + Look up the original Python CommandDecl for an egglog rule key. """ - if egglog_key in self.egg_rule_to_command_decl: - return pretty_decl(self.__egg_decls__, self.egg_rule_to_command_decl[egglog_key]) - return egglog_key + normalized = _normalize_rule_key(egglog_key) + return self.egg_rule_to_command_decl[normalized] def ruleset_to_egg(self, ident: Ident) -> None: """ @@ -304,7 +309,11 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command else: egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) - self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + normalized = _normalize_rule_key(str(egg_cmd)) + self.egg_rule_to_command_decl[normalized] = cmd + if isinstance(cmd, BiRewriteDecl): + self.egg_rule_to_command_decl[normalized + "=>"] = cmd + self.egg_rule_to_command_decl[normalized + "<="] = cmd return egg_cmd case RuleDecl(head, body, name): egg_cmd = bindings.RuleCommand( @@ -316,7 +325,9 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command str(ruleset), ) ) - self.egg_rule_to_command_decl[str(egg_cmd)] = cmd + self.egg_rule_to_command_decl[_normalize_rule_key(str(egg_cmd))] = cmd + if name: + self.egg_rule_to_command_decl[name] = cmd return egg_cmd # TODO: Replace with just constants value and looking at REF of function case DefaultRewriteDecl(ref, expr, subsume): diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py index 703c7006..493a3fd9 100644 --- a/python/egglog/run_report.py +++ b/python/egglog/run_report.py @@ -1,10 +1,16 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import timedelta from . import bindings +from .declarations import CommandDecl, Declarations from .egraph_state import EGraphState +from .pretty import pretty_decl + + +def _format_rule_key(decls: Declarations, key: CommandDecl) -> str: + return pretty_decl(decls, key) @dataclass @@ -25,12 +31,15 @@ def from_bindings(cls, report: bindings.RuleReport) -> RuleReport: @dataclass class RuleSetReport: changed: bool - rule_reports: dict[str, list[RuleReport]] + rule_reports: dict[CommandDecl, list[RuleReport]] search_and_apply_time: timedelta merge_time: timedelta + _decls: Declarations = field(repr=False, default=None) @classmethod - def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) -> RuleSetReport: + def from_bindings( + cls, report: bindings.RuleSetReport, translate_key: callable, decls: Declarations + ) -> RuleSetReport: return cls( changed=report.changed, rule_reports={ @@ -38,6 +47,16 @@ def from_bindings(cls, report: bindings.RuleSetReport, translate_key: callable) }, search_and_apply_time=report.search_and_apply_time, merge_time=report.merge_time, + _decls=decls, + ) + + def __repr__(self) -> str: + rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()} + return ( + f"RuleSetReport(changed={self.changed}, " + f"rule_reports={rule_reports_str}, " + f"search_and_apply_time={self.search_and_apply_time}, " + f"merge_time={self.merge_time})" ) @@ -47,9 +66,11 @@ class IterationReport: rebuild_time: timedelta @classmethod - def from_bindings(cls, report: bindings.IterationReport, translate_key: callable) -> IterationReport: + def from_bindings( + cls, report: bindings.IterationReport, translate_key: callable, decls: Declarations + ) -> IterationReport: return cls( - rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key), + rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key, decls), rebuild_time=report.rebuild_time, ) @@ -60,26 +81,40 @@ class RunReport: iterations: list[IterationReport] updated: bool - search_and_apply_time_per_rule: dict[str, timedelta] - num_matches_per_rule: dict[str, int] + search_and_apply_time_per_rule: dict[CommandDecl, timedelta] + num_matches_per_rule: dict[CommandDecl, int] search_and_apply_time_per_ruleset: dict[str, timedelta] merge_time_per_ruleset: dict[str, timedelta] rebuild_time_per_ruleset: dict[str, timedelta] + _decls: Declarations = field(repr=False, default=None) + + def __repr__(self) -> str: + time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()} + matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()} + return ( + f"RunReport(iterations={self.iterations}, " + f"updated={self.updated}, " + f"search_and_apply_time_per_rule={time_per_rule}, " + f"num_matches_per_rule={matches_per_rule}, " + f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, " + f"merge_time_per_ruleset={self.merge_time_per_ruleset}, " + f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})" + ) @classmethod def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport: return cls( - iterations=[IterationReport.from_bindings(it, state.translate_rule_key) for it in report.iterations], + iterations=[ + IterationReport.from_bindings(it, state.translate_rule_key, state.__egg_decls__) + for it in report.iterations + ], updated=report.updated, search_and_apply_time_per_rule={ state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() }, num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()}, - search_and_apply_time_per_ruleset={ - state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_ruleset.items() - }, - merge_time_per_ruleset={state.translate_rule_key(k): v for k, v in report.merge_time_per_ruleset.items()}, - rebuild_time_per_ruleset={ - state.translate_rule_key(k): v for k, v in report.rebuild_time_per_ruleset.items() - }, + search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset, + merge_time_per_ruleset=report.merge_time_per_ruleset, + rebuild_time_per_ruleset=report.rebuild_time_per_ruleset, + _decls=state.__egg_decls__, ) diff --git a/python/tests/test_run_report.py b/python/tests/test_run_report.py index e804b73a..9f502558 100644 --- a/python/tests/test_run_report.py +++ b/python/tests/test_run_report.py @@ -4,6 +4,7 @@ from datetime import timedelta from egglog import * +from egglog.declarations import BiRewriteDecl, RewriteDecl, RuleDecl class TestRunReport: @@ -35,12 +36,10 @@ def test_rule_names_translated_in_top_level_dicts(self): report = egraph.run(10) for key in report.search_and_apply_time_per_rule: - assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" - assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + assert isinstance(key, RewriteDecl) for key in report.num_matches_per_rule: - assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" - assert "__main__" not in key, f"Key should not contain mangled egglog names: {key}" + assert isinstance(key, RewriteDecl) def test_rule_names_translated_in_iterations(self): egraph = self._setup_simple_egraph() @@ -49,8 +48,7 @@ def test_rule_names_translated_in_iterations(self): assert len(report.iterations) > 0 for iteration in report.iterations: for key in iteration.rule_set_report.rule_reports: - assert "__main__" not in key, f"Iteration rule key not translated: {key}" - assert "rewrite" in key, f"Expected Python rewrite syntax, got: {key}" + assert isinstance(key, RewriteDecl) def test_updated_field(self): egraph = self._setup_simple_egraph() @@ -117,7 +115,7 @@ def __mul__(self, other: Math) -> Math: ... rule_keys = list(report.search_and_apply_time_per_rule.keys()) assert len(rule_keys) == 2 for key in rule_keys: - assert "__main__" not in key, f"Key not translated: {key}" + assert isinstance(key, RewriteDecl) def test_empty_run(self): egraph = EGraph() @@ -158,7 +156,7 @@ def __add__(self, other: Num) -> Num: ... rule_keys = list(report.search_and_apply_time_per_rule.keys()) assert len(rule_keys) > 0 for key in rule_keys: - assert "__main__" not in key, f"RuleDecl key not translated: {key}" + assert isinstance(key, RuleDecl) def test_birewrite_decl(self): egraph = EGraph() @@ -178,23 +176,4 @@ def __mul__(self, other: Num) -> Num: ... rule_keys = list(report.search_and_apply_time_per_rule.keys()) assert len(rule_keys) > 0 for key in rule_keys: - assert "__main__" not in key, f"BiRewriteDecl key not translated: {key}" - assert "birewrite" in key, f"Expected birewrite() syntax, got: {key}" - - def test_rewrite_decl(self): - egraph = EGraph() - - class Num(Expr): - def __init__(self, n: i64Like) -> None: ... - def __add__(self, other: Num) -> Num: ... - - x, y = vars_("x y", Num) - egraph.register(rewrite(x + y).to(y + x)) - egraph.register(Num(1) + Num(2)) - report = egraph.run(10) - - rule_keys = list(report.search_and_apply_time_per_rule.keys()) - assert len(rule_keys) == 1 - key = rule_keys[0] - assert "rewrite" in key, f"Expected rewrite() syntax, got: {key}" - assert "__main__" not in key, f"RewriteDecl key not translated: {key}" + assert isinstance(key, BiRewriteDecl) From c09c7de32cca9f1e9d9b66a62558cba4366a5411 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 7 May 2026 02:10:52 +0100 Subject: [PATCH 6/7] fix: make methods private --- python/egglog/egraph.py | 4 ++-- python/egglog/run_report.py | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index bb7dfb73..aa87d41a 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -973,7 +973,7 @@ def _run_schedule(self, schedule: Schedule) -> RunReport: cmd = self._state.run_schedule_to_egg(schedule.schedule) (command_output,) = self._run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) - return RunReport.from_bindings(command_output.report, self._state) + return RunReport._from_bindings(command_output.report, self._state) def stats(self) -> RunReport: """ @@ -981,7 +981,7 @@ def stats(self) -> RunReport: """ (output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None)) assert isinstance(output, bindings.OverallStatistics) - return RunReport.from_bindings(output.report, self._state) + return RunReport._from_bindings(output.report, self._state) def check_bool(self, *facts: FactLike) -> bool: """ diff --git a/python/egglog/run_report.py b/python/egglog/run_report.py index 493a3fd9..1a5b0fc2 100644 --- a/python/egglog/run_report.py +++ b/python/egglog/run_report.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field from datetime import timedelta @@ -20,7 +21,7 @@ class RuleReport: num_matches: int @classmethod - def from_bindings(cls, report: bindings.RuleReport) -> RuleReport: + def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport: return cls( plan=report.plan, search_and_apply_time=report.search_and_apply_time, @@ -37,13 +38,13 @@ class RuleSetReport: _decls: Declarations = field(repr=False, default=None) @classmethod - def from_bindings( - cls, report: bindings.RuleSetReport, translate_key: callable, decls: Declarations + def _from_bindings( + cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl], decls: Declarations ) -> RuleSetReport: return cls( changed=report.changed, rule_reports={ - translate_key(k): [RuleReport.from_bindings(rr) for rr in v] for k, v in report.rule_reports.items() + translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items() }, search_and_apply_time=report.search_and_apply_time, merge_time=report.merge_time, @@ -66,11 +67,11 @@ class IterationReport: rebuild_time: timedelta @classmethod - def from_bindings( - cls, report: bindings.IterationReport, translate_key: callable, decls: Declarations + def _from_bindings( + cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl], decls: Declarations ) -> IterationReport: return cls( - rule_set_report=RuleSetReport.from_bindings(report.rule_set_report, translate_key, decls), + rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls), rebuild_time=report.rebuild_time, ) @@ -102,10 +103,10 @@ def __repr__(self) -> str: ) @classmethod - def from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport: + def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport: return cls( iterations=[ - IterationReport.from_bindings(it, state.translate_rule_key, state.__egg_decls__) + IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__) for it in report.iterations ], updated=report.updated, From 3d7bec7c8cb08c744b90c6df910f8f727e50e7a0 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 7 May 2026 02:13:57 +0100 Subject: [PATCH 7/7] feat: update changelog --- docs/changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changelog.md b/docs/changelog.md index 461f8eac..275fcbbc 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ ## 13.1.0 (2026-03-25) +- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416) - Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397) - Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection. - Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.