Skip to content

Commit fbaba0b

Browse files
authored
refactor: ensure only valid IDs are propagated in identifier remapping (#2448)
1 parent 0434f26 commit fbaba0b

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

bigframes/core/rewrite/identifiers.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -57,11 +57,6 @@ def remap_variables(
5757
new_root = root.transform_children(lambda node: remapped_children[node])
5858

5959
# Step 3: Transform the current node using the mappings from its children.
60-
# "reversed" is required for InNode so that in case of a duplicate column ID,
61-
# the left child's mapping is the one that's kept.
62-
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
63-
k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items()
64-
}
6560
if isinstance(new_root, nodes.InNode):
6661
new_root = typing.cast(nodes.InNode, new_root)
6762
new_root = dataclasses.replace(
@@ -71,6 +66,9 @@ def remap_variables(
7166
),
7267
)
7368
else:
69+
downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
70+
k: v for mapping in new_child_mappings for k, v in mapping.items()
71+
}
7472
new_root = new_root.remap_refs(downstream_mappings)
7573

7674
# Step 4: Create new IDs for columns defined by the current node.
@@ -82,12 +80,8 @@ def remap_variables(
8280
new_root._validate()
8381

8482
# Step 5: Determine which mappings to propagate up to the parent.
85-
if root.defines_namespace:
86-
# If a node defines a new namespace (e.g., a join), mappings from its
87-
# children are not visible to its parents.
88-
mappings_for_parent = node_defined_mappings
89-
else:
90-
# Otherwise, pass up the combined mappings from children and the current node.
91-
mappings_for_parent = downstream_mappings | node_defined_mappings
83+
propagated_mappings = {
84+
old_id: new_id for old_id, new_id in zip(root.ids, new_root.ids)
85+
}
9286

93-
return new_root, mappings_for_parent
87+
return new_root, propagated_mappings

tests/unit/core/rewrite/test_identifiers.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,12 @@
1515

1616
from bigframes.core import bq_data
1717
import bigframes.core as core
18+
import bigframes.core.agg_expressions as agg_ex
1819
import bigframes.core.expression as ex
1920
import bigframes.core.identifiers as identifiers
2021
import bigframes.core.nodes as nodes
2122
import bigframes.core.rewrite.identifiers as id_rewrite
23+
import bigframes.operations.aggregations as agg_ops
2224

2325

2426
def test_remap_variables_single_node(leaf):
@@ -52,6 +54,51 @@ def test_remap_variables_projection(leaf):
5254
assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)}
5355

5456

57+
def test_remap_variables_aggregate(leaf):
    """Remapping an ungrouped aggregate should expose only the aggregate's output ID.

    Builds ``sum(<first leaf column>) AS sum_a`` with no grouping keys, runs
    ``remap_variables`` over it, and checks that the returned mapping has a
    single entry keyed by ``sum_a``.
    """
    output_id = identifiers.ColumnId("sum_a")

    # sum(col_a) AS sum_a, with an empty group-by.
    summed_column = ex.DerefOp(leaf.fields[0].id)
    total = agg_ex.UnaryAggregation(op=agg_ops.sum_op, arg=summed_column)
    aggregate = nodes.AggregateNode(
        child=leaf,
        aggregations=((total, output_id),),
        by_column_ids=(),
    )

    fresh_ids = (identifiers.ColumnId(f"id_{i}") for i in range(100))
    _new_root, remapped = id_rewrite.remap_variables(aggregate, fresh_ids)

    # The leaf's own columns (col_a and col_b, per the fixture defined outside
    # this view) must not leak into the mapping; only the column defined by
    # the aggregate is part of its output.
    assert len(remapped) == 1
    assert output_id in remapped
78+
79+
80+
def test_remap_variables_aggregate_with_grouping(leaf):
    """Remapping a grouped aggregate should expose the grouping key and the aggregate.

    Builds ``sum(<second leaf column>) AS sum_b`` grouped by the first leaf
    column, runs ``remap_variables`` over it, and checks that the returned
    mapping is keyed by exactly those two output columns.
    """
    group_key_id = leaf.fields[0].id
    output_id = identifiers.ColumnId("sum_b")

    # sum(col_b) AS sum_b, grouped by col_a.
    summed_column = ex.DerefOp(leaf.fields[1].id)
    total = agg_ex.UnaryAggregation(op=agg_ops.sum_op, arg=summed_column)
    aggregate = nodes.AggregateNode(
        child=leaf,
        aggregations=((total, output_id),),
        by_column_ids=(ex.DerefOp(group_key_id),),
    )

    fresh_ids = (identifiers.ColumnId(f"id_{i}") for i in range(100))
    _new_root, remapped = id_rewrite.remap_variables(aggregate, fresh_ids)

    # Two output columns are expected: the grouping key and the aggregate.
    assert len(remapped) == 2
    assert group_key_id in remapped
    assert output_id in remapped
100+
101+
55102
def test_remap_variables_nested_join_stability(leaf, fake_session, table):
56103
# Create two more distinct leaf nodes
57104
leaf2_uncached = core.ArrayValue.from_table(

0 commit comments

Comments
 (0)