From d1464657c3e7e21ddb73ad908deb69bd24f0e979 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Feb 2026 21:57:48 +0000 Subject: [PATCH] refactor: add brackets on sqlglot expression --- .../sqlglot/aggregations/unary_compiler.py | 6 +++- .../sqlglot/expressions/generic_ops.py | 5 ++- .../sqlglot/expressions/numeric_ops.py | 2 +- .../test_unary_compiler/test_cut/int_bins.sql | 4 ++- .../test_cut/int_bins_labels.sql | 4 ++- .../test_unary_compiler/test_qcut/out.sql | 32 +++++++++---------- .../test_generic_ops/test_notnull/out.sql | 4 +-- .../test_numeric_ops/test_isfinite/out.sql | 3 ++ .../test_numeric_ops/test_pow/out.sql | 21 +++++++++--- .../sqlglot/expressions/test_numeric_ops.py | 11 +++++++ .../out.sql | 12 +++---- .../out.sql | 4 +-- .../out.sql | 4 +-- .../out.sql | 4 +-- 14 files changed, 76 insertions(+), 40 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_isfinite/out.sql diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index 381b472bce..add3ccd923 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -16,6 +16,7 @@ import typing +import bigframes_vendored.sqlglot as sg import bigframes_vendored.sqlglot.expressions as sge import pandas as pd @@ -189,7 +190,10 @@ def _cut_ops_w_int_bins( condition: sge.Expression if this_bin == bins - 1: - condition = sge.Is(this=column.expr, expression=sge.Not(this=sge.Null())) + condition = sge.Is( + this=sge.paren(column.expr, copy=False), + expression=sg.not_(sge.Null(), copy=False), + ) else: if op.right: condition = sge.LTE( diff --git a/bigframes/core/compile/sqlglot/expressions/generic_ops.py b/bigframes/core/compile/sqlglot/expressions/generic_ops.py index 94ff12a7ef..14af91e591 100644 --- a/bigframes/core/compile/sqlglot/expressions/generic_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/generic_ops.py @@ -125,7 +125,10 @@ def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression: @register_unary_op(ops.notnull_op) def _(expr: TypedExpr) -> sge.Expression: - return sge.Not(this=sge.Is(this=sge.paren(expr.expr), expression=sge.Null())) + return sge.Is( + this=sge.paren(expr.expr, copy=False), + expression=sg.not_(sge.Null(), copy=False), + ) @register_ternary_op(ops.where_op) diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index f2ae6cd82e..2285a3a0bc 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -362,7 +362,7 @@ def _float_pow_op( sge.If( this=sge.and_( sge.LT(this=left_expr, expression=constants._ZERO), - sge.Not(this=exponent_is_whole), + sge.Not(this=sge.paren(exponent_is_whole)), ), true=constants._NAN, ), diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql index d7b0fde710..0a4aa961ab 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql @@ -30,7 +30,9 @@ SELECT 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) ) + 0 AS `right_inclusive` ) - WHEN `int64_col` IS NOT NULL + WHEN ( + `int64_col` + ) IS NOT NULL THEN STRUCT( ( MIN(`int64_col`) OVER () + ( diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql index 1a3aede050..b104228836 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql @@ -8,7 +8,9 @@ SELECT 2 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3) ) THEN 'b' - WHEN `int64_col` IS NOT NULL + WHEN ( + `int64_col` + ) IS NOT NULL THEN 'c' END AS `int_bins_labels` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql index e24f505030..35a95c5367 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_qcut/out.sql @@ -2,17 +2,17 @@ SELECT `rowindex`, `int64_col`, IF( - NOT ( + ( `int64_col` - ) IS NULL, + ) IS NOT NULL, IF( `int64_col` IS NULL, NULL, CAST(GREATEST( CEIL( - PERCENT_RANK() OVER (PARTITION BY NOT ( + PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) * 4 + ) IS NOT NULL ORDER BY `int64_col` ASC) * 4 ) - 1, 0 ) AS INT64) @@ -20,29 +20,29 @@ SELECT NULL ) AS `qcut_w_int`, IF( - NOT ( + ( `int64_col` - ) IS NULL, + ) IS NOT NULL, CASE - WHEN PERCENT_RANK() OVER (PARTITION BY NOT ( + WHEN PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) < 0 + ) IS NOT NULL ORDER BY `int64_col` ASC) < 0 THEN NULL - WHEN PERCENT_RANK() OVER (PARTITION BY NOT ( + WHEN PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) <= 0.25 + ) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.25 THEN 0 - WHEN PERCENT_RANK() OVER (PARTITION BY NOT ( + WHEN PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) <= 0.5 + ) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.5 THEN 1 - WHEN PERCENT_RANK() OVER (PARTITION BY NOT ( + WHEN PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) <= 0.75 + ) IS NOT NULL ORDER BY `int64_col` ASC) <= 0.75 THEN 2 - WHEN PERCENT_RANK() OVER (PARTITION BY NOT ( + WHEN PERCENT_RANK() OVER (PARTITION BY ( `int64_col` - ) IS NULL ORDER BY `int64_col` ASC) <= 1 + ) IS NOT NULL ORDER BY `int64_col` ASC) <= 1 THEN 3 ELSE NULL END, diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_notnull/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_notnull/out.sql index 1865e24c4c..c65fda76eb 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_notnull/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_notnull/out.sql @@ -1,5 +1,5 @@ SELECT - NOT ( + ( `float64_col` - ) IS NULL AS `float64_col` + ) IS NOT NULL AS `float64_col` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_isfinite/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_isfinite/out.sql new file mode 100644 index 0000000000..500d6a6769 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_isfinite/out.sql @@ -0,0 +1,3 @@ +SELECT + NOT IS_INF(`float64_col`) OR IS_NAN(`float64_col`) AS `float64_col` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql index 213d8a011b..8455e4a66f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_pow/out.sql @@ -33,7 +33,9 @@ SELECT END ) WHEN `int64_col` < CAST(0 AS INT64) - AND NOT CAST(`float64_col` AS INT64) = `float64_col` + AND NOT ( + CAST(`float64_col` AS INT64) = `float64_col` + ) THEN CAST('NaN' AS FLOAT64) WHEN `int64_col` <> CAST(0 AS INT64) AND `float64_col` * LN(ABS(`int64_col`)) > 709.78 THEN CAST('Infinity' AS FLOAT64) * CASE @@ -75,7 +77,10 @@ SELECT ELSE `int64_col` END ) - WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(`int64_col` AS INT64) = `int64_col` + WHEN `float64_col` < CAST(0 AS INT64) + AND NOT ( + CAST(`int64_col` AS INT64) = `int64_col` + ) THEN CAST('NaN' AS FLOAT64) WHEN `float64_col` <> CAST(0 AS INT64) AND `int64_col` * LN(ABS(`float64_col`)) > 709.78 @@ -119,7 +124,9 @@ SELECT END ) WHEN `float64_col` < CAST(0 AS INT64) - AND NOT CAST(`float64_col` AS INT64) = `float64_col` + AND NOT ( + CAST(`float64_col` AS INT64) = `float64_col` + ) THEN CAST('NaN' AS FLOAT64) WHEN `float64_col` <> CAST(0 AS INT64) AND `float64_col` * LN(ABS(`float64_col`)) > 709.78 @@ -167,7 +174,9 @@ SELECT ELSE 0 END ) - WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(0 AS INT64) = 0 + WHEN `float64_col` < CAST(0 AS INT64) AND NOT ( + CAST(0 AS INT64) = 0 + ) THEN CAST('NaN' AS FLOAT64) WHEN `float64_col` <> CAST(0 AS INT64) AND 0 * LN(ABS(`float64_col`)) > 709.78 THEN CAST('Infinity' AS FLOAT64) * CASE @@ -214,7 +223,9 @@ SELECT ELSE 1 END ) - WHEN `float64_col` < CAST(0 AS INT64) AND NOT CAST(1 AS INT64) = 1 + WHEN `float64_col` < CAST(0 AS INT64) AND NOT ( + CAST(1 AS INT64) = 1 + ) THEN CAST('NaN' AS FLOAT64) WHEN `float64_col` <> CAST(0 AS INT64) AND 1 * LN(ABS(`float64_col`)) > 709.78 THEN CAST('Infinity' AS FLOAT64) * CASE diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 1a08a80eb1..f0237159bc 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -17,6 +17,7 @@ from bigframes import operations as ops import bigframes.core.expression as ex +from bigframes.operations import numeric_ops import bigframes.pandas as bpd from bigframes.testing import utils @@ -156,6 +157,16 @@ def test_floor(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_isfinite(scalar_types_df: bpd.DataFrame, snapshot): + col_name = "float64_col" + bf_df = scalar_types_df[[col_name]] + sql = utils._apply_ops_to_sql( + bf_df, [numeric_ops.isfinite_op.as_expr(col_name)], [col_name] + ) + + snapshot.assert_match(sql, "out.sql") + + def test_ln(scalar_types_df: bpd.DataFrame, snapshot): col_name = "float64_col" bf_df = scalar_types_df[[col_name]] diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql index 155b7fae20..b91aafcbee 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql @@ -3,9 +3,9 @@ SELECT `rowindex`, CASE WHEN COALESCE( - SUM(CAST(NOT ( + SUM(CAST(( `bool_col` - ) IS NULL AS INT64)) OVER ( + ) IS NOT NULL AS INT64)) OVER ( PARTITION BY `bool_col` ORDER BY `bool_col` ASC NULLS LAST, `rowindex` ASC NULLS LAST ROWS BETWEEN 3 PRECEDING AND CURRENT ROW @@ -25,9 +25,9 @@ SELECT END AS `bool_col_1`, CASE WHEN COALESCE( - SUM(CAST(NOT ( + SUM(CAST(( `int64_col` - ) IS NULL AS INT64)) OVER ( + ) IS NOT NULL AS INT64)) OVER ( PARTITION BY `bool_col` ORDER BY `bool_col` ASC NULLS LAST, `rowindex` ASC NULLS LAST ROWS BETWEEN 3 PRECEDING AND CURRENT ROW @@ -47,9 +47,9 @@ SELECT END AS `int64_col` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` WHERE - NOT ( + ( `bool_col` - ) IS NULL + ) IS NOT NULL ORDER BY `bool_col` ASC NULLS LAST, `rowindex` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql index 2cee33d599..887e7e9212 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql @@ -7,9 +7,9 @@ SELECT `bfcol_0` AS `ts_col`, CASE WHEN COALESCE( - SUM(CAST(NOT ( + SUM(CAST(( `bfcol_1` - ) IS NULL AS INT64)) OVER ( + ) IS NOT NULL AS INT64)) OVER ( ORDER BY UNIX_MICROS(`bfcol_0`) ASC RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW ), diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql index 03babee380..8a8bf6445a 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql @@ -2,9 +2,9 @@ SELECT `rowindex`, CASE WHEN COALESCE( - SUM(CAST(NOT ( + SUM(CAST(( `int64_col` - ) IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), + ) IS NOT NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), 0 ) < 3 THEN NULL diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql index f9496e983e..cf14f1cd05 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql @@ -1,9 +1,9 @@ SELECT `rowindex`, CASE - WHEN COUNT(NOT ( + WHEN COUNT(( `int64_col` - ) IS NULL) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) < 5 + ) IS NOT NULL) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) < 5 THEN NULL WHEN TRUE THEN COUNT(`int64_col`) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 4 PRECEDING AND CURRENT ROW)