From d1e39706fa898551ceb1d6ef15c7c7e77094f869 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 11 Feb 2026 18:41:59 +0000 Subject: [PATCH] refactor: Ensure only valid IDs are propagated in identifier remapping Corrected remap_variables to only propagate column IDs that are actually present in the current node's output fields. This prevents parent nodes from seeing internal or leaked column IDs from child nodes, which was specifically problematic for aggregate nodes. Added unit tests to verify correct propagation for AggregateNode with and without grouping. --- bigframes/core/rewrite/identifiers.py | 20 +++------ tests/unit/core/rewrite/test_identifiers.py | 47 +++++++++++++++++++++ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py index da43fdf8b9..8efcbb4a0b 100644 --- a/bigframes/core/rewrite/identifiers.py +++ b/bigframes/core/rewrite/identifiers.py @@ -57,11 +57,6 @@ def remap_variables( new_root = root.transform_children(lambda node: remapped_children[node]) # Step 3: Transform the current node using the mappings from its children. - # "reversed" is required for InNode so that in case of a duplicate column ID, - # the left child's mapping is the one that's kept. - downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = { - k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items() - } if isinstance(new_root, nodes.InNode): new_root = typing.cast(nodes.InNode, new_root) new_root = dataclasses.replace( @@ -71,6 +66,9 @@ def remap_variables( ), ) else: + downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = { + k: v for mapping in new_child_mappings for k, v in mapping.items() + } new_root = new_root.remap_refs(downstream_mappings) # Step 4: Create new IDs for columns defined by the current node. @@ -82,12 +80,8 @@ def remap_variables( new_root._validate() # Step 5: Determine which mappings to propagate up to the parent. - if root.defines_namespace: - # If a node defines a new namespace (e.g., a join), mappings from its - # children are not visible to its parents. - mappings_for_parent = node_defined_mappings - else: - # Otherwise, pass up the combined mappings from children and the current node. - mappings_for_parent = downstream_mappings | node_defined_mappings + propagated_mappings = { + old_id: new_id for old_id, new_id in zip(root.ids, new_root.ids) + } - return new_root, mappings_for_parent + return new_root, propagated_mappings diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py index 09904ac4ba..5a909243a7 100644 --- a/tests/unit/core/rewrite/test_identifiers.py +++ b/tests/unit/core/rewrite/test_identifiers.py @@ -14,10 +14,12 @@ import typing import bigframes.core as core +import bigframes.core.agg_expressions as agg_ex import bigframes.core.expression as ex import bigframes.core.identifiers as identifiers import bigframes.core.nodes as nodes import bigframes.core.rewrite.identifiers as id_rewrite +import bigframes.operations.aggregations as agg_ops def test_remap_variables_single_node(leaf): @@ -51,6 +53,51 @@ def test_remap_variables_projection(leaf): assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)} +def test_remap_variables_aggregate(leaf): + # Aggregation: sum(col_a) AS sum_a + # Group by nothing + agg_op = agg_ex.UnaryAggregation( + op=agg_ops.sum_op, + arg=ex.DerefOp(leaf.fields[0].id), + ) + node = nodes.AggregateNode( + child=leaf, + aggregations=((agg_op, identifiers.ColumnId("sum_a")),), + by_column_ids=(), + ) + + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + _, mapping = id_rewrite.remap_variables(node, id_generator) + + # leaf has 2 columns: col_a, col_b + # AggregateNode defines 1 column: sum_a + # Output of AggregateNode should only be sum_a + assert len(mapping) == 1 + assert identifiers.ColumnId("sum_a") in mapping + + +def test_remap_variables_aggregate_with_grouping(leaf): + # Aggregation: sum(col_b) AS sum_b + # Group by col_a + agg_op = agg_ex.UnaryAggregation( + op=agg_ops.sum_op, + arg=ex.DerefOp(leaf.fields[1].id), + ) + node = nodes.AggregateNode( + child=leaf, + aggregations=((agg_op, identifiers.ColumnId("sum_b")),), + by_column_ids=(ex.DerefOp(leaf.fields[0].id),), + ) + + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + _, mapping = id_rewrite.remap_variables(node, id_generator) + + # Output should have 2 columns: col_a (grouping) and sum_b (agg) + assert len(mapping) == 2 + assert leaf.fields[0].id in mapping + assert identifiers.ColumnId("sum_b") in mapping + + def test_remap_variables_nested_join_stability(leaf, fake_session, table): # Create two more distinct leaf nodes leaf2_uncached = core.ArrayValue.from_table(