
Commit 08c0783

Yicong-Huang authored and zhengruifeng committed
[SPARK-54617][PYTHON][SQL] Enable Arrow Grouped Iter Aggregate UDF registration for SQL
### What changes were proposed in this pull request?

This PR enables Arrow grouped iter aggregate UDFs to be registered and used in SQL queries. Previously, Arrow iter aggregate UDFs could only be used via the DataFrame API, but not in SQL. The main change is adding `SQL_GROUPED_AGG_ARROW_ITER_UDF` to the allowed eval types in the `UDFRegistration.register()` method, along with comprehensive test cases.

### Why are the changes needed?

Arrow iter aggregate UDFs provide a memory-efficient way to perform grouped aggregations by processing data in batches iteratively. However, they could only be used via the DataFrame API, not in SQL queries. This limitation prevented users from using these UDFs in SQL-based workflows.

### Does this PR introduce _any_ user-facing change?

Yes. Users can now register Arrow grouped iter aggregate UDFs and use them in SQL queries. Example:

```python
from typing import Iterator

import pyarrow as pa

from pyspark.sql.functions import arrow_udf


@arrow_udf("double")
def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
    sum_val = 0.0
    cnt = 0
    for v in it:
        sum_val += pa.compute.sum(v).as_py()
        cnt += len(v)
    return sum_val / cnt if cnt > 0 else 0.0


# Now this works:
spark.udf.register("arrow_mean_iter", arrow_mean_iter)
spark.sql("SELECT id, arrow_mean_iter(v) as mean FROM test_table GROUP BY id").show()
```

### How was this patch tested?

Added comprehensive test cases covering:
- Single column Arrow iter aggregate UDF in SQL
- Multiple columns Arrow iter aggregate UDF in SQL

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53357 from Yicong-Huang/SPARK-54617/feat/arrow-iter-agg-udf-sql.

Authored-by: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
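The new tests also cover a multi-column variant; for reference, here is the same pattern sketched for a weighted mean (assuming `test_table` additionally has a numeric weight column `w`, as in the tests below):

```python
from typing import Iterator, Tuple

import pyarrow as pa

from pyspark.sql.functions import arrow_udf


@arrow_udf("double")
def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
    # Accumulate sum(v * w) and sum(w) across batches, then divide once.
    weighted_sum = 0.0
    weight = 0.0
    for v, w in it:
        weighted_sum += pa.compute.sum(pa.compute.multiply(v, w)).as_py()
        weight += pa.compute.sum(w).as_py()
    return weighted_sum / weight if weight > 0 else 0.0


spark.udf.register("arrow_weighted_mean_iter", arrow_weighted_mean_iter)
spark.sql(
    "SELECT id, arrow_weighted_mean_iter(v, w) AS wm FROM test_table GROUP BY id"
).show()
```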
1 parent 3dffd12 commit 08c0783

File tree: 4 files changed, +80 −3 lines

python/pyspark/sql/connect/udf.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -295,14 +295,16 @@ def register(
                 PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+                PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
             ]:
                 raise PySparkTypeError(
                     errorClass="INVALID_UDF_EVAL_TYPE",
                     messageParameters={
                         "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
                         "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
                         "SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
-                        "SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
+                        "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF or "
+                        "SQL_GROUPED_AGG_ARROW_ITER_UDF"
                     },
                 )
         self.sparkSession._client.register_udf(
```
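The same allow-list now exists on the Spark Connect client path, so registration works identically against a remote session. A minimal sketch (the connect URL is a placeholder, and `arrow_mean_iter` is the UDF from the commit message above):

```python
from pyspark.sql import SparkSession

# Placeholder endpoint (assumption); any reachable Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# arrow_mean_iter: the @arrow_udf("double") iterator aggregate defined above.
spark.udf.register("arrow_mean_iter", arrow_mean_iter)
spark.sql("SELECT id, arrow_mean_iter(v) AS mean FROM test_table GROUP BY id").show()
```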

python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py

Lines changed: 72 additions & 0 deletions
```diff
@@ -1212,6 +1212,78 @@ def arrow_count_sum_partial(it: Iterator[pa.Array]) -> dict:
             group2_result["result"]["sum"], 2.0, places=5, msg="Group 2 should sum to 2.0"
         )
 
+    def test_iterator_grouped_agg_sql_single_column(self):
+        """
+        Test iterator API for grouped aggregation with single column in SQL.
+        """
+        import pyarrow as pa
+
+        @arrow_udf("double")
+        def arrow_mean_iter(it: Iterator[pa.Array]) -> float:
+            sum_val = 0.0
+            cnt = 0
+            for v in it:
+                assert isinstance(v, pa.Array)
+                sum_val += pa.compute.sum(v).as_py()
+                cnt += len(v)
+            return sum_val / cnt if cnt > 0 else 0.0
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+        )
+
+        with self.tempView("test_table"), self.temp_func("arrow_mean_iter"):
+            df.createOrReplaceTempView("test_table")
+            self.spark.udf.register("arrow_mean_iter", arrow_mean_iter)
+
+            # Test SQL query with GROUP BY
+            result_sql = self.spark.sql(
+                "SELECT id, arrow_mean_iter(v) as mean FROM test_table GROUP BY id ORDER BY id"
+            )
+            expected = df.groupby("id").agg(sf.mean(df["v"]).alias("mean")).sort("id").collect()
+
+            self.assertEqual(expected, result_sql.collect())
+
+    def test_iterator_grouped_agg_sql_multiple_columns(self):
+        """
+        Test iterator API for grouped aggregation with multiple columns in SQL.
+        """
+        import pyarrow as pa
+
+        @arrow_udf("double")
+        def arrow_weighted_mean_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> float:
+            weighted_sum = 0.0
+            weight = 0.0
+            for v, w in it:
+                assert isinstance(v, pa.Array)
+                assert isinstance(w, pa.Array)
+                weighted_sum += pa.compute.sum(pa.compute.multiply(v, w)).as_py()
+                weight += pa.compute.sum(w).as_py()
+            return weighted_sum / weight if weight > 0 else 0.0
+
+        df = self.spark.createDataFrame(
+            [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+            ("id", "v", "w"),
+        )
+
+        with self.tempView("test_table"), self.temp_func("arrow_weighted_mean_iter"):
+            df.createOrReplaceTempView("test_table")
+            self.spark.udf.register("arrow_weighted_mean_iter", arrow_weighted_mean_iter)
+
+            # Test SQL query with GROUP BY and multiple columns
+            result_sql = self.spark.sql(
+                "SELECT id, arrow_weighted_mean_iter(v, w) as wm "
+                "FROM test_table GROUP BY id ORDER BY id"
+            )
+
+            # Expected weighted means:
+            # Group 1: (1.0*1.0 + 2.0*2.0) / (1.0 + 2.0) = 5.0 / 3.0
+            # Group 2: (3.0*1.0 + 5.0*2.0 + 10.0*3.0) / (1.0 + 2.0 + 3.0) = 43.0 / 6.0
+            expected = [Row(id=1, wm=5.0 / 3.0), Row(id=2, wm=43.0 / 6.0)]
+
+            actual_results = result_sql.collect()
+            self.assertEqual(actual_results, expected)
+
 
 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
```
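Once registered, the same function also remains callable from the DataFrame API by its registered name, e.g. via `call_udf`. A sketch, assuming `df` and the registration from the multi-column test above:

```python
from pyspark.sql import functions as sf

# Equivalent of the SQL in the test: group by id, apply the registered
# iterator aggregate to (v, w), and order the result by id.
result_df = (
    df.groupBy("id")
    .agg(sf.call_udf("arrow_weighted_mean_iter", df["v"], df["w"]).alias("wm"))
    .orderBy("id")
)
result_df.show()
```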

python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -212,7 +212,8 @@ def test_register_grouped_map_udf(self):
                 "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
                 "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
                 "SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
-                "SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
+                "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF or "
+                "SQL_GROUPED_AGG_ARROW_ITER_UDF"
             },
         )
```
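This test only needed its expected error message updated: registering a grouped map UDF is still rejected, since `SQL_GROUPED_MAP_PANDAS_UDF` is not in the allow-list. A hypothetical repro of the guard (using the legacy `PandasUDFType.GROUPED_MAP` style, mirroring the test):

```python
import pandas as pd

from pyspark.errors import PySparkTypeError
from pyspark.sql.functions import pandas_udf, PandasUDFType


@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def noop(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf


try:
    spark.udf.register("noop", noop)
except PySparkTypeError as e:
    # Raised with errorClass INVALID_UDF_EVAL_TYPE, listing the allowed types.
    print(e)
```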

python/pyspark/sql/udf.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -681,14 +681,16 @@ def register(
                 PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+                PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF,
             ]:
                 raise PySparkTypeError(
                     errorClass="INVALID_UDF_EVAL_TYPE",
                     messageParameters={
                         "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
                         "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_ARROW_UDF, "
                         "SQL_SCALAR_PANDAS_ITER_UDF, SQL_SCALAR_ARROW_ITER_UDF, "
-                        "SQL_GROUPED_AGG_PANDAS_UDF or SQL_GROUPED_AGG_ARROW_UDF"
+                        "SQL_GROUPED_AGG_PANDAS_UDF, SQL_GROUPED_AGG_ARROW_UDF or "
+                        "SQL_GROUPED_AGG_ARROW_ITER_UDF"
                     },
                 )
         source_udf = _create_udf(
```
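A quick way to see which branch of this check a UDF takes is to inspect its `evalType`; a sketch (the `PythonEvalType` import path follows recent Spark versions and is an assumption, as is the `evalType` attribute on the decorated function):

```python
from pyspark.util import PythonEvalType  # import path is an assumption

# arrow_mean_iter: the iterator-style aggregate from the commit message.
assert arrow_mean_iter.evalType == PythonEvalType.SQL_GROUPED_AGG_ARROW_ITER_UDF
```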
