[SPARK-56612][PYTHON] Unify verify_result and container-type checks into verify_return_type helper #55532
Changes from all commits: 400637e, bf7662b, c142f70, 46848e3, 68ba541, 21eded8, 5d1c650, c775a7c, ad98e2b, b9775d7, 9533ad8, 90bc061
```diff
@@ -26,7 +26,21 @@
 import inspect
 import itertools
 import json
-from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    overload,
+)
+
+T = TypeVar("T")
 
 if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import GroupedBatch
```
```diff
@@ -234,46 +248,56 @@ def chain(f, g):
     return lambda *a: g(f(*a))
 
 
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
-    """
-    Create a result verifier that checks both iterability and element types.
-
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
-
-    Parameters
-    ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
-    """
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
-
-    def check_element(element: Any) -> Any:
-        if not isinstance(element, expected_type):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": f"iterator of {type(element).__name__}",
-                },
-            )
-        return element
-
-    def check(result: Any) -> Iterator:
-        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": type(result).__name__,
-                },
-            )
-
-        return map(check_element, result)
-
-    return check
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
+
+
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
+    """
+    Verify a UDF return value against an expected type.
+
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+    """
+    if getattr(expected_type, "_name", None) == "Iterator":
+        (element_type,) = expected_type.__args__
+        package = getattr(inspect.getmodule(element_type), "__package__", "")
+        label = f"iterator of {package}.{element_type.__name__}"
+
+        if not isinstance(result, Iterator):
+            raise PySparkTypeError(
+                errorClass="UDF_RETURN_TYPE",
+                messageParameters={"expected": label, "actual": type(result).__name__},
+            )
+
+        def check_element(element: T) -> T:
+            if not isinstance(element, element_type):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": label,
+                        "actual": f"iterator of {type(element).__name__}",
+                    },
+                )
+            return element
+
+        return map(check_element, result)
+
+    if not isinstance(result, expected_type):
+        package = getattr(inspect.getmodule(expected_type), "__package__", "")
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": f"{package}.{expected_type.__name__}",
+                "actual": type(result).__name__,
+            },
+        )
+    return result
 
 
 def verify_result_row_count(result_length: int, expected: int) -> None:
```
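For readers tracing the new control flow, here is a minimal usage sketch of the helper as defined above. It assumes `verify_return_type` is in scope (imported from the module this diff patches) and that pyarrow is installed; the sample batch is illustrative, not code from this PR.

```python
import pyarrow as pa
from typing import Iterator

# Hypothetical sample data; verify_return_type is assumed to be imported
# from the patched module.
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]})

# Concrete expected type: the value is returned unchanged on success,
# and PySparkTypeError(errorClass="UDF_RETURN_TYPE") is raised otherwise.
table = verify_return_type(pa.Table.from_batches([batch]), pa.Table)

# Iterator[T]: the "is it an iterator at all?" check fires immediately,
# but each element is only type-checked as the map() result is consumed.
verified = verify_return_type(iter([batch, batch]), Iterator[pa.RecordBatch])
for b in verified:  # a mistyped element would raise here, on consumption
    assert isinstance(b, pa.RecordBatch)
```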
```diff
@@ -512,6 +536,8 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu
 
 
 def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
+    import pyarrow as pa
+
     if runner_conf.assign_cols_by_name:
         expected_cols_and_types = {
             col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
```
```diff
@@ -529,7 +555,8 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table
         key = tuple(c[0] for c in key_table.columns)
         result = f(key, left_value_table, right_value_table)
 
-        verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+        verify_return_type(result, pa.Table)
+        verify_arrow_result(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
 
         return result.to_batches()
```
```diff
@@ -622,36 +649,6 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
     )
 
 
-def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(table, pa.Table):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.Table",
-                "actual": type(table).__name__,
-            },
-        )
-
-    verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)
-
-
-def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(batch, pa.RecordBatch):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.RecordBatch",
-                "actual": type(batch).__name__,
-            },
-        )
-
-    verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
-
-
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     def wrapped(key_series, value_series):
         import pandas as pd
```
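For reviewers checking that this deletion loses nothing: each removed wrapper was an `isinstance` gate followed by a delegation to `verify_arrow_result`, which is exactly the two-call pattern the call sites now spell out. A sketch of the equivalence for the `pa.Table` case (the function name is hypothetical; `verify_return_type` and `verify_arrow_result` are assumed in scope from the patched module):

```python
import pyarrow as pa

def verify_table_like_before(result, assign_cols_by_name, expected_cols_and_types):
    # Equivalent to the deleted verify_arrow_table: its isinstance gate is
    # now verify_return_type, and its schema/column validation was already
    # delegated to verify_arrow_result.
    verify_return_type(result, pa.Table)
    verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types)
```

The error label should also come out the same: the deleted helper hard-coded `"pyarrow.Table"`, and the new one derives it from the defining module's `__package__`, which for `pa.Table` is `pyarrow`.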
```diff
@@ -2561,8 +2558,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.Record
             output_batches = udf_func(input_batches)
 
             # Post-processing
-            verified: Iterator[pa.RecordBatch] = verify_result(pa.RecordBatch)(output_batches)
-            yield from map(ArrowBatchTransformer.wrap_struct, verified)
+            verified_iter = verify_return_type(output_batches, Iterator[pa.RecordBatch])
+            yield from map(ArrowBatchTransformer.wrap_struct, verified_iter)
 
         # profiling is not supported for UDF
         return func, None, ser, ser
```

**Contributor:** The behavior change called out in the PR description: "a UDF returning a non-…"
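Comparing the two gates in this diff makes the change concrete, and is plausibly what the truncated comment above refers to: the old `verify_result` accepted any object with `__iter__`, while `verify_return_type` requires a true `Iterator`. A sketch, assuming the helper is in scope from the patched module and pyarrow is installed:

```python
import pyarrow as pa
from typing import Iterator

batches = [pa.RecordBatch.from_pydict({"x": [1]})]  # illustrative data

# Old gate: `not isinstance(result, Iterator) and not hasattr(result, "__iter__")`
# A plain list has __iter__, so it used to pass and then be consumed by map().
# New gate: `not isinstance(result, Iterator)`
# A list is Iterable but not an Iterator, so the call below now raises
# PySparkTypeError with expected="iterator of pyarrow.RecordBatch", actual="list".
verify_return_type(iter(batches), Iterator[pa.RecordBatch])  # OK: a real iterator
verify_return_type(batches, Iterator[pa.RecordBatch])        # raises
```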
```diff
@@ -2626,7 +2623,7 @@ def extract_args(batch: pa.RecordBatch):
         args_iter = map(extract_args, data)
 
         # Call UDF and verify result type (iterator of pa.Array)
-        verified_iter = verify_result(pa.Array)(udf_func(args_iter))
+        verified_iter = verify_return_type(udf_func(args_iter), Iterator[pa.Array])
 
         # Process results: enforce schema and assemble into RecordBatch
         target_schema = pa.schema([pa.field("_0", arrow_return_type)])
```
```diff
@@ -2855,7 +2852,10 @@ def grouped_func(
             key = tuple(c[0] for c in keys.columns)
             result = grouped_udf(key, value_table)
 
-            verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+            verify_return_type(result, pa.Table)
+            verify_arrow_result(
+                result, runner_conf.assign_cols_by_name, expected_cols_and_types
+            )
 
             # Reorder columns if needed and wrap into struct
             for batch in result.to_batches():
```
```diff
@@ -2926,7 +2926,8 @@ def grouped_func(
 
             # Verify, reorder, and wrap each output batch
             for batch in result:
-                verify_arrow_batch(
+                verify_return_type(batch, pa.RecordBatch)
+                verify_arrow_result(
                     batch, runner_conf.assign_cols_by_name, expected_cols_and_types
                 )
                 if runner_conf.assign_cols_by_name:
```

**Contributor:** The PR description frames this as "aligning runtime with the documented…"
Minor / forward-looking: `getattr(expected_type, "_name", None) == "Iterator"` matches `typing.Iterator[T]` but not `collections.abc.Iterator[T]` (the PEP 585 form is a `types.GenericAlias` and has no `_name`). A future caller writing `from collections.abc import Iterator` would silently fall through to the concrete-type branch, where `isinstance(result, Iterator[T])` would raise `TypeError: isinstance() argument 2 cannot be a parameterized generic`, which is confusing relative to the actual mistake. `typing.get_origin(expected_type) is collections.abc.Iterator` handles both forms; alternatively, a one-line docstring caveat that the helper expects the `typing.X` form would suffice. All current callers use `from typing import Iterator`, so this is forward-looking only.
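A minimal sketch of the suggested `get_origin`-based guard, which treats both spellings uniformly (illustrative only, not part of this PR; `is_iterator_hint` is a hypothetical name):

```python
import collections.abc
from typing import Iterator, get_args, get_origin

def is_iterator_hint(tp: object) -> bool:
    # get_origin() normalizes typing.Iterator[T] and the PEP 585
    # collections.abc.Iterator[T] form to the same runtime class.
    return get_origin(tp) is collections.abc.Iterator

assert is_iterator_hint(Iterator[int])
assert is_iterator_hint(collections.abc.Iterator[int])

# The element type is recovered the same way in either form.
(element_type,) = get_args(collections.abc.Iterator[int])
assert element_type is int
```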