From 39dd966ab7233cd8690d5ba29964766b0a773605 Mon Sep 17 00:00:00 2001
From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Date: Fri, 24 Apr 2026 06:30:50 +0000
Subject: [PATCH] refactor: migrate verify_arrow_result into enforce_schema for
 grouped/cogrouped map Arrow UDF paths

---
 python/pyspark/sql/conversion.py            | 117 +++++++----
 .../sql/tests/arrow/test_arrow_udtf.py      |   9 +-
 python/pyspark/sql/tests/test_conversion.py |  66 +++++++-
 python/pyspark/worker.py                    | 156 ++++++++----------
 4 files changed, 216 insertions(+), 132 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 8fc4fa5cc0cc7..bca3843abd7b6 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -21,7 +21,7 @@
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload
 
 import pyspark
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
 from pyspark.sql.pandas.types import (
     _dedup_names,
     _deduplicate_field_names,
@@ -110,19 +110,20 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
     @classmethod
     def enforce_schema(
         cls,
-        batch: "pa.RecordBatch",
+        batch: Union["pa.RecordBatch", "pa.Table"],
         arrow_schema: "pa.Schema",
         *,
         arrow_cast: bool = True,
         safecheck: bool = True,
-    ) -> "pa.RecordBatch":
+        reorder_by_name: bool = True,
+    ) -> Union["pa.RecordBatch", "pa.Table"]:
         """
-        Enforce target schema on a RecordBatch by reordering columns and coercing types.
+        Enforce a target schema on an Arrow RecordBatch or Table.
 
         Parameters
         ----------
-        batch : pa.RecordBatch
-            Input RecordBatch to transform.
+        batch : pa.RecordBatch or pa.Table
+            Input to transform. Output is of the same container type.
         arrow_schema : pa.Schema
             Target Arrow schema. Callers should pre-compute this once via
             to_arrow_schema() to avoid repeated conversion.
@@ -131,11 +132,28 @@
             If False, raise an error on type mismatch instead of casting.
         safecheck : bool, default True
             If True, use safe casting (fails on overflow/truncation).
+        reorder_by_name : bool, default True
+            If True, match columns by name and reorder to the target order; any
+            missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output
+            columns are renamed to target names.
+            If False, match columns by position (ignore names) and preserve the
+            original column names in the output.
 
         Returns
        -------
-        pa.RecordBatch
-            RecordBatch with columns reordered and types coerced to match target schema.
+        pa.RecordBatch or pa.Table
+            Same container type as ``batch``, with columns matched (and possibly
+            reordered/cast) per the target schema.
+
+        Raises
+        ------
+        PySparkRuntimeError
+            ``RESULT_COLUMN_NAMES_MISMATCH`` when ``reorder_by_name=True`` and the
+            batch has missing or extra column names.
+            ``RESULT_COLUMN_SCHEMA_MISMATCH`` when ``reorder_by_name=False`` and
+            the number of columns differs from the target schema.
+            ``RESULT_COLUMN_TYPES_MISMATCH`` when any column's type does not match
+            the target (and either ``arrow_cast=False`` or the cast itself fails).
         """
         import pyarrow as pa
 
@@ -146,37 +164,68 @@
         if batch.schema.equals(arrow_schema, check_metadata=False):
             return batch
 
-        # Check if columns are in the same order (by name) as the target schema.
-        # If so, use index-based access (faster than name lookup).
-        batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
         target_names = [field.name for field in arrow_schema]
-        use_index = batch_names == target_names
 
-        coerced_arrays = []
-        for i, field in enumerate(arrow_schema):
-            try:
-                arr = batch.column(i) if use_index else batch.column(field.name)
-            except KeyError:
-                raise PySparkTypeError(
-                    f"Result column '{field.name}' does not exist in the output. "
-                    f"Expected schema: {arrow_schema}, got: {batch.schema}."
+        # Step 1: pick source columns from batch to align with target schema
+        if reorder_by_name:
+            batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
+            missing = sorted(set(target_names) - set(batch_names))
+            extra = sorted(set(batch_names) - set(target_names))
+            if missing or extra:
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_COLUMN_NAMES_MISMATCH",
+                    messageParameters={
+                        "missing": f" Missing: {', '.join(missing)}." if missing else "",
+                        "extra": f" Unexpected: {', '.join(extra)}." if extra else "",
+                    },
                 )
-            if arr.type != field.type:
-                if not arrow_cast:
-                    raise PySparkTypeError(
-                        f"Result type of column '{field.name}' does not match "
-                        f"the expected type. Expected: {field.type}, got: {arr.type}."
-                    )
+            source_columns = [batch.column(name) for name in target_names]
+            output_names = target_names
+        else:
+            # Positional: require exact column-count match, then take columns by
+            # index, preserving the batch's original column names.
+            if batch.num_columns != len(arrow_schema):
+                raise PySparkRuntimeError(
+                    errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
+                    messageParameters={
+                        "expected": str(len(arrow_schema)),
+                        "actual": str(batch.num_columns),
+                    },
+                )
+            source_columns = [batch.column(i) for i in range(len(arrow_schema))]
+            output_names = [batch.schema.field(i).name for i in range(len(arrow_schema))]
+
+        # Step 2: check types / cast, collect all mismatches
+        type_mismatches = []
+        coerced_arrays = []
+        for field, arr in zip(arrow_schema, source_columns):
+            if arr.type == field.type:
+                coerced_arrays.append(arr)
+            elif not arrow_cast:
+                type_mismatches.append((field.name, field.type, arr.type))
+                coerced_arrays.append(arr)
+            else:
                 try:
-                    arr = arr.cast(target_type=field.type, safe=safecheck)
-                except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
-                    raise PySparkTypeError(
-                        f"Result type of column '{field.name}' does not match "
-                        f"the expected type. Expected: {field.type}, got: {arr.type}."
-                    ) from e
-            coerced_arrays.append(arr)
+                    coerced_arrays.append(arr.cast(target_type=field.type, safe=safecheck))
+                except (pa.ArrowInvalid, pa.ArrowTypeError):
+                    type_mismatches.append((field.name, field.type, arr.type))
+                    coerced_arrays.append(arr)
+
+        if type_mismatches:
+            raise PySparkRuntimeError(
+                errorClass="RESULT_COLUMN_TYPES_MISMATCH",
+                messageParameters={
+                    "mismatch": ", ".join(
+                        f"column '{name}' (expected {expected}, actual {actual})"
+                        for name, expected, actual in type_mismatches
+                    )
+                },
+            )
 
-        return pa.RecordBatch.from_arrays(coerced_arrays, names=target_names)
+        # Preserve input container type (Table vs RecordBatch)
+        if isinstance(batch, pa.Table):
+            return pa.Table.from_arrays(coerced_arrays, names=output_names)
+        return pa.RecordBatch.from_arrays(coerced_arrays, names=output_names)
 
     @classmethod
     def to_pandas(
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index f41b7613ec42d..b82523005ac72 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -211,9 +211,8 @@ def eval(self) -> Iterator["pa.Table"]:
 
         with self.assertRaisesRegex(
             PythonException,
-            r"(?s)Result column 'x' does not exist in the output\. "
-            r"Expected schema: x: int32\ny: string, "
-            r"got: wrong_col: int32\nanother_wrong_col: double\.",
+            r"(?s)\[RESULT_COLUMN_NAMES_MISMATCH\].*"
+            r"Missing: x, y\..*Unexpected: another_wrong_col, wrong_col\.",
         ):
             result_df = MismatchedSchemaUDTF()
             result_df.collect()
@@ -375,8 +374,8 @@ def eval(self) -> Iterator["pa.Table"]:
         # Should fail with Arrow cast exception since string cannot be cast to int
         with self.assertRaisesRegex(
             PythonException,
-            "Result type of column 'id' does not match "
-            "the expected type. Expected: int32, got: string.",
+            r"(?s)\[RESULT_COLUMN_TYPES_MISMATCH\].*"
+            r"column 'id' \(expected int32, actual string\)",
         ):
             result_df = StringToIntUDTF()
             result_df.collect()
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index 304d8be740d41..dd5c7f44d2818 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -18,7 +18,7 @@
 import unittest
 from zoneinfo import ZoneInfo
 
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
 from pyspark.sql.conversion import (
     ArrowArrayToPandasConversion,
     ArrowTableToRowsConversion,
@@ -185,8 +185,9 @@ def test_enforce_schema_arrow_cast_false(self):
         batch = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())], names=["x"])
         target = pa.schema([("x", pa.int64())])
 
-        with self.assertRaises(PySparkTypeError):
+        with self.assertRaises(PySparkRuntimeError) as cm:
             ArrowBatchTransformer.enforce_schema(batch, target, arrow_cast=False)
+        self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_TYPES_MISMATCH")
 
     def test_enforce_schema_safecheck(self):
         """safecheck=True rejects overflow; safecheck=False allows it."""
@@ -194,18 +195,73 @@ def test_enforce_schema_safecheck(self):
         batch = pa.RecordBatch.from_arrays([pa.array([999], type=pa.int64())], names=["x"])
         target = pa.schema([("x", pa.int8())])
 
-        with self.assertRaises(PySparkTypeError):
+        with self.assertRaises(PySparkRuntimeError) as cm:
             ArrowBatchTransformer.enforce_schema(batch, target, safecheck=True)
+        self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_TYPES_MISMATCH")
 
         result = ArrowBatchTransformer.enforce_schema(batch, target, safecheck=False)
         self.assertEqual(result.schema, target)
 
     def test_enforce_schema_missing_column(self):
-        """Missing column raises PySparkTypeError."""
+        """Missing column raises RESULT_COLUMN_NAMES_MISMATCH."""
         import pyarrow as pa
 
         batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
-        with self.assertRaises(PySparkTypeError):
+        with self.assertRaises(PySparkRuntimeError) as cm:
             ArrowBatchTransformer.enforce_schema(batch, pa.schema([("missing", pa.int64())]))
+        self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_NAMES_MISMATCH")
+
+    def test_enforce_schema_extra_column(self):
+        """Extra column raises RESULT_COLUMN_NAMES_MISMATCH with the extra name listed."""
+        import pyarrow as pa
+
+        batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array([2])], names=["a", "b"])
+        with self.assertRaises(PySparkRuntimeError) as cm:
+            ArrowBatchTransformer.enforce_schema(batch, pa.schema([("a", pa.int64())]))
+        self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_NAMES_MISMATCH")
+        self.assertIn("b", str(cm.exception))
+
+    def test_enforce_schema_reorder_by_name(self):
+        """reorder_by_name=True reorders input columns to match target schema order."""
+        import pyarrow as pa
+
+        batch = pa.RecordBatch.from_arrays([pa.array(["x"]), pa.array([1])], names=["b", "a"])
+        target = pa.schema([("a", pa.int64()), ("b", pa.string())])
+        result = ArrowBatchTransformer.enforce_schema(batch, target)
+        self.assertEqual(result.schema.names, ["a", "b"])
+        self.assertEqual(result.column(0).to_pylist(), [1])
+        self.assertEqual(result.column(1).to_pylist(), ["x"])
+
+    def test_enforce_schema_positional(self):
+        """reorder_by_name=False matches columns by index, preserving input names."""
+        import pyarrow as pa
+
+        batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array(["x"])], names=["foo", "bar"])
+        target = pa.schema([("a", pa.int64()), ("b", pa.string())])
+        result = ArrowBatchTransformer.enforce_schema(batch, target, reorder_by_name=False)
+        # Input column names are preserved
+        self.assertEqual(result.schema.names, ["foo", "bar"])
+        self.assertEqual(result.column(0).to_pylist(), [1])
+        self.assertEqual(result.column(1).to_pylist(), ["x"])
+
+    def test_enforce_schema_positional_count_mismatch(self):
+        """reorder_by_name=False with wrong column count raises RESULT_COLUMN_SCHEMA_MISMATCH."""
+        import pyarrow as pa
+
+        batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
+        target = pa.schema([("x", pa.int64()), ("y", pa.int64())])
+        with self.assertRaises(PySparkRuntimeError) as cm:
+            ArrowBatchTransformer.enforce_schema(batch, target, reorder_by_name=False)
+        self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_SCHEMA_MISMATCH")
+
+    def test_enforce_schema_table_input(self):
+        """enforce_schema accepts pa.Table and returns pa.Table."""
+        import pyarrow as pa
+
+        table = pa.table({"x": pa.array([1], type=pa.int32())})
+        target = pa.schema([("x", pa.int64())])
+        result = ArrowBatchTransformer.enforce_schema(table, target)
+        self.assertIsInstance(result, pa.Table)
+        self.assertEqual(result.schema, target)
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 95a7ccdc4f8dc..e0277c61c9a6b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -512,14 +512,11 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_retu
 
 
 def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
-    if runner_conf.assign_cols_by_name:
-        expected_cols_and_types = {
-            col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
-        }
-    else:
-        expected_cols_and_types = [
-            (col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields
-        ]
+    import pyarrow as pa
+
+    arrow_return_schema = pa.schema(
+        [(col.name, to_arrow_type(col.dataType, timezone="UTC")) for col in return_type.fields]
+    )
 
     def wrapped(left_key_table, left_value_table, right_key_table, right_value_table):
         if len(argspec.args) == 2:
@@ -529,7 +526,21 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table
             key = tuple(c[0] for c in key_table.columns)
             result = f(key, left_value_table, right_value_table)
 
-        verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+        if not isinstance(result, pa.Table):
+            raise PySparkTypeError(
+                errorClass="UDF_RETURN_TYPE",
+                messageParameters={
+                    "expected": "pyarrow.Table",
+                    "actual": type(result).__name__,
+                },
+            )
+
+        result = ArrowBatchTransformer.enforce_schema(
+            result,
+            arrow_return_schema,
+            arrow_cast=False,
+            reorder_by_name=runner_conf.assign_cols_by_name,
+        )
 
         return result.to_batches()
 
@@ -539,28 +550,6 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table
     )
 
 
-def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
-        import pandas as pd
-
-        left_df = pd.concat(left_value_series, axis=1)
-        right_df = pd.concat(right_value_series, axis=1)
-
-        if len(argspec.args) == 2:
-            result = f(left_df, right_df)
-        elif len(argspec.args) == 3:
-            key_series = left_key_series if not left_df.empty else right_key_series
-            key = tuple(s[0] for s in key_series)
-            result = f(key, left_df, right_df)
-        verify_pandas_result(
-            result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
-        )
-
-        return result
-
-    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]
-
-
 def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
     # the types of the fields have to be identical to return type
     # an empty table can have no columns; if there are columns, they have to match
@@ -622,34 +611,26 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
             )
 
 
-def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(table, pa.Table):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.Table",
-                "actual": type(table).__name__,
-            },
-        )
-
-    verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)
+def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
+        import pandas as pd
 
-
-def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
+        left_df = pd.concat(left_value_series, axis=1)
+        right_df = pd.concat(right_value_series, axis=1)
 
-    if not isinstance(batch, pa.RecordBatch):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.RecordBatch",
-                "actual": type(batch).__name__,
-            },
+        if len(argspec.args) == 2:
+            result = f(left_df, right_df)
+        elif len(argspec.args) == 3:
+            key_series = left_key_series if not left_df.empty else right_key_series
+            key = tuple(s[0] for s in key_series)
+            result = f(key, left_df, right_df)
+        verify_pandas_result(
+            result, return_type, runner_conf.assign_cols_by_name, truncate_return_schema=False
         )
 
-    verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
+        return result
+
+    return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), return_type)]
 
 
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
@@ -2812,15 +2793,7 @@ def grouped_func(
         arrow_return_type = to_arrow_type(
             return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
         )
-        if runner_conf.assign_cols_by_name:
-            expected_cols_and_types = {
-                col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
-            }
-        else:
-            expected_cols_and_types = [
-                (col.name, to_arrow_type(col.dataType, timezone="UTC"))
-                for col in return_type.fields
-            ]
+        arrow_return_schema = pa.schema(list(arrow_return_type))
 
         key_offsets = parsed_offsets[0][0]
         value_offsets = parsed_offsets[0][1]
@@ -2855,15 +2828,24 @@ def grouped_func(
                 key = tuple(c[0] for c in keys.columns)
                 result = grouped_udf(key, value_table)
 
-            verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+            if not isinstance(result, pa.Table):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": "pyarrow.Table",
+                        "actual": type(result).__name__,
+                    },
+                )
+
+            # Verify the result schema (and reorder by name when configured).
+            result = ArrowBatchTransformer.enforce_schema(
+                result,
+                arrow_return_schema,
+                arrow_cast=False,
+                reorder_by_name=runner_conf.assign_cols_by_name,
+            )
 
-            # Reorder columns if needed and wrap into struct
             for batch in result.to_batches():
-                if runner_conf.assign_cols_by_name:
-                    batch = pa.RecordBatch.from_arrays(
-                        [batch.column(field.name) for field in arrow_return_type],
-                        names=[field.name for field in arrow_return_type],
-                    )
                 yield ArrowBatchTransformer.wrap_struct(batch)
 
             # profiling is not supported for UDF
@@ -2882,15 +2864,7 @@ def grouped_func(
         arrow_return_type = to_arrow_type(
             return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
         )
-        if runner_conf.assign_cols_by_name:
-            expected_cols_and_types = {
-                col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
-            }
-        else:
-            expected_cols_and_types = [
-                (col.name, to_arrow_type(col.dataType, timezone="UTC"))
-                for col in return_type.fields
-            ]
+        arrow_return_schema = pa.schema(list(arrow_return_type))
 
         key_offsets = parsed_offsets[0][0]
         value_offsets = parsed_offsets[0][1]
@@ -2924,16 +2898,22 @@ def grouped_func(
                 key = tuple(c[0] for c in keys.columns)
                 result = grouped_udf(key, value_batches)
 
-            # Verify, reorder, and wrap each output batch
+            # Verify (and reorder by name when configured) each output batch
             for batch in result:
-                verify_arrow_batch(
-                    batch, runner_conf.assign_cols_by_name, expected_cols_and_types
-                )
-                if runner_conf.assign_cols_by_name:
-                    batch = pa.RecordBatch.from_arrays(
-                        [batch.column(field.name) for field in arrow_return_type],
-                        names=[field.name for field in arrow_return_type],
+                if not isinstance(batch, pa.RecordBatch):
+                    raise PySparkTypeError(
+                        errorClass="UDF_RETURN_TYPE",
+                        messageParameters={
+                            "expected": "pyarrow.RecordBatch",
+                            "actual": type(batch).__name__,
+                        },
                     )
+                batch = ArrowBatchTransformer.enforce_schema(
+                    batch,
+                    arrow_return_schema,
+                    arrow_cast=False,
+                    reorder_by_name=runner_conf.assign_cols_by_name,
+                )
                 yield ArrowBatchTransformer.wrap_struct(batch)
 
             # Drain remaining input batches to maintain stream position
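Reviewer note, not part of the patch: the sketch below illustrates the enforce_schema semantics this change standardizes on (by-name matching with casting, and positional matching). It assumes the patch is applied and pyarrow is installed, and it mirrors the behavior pinned down by the new tests in test_conversion.py.

import pyarrow as pa

from pyspark.sql.conversion import ArrowBatchTransformer

# A result batch with out-of-order columns and an int32 column that needs widening.
batch = pa.RecordBatch.from_arrays(
    [pa.array(["a", "b"]), pa.array([1, 2], type=pa.int32())],
    names=["name", "id"],
)
target = pa.schema([("id", pa.int64()), ("name", pa.string())])

# Default by-name mode: columns are matched by name, reordered to the target
# order, and cast (safely, since safecheck defaults to True).
out = ArrowBatchTransformer.enforce_schema(batch, target)
assert out.schema.names == ["id", "name"]
assert out.schema.field("id").type == pa.int64()

# Positional mode: names are ignored for matching and preserved in the output;
# only the column count and types are enforced against the target.
positional = pa.RecordBatch.from_arrays(
    [pa.array([1], type=pa.int64()), pa.array(["x"])],
    names=["foo", "bar"],
)
out = ArrowBatchTransformer.enforce_schema(positional, target, reorder_by_name=False)
assert out.schema.names == ["foo", "bar"]

A batch missing the "id" column would make the first call raise PySparkRuntimeError with condition RESULT_COLUMN_NAMES_MISMATCH, and a wrong column count would make the second call raise RESULT_COLUMN_SCHEMA_MISMATCH.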