Commit a189a45
[SPARK-54650][PYTHON] Move int to decimal conversion into _create_converter_from_pandas
### What changes were proposed in this pull request?

Move the int-to-decimal conversion into `_create_converter_from_pandas`.

### Why are the changes needed?

This conversion belongs in `_create_converter_from_pandas`, alongside the other per-type pandas-to-Arrow converters.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53405 from zhengruifeng/mv_int_dec.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 08c0783
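The precision failure this coercion works around is spelled out in the patch comments; here is a minimal repro sketch built from those comments, assuming pandas and pyarrow are installed (the exact error text may vary across PyArrow versions):

# PyArrow infers int64 for a plain integer Series; the matching decimal
# type needs precision >= 19, so casting down to decimal128(1) fails.
from decimal import Decimal

import pandas as pd
import pyarrow as pa

ints = pd.Series([1, 2, 3])
try:
    pa.Array.from_pandas(ints).cast(pa.decimal128(1))
except pa.ArrowInvalid as e:
    # e.g. "Precision is not great enough for the result. It should be at least 19."
    print("direct cast failed:", e)

# The workaround the patch applies: convert values to Decimal objects
# first, so PyArrow never routes the data through int64.
arr = pa.Array.from_pandas(ints.apply(lambda x: Decimal(x))).cast(pa.decimal128(1))
print(arr.type)  # decimal128(1, 0)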

2 files changed: +30 -42 lines changed

python/pyspark/sql/pandas/serializers.py

Lines changed: 5 additions & 42 deletions
@@ -19,7 +19,6 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """

-from decimal import Decimal
 from itertools import groupby
 from typing import TYPE_CHECKING, Iterator, Optional

@@ -356,40 +355,6 @@ def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
         self._safecheck = safecheck
         self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled

-    @staticmethod
-    def _apply_python_coercions(series, arrow_type):
-        """
-        Apply additional coercions to the series in Python before converting to Arrow:
-        - Convert integer series to decimal type.
-        When we have a pandas series of integers that needs to be converted to
-        pyarrow.decimal128 (with precision < 20), PyArrow fails with precision errors.
-        Explicitly cast to Decimal first.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            The series to potentially convert
-        arrow_type : pyarrow.DataType
-            The target arrow type
-
-        Returns
-        -------
-        pandas.Series
-            The potentially converted pandas series
-        """
-        import pyarrow.types as types
-        import pandas as pd
-
-        # Convert integer series to Decimal objects
-        if (
-            types.is_decimal(arrow_type)
-            and series.dtype.kind in ["i", "u"]  # integer types (signed/unsigned)
-            and not series.empty
-        ):
-            series = series.apply(lambda x: Decimal(x) if pd.notna(x) else None)
-
-        return series
-
     def arrow_to_pandas(
         self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
     ):
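The deleted helper keyed off the Arrow type via `pyarrow.types.is_decimal`, while the replacement in `_create_converter_from_pandas` dispatches on the Spark type; the two checks agree, since `from_arrow_type` maps Arrow decimals to `DecimalType`. (One minor difference visible in the diff: the old helper also skipped empty series, while the new converter checks only the dtype kind.) A quick sanity check, assuming a local pyspark install:

import pyarrow as pa
import pyarrow.types as types
from pyspark.sql.pandas.types import from_arrow_type
from pyspark.sql.types import DecimalType

arrow_type = pa.decimal128(1)
print(types.is_decimal(arrow_type))                          # old check: True
print(isinstance(from_arrow_type(arrow_type), DecimalType))  # new check: True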
@@ -442,13 +407,13 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
         if arrow_type is not None:
             dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
             conv = _create_converter_from_pandas(
-                dt, timezone=self._timezone, error_on_duplicated_field_names=False
+                dt,
+                timezone=self._timezone,
+                error_on_duplicated_field_names=False,
+                int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
             )
             series = conv(series)

-            if self._int_to_decimal_coercion_enabled:
-                series = self._apply_python_coercions(series, arrow_type)
-
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:
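With the flag threaded through, the coercion happens inside the converter itself rather than as a separate post-processing step. A sketch of the end-to-end behavior, assuming a local pyspark install (`_create_converter_from_pandas` is internal API, so its signature may change):

import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _create_converter_from_pandas
from pyspark.sql.types import DecimalType

conv = _create_converter_from_pandas(
    DecimalType(1, 0),
    timezone=None,
    error_on_duplicated_field_names=False,
    int_to_decimal_coercion_enabled=True,
)
converted = conv(pd.Series([1, 2, 3]))  # integers become Decimal objects
print(pa.Array.from_pandas(converted).cast(pa.decimal128(1)))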
@@ -1046,12 +1011,10 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
                 timezone=self._timezone,
                 error_on_duplicated_field_names=False,
                 ignore_unexpected_complex_type_values=True,
+                int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
             )
             series = conv(series)

-            if self._int_to_decimal_coercion_enabled:
-                series = self._apply_python_coercions(series, arrow_type)
-
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:

python/pyspark/sql/pandas/types.py

Lines changed: 25 additions & 0 deletions
@@ -22,6 +22,7 @@
 import datetime
 import itertools
 import functools
+from decimal import Decimal
 from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING

 from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
@@ -1225,6 +1226,7 @@ def _create_converter_from_pandas(
     timezone: Optional[str] = None,
     error_on_duplicated_field_names: bool = True,
     ignore_unexpected_complex_type_values: bool = False,
+    int_to_decimal_coercion_enabled: bool = False,
 ) -> Callable[["pd.Series"], "pd.Series"]:
     """
     Create a converter of pandas Series to create Spark DataFrame with Arrow optimization.
@@ -1264,6 +1266,29 @@ def correct_timestamp(pser: pd.Series) -> pd.Series:

         return correct_timestamp

+    elif isinstance(data_type, DecimalType):
+        if int_to_decimal_coercion_enabled:
+            # For decimal with low precision, e.g. pa.decimal128(1)
+            # pa.Array.from_pandas(pd.Series([1,2,3])).cast(pa.decimal128(1)) fails with
+            # ArrowInvalid: Precision is not great enough for the result.
+            # It should be at least 19.
+            # Here change it to
+            # pa.Array.from_pandas(pd.Series([1,2,3]).apply(
+            #     lambda x: Decimal(x))).cast(pa.decimal128(1))
+
+            def convert_int_to_decimal(pser: pd.Series) -> pd.Series:
+                if pser.dtype.kind in ["i", "u"]:
+                    return pser.apply(  # type: ignore[return-value]
+                        lambda x: Decimal(x) if pd.notna(x) else None
+                    )
+                else:
+                    return pser
+
+            return convert_int_to_decimal
+
+        else:
+            return lambda pser: pser
+
     def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
         if isinstance(dt, ArrayType):
             _element_conv = _converter(dt.elementType)
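The new converter rewrites a series only when `pser.dtype.kind` is `'i'` or `'u'`, and the `pd.notna` guard maps missing values to `None`. A quick illustration of which pandas dtypes that matches (note the nullable `Int64` extension dtype also reports kind `'i'`):

import pandas as pd

for s in [
    pd.Series([1, 2], dtype="int64"),
    pd.Series([1, 2], dtype="uint8"),
    pd.Series([1, None], dtype="Int64"),   # nullable integer, kind 'i'
    pd.Series([1.0, 2.0], dtype="float64"),
]:
    print(s.dtype, s.dtype.kind, s.dtype.kind in ["i", "u"])
# int64 i True / uint8 u True / Int64 i True / float64 f False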
