From 400637e7056589a5eade5f7160494579e3a8dac9 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:18:09 +0000 Subject: [PATCH 01/18] refactor: unify verify_result and container-type checks into verify_return_type --- dev/sparktestsupport/modules.py | 1 + .../pyspark/tests/test_verify_return_type.py | 90 ++++++++++++++++++ python/pyspark/worker.py | 91 +++++++++++++------ 3 files changed, 152 insertions(+), 30 deletions(-) create mode 100644 python/pyspark/tests/test_verify_return_type.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index bd7d1f55aaee5..14ac7347e5a11 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -517,6 +517,7 @@ def __hash__(self): "pyspark.tests.test_statcounter", "pyspark.tests.test_taskcontext", "pyspark.tests.test_util", + "pyspark.tests.test_verify_return_type", "pyspark.tests.test_worker", "pyspark.tests.test_stage_sched", # unittests for upstream projects diff --git a/python/pyspark/tests/test_verify_return_type.py b/python/pyspark/tests/test_verify_return_type.py new file mode 100644 index 0000000000000..3f18999533da4 --- /dev/null +++ b/python/pyspark/tests/test_verify_return_type.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest +from typing import Iterator + +from pyspark.errors import PySparkTypeError +from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message +from pyspark.worker import verify_return_type + + +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class VerifyReturnTypeTests(unittest.TestCase): + def test_non_iterator_accepts_matching_type(self): + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["x"]) + self.assertIs(verify_return_type(batch, pa.RecordBatch), batch) + + def test_non_iterator_rejects_wrong_type(self): + import pyarrow as pa + + with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type(123, pa.RecordBatch) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "pyarrow.RecordBatch", "actual": "int"}, + ) + + def test_iterator_accepts_and_is_lazy(self): + import pyarrow as pa + + arrays = [pa.array([1]), pa.array([2])] + verified = verify_return_type(iter(arrays), Iterator[pa.Array]) + self.assertEqual(list(verified), arrays) + + def test_iterator_rejects_non_iterable(self): + import pyarrow as pa + + with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type(5, Iterator[pa.RecordBatch]) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.RecordBatch", "actual": "int"}, + ) + + def test_iterator_rejects_non_iterator_iterable(self): + import pyarrow as pa + + # A list is Iterable but not an Iterator: per the UDF contract we reject it. + with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type([pa.array([1])], Iterator[pa.Array]) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.Array", "actual": "list"}, + ) + + def test_iterator_rejects_wrong_element(self): + import pyarrow as pa + + verified = verify_return_type(iter([1]), Iterator[pa.Array]) + with self.assertRaises(PySparkTypeError) as ctx: + list(verified) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.Array", "actual": "iterator of int"}, + ) + + +if __name__ == "__main__": + from pyspark.testing import main + + main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 95a7ccdc4f8dc..113f58c8e417c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -19,6 +19,7 @@ Worker that receives input from Piped RDD. """ +import collections.abc import os import sys import dataclasses @@ -26,7 +27,18 @@ import inspect import itertools import json -from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Iterable, + Iterator, + Optional, + Tuple, + TYPE_CHECKING, + Union, + get_args, + get_origin, +) if TYPE_CHECKING: from pyspark.sql.pandas._typing import GroupedBatch @@ -234,46 +246,63 @@ def chain(f, g): return lambda *a: g(f(*a)) -def verify_result(expected_type: type) -> Callable[[Any], Iterator]: +def _type_label(t: type) -> str: + package = getattr(inspect.getmodule(t), "__package__", "") + return f"{package}.{t.__name__}" + + +def verify_return_type(result: Any, expected_type: Any) -> Any: """ - Create a result verifier that checks both iterability and element types. 
+    Verify a UDF return value against an expected container type.
+
+    If ``expected_type`` is a concrete type (e.g. ``pa.Table``), checks
+    ``isinstance(result, expected_type)`` and returns ``result`` unchanged.
 
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
+    If ``expected_type`` is ``Iterator[T]``, checks that ``result`` is an iterator
+    and returns a lazy iterator that type-checks each element against ``T`` on
+    consumption.
 
     Parameters
     ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
