diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 9cb22558cd032..fe66bc11f593f 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -148,8 +148,8 @@ def stats_iter(
 
         with self.assertRaisesRegex(
             PythonException,
-            "Return type of the user-defined function should be pyarrow.RecordBatch, but is "
-            + "tuple",
+            "Return type of the user-defined function should be iterator of "
+            "pyarrow.RecordBatch, but is iterator of tuple",
         ):
             df.groupby("id").applyInArrow(stats_iter, schema="id long, m double").collect()
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index af26260849259..252a21d317a8a 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -114,6 +114,10 @@ def not_iter(_):
         def bad_iter_elem(_):
             return iter([1])
 
+        def list_not_iter(_):
+            # Iterable but not an Iterator: violates the Iterator[pa.RecordBatch] contract.
+            return [pa.RecordBatch.from_pandas(pd.DataFrame({"a": [0]}))]
+
         with self.assertRaisesRegex(
             PythonException,
             "Return type of the user-defined function should be iterator "
@@ -128,6 +132,13 @@ def bad_iter_elem(_):
         ):
             (self.spark.range(10, numPartitions=3).mapInArrow(bad_iter_elem, "a int").count())
 
+        with self.assertRaisesRegex(
+            PythonException,
+            "Return type of the user-defined function should be iterator "
+            "of pyarrow.RecordBatch, but is list",
+        ):
+            (self.spark.range(10, numPartitions=3).mapInArrow(list_not_iter, "a int").count())
+
     def test_empty_iterator(self):
         def empty_iter(_):
             return iter([])
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 95a7ccdc4f8dc..27a69a500bd74 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -26,7 +26,24 @@
 import inspect
 import itertools
 import json
-from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union
+import collections.abc
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    get_args,
+    get_origin,
+    overload,
+)
+
+T = TypeVar("T")
 
 if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import GroupedBatch
@@ -234,46 +251,56 @@ def chain(f, g):
     return lambda *a: g(f(*a))
 
 
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
-    """
-    Create a result verifier that checks both iterability and element types.
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
 
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
 
-    Parameters
-    ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
     """
+    Verify a UDF return value against an expected type.
 
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+ """ + if get_origin(expected_type) is collections.abc.Iterator: + (element_type,) = get_args(expected_type) + package = getattr(inspect.getmodule(element_type), "__package__", "") + label = f"iterator of {package}.{element_type.__name__}" - def check_element(element: Any) -> Any: - if not isinstance(element, expected_type): + if not isinstance(result, Iterator): raise PySparkTypeError( errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": f"iterator of {label}", - "actual": f"iterator of {type(element).__name__}", - }, + messageParameters={"expected": label, "actual": type(result).__name__}, ) - return element - def check(result: Any) -> Iterator: - if not isinstance(result, Iterator) and not hasattr(result, "__iter__"): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": f"iterator of {label}", - "actual": type(result).__name__, - }, - ) + def check_element(element: T) -> T: + if not isinstance(element, element_type): + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": label, + "actual": f"iterator of {type(element).__name__}", + }, + ) + return element + return map(check_element, result) - return check + if not isinstance(result, expected_type): + package = getattr(inspect.getmodule(expected_type), "__package__", "") + raise PySparkTypeError( + errorClass="UDF_RETURN_TYPE", + messageParameters={ + "expected": f"{package}.{expected_type.__name__}", + "actual": type(result).__name__, + }, + ) + return result def verify_result_row_count(result_length: int, expected: int) -> None: @@ -512,6 +539,8 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): + import pyarrow as pa + if runner_conf.assign_cols_by_name: expected_cols_and_types = { col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields @@ -529,7 +558,8 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table key = tuple(c[0] for c in key_table.columns) result = f(key, left_value_table, right_value_table) - verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types) + verify_return_type(result, pa.Table) + verify_arrow_result(result, runner_conf.assign_cols_by_name, expected_cols_and_types) return result.to_batches() @@ -622,36 +652,6 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types): ) -def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types): - import pyarrow as pa - - if not isinstance(table, pa.Table): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "pyarrow.Table", - "actual": type(table).__name__, - }, - ) - - verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types) - - -def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types): - import pyarrow as pa - - if not isinstance(batch, pa.RecordBatch): - raise PySparkTypeError( - errorClass="UDF_RETURN_TYPE", - messageParameters={ - "expected": "pyarrow.RecordBatch", - "actual": type(batch).__name__, - }, - ) - - verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types) - - def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): def wrapped(key_series, value_series): import pandas as pd @@ -2561,8 +2561,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record output_batches = udf_func(input_batches) # Post-processing - 
-            verified: Iterator[pa.RecordBatch] = verify_result(pa.RecordBatch)(output_batches)
-            yield from map(ArrowBatchTransformer.wrap_struct, verified)
+            verified_iter = verify_return_type(output_batches, Iterator[pa.RecordBatch])
+            yield from map(ArrowBatchTransformer.wrap_struct, verified_iter)
 
         # profiling is not supported for UDF
         return func, None, ser, ser
@@ -2626,7 +2626,7 @@ def extract_args(batch: pa.RecordBatch):
             args_iter = map(extract_args, data)
 
             # Call UDF and verify result type (iterator of pa.Array)
-            verified_iter = verify_result(pa.Array)(udf_func(args_iter))
+            verified_iter = verify_return_type(udf_func(args_iter), Iterator[pa.Array])
 
             # Process results: enforce schema and assemble into RecordBatch
             target_schema = pa.schema([pa.field("_0", arrow_return_type)])
@@ -2855,7 +2855,10 @@ def grouped_func(
 
                 key = tuple(c[0] for c in keys.columns)
                 result = grouped_udf(key, value_table)
-                verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+                verify_return_type(result, pa.Table)
+                verify_arrow_result(
+                    result, runner_conf.assign_cols_by_name, expected_cols_and_types
+                )
 
                 # Reorder columns if needed and wrap into struct
                 for batch in result.to_batches():
@@ -2925,8 +2928,8 @@ def grouped_func(
             result = grouped_udf(key, value_batches)
 
             # Verify, reorder, and wrap each output batch
-            for batch in result:
-                verify_arrow_batch(
+            for batch in verify_return_type(result, Iterator[pa.RecordBatch]):
+                verify_arrow_result(
                     batch, runner_conf.assign_cols_by_name, expected_cols_and_types
                 )
                 if runner_conf.assign_cols_by_name:
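Note (not part of the patch): a rough usage sketch of the verify_return_type helper introduced above, assuming this change is applied and pyspark plus pyarrow are importable. The expected behavior follows directly from the implementation and the tests in this diff; exact error wording comes from the UDF_RETURN_TYPE error class.

    from typing import Iterator

    import pyarrow as pa

    from pyspark.errors import PySparkTypeError
    from pyspark.worker import verify_return_type

    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})

    # Plain expected type: the value is returned unchanged when it matches.
    table = verify_return_type(pa.Table.from_batches([batch]), pa.Table)

    # Iterator[T]: a non-iterator return value (e.g. a list) fails immediately,
    # matching the new list_not_iter test above ("... but is list").
    try:
        verify_return_type([batch], Iterator[pa.RecordBatch])
    except PySparkTypeError as e:
        print(e)

    # Iterator[T]: element checks are lazy and only fire when the wrapped
    # iterator is consumed, so the bad element raises on the second next().
    checked = verify_return_type(iter([batch, ("not", "a", "batch")]), Iterator[pa.RecordBatch])
    next(checked)  # the valid RecordBatch passes through
    try:
        next(checked)
    except PySparkTypeError as e:
        print(e)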