diff --git a/src/taskgraph/graph.py b/src/taskgraph/graph.py
index 57327508..a521cdfc 100644
--- a/src/taskgraph/graph.py
+++ b/src/taskgraph/graph.py
@@ -61,21 +61,29 @@ def transitive_closure(self, nodes, reverse=False):
                 f"Unknown nodes in transitive closure: {nodes - self.nodes}"
             )
 
-        # generate a new graph by expanding along edges until reaching a fixed
-        # point
-        new_nodes, new_edges = nodes, set()
-        nodes, edges = set(), set()
-        while (new_nodes, new_edges) != (nodes, edges):
-            nodes, edges = new_nodes, new_edges
-            add_edges = {
-                (left, right, name)
-                for (left, right, name) in self.edges
-                if (right if reverse else left) in nodes
-            }
-            add_nodes = {(left if reverse else right) for (left, right, _) in add_edges}
-            new_nodes = nodes | add_nodes
-            new_edges = edges | add_edges
-        return Graph(new_nodes, new_edges)
+        # Build an adjacency list once, then BFS — O(Vertices + Edges)
+        adjacency = collections.defaultdict(set)
+        for left, right, _name in self.edges:
+            if reverse:
+                adjacency[right].add(left)
+            else:
+                adjacency[left].add(right)
+
+        result_nodes = set(nodes)
+        queue = collections.deque(nodes)
+        while queue:
+            node = queue.popleft()
+            for neighbor in adjacency.get(node, ()):
+                if neighbor not in result_nodes:
+                    result_nodes.add(neighbor)
+                    queue.append(neighbor)
+
+        result_edges = frozenset(
+            (left, right, name)
+            for left, right, name in self.edges
+            if left in result_nodes and right in result_nodes
+        )
+        return Graph(frozenset(result_nodes), result_edges)
 
     def _visit(self, reverse):
         forward_links, reverse_links = self.links_and_reverse_links_dict()
diff --git a/test/test_graph.py b/test/test_graph.py
index f1d683cc..2d14acd6 100644
--- a/test/test_graph.py
+++ b/test/test_graph.py
@@ -134,6 +134,46 @@ def test_transitive_closure_loopy(self):
         "transitive closure of a loop is the whole loop"
         self.assertEqual(self.loopy.transitive_closure({"A"}), self.loopy)
 
+    def test_transitive_closure_reverse_tree(self):
+        "reverse transitive closure from leaf nodes finds ancestors"
+        self.assertEqual(
+            self.tree.transitive_closure({"d"}, reverse=True),
+            Graph(
+                {"a", "b", "d"},
+                {("a", "b", "L"), ("b", "d", "K")},
+            ),
+        )
+
+    def test_transitive_closure_reverse_root(self):
+        "reverse transitive closure from root has no ancestors"
+        self.assertEqual(
+            self.tree.transitive_closure({"a"}, reverse=True),
+            Graph({"a"}, set()),
+        )
+
+    def test_transitive_closure_unknown_nodes(self):
+        "transitive closure raises on nodes not in the graph"
+        with self.assertRaisesRegex(Exception, "Unknown nodes"):
+            self.tree.transitive_closure({"z"})
+
+    def test_transitive_closure_diamond(self):
+        "transitive closure on a diamond graph reaches shared descendants via convergent paths"
+        self.assertEqual(
+            self.diamonds.transitive_closure({"A"}),
+            Graph(
+                {"A", "D", "F", "G", "I", "J"},
+                {
+                    ("A", "F", "L"),
+                    ("A", "D", "L"),
+                    ("D", "F", "L"),
+                    ("D", "G", "L"),
+                    ("F", "I", "L"),
+                    ("G", "I", "L"),
+                    ("G", "J", "L"),
+                },
+            ),
+        )
+
     def test_visit_postorder_empty(self):
         "postorder visit of an empty graph is empty"
         self.assertEqual(list(Graph(set(), set()).visit_postorder()), [])