+    result : Any
+        The UDF return value.
+    expected_type : type or Iterator[type]
+        The expected Python/PyArrow container type (e.g. ``pa.Table``,
+        ``pa.RecordBatch``, ``pa.Array``), or ``Iterator[T]`` to require an
+        iterator of ``T``.
     """
+    if get_origin(expected_type) is collections.abc.Iterator:
+        (element_type,) = get_args(expected_type)
+        label = f"iterator of {_type_label(element_type)}"
 
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
-
-    def check_element(element: Any) -> Any:
-        if not isinstance(element, expected_type):
+        if not isinstance(result, Iterator):
             raise PySparkTypeError(
                 errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": f"iterator of {type(element).__name__}",
-                },
+                messageParameters={"expected": label, "actual": type(result).__name__},
             )
-        return element
 
-    def check(result: Any) -> Iterator:
-        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": type(result).__name__,
-                },
-            )
+        def check_element(element: Any) -> Any:
+            if not isinstance(element, element_type):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": label,
+                        "actual": f"iterator of {type(element).__name__}",
+                    },
+                )
+            return element
+
         return map(check_element, result)
 
-    return check
+    if not isinstance(result, expected_type):
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": _type_label(expected_type),
+                "actual": type(result).__name__,
+            },
+        )
+    return result
 
 
 def verify_result_row_count(result_length: int, expected: int) -> None:
@@ -2561,7 +2590,9 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
         output_batches = udf_func(input_batches)
 
         # Post-processing
-        verified: Iterator[pa.RecordBatch] = verify_result(pa.RecordBatch)(output_batches)
+        verified: Iterator[pa.RecordBatch] = verify_return_type(
+            output_batches, Iterator[pa.RecordBatch]
+        )
         yield from map(ArrowBatchTransformer.wrap_struct, verified)
 
     # profiling is not supported for UDF
@@ -2626,7 +2657,7 @@ def extract_args(batch: pa.RecordBatch):
         args_iter = map(extract_args, data)
 
         # Call UDF and verify result type (iterator of pa.Array)
-        verified_iter = verify_result(pa.Array)(udf_func(args_iter))
+        verified_iter = verify_return_type(udf_func(args_iter), Iterator[pa.Array])
 
         # Process results: enforce schema and assemble into RecordBatch
         target_schema = pa.schema([pa.field("_0", arrow_return_type)])

From bf7662bf11e72ef8fc8c817fe34a83b9533b198f Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Fri, 24 Apr 2026 08:24:41 +0000
Subject: [PATCH 02/18] test: move VerifyReturnTypeTests into
pyspark.tests.test_worker --- dev/sparktestsupport/modules.py | 1 - .../pyspark/tests/test_verify_return_type.py | 90 ------------------- python/pyspark/tests/test_worker.py | 74 ++++++++++++++- 3 files changed, 73 insertions(+), 92 deletions(-) delete mode 100644 python/pyspark/tests/test_verify_return_type.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 14ac7347e5a11..bd7d1f55aaee5 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -517,7 +517,6 @@ def __hash__(self): "pyspark.tests.test_statcounter", "pyspark.tests.test_taskcontext", "pyspark.tests.test_util", - "pyspark.tests.test_verify_return_type", "pyspark.tests.test_worker", "pyspark.tests.test_stage_sched", # unittests for upstream projects diff --git a/python/pyspark/tests/test_verify_return_type.py b/python/pyspark/tests/test_verify_return_type.py deleted file mode 100644 index 3f18999533da4..0000000000000 --- a/python/pyspark/tests/test_verify_return_type.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import unittest -from typing import Iterator - -from pyspark.errors import PySparkTypeError -from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message -from pyspark.worker import verify_return_type - - -@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) -class VerifyReturnTypeTests(unittest.TestCase): - def test_non_iterator_accepts_matching_type(self): - import pyarrow as pa - - batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["x"]) - self.assertIs(verify_return_type(batch, pa.RecordBatch), batch) - - def test_non_iterator_rejects_wrong_type(self): - import pyarrow as pa - - with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type(123, pa.RecordBatch) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "pyarrow.RecordBatch", "actual": "int"}, - ) - - def test_iterator_accepts_and_is_lazy(self): - import pyarrow as pa - - arrays = [pa.array([1]), pa.array([2])] - verified = verify_return_type(iter(arrays), Iterator[pa.Array]) - self.assertEqual(list(verified), arrays) - - def test_iterator_rejects_non_iterable(self): - import pyarrow as pa - - with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type(5, Iterator[pa.RecordBatch]) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.RecordBatch", "actual": "int"}, - ) - - def test_iterator_rejects_non_iterator_iterable(self): - import pyarrow as pa - - # A list is Iterable but not an Iterator: per the UDF contract we reject it. 
- with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type([pa.array([1])], Iterator[pa.Array]) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.Array", "actual": "list"}, - ) - - def test_iterator_rejects_wrong_element(self): - import pyarrow as pa - - verified = verify_return_type(iter([1]), Iterator[pa.Array]) - with self.assertRaises(PySparkTypeError) as ctx: - list(verified) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.Array", "actual": "iterator of int"}, - ) - - -if __name__ == "__main__": - from pyspark.testing import main - - main() diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 5d33cc9779ac5..7ff04150a91be 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -22,6 +22,7 @@ import threading import time import unittest +from typing import Iterator has_resource_module = True try: @@ -32,7 +33,16 @@ from py4j.protocol import Py4JJavaError from pyspark import SparkConf, SparkContext -from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually +from pyspark.errors import PySparkTypeError +from pyspark.testing.utils import ( + ReusedPySparkTestCase, + PySparkTestCase, + QuietTest, + eventually, + have_pyarrow, + pyarrow_requirement_message, +) +from pyspark.worker import verify_return_type class WorkerTests(ReusedPySparkTestCase): @@ -286,6 +296,68 @@ def conf(cls): return conf +@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) +class VerifyReturnTypeTests(unittest.TestCase): + def test_non_iterator_accepts_matching_type(self): + import pyarrow as pa + + batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["x"]) + self.assertIs(verify_return_type(batch, pa.RecordBatch), batch) + + def test_non_iterator_rejects_wrong_type(self): + import pyarrow as pa + + with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type(123, pa.RecordBatch) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "pyarrow.RecordBatch", "actual": "int"}, + ) + + def test_iterator_accepts_and_is_lazy(self): + import pyarrow as pa + + arrays = [pa.array([1]), pa.array([2])] + verified = verify_return_type(iter(arrays), Iterator[pa.Array]) + self.assertEqual(list(verified), arrays) + + def test_iterator_rejects_non_iterable(self): + import pyarrow as pa + + with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type(5, Iterator[pa.RecordBatch]) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.RecordBatch", "actual": "int"}, + ) + + def test_iterator_rejects_non_iterator_iterable(self): + import pyarrow as pa + + # A list is Iterable but not an Iterator: per the UDF contract we reject it. 
+ with self.assertRaises(PySparkTypeError) as ctx: + verify_return_type([pa.array([1])], Iterator[pa.Array]) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.Array", "actual": "list"}, + ) + + def test_iterator_rejects_wrong_element(self): + import pyarrow as pa + + verified = verify_return_type(iter([1]), Iterator[pa.Array]) + with self.assertRaises(PySparkTypeError) as ctx: + list(verified) + self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") + self.assertEqual( + ctx.exception.getMessageParameters(), + {"expected": "iterator of pyarrow.Array", "actual": "iterator of int"}, + ) + + if __name__ == "__main__": from pyspark.testing import main From c142f709a9b854754a396f562a46ebd826f26e34 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:30:21 +0000 Subject: [PATCH 03/18] refactor: use overloaded TypeVar for verify_return_type return type --- python/pyspark/worker.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 113f58c8e417c..15b48c207b14e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,12 +34,17 @@ Iterator, Optional, Tuple, + Type, + TypeVar, TYPE_CHECKING, Union, get_args, get_origin, + overload, ) +T = TypeVar("T") + if TYPE_CHECKING: from pyspark.sql.pandas._typing import GroupedBatch @@ -251,6 +256,14 @@ def _type_label(t: type) -> str: return f"{package}.{t.__name__}" +@overload +def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... + + +@overload +def verify_return_type(result: Any, expected_type: Any) -> Any: ... + + def verify_return_type(result: Any, expected_type: Any) -> Any: """ Verify a UDF return value against an expected container type. 
From 46848e3c669670a4e14894401c929590207e7208 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:32:48 +0000 Subject: [PATCH 04/18] refactor: type check_element as (element: T) -> T --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 15b48c207b14e..0b7f086e7ab2e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -294,7 +294,7 @@ def verify_return_type(result: Any, expected_type: Any) -> Any: messageParameters={"expected": label, "actual": type(result).__name__}, ) - def check_element(element: Any) -> Any: + def check_element(element: T) -> T: if not isinstance(element, element_type): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", From 68ba54102dc7cb25a24895f3a02c076cd9acd06f Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:36:18 +0000 Subject: [PATCH 05/18] refactor: rename verified to verified_iter for consistency --- python/pyspark/worker.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0b7f086e7ab2e..29c839d9896b3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -2603,10 +2603,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record output_batches = udf_func(input_batches) # Post-processing - verified: Iterator[pa.RecordBatch] = verify_return_type( - output_batches, Iterator[pa.RecordBatch] - ) - yield from map(ArrowBatchTransformer.wrap_struct, verified) + verified_iter = verify_return_type(output_batches, Iterator[pa.RecordBatch]) + yield from map(ArrowBatchTransformer.wrap_struct, verified_iter) # profiling is not supported for UDF return func, None, ser, ser From 21eded84ff73ade47c4add3edbbe122f8503e75d Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:39:09 +0000 Subject: [PATCH 06/18] refactor: use verify_return_type in verify_arrow_table and verify_arrow_batch --- python/pyspark/worker.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 29c839d9896b3..adb25f8b25d94 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -667,30 +667,14 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types): import pyarrow as pa - if not isinstance(table, pa.Table): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "pyarrow.Table", - "actual": type(table).__name__, - }, - ) - + verify_return_type(table, pa.Table) verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types) def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types): import pyarrow as pa - if not isinstance(batch, pa.RecordBatch): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "pyarrow.RecordBatch", - "actual": type(batch).__name__, - }, - ) - + verify_return_type(batch, pa.RecordBatch) verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types) From 5d1c6509a7af14f16c82371da546194e8270146f Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:40:59 +0000 Subject: 
[PATCH 07/18] refactor: inline verify_arrow_table and verify_arrow_batch --- python/pyspark/worker.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index adb25f8b25d94..f40f36f218104 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -554,6 +554,8 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): + import pyarrow as pa + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -571,7 +573,8 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table key = tuple(c[0] for c in key_table.columns) result = f(key, left_value_table, right_value_table) - verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) + verify_return_type(result, pa.Table) + verify_arrow_result(result, runner_conf.assign_cols_by_name, expected_cols_and_types) return result.to_batches() @@ -664,20 +667,6 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): ) -def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types): - import pyarrow as pa - - verify_return_type(table, pa.Table) - verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types) - - -def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types): - import pyarrow as pa - - verify_return_type(batch, pa.RecordBatch) - verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types) - - def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrapped(key_series, value_series): import pandas as pd @@ -2881,7 +2870,10 @@ def grouped_func( key = tuple(c[0] for c in keys.columns) result = grouped_udf(key, value_table) - verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) + verify_return_type(result, pa.Table) + verify_arrow_result( + result, runner_conf.assign_cols_by_name, expected_cols_and_types + ) # Reorder columns if needed and wrap into struct for batch in result.to_batches(): @@ -2952,7 +2944,8 @@ def grouped_func( # Verify, reorder, and wrap each output batch for batch in result: - verify_arrow_batch( + verify_return_type(batch, pa.RecordBatch) + verify_arrow_result( batch, runner_conf.assign_cols_by_name, expected_cols_and_types ) if runner_conf.assign_cols_by_name: From c775a7cd60224a9a1472d9322f77c599083b7efc Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 08:43:50 +0000 Subject: [PATCH 08/18] refactor: inline _type_label helper --- python/pyspark/worker.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index f40f36f218104..f0e261d649318 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -251,11 +251,6 @@ def chain(f, g): return lambda *a: g(f(*a)) -def _type_label(t: type) -> str: - package = getattr(inspect.getmodule(t), "__package__", "") - return f"{package}.{t.__name__}" - - @overload def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... 
@@ -286,7 +281,8 @@ def verify_return_type(result: Any, expected_type: Any) -> Any:
     """
     if get_origin(expected_type) is collections.abc.Iterator:
         (element_type,) = get_args(expected_type)
-        label = f"iterator of {_type_label(element_type)}"
+        package = getattr(inspect.getmodule(element_type), "__package__", "")
+        label = f"iterator of {package}.{element_type.__name__}"
 
         if not isinstance(result, Iterator):
             raise PySparkTypeError(
@@ -308,10 +304,11 @@ def check_element(element: T) -> T:
         return map(check_element, result)
 
     if not isinstance(result, expected_type):
+        package = getattr(inspect.getmodule(expected_type), "__package__", "")
         raise PySparkTypeError(
             errorClass="UDF_RETURN_TYPE",
             messageParameters={
-                "expected": _type_label(expected_type),
+                "expected": f"{package}.{expected_type.__name__}",
                 "actual": type(result).__name__,
             },
         )
     return result

From ad98e2b79804ea75d86eac4fc913ee4d974f9378 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Fri, 24 Apr 2026 08:44:26 +0000
Subject: [PATCH 09/18] docs: simplify verify_return_type docstring

---
 python/pyspark/worker.py | 20 ++++----------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f0e261d649318..434354f2d9002 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -261,23 +261,11 @@

 def verify_return_type(result: Any, expected_type: Any) -> Any:
     """
