diff --git a/bigframes/core/rewrite/identifiers.py b/bigframes/core/rewrite/identifiers.py
index da43fdf8b9..8efcbb4a0b 100644
--- a/bigframes/core/rewrite/identifiers.py
+++ b/bigframes/core/rewrite/identifiers.py
@@ -57,11 +57,6 @@ def remap_variables(
     new_root = root.transform_children(lambda node: remapped_children[node])
 
     # Step 3: Transform the current node using the mappings from its children.
-    # "reversed" is required for InNode so that in case of a duplicate column ID,
-    # the left child's mapping is the one that's kept.
-    downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
-        k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items()
-    }
     if isinstance(new_root, nodes.InNode):
         new_root = typing.cast(nodes.InNode, new_root)
         new_root = dataclasses.replace(
@@ -71,6 +66,9 @@ def remap_variables(
             ),
         )
     else:
+        downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
+            k: v for mapping in new_child_mappings for k, v in mapping.items()
+        }
         new_root = new_root.remap_refs(downstream_mappings)
 
     # Step 4: Create new IDs for columns defined by the current node.
@@ -82,12 +80,8 @@ def remap_variables(
     new_root._validate()
 
     # Step 5: Determine which mappings to propagate up to the parent.
-    if root.defines_namespace:
-        # If a node defines a new namespace (e.g., a join), mappings from its
-        # children are not visible to its parents.
-        mappings_for_parent = node_defined_mappings
-    else:
-        # Otherwise, pass up the combined mappings from children and the current node.
-        mappings_for_parent = downstream_mappings | node_defined_mappings
+    propagated_mappings = {
+        old_id: new_id for old_id, new_id in zip(root.ids, new_root.ids)
+    }
 
-    return new_root, mappings_for_parent
+    return new_root, propagated_mappings
diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py
index 9def077a89..54bcd85e3e 100644
--- a/tests/unit/core/rewrite/test_identifiers.py
+++ b/tests/unit/core/rewrite/test_identifiers.py
@@ -15,10 +15,12 @@
 
 from bigframes.core import bq_data
 import bigframes.core as core
+import bigframes.core.agg_expressions as agg_ex
 import bigframes.core.expression as ex
 import bigframes.core.identifiers as identifiers
 import bigframes.core.nodes as nodes
 import bigframes.core.rewrite.identifiers as id_rewrite
+import bigframes.operations.aggregations as agg_ops
 
 
 def test_remap_variables_single_node(leaf):
@@ -52,6 +54,51 @@ def test_remap_variables_projection(leaf):
     assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)}
 
 
+def test_remap_variables_aggregate(leaf):
+    # Aggregation: sum(col_a) AS sum_a
+    # Group by nothing
+    agg_op = agg_ex.UnaryAggregation(
+        op=agg_ops.sum_op,
+        arg=ex.DerefOp(leaf.fields[0].id),
+    )
+    node = nodes.AggregateNode(
+        child=leaf,
+        aggregations=((agg_op, identifiers.ColumnId("sum_a")),),
+        by_column_ids=(),
+    )
+
+    id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
+    _, mapping = id_rewrite.remap_variables(node, id_generator)
+
+    # leaf has 2 columns: col_a, col_b
+    # AggregateNode defines 1 column: sum_a
+    # Output of AggregateNode should only be sum_a
+    assert len(mapping) == 1
+    assert identifiers.ColumnId("sum_a") in mapping
+
+
+def test_remap_variables_aggregate_with_grouping(leaf):
+    # Aggregation: sum(col_b) AS sum_b
+    # Group by col_a
+    agg_op = agg_ex.UnaryAggregation(
+        op=agg_ops.sum_op,
+        arg=ex.DerefOp(leaf.fields[1].id),
+    )
+    node = nodes.AggregateNode(
+        child=leaf,
+        aggregations=((agg_op, identifiers.ColumnId("sum_b")),),
+        by_column_ids=(ex.DerefOp(leaf.fields[0].id),),
+    )
+
+    id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
+    _, mapping = id_rewrite.remap_variables(node, id_generator)
+
+    # Output should have 2 columns: col_a (grouping) and sum_b (agg)
+    assert len(mapping) == 2
+    assert leaf.fields[0].id in mapping
+    assert identifiers.ColumnId("sum_b") in mapping
+
+
 def test_remap_variables_nested_join_stability(leaf, fake_session, table):
     # Create two more distinct leaf nodes
     leaf2_uncached = core.ArrayValue.from_table(