|
15 | 15 |
|
16 | 16 | from bigframes.core import bq_data |
17 | 17 | import bigframes.core as core |
| 18 | +import bigframes.core.agg_expressions as agg_ex |
18 | 19 | import bigframes.core.expression as ex |
19 | 20 | import bigframes.core.identifiers as identifiers |
20 | 21 | import bigframes.core.nodes as nodes |
21 | 22 | import bigframes.core.rewrite.identifiers as id_rewrite |
| 23 | +import bigframes.operations.aggregations as agg_ops |
22 | 24 |
|
23 | 25 |
|
24 | 26 | def test_remap_variables_single_node(leaf): |
@@ -52,6 +54,51 @@ def test_remap_variables_projection(leaf): |
52 | 54 | assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)} |
53 | 55 |
|
54 | 56 |
|
def test_remap_variables_aggregate(leaf):
    """Remap an ungrouped aggregate and check the shape of the returned mapping.

    Builds ``sum(<first leaf field>) AS sum_a`` with no grouping keys, runs
    ``remap_variables`` over it, and verifies that the mapping holds exactly
    one entry, keyed by the aggregate's output column id.
    """
    expected_output_id = identifiers.ColumnId("sum_a")

    # sum over the first field of the leaf; by_column_ids is empty, so there
    # are no grouping keys.
    summed_first_field = agg_ex.UnaryAggregation(
        op=agg_ops.sum_op,
        arg=ex.DerefOp(leaf.fields[0].id),
    )
    aggregate_node = nodes.AggregateNode(
        child=leaf,
        aggregations=((summed_first_field, expected_output_id),),
        by_column_ids=(),
    )

    # Supply of replacement ids handed to the rewrite.
    fresh_ids = (identifiers.ColumnId(f"id_{i}") for i in range(100))
    _, mapping = id_rewrite.remap_variables(aggregate_node, fresh_ids)

    # The aggregate node defines a single column (sum_a); the mapping is
    # expected to contain that column and nothing else.
    assert len(mapping) == 1
    assert expected_output_id in mapping
| 78 | + |
| 79 | + |
def test_remap_variables_aggregate_with_grouping(leaf):
    """Remap a grouped aggregate and check the shape of the returned mapping.

    Builds ``sum(<second leaf field>) AS sum_b`` grouped by the first leaf
    field, runs ``remap_variables`` over it, and verifies that the mapping
    holds exactly two entries: the grouping column and the aggregate output.
    """
    grouping_id = leaf.fields[0].id
    summed_id = leaf.fields[1].id
    expected_output_id = identifiers.ColumnId("sum_b")

    # sum over the second field, grouped by the first field.
    summed_second_field = agg_ex.UnaryAggregation(
        op=agg_ops.sum_op,
        arg=ex.DerefOp(summed_id),
    )
    aggregate_node = nodes.AggregateNode(
        child=leaf,
        aggregations=((summed_second_field, expected_output_id),),
        by_column_ids=(ex.DerefOp(grouping_id),),
    )

    # Supply of replacement ids handed to the rewrite.
    fresh_ids = (identifiers.ColumnId(f"id_{i}") for i in range(100))
    _, mapping = id_rewrite.remap_variables(aggregate_node, fresh_ids)

    # Expected keys: the grouping column plus the aggregate's output column.
    assert len(mapping) == 2
    assert grouping_id in mapping
    assert expected_output_id in mapping
| 100 | + |
| 101 | + |
55 | 102 | def test_remap_variables_nested_join_stability(leaf, fake_session, table): |
56 | 103 | # Create two more distinct leaf nodes |
57 | 104 | leaf2_uncached = core.ArrayValue.from_table( |
|
0 commit comments