From 746d1e08bbd8a797428401341b320d9fda03af4a Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 9 Dec 2025 18:15:59 -0800 Subject: [PATCH 1/2] Update egglog to fix #387 --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90841891..5263bef7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -317,7 +317,7 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "egglog" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "csv", "dyn-clone", @@ -344,7 +344,7 @@ dependencies = [ [[package]] name = "egglog-add-primitive" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "quote", "syn 2.0.107", @@ -353,7 +353,7 @@ dependencies = [ [[package]] name = "egglog-ast" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "ordered-float", ] @@ -361,7 +361,7 @@ dependencies = [ [[package]] name = "egglog-bridge" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "anyhow", "dyn-clone", @@ -385,7 +385,7 @@ dependencies = [ [[package]] name = "egglog-concurrency" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "arc-swap", "rayon", @@ -394,7 +394,7 @@ dependencies = [ [[package]] name = "egglog-core-relations" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "anyhow", "bumpalo", @@ -437,7 +437,7 @@ dependencies = [ [[package]] name = "egglog-numeric-id" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "rayon", ] @@ -445,7 +445,7 @@ dependencies = [ [[package]] name = "egglog-reports" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "clap", "hashbrown 0.16.0", @@ -459,7 +459,7 @@ dependencies = [ [[package]] name = "egglog-union-find" version = "1.0.0" -source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=fix-fn-bug#fdc03716d3acc47d603a9797bb50330025db0269" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=dag-extract#cbf82c9755f0422d53b995b3e3223d991ab18905" dependencies = [ "crossbeam", "egglog-concurrency", diff --git a/Cargo.toml b/Cargo.toml index 4d468878..185cd331 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,12 +13,12 @@ crate-type = ["cdylib"] pyo3 = { version = "0.27", features = ["extension-module", "num-bigint", "num-rational"] } num-bigint = "*" num-rational = "*" -egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug", default-features = false } -egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract", default-features = false } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false } -egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } +egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] } serde_json = "1" pyo3-log = "*" @@ -31,11 +31,11 @@ base64 = "0.22.1" # Use patched version of egglog in experimental [patch.'https://github.com/egraphs-good/egglog'] -egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } -egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-fn-bug" } +egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } +egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "dag-extract" } # enable debug symbols for easier profiling [profile.release] From aab8e15e799015fcd665865bc64d68cbc044d8d7 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 9 Dec 2025 18:18:31 -0800 Subject: [PATCH 2/2] Add test case --- python/tests/test_extract_bug.py | 131 +++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 python/tests/test_extract_bug.py diff --git a/python/tests/test_extract_bug.py b/python/tests/test_extract_bug.py new file mode 100644 index 00000000..7dc3718b --- /dev/null +++ b/python/tests/test_extract_bug.py @@ -0,0 +1,131 @@ +""" +Tests extraction with a DAG-based cost model. +from https://github.com/egraphs-good/egglog-python/issues/387#issuecomment-3628927075 +""" + +from dataclasses import dataclass, field + +from egglog import * +from egglog import bindings + +# A cost model, approximately equivalent to, greedy_dag_cost_model, +# which operates purely on the `bindings` level, for the sake of +# minimization. + +ENode = tuple[str, tuple[bindings.Value, ...]] + + +@dataclass +class DAGCostValue: + """Cost value for DAG-based extraction.""" + + cost: int + _values: dict[ENode, int] + + def __eq__(self, rhs: object) -> bool: + if not isinstance(rhs, DAGCostValue): + return False + return self.cost == rhs.cost + + def __lt__(self, other: "DAGCostValue") -> bool: + return self.cost < other.cost + + def __le__(self, other: "DAGCostValue") -> bool: + return self.cost <= other.cost + + def __gt__(self, other: "DAGCostValue") -> bool: + return self.cost > other.cost + + def __ge__(self, other: "DAGCostValue") -> bool: + return self.cost >= other.cost + + def __hash__(self) -> int: + return hash(self.cost) + + def __str__(self) -> str: + return f"DAGCostValue(cost={self.cost})" + + def __repr__(self) -> str: + return f"DAGCostValue(cost={self.cost}, nchildren={len(self._values)})" + + +@dataclass +class DAGCost: + """ + DAG-based cost model for e-graph extraction. + + This cost model counts each unique e-node once, implementing + a greedy DAG extraction strategy. + """ + + graph: bindings.EGraph + cache: dict[ENode, DAGCostValue] = field(default_factory=dict) + + def merge_costs(self, costs: list[DAGCostValue], node: ENode, self_cost: int = 0) -> DAGCostValue: + # if node in self.cache: + # return self.cache[node] + values: dict[ENode, int] = {} + for child in costs: + values.update(child._values) + cost = DAGCostValue(cost=sum(values.values(), start=self_cost), _values=values) + cost._values[node] = self_cost + # self.cache[node] = cost + # print(f"merge {costs=} out={cost}") + return cost + + def cost_fold(self, fn: str, enode: ENode, children_costs: list[DAGCostValue]) -> DAGCostValue: + return self.merge_costs(children_costs, enode, 1) + # print(f"fold {fn=} {out=}") + + def enode_cost(self, name: str, args: list[bindings.Value]) -> ENode: + return (name, tuple(args)) + + def container_cost(self, tp: str, value: bindings.Value, element_costs: list[DAGCostValue]) -> DAGCostValue: + return self.merge_costs(element_costs, (tp, (value,)), 1) + + def base_value_cost(self, tp: str, value: bindings.Value) -> DAGCostValue: + return self.merge_costs([], (tp, (value,)), 1) + + @property + def egg_cost_model(self) -> bindings.CostModel: + return bindings.CostModel( + fold=self.cost_fold, + enode_cost=self.enode_cost, + container_cost=self.container_cost, + base_value_cost=self.base_value_cost, + ) + + +def test_dag_cost_model(): + graph = EGraph() + + commands = graph._egraph.parse_program(""" + (sort S) + + (constructor Si (i64) S) + (constructor Swide (S S S S S S S S) S ) + (constructor Ssa (S) S) + (constructor Ssb (S) S) + (constructor Ssc (S) S) + (constructor Sp (S S) S) + + + (let w + (Swide (Si 0) (Si 1) (Si 2) (Si 3) (Si 4) (Si 5) (Si 6) (Si 7))) + + (let l (Ssa (Ssb (Ssc (Si 0))))) + (let x (Ssa w)) + (let v (Sp w x)) + + (union x l) + """) + graph._egraph.run_program(*commands) + + cost_model = DAGCost(graph._egraph) + extractor = bindings.Extractor(["S"], graph._egraph, cost_model.egg_cost_model) + termdag = bindings.TermDag() + value = graph._egraph.lookup_function("v", []) + assert value is not None + cost, _term = extractor.extract_best(graph._egraph, termdag, value, "S") + + assert cost.cost in {19, 21}