-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add pretty run report #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2007b1a
86133fa
f554bb2
c554d1a
810bd95
c09c7de
3d7bec7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,12 @@ def _normalize_global_let_name(name: str) -> str: | |
| return name if name.startswith("$") else f"${name}" | ||
|
|
||
|
|
||
def _normalize_rule_key(key: str) -> str:
    """
    Return ``key`` in a canonical form so equivalent egglog rule strings compare equal.

    Single quotes are replaced by double quotes, every run of whitespace is collapsed
    to one space, and leading/trailing whitespace is stripped.
    """
    double_quoted = key.replace("'", '"')
    single_spaced = re.sub(r"\s+", " ", double_quoted)
    return single_spaced.strip()
|
|
||
|
|
||
| @dataclass | ||
| class EGraphState: | ||
| """ | ||
|
|
@@ -76,6 +82,8 @@ class EGraphState: | |
| type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) | ||
| egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) | ||
|
|
||
| egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) | ||
|
|
||
| # Cache of egg expressions for converting to egg | ||
| expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) | ||
|
|
||
|
|
@@ -247,6 +255,13 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 | |
| case _: | ||
| assert_never(schedule) | ||
|
|
||
def translate_rule_key(self, egglog_key: str) -> CommandDecl:
    """
    Look up the original Python CommandDecl for an egglog rule key.

    The key is matched in its normalized form first. If that misses, the key is
    retried exactly as given: named rules are registered under their raw name,
    which normalization would alter if the name contains single quotes or
    irregular whitespace.

    Raises KeyError when neither form has been registered.
    """
    registered = self.egg_rule_to_command_decl
    normalized = _normalize_rule_key(egglog_key)
    if normalized in registered:
        return registered[normalized]
    # Fall back to the raw key; this still raises KeyError for unknown rules.
    return registered[egglog_key]
|
|
||
| def ruleset_to_egg(self, ident: Ident) -> None: | ||
| """ | ||
| Registers a ruleset if it's not already registered. | ||
|
|
@@ -289,13 +304,19 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | |
| self._expr_to_egg(rhs), | ||
| [self.fact_to_egg(c) for c in conditions], | ||
| ) | ||
| return ( | ||
| bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| if isinstance(cmd, RewriteDecl) | ||
| else bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
| ) | ||
| if isinstance(cmd, RewriteDecl): | ||
| egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| else: | ||
| egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
|
|
||
| normalized = _normalize_rule_key(str(egg_cmd)) | ||
| self.egg_rule_to_command_decl[normalized] = cmd | ||
| if isinstance(cmd, BiRewriteDecl): | ||
| self.egg_rule_to_command_decl[normalized + "=>"] = cmd | ||
| self.egg_rule_to_command_decl[normalized + "<="] = cmd | ||
| return egg_cmd | ||
| case RuleDecl(head, body, name): | ||
| return bindings.RuleCommand( | ||
| egg_cmd = bindings.RuleCommand( | ||
| bindings.Rule( | ||
| span(), | ||
| [self.action_to_egg(a) for a in head], | ||
|
|
@@ -304,6 +325,10 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | |
| str(ruleset), | ||
| ) | ||
| ) | ||
| self.egg_rule_to_command_decl[_normalize_rule_key(str(egg_cmd))] = cmd | ||
| if name: | ||
| self.egg_rule_to_command_decl[name] = cmd | ||
|
Comment on lines
+328
to
+330
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If a name is provided, we don't need to save the normalized version, right? |
||
| return egg_cmd | ||
| # TODO: Replace with just constants value and looking at REF of function | ||
| case DefaultRewriteDecl(ref, expr, subsume): | ||
| sig = self.__egg_decls__.get_callable_decl(ref).signature | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from dataclasses import dataclass, field | ||
| from datetime import timedelta | ||
|
|
||
| from . import bindings | ||
| from .declarations import CommandDecl, Declarations | ||
| from .egraph_state import EGraphState | ||
| from .pretty import pretty_decl | ||
|
|
||
|
|
||
def _format_rule_key(decls: Declarations, key: CommandDecl) -> str:
    """Return the ``pretty_decl`` rendering of the command declaration ``key`` using ``decls``."""
    rendered = pretty_decl(decls, key)
    return rendered
|
Comment on lines
+13
to
+14
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: remove this extra function. |
||
|
|
||
|
|
||
@dataclass
class RuleReport:
    """Python-side copy of the fields of a single ``bindings.RuleReport``."""

    plan: bindings.Plan | None
    search_and_apply_time: timedelta
    num_matches: int

    @classmethod
    def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
        """Build a ``RuleReport`` by copying ``plan``, ``search_and_apply_time`` and ``num_matches`` from ``report``."""
        plan = report.plan
        elapsed = report.search_and_apply_time
        matches = report.num_matches
        return cls(plan, elapsed, matches)
|
|
||
|
|
||
@dataclass
class RuleSetReport:
    """
    Python-side view of a ``bindings.RuleSetReport``.

    Rule reports are keyed by the Python ``CommandDecl`` they were translated to
    rather than by the egglog rule string.
    """

    changed: bool
    rule_reports: dict[CommandDecl, list[RuleReport]]
    search_and_apply_time: timedelta
    merge_time: timedelta
    # Only used to pretty-print the rule keys in ``__repr__``; may be unset.
    _decls: Declarations | None = field(repr=False, default=None)

    @classmethod
    def _from_bindings(
        cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
    ) -> RuleSetReport:
        """
        Convert ``report``, translating each egglog rule key with ``translate_key``.

        When several egglog keys translate to the same declaration, their rule
        reports are concatenated under that one declaration.
        """
        rule_reports: dict[CommandDecl, list[RuleReport]] = {}
        for egg_key, egg_reports in report.rule_reports.items():
            # A birewrite shows up under two egglog keys that both map back to one
            # Python declaration, so extend the existing list instead of overwriting it.
            rule_reports.setdefault(translate_key(egg_key), []).extend(
                RuleReport._from_bindings(rr) for rr in egg_reports
            )
        return cls(
            changed=report.changed,
            rule_reports=rule_reports,
            search_and_apply_time=report.search_and_apply_time,
            merge_time=report.merge_time,
            _decls=decls,
        )

    def __repr__(self) -> str:
        """Return a repr in which rule keys are pretty-printed when declarations are available."""
        decls = self._decls
        # Without declarations the keys cannot be pretty-printed; fall back to their plain repr.
        rule_reports_str = {
            (repr(k) if decls is None else _format_rule_key(decls, k)): v for k, v in self.rule_reports.items()
        }
        return (
            f"RuleSetReport(changed={self.changed}, "
            f"rule_reports={rule_reports_str}, "
            f"search_and_apply_time={self.search_and_apply_time}, "
            f"merge_time={self.merge_time})"
        )
|
|
||
|
|
||
@dataclass
class IterationReport:
    """Python-side copy of a ``bindings.IterationReport``."""

    rule_set_report: RuleSetReport
    rebuild_time: timedelta

    @classmethod
    def _from_bindings(
        cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl], decls: Declarations
    ) -> IterationReport:
        """Convert ``report``, delegating its rule set report to ``RuleSetReport._from_bindings``."""
        converted = RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls)
        return cls(converted, report.rebuild_time)
|
|
||
|
|
||
@dataclass
class RunReport:
    """Python-friendly wrapper around bindings.RunReport."""

    iterations: list[IterationReport]
    updated: bool
    search_and_apply_time_per_rule: dict[CommandDecl, timedelta]
    num_matches_per_rule: dict[CommandDecl, int]
    # NOTE(review): the three per-ruleset mappings are keyed by strings; keying them by
    # ruleset objects instead was raised in review as an optional follow-up.
    search_and_apply_time_per_ruleset: dict[str, timedelta]
    merge_time_per_ruleset: dict[str, timedelta]
    rebuild_time_per_ruleset: dict[str, timedelta]
    # Only used to pretty-print the rule keys in ``__repr__``; may be unset.
    _decls: Declarations | None = field(repr=False, default=None)

    def __repr__(self) -> str:
        """Return a repr in which rule keys are pretty-printed when declarations are available."""
        decls = self._decls

        def fmt(key: CommandDecl) -> str:
            # Without declarations the key cannot be pretty-printed; fall back to its plain repr.
            return repr(key) if decls is None else _format_rule_key(decls, key)

        time_per_rule = {fmt(k): v for k, v in self.search_and_apply_time_per_rule.items()}
        matches_per_rule = {fmt(k): v for k, v in self.num_matches_per_rule.items()}
        return (
            f"RunReport(iterations={self.iterations}, "
            f"updated={self.updated}, "
            f"search_and_apply_time_per_rule={time_per_rule}, "
            f"num_matches_per_rule={matches_per_rule}, "
            f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
            f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
            f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
        )

    @classmethod
    def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
        """
        Convert ``report``, translating egglog rule keys with ``state.translate_rule_key``.

        When several egglog keys translate to the same declaration, their timings and
        match counts are added together under that one declaration.
        """
        iterations = [
            IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__)
            for it in report.iterations
        ]

        # A birewrite shows up under two egglog keys that both map back to one Python
        # declaration, so accumulate per declaration instead of overwriting.
        time_per_rule: dict[CommandDecl, timedelta] = {}
        for egg_key, elapsed in report.search_and_apply_time_per_rule.items():
            decl = state.translate_rule_key(egg_key)
            time_per_rule[decl] = time_per_rule.get(decl, timedelta()) + elapsed

        matches_per_rule: dict[CommandDecl, int] = {}
        for egg_key, count in report.num_matches_per_rule.items():
            decl = state.translate_rule_key(egg_key)
            matches_per_rule[decl] = matches_per_rule.get(decl, 0) + count

        return cls(
            iterations=iterations,
            updated=report.updated,
            search_and_apply_time_per_rule=time_per_rule,
            num_matches_per_rule=matches_per_rule,
            search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
            merge_time_per_ruleset=report.merge_time_per_ruleset,
            rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
            _decls=state.__egg_decls__,
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you leave a comment here explaining that the birewrite is de-sugared into these two names in egglog, which is why we have both of them?