20 changes: 7 additions & 13 deletions bigframes/core/rewrite/identifiers.py
@@ -57,11 +57,6 @@ def remap_variables(
     new_root = root.transform_children(lambda node: remapped_children[node])
 
     # Step 3: Transform the current node using the mappings from its children.
-    # "reversed" is required for InNode so that in case of a duplicate column ID,
-    # the left child's mapping is the one that's kept.
-    downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
-        k: v for mapping in reversed(new_child_mappings) for k, v in mapping.items()
-    }
     if isinstance(new_root, nodes.InNode):
         new_root = typing.cast(nodes.InNode, new_root)
         new_root = dataclasses.replace(
@@ -71,6 +66,9 @@ def remap_variables(
             ),
         )
     else:
+        downstream_mappings: dict[identifiers.ColumnId, identifiers.ColumnId] = {
+            k: v for mapping in new_child_mappings for k, v in mapping.items()
+        }
         new_root = new_root.remap_refs(downstream_mappings)
 
     # Step 4: Create new IDs for columns defined by the current node.
@@ -82,12 +80,8 @@ def remap_variables(
     new_root._validate()
 
     # Step 5: Determine which mappings to propagate up to the parent.
-    if root.defines_namespace:
-        # If a node defines a new namespace (e.g., a join), mappings from its
-        # children are not visible to its parents.
-        mappings_for_parent = node_defined_mappings
-    else:
-        # Otherwise, pass up the combined mappings from children and the current node.
-        mappings_for_parent = downstream_mappings | node_defined_mappings
+    propagated_mappings = {
+        old_id: new_id for old_id, new_id in zip(root.ids, new_root.ids)
+    }
 
-    return new_root, mappings_for_parent
+    return new_root, propagated_mappings
47 changes: 47 additions & 0 deletions tests/unit/core/rewrite/test_identifiers.py
@@ -15,10 +15,12 @@

 from bigframes.core import bq_data
 import bigframes.core as core
+import bigframes.core.agg_expressions as agg_ex
 import bigframes.core.expression as ex
 import bigframes.core.identifiers as identifiers
 import bigframes.core.nodes as nodes
 import bigframes.core.rewrite.identifiers as id_rewrite
+import bigframes.operations.aggregations as agg_ops
 
 
 def test_remap_variables_single_node(leaf):
@@ -52,6 +54,51 @@ def test_remap_variables_projection(leaf):
     assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)}
 
 
+def test_remap_variables_aggregate(leaf):
+    # Aggregation: sum(col_a) AS sum_a
+    # Group by nothing
+    agg_op = agg_ex.UnaryAggregation(
+        op=agg_ops.sum_op,
+        arg=ex.DerefOp(leaf.fields[0].id),
+    )
+    node = nodes.AggregateNode(
+        child=leaf,
+        aggregations=((agg_op, identifiers.ColumnId("sum_a")),),
+        by_column_ids=(),
+    )
+
+    id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
+    _, mapping = id_rewrite.remap_variables(node, id_generator)
+
+    # leaf has 2 columns: col_a, col_b
+    # AggregateNode defines 1 column: sum_a
+    # Output of AggregateNode should only be sum_a
+    assert len(mapping) == 1
+    assert identifiers.ColumnId("sum_a") in mapping
+
+
+def test_remap_variables_aggregate_with_grouping(leaf):
+    # Aggregation: sum(col_b) AS sum_b
+    # Group by col_a
+    agg_op = agg_ex.UnaryAggregation(
+        op=agg_ops.sum_op,
+        arg=ex.DerefOp(leaf.fields[1].id),
+    )
+    node = nodes.AggregateNode(
+        child=leaf,
+        aggregations=((agg_op, identifiers.ColumnId("sum_b")),),
+        by_column_ids=(ex.DerefOp(leaf.fields[0].id),),
+    )
+
+    id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
+    _, mapping = id_rewrite.remap_variables(node, id_generator)
+
+    # Output should have 2 columns: col_a (grouping) and sum_b (agg)
+    assert len(mapping) == 2
+    assert leaf.fields[0].id in mapping
+    assert identifiers.ColumnId("sum_b") in mapping
+
+
 def test_remap_variables_nested_join_stability(leaf, fake_session, table):
     # Create two more distinct leaf nodes
     leaf2_uncached = core.ArrayValue.from_table(
Expand Down