-    Verify a UDF return value against an expected container type.
+    Verify a UDF return value against an expected type.
 
-    If ``expected_type`` is a concrete type (e.g. ``pa.Table``), checks
-    ``isinstance(result, expected_type)`` and returns ``result`` unchanged.
-
-    If ``expected_type`` is ``Iterator[T]``, checks that ``result`` is an iterator
-    and returns a lazy iterator that type-checks each element against ``T`` on
-    consumption.
-
-    Parameters
-    ----------
-    result : Any
-        The UDF return value.
-    expected_type : type or Iterator[type]
-        The expected Python/PyArrow container type (e.g. ``pa.Table``,
-        ``pa.RecordBatch``, ``pa.Array``), or ``Iterator[T]`` to require an
-        iterator of ``T``.
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.

From b9775d73af453571be2b59f10c03712f0f4c4702 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Fri, 24 Apr 2026 08:49:16 +0000
Subject: [PATCH 10/18] refactor: align Iterator detection with typehints.py (_name string check)

---
 python/pyspark/worker.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 434354f2d9002..b113c45ac7358 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -19,7 +19,6 @@
 Worker that receives input from Piped RDD.
 """
-import collections.abc
 import os
 import sys
 import dataclasses
@@ -38,8 +37,6 @@
     TypeVar,
     TYPE_CHECKING,
     Union,
-    get_args,
-    get_origin,
     overload,
 )
 
@@ -267,8 +264,8 @@ def verify_return_type(result: Any, expected_type: Any) -> Any:
     For ``Iterator[T]``, returns a lazy iterator that checks each element
    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
""" - if get_origin(expected_type) is collections.abc.Iterator: - (element_type,) = get_args(expected_type) + if getattr(expected_type, "_name", None) == "Iterator": + (element_type,) = expected_type.__args__ package = getattr(inspect.getmodule(element_type), "__package__", "") label = f"iterator of {package}.{element_type.__name__}" From 9533ad8d3f9bd3438dbc528d0559d26353e9fd15 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 09:50:10 +0000 Subject: [PATCH 11/18] fix: move verify_return_type out of worker.py to fix test import --- python/pyspark/sql/pandas/_verify.py | 74 ++++++++++++++++++++++++++++ python/pyspark/tests/test_worker.py | 2 +- python/pyspark/worker.py | 69 +------------------------- 3 files changed, 77 insertions(+), 68 deletions(-) create mode 100644 python/pyspark/sql/pandas/_verify.py diff --git a/python/pyspark/sql/pandas/_verify.py b/python/pyspark/sql/pandas/_verify.py new file mode 100644 index 0000000000000..256ef81625ea5 --- /dev/null +++ b/python/pyspark/sql/pandas/_verify.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +from typing import Any, Iterator, Type, TypeVar, overload + +from pyspark.errors import PySparkTypeError + +T = TypeVar("T") + + +@overload +def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... + + +@overload +def verify_return_type(result: Any, expected_type: Any) -> Any: ... + + +def verify_return_type(result: Any, expected_type: Any) -> Any: + """ + Verify a UDF return value against an expected type. + + Returns ``result`` unchanged if ``isinstance(result, expected_type)``. + For ``Iterator[T]``, returns a lazy iterator that checks each element + against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch. 
+ """ + if getattr(expected_type, "_name", None) == "Iterator": + (element_type,) = expected_type.__args__ + package = getattr(inspect.getmodule(element_type), "__package__", "") + label = f"iterator of {package}.{element_type.__name__}" + + if not isinstance(result, Iterator): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={"expected": label, "actual": type(result).__name__}, + ) + + def check_element(element: T) -> T: + if not isinstance(element, element_type): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": label, + "actual": f"iterator of {type(element).__name__}", + }, + ) + return element + + return map(check_element, result) + + if not isinstance(result, expected_type): + package = getattr(inspect.getmodule(expected_type), "__package__", "") + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": f"{package}.{expected_type.__name__}", + "actual": type(result).__name__, + }, + ) + return result diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 7ff04150a91be..968d8158a2d09 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -34,6 +34,7 @@ from pyspark import SparkConf, SparkContext from pyspark.errors import PySparkTypeError +from pyspark.sql.pandas._verify import verify_return_type from pyspark.testing.utils import ( ReusedPySparkTestCase, PySparkTestCase, @@ -42,7 +43,6 @@ have_pyarrow, pyarrow_requirement_message, ) -from pyspark.worker import verify_return_type class WorkerTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b113c45ac7358..4fdc23715c28e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,21 +26,7 @@ import inspect import itertools import json -from typing import ( - Any, - Callable, - Iterable, - Iterator, - Optional, - Tuple, - Type, - TypeVar, - TYPE_CHECKING, - Union, - overload, -) - -T = TypeVar("T") +from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union if TYPE_CHECKING: from pyspark.sql.pandas._typing import GroupedBatch @@ -89,6 +75,7 @@ ArrowStreamArrowUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type +from pyspark.sql.pandas._verify import verify_return_type from pyspark.sql.types import ( ArrayType, BinaryType, @@ -248,58 +235,6 @@ def chain(f, g): return lambda *a: g(f(*a)) -@overload -def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... - - -@overload -def verify_return_type(result: Any, expected_type: Any) -> Any: ... - - -def verify_return_type(result: Any, expected_type: Any) -> Any: - """ - Verify a UDF return value against an expected type. - - Returns ``result`` unchanged if ``isinstance(result, expected_type)``. - For ``Iterator[T]``, returns a lazy iterator that checks each element - against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch. 
- """ - if getattr(expected_type, "_name", None) == "Iterator": - (element_type,) = expected_type.__args__ - package = getattr(inspect.getmodule(element_type), "__package__", "") - label = f"iterator of {package}.{element_type.__name__}" - - if not isinstance(result, Iterator): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={"expected": label, "actual": type(result).__name__}, - ) - - def check_element(element: T) -> T: - if not isinstance(element, element_type): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": label, - "actual": f"iterator of {type(element).__name__}", - }, - ) - return element - - return map(check_element, result) - - if not isinstance(result, expected_type): - package = getattr(inspect.getmodule(expected_type), "__package__", "") - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": f"{package}.{expected_type.__name__}", - "actual": type(result).__name__, - }, - ) - return result - - def verify_result_row_count(result_length: int, expected: int) -> None: """Raise if the result row count doesn't match the expected input row count.""" if result_length != expected: From 90bc06158ce6fde7a5b8764f831a32b342779d98 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 24 Apr 2026 09:53:32 +0000 Subject: [PATCH 12/18] refactor: move verify_return_type back to worker.py, drop unit test --- python/pyspark/sql/pandas/_verify.py | 74 ---------------------------- python/pyspark/tests/test_worker.py | 74 +--------------------------- python/pyspark/worker.py | 69 +++++++++++++++++++++++++- 3 files changed, 68 insertions(+), 149 deletions(-) delete mode 100644 python/pyspark/sql/pandas/_verify.py diff --git a/python/pyspark/sql/pandas/_verify.py b/python/pyspark/sql/pandas/_verify.py deleted file mode 100644 index 256ef81625ea5..0000000000000 --- a/python/pyspark/sql/pandas/_verify.py +++ /dev/null @@ -1,74 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import inspect -from typing import Any, Iterator, Type, TypeVar, overload - -from pyspark.errors import PySparkTypeError - -T = TypeVar("T") - - -@overload -def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... - - -@overload -def verify_return_type(result: Any, expected_type: Any) -> Any: ... - - -def verify_return_type(result: Any, expected_type: Any) -> Any: - """ - Verify a UDF return value against an expected type. - - Returns ``result`` unchanged if ``isinstance(result, expected_type)``. - For ``Iterator[T]``, returns a lazy iterator that checks each element - against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch. 
- """ - if getattr(expected_type, "_name", None) == "Iterator": - (element_type,) = expected_type.__args__ - package = getattr(inspect.getmodule(element_type), "__package__", "") - label = f"iterator of {package}.{element_type.__name__}" - - if not isinstance(result, Iterator): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={"expected": label, "actual": type(result).__name__}, - ) - - def check_element(element: T) -> T: - if not isinstance(element, element_type): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": label, - "actual": f"iterator of {type(element).__name__}", - }, - ) - return element - - return map(check_element, result) - - if not isinstance(result, expected_type): - package = getattr(inspect.getmodule(expected_type), "__package__", "") - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": f"{package}.{expected_type.__name__}", - "actual": type(result).__name__, - }, - ) - return result diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 968d8158a2d09..5d33cc9779ac5 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -22,7 +22,6 @@ import threading import time import unittest -from typing import Iterator has_resource_module = True try: @@ -33,16 +32,7 @@ from py4j.protocol import Py4JJavaError from pyspark import SparkConf, SparkContext -from pyspark.errors import PySparkTypeError -from pyspark.sql.pandas._verify import verify_return_type -from pyspark.testing.utils import ( - ReusedPySparkTestCase, - PySparkTestCase, - QuietTest, - eventually, - have_pyarrow, - pyarrow_requirement_message, -) +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, eventually class WorkerTests(ReusedPySparkTestCase): @@ -296,68 +286,6 @@ def conf(cls): return conf -@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) -class VerifyReturnTypeTests(unittest.TestCase): - def test_non_iterator_accepts_matching_type(self): - import pyarrow as pa - - batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=["x"]) - self.assertIs(verify_return_type(batch, pa.RecordBatch), batch) - - def test_non_iterator_rejects_wrong_type(self): - import pyarrow as pa - - with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type(123, pa.RecordBatch) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "pyarrow.RecordBatch", "actual": "int"}, - ) - - def test_iterator_accepts_and_is_lazy(self): - import pyarrow as pa - - arrays = [pa.array([1]), pa.array([2])] - verified = verify_return_type(iter(arrays), Iterator[pa.Array]) - self.assertEqual(list(verified), arrays) - - def test_iterator_rejects_non_iterable(self): - import pyarrow as pa - - with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type(5, Iterator[pa.RecordBatch]) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.RecordBatch", "actual": "int"}, - ) - - def test_iterator_rejects_non_iterator_iterable(self): - import pyarrow as pa - - # A list is Iterable but not an Iterator: per the UDF contract we reject it. 
- with self.assertRaises(PySparkTypeError) as ctx: - verify_return_type([pa.array([1])], Iterator[pa.Array]) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.Array", "actual": "list"}, - ) - - def test_iterator_rejects_wrong_element(self): - import pyarrow as pa - - verified = verify_return_type(iter([1]), Iterator[pa.Array]) - with self.assertRaises(PySparkTypeError) as ctx: - list(verified) - self.assertEqual(ctx.exception.getErrorClass(), "UDF_RETURN_TYPE") - self.assertEqual( - ctx.exception.getMessageParameters(), - {"expected": "iterator of pyarrow.Array", "actual": "iterator of int"}, - ) - - if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4fdc23715c28e..b113c45ac7358 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,7 +26,21 @@ import inspect import itertools import json -from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Iterable, + Iterator, + Optional, + Tuple, + Type, + TypeVar, + TYPE_CHECKING, + Union, + overload, +) + +T = TypeVar("T") if TYPE_CHECKING: from pyspark.sql.pandas._typing import GroupedBatch @@ -75,7 +89,6 @@ ArrowStreamArrowUDTFSerializer, ) from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type -from pyspark.sql.pandas._verify import verify_return_type from pyspark.sql.types import ( ArrayType, BinaryType, @@ -235,6 +248,58 @@ def chain(f, g): return lambda *a: g(f(*a)) +@overload +def verify_return_type(result: Any, expected_type: Type[T]) -> T: ... + + +@overload +def verify_return_type(result: Any, expected_type: Any) -> Any: ... + + +def verify_return_type(result: Any, expected_type: Any) -> Any: + """ + Verify a UDF return value against an expected type. + + Returns ``result`` unchanged if ``isinstance(result, expected_type)``. + For ``Iterator[T]``, returns a lazy iterator that checks each element + against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch. 
+ """ + if getattr(expected_type, "_name", None) == "Iterator": + (element_type,) = expected_type.__args__ + package = getattr(inspect.getmodule(element_type), "__package__", "") + label = f"iterator of {package}.{element_type.__name__}" + + if not isinstance(result, Iterator): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={"expected": label, "actual": type(result).__name__}, + ) + + def check_element(element: T) -> T: + if not isinstance(element, element_type): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": label, + "actual": f"iterator of {type(element).__name__}", + }, + ) + return element + + return map(check_element, result) + + if not isinstance(result, expected_type): + package = getattr(inspect.getmodule(expected_type), "__package__", "") + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": f"{package}.{expected_type.__name__}", + "actual": type(result).__name__, + }, + ) + return result + + def verify_result_row_count(result_length: int, expected: int) -> None: """Raise if the result row count doesn't match the expected input row count.""" if result_length != expected: From 6f2967b12f2d748c27090845fa229ada31f6ef3e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 16:29:41 +0000 Subject: [PATCH 13/18] fix: address review comments on verify_return_type --- python/pyspark/sql/tests/arrow/test_arrow_map.py | 11 +++++++++++ python/pyspark/worker.py | 10 ++++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py index af26260849259..252a21d317a8a 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_map.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py @@ -114,6 +114,10 @@ def not_iter(_): def bad_iter_elem(_): return iter([1]) + def list_not_iter(_): + # Iterable but not an Iterator: violates the Iterator[pa.RecordBatch] contract. + return [pa.RecordBatch.from_pandas(pd.DataFrame({"a": [0]}))] + with self.assertRaisesRegex( PythonException, "Return type of the user-defined function should be iterator " @@ -128,6 +132,13 @@ def bad_iter_elem(_): ): (self.spark.range(10, numPartitions=3).mapInArrow(bad_iter_elem, "a int").count()) + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be iterator " + "of pyarrow.RecordBatch, but is list", + ): + (self.spark.range(10, numPartitions=3).mapInArrow(list_not_iter, "a int").count()) + def test_empty_iterator(self): def empty_iter(_): return iter([]) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b113c45ac7358..27a69a500bd74 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ import inspect import itertools import json +import collections.abc from typing import ( Any, Callable, @@ -37,6 +38,8 @@ TypeVar, TYPE_CHECKING, Union, + get_args, + get_origin, overload, ) @@ -264,8 +267,8 @@ def verify_return_type(result: Any, expected_type: Any) -> Any: For ``Iterator[T]``, returns a lazy iterator that checks each element against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch. 
""" - if getattr(expected_type, "_name", None) == "Iterator": - (element_type,) = expected_type.__args__ + if get_origin(expected_type) is collections.abc.Iterator: + (element_type,) = get_args(expected_type) package = getattr(inspect.getmodule(element_type), "__package__", "") label = f"iterator of {package}.{element_type.__name__}" @@ -2925,8 +2928,7 @@ def grouped_func( result = grouped_udf(key, value_batches) # Verify, reorder, and wrap each output batch - for batch in result: - verify_return_type(batch, pa.RecordBatch) + for batch in verify_return_type(result, Iterator[pa.RecordBatch]): verify_arrow_result( batch, runner_conf.assign_cols_by_name, expected_cols_and_types ) From 52e8935ec938f93413b81744688be3918fea42b0 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 19:37:05 +0000 Subject: [PATCH 14/18] test: update applyInArrow iter UDF error message expectation --- python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py index 9cb22558cd032..fe66bc11f593f 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py +++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py @@ -148,8 +148,8 @@ def stats_iter( with self.assertRaisesRegex( PythonException, - "Return type of the user-defined function should be pyarrow.RecordBatch, but is " - + "tuple", + "Return type of the user-defined function should be iterator of " + "pyarrow.RecordBatch, but is iterator of tuple", ): df.groupby("id").applyInArrow(stats_iter, schema="id long, m double").collect() From ac4cac33cbcab293d2f8a18960a3d17c5b7022d7 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 20:24:12 +0000 Subject: [PATCH 15/18] chore: retrigger CI From 4577a7f725b01ccc1c1c92c690724de8a0208333 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 21:16:28 +0000 Subject: [PATCH 16/18] chore: retrigger CI From 6701cb4e358f04c1bf06a9fd936dd2cc6d6a6879 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 21:19:29 +0000 Subject: [PATCH 17/18] chore: retrigger CI From 5817417b7efb7d8fe68712480b3fcdaf5f344dcd Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 27 Apr 2026 22:43:15 +0000 Subject: [PATCH 18/18] chore: retrigger CI