
Commit 015dd76

Pass Field information back and forth when using scalar UDFs (#1299)
* Pass Field information back and forth when using scalar UDFs
* Add ArrowArrayExportable class and use it to create pyarrow arrays for python UDFs
* Minor user documentation update
* Update naming from type to field where appropriate
* Add unit test to check field inputs
* Update docstring
* Add text to user documentation on passing field information for scalar UDFs
* Minor change requested in code review
* Make type hints match outer
1 parent 3227276 commit 015dd76

File tree

8 files changed: +348 −113 lines changed


docs/source/user-guide/common-operations/udf-and-udfa.rst

Lines changed: 11 additions & 0 deletions
@@ -90,6 +90,17 @@ converting to Python objects to do the evaluation.
 
     df.select(col("a"), is_null_arr(col("a")).alias("is_null")).show()
 
+In this example we passed the PyArrow ``DataType`` when we defined the function
+by calling ``udf()``. If you need additional control, such as specifying
+metadata or nullability of the input or output, you can instead specify a
+PyArrow ``Field``.
+
+If you need to write a custom function but do not want to incur the performance
+cost of converting to Python objects and back, a more advanced approach is to
+write Rust based UDFs and to expose them to Python. There is an example in the
+`DataFusion blog <https://datafusion.apache.org/blog/2024/11/19/datafusion-python-udf-comparisons/>`_
+describing how to do this.
+
 Aggregate Functions
 -------------------
 
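
For readers following along, a minimal sketch of what the ``Field``-based definition described in the added documentation could look like (the function and field names below are illustrative, not part of the commit):

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import udf

    # A Field lets you state nullability and attach metadata, which a bare DataType cannot.
    in_field = pa.field("values", pa.int64(), nullable=False, metadata={"unit": "ms"})
    out_field = pa.field("doubled", pa.int64(), nullable=False)

    @udf([in_field], out_field, "immutable")
    def double_ms(arr: pa.Array) -> pa.Array:
        # Operates on the Arrow array directly; no per-row conversion to Python objects.
        return pc.multiply(arr, 2)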

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ dev = [
     "maturin>=1.8.1",
     "numpy>1.25.0;python_version<'3.14'",
     "numpy>=2.3.2;python_version>='3.14'",
+    "pyarrow>=19.0.0",
     "pre-commit>=4.3.0",
     "pyyaml>=6.0.3",
     "pytest>=7.4.4",

python/datafusion/user_defined.py

Lines changed: 53 additions & 26 deletions
@@ -34,7 +34,7 @@
     from _typeshed import CapsuleType as _PyCapsule
 
     _R = TypeVar("_R", bound=pa.DataType)
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 
 
 class Volatility(Enum):
@@ -81,6 +81,27 @@ def __str__(self) -> str:
         return self.name.lower()
 
 
+def data_type_or_field_to_field(value: pa.DataType | pa.Field, name: str) -> pa.Field:
+    """Helper function to return a Field from either a Field or DataType."""
+    if isinstance(value, pa.Field):
+        return value
+    return pa.field(name, type=value)
+
+
+def data_types_or_fields_to_field_list(
+    inputs: Sequence[pa.Field | pa.DataType] | pa.Field | pa.DataType,
+) -> list[pa.Field]:
+    """Helper function to return a list of Fields."""
+    if isinstance(inputs, pa.DataType):
+        return [pa.field("value", type=inputs)]
+    if isinstance(inputs, pa.Field):
+        return [inputs]
+
+    return [
+        data_type_or_field_to_field(v, f"value_{idx}") for (idx, v) in enumerate(inputs)
+    ]
+
+
 class ScalarUDFExportable(Protocol):
     """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
 
@@ -103,8 +124,8 @@ def __init__(
         self,
         name: str,
         func: Callable[..., _R],
-        input_types: pa.DataType | list[pa.DataType],
-        return_type: _R,
+        input_fields: list[pa.Field],
+        return_field: _R,
         volatility: Volatility | str,
     ) -> None:
         """Instantiate a scalar user-defined function (UDF).
@@ -114,10 +135,10 @@ def __init__(
         if hasattr(func, "__datafusion_scalar_udf__"):
             self._udf = df_internal.ScalarUDF.from_pycapsule(func)
             return
-        if isinstance(input_types, pa.DataType):
-            input_types = [input_types]
+        if isinstance(input_fields, pa.DataType):
+            input_fields = [input_fields]
         self._udf = df_internal.ScalarUDF(
-            name, func, input_types, return_type, str(volatility)
+            name, func, input_fields, return_field, str(volatility)
         )
 
     def __repr__(self) -> str:
@@ -136,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr:
     @overload
     @staticmethod
     def udf(
-        input_types: list[pa.DataType],
-        return_type: _R,
+        input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
+        return_field: pa.DataType | pa.Field,
         volatility: Volatility | str,
         name: str | None = None,
     ) -> Callable[..., ScalarUDF]: ...
@@ -146,8 +167,8 @@ def udf(
     @staticmethod
     def udf(
         func: Callable[..., _R],
-        input_types: list[pa.DataType],
-        return_type: _R,
+        input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
+        return_field: pa.DataType | pa.Field,
         volatility: Volatility | str,
         name: str | None = None,
     ) -> ScalarUDF: ...
@@ -163,20 +184,24 @@ def udf(*args: Any, **kwargs: Any):  # noqa: D417
         This class can be used both as either a function or a decorator.
 
         Usage:
-            - As a function: ``udf(func, input_types, return_type, volatility, name)``.
-            - As a decorator: ``@udf(input_types, return_type, volatility, name)``.
+            - As a function: ``udf(func, input_fields, return_field, volatility, name)``.
+            - As a decorator: ``@udf(input_fields, return_field, volatility, name)``.
         When used a decorator, do **not** pass ``func`` explicitly.
 
+        In lieu of passing a PyArrow Field, you can pass a DataType for simplicity.
+        When you do so, it will be assumed that the nullability of the inputs and
+        output are True and that they have no metadata.
+
         Args:
             func (Callable, optional): Only needed when calling as a function.
                 Skip this argument when using `udf` as a decorator. If you have a Rust
                 backed ScalarUDF within a PyCapsule, you can pass this parameter
                 and ignore the rest. They will be determined directly from the
                 underlying function. See the online documentation for more information.
-            input_types (list[pa.DataType]): The data types of the arguments
-                to ``func``. This list must be of the same length as the number of
-                arguments.
-            return_type (_R): The data type of the return value from the function.
+            input_fields (list[pa.Field | pa.DataType]): The data types or Fields
+                of the arguments to ``func``. This list must be of the same length
+                as the number of arguments.
+            return_field (_R): The field of the return value from the function.
             volatility (Volatility | str): See `Volatility` for allowed values.
             name (Optional[str]): A descriptive name for the function.
 
@@ -196,12 +221,12 @@ def double_func(x):
             @udf([pa.int32()], pa.int32(), "volatile", "double_it")
             def double_udf(x):
                 return x * 2
-        """
+        """  # noqa: W505 E501
 
         def _function(
             func: Callable[..., _R],
-            input_types: list[pa.DataType],
-            return_type: _R,
+            input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
+            return_field: pa.DataType | pa.Field,
             volatility: Volatility | str,
             name: str | None = None,
         ) -> ScalarUDF:
@@ -213,23 +238,25 @@ def _function(
                 name = func.__qualname__.lower()
             else:
                 name = func.__class__.__name__.lower()
+            input_fields = data_types_or_fields_to_field_list(input_fields)
+            return_field = data_type_or_field_to_field(return_field, "value")
             return ScalarUDF(
                 name=name,
                 func=func,
-                input_types=input_types,
-                return_type=return_type,
+                input_fields=input_fields,
+                return_field=return_field,
                 volatility=volatility,
             )
 
         def _decorator(
-            input_types: list[pa.DataType],
-            return_type: _R,
+            input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
+            return_field: _R,
             volatility: Volatility | str,
             name: str | None = None,
         ) -> Callable:
            def decorator(func: Callable) -> Callable:
                 udf_caller = ScalarUDF.udf(
-                    func, input_types, return_type, volatility, name
+                    func, input_fields, return_field, volatility, name
                 )
 
                 @functools.wraps(func)
@@ -260,8 +287,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
         return ScalarUDF(
             name=name,
             func=func,
-            input_types=None,
-            return_type=None,
+            input_fields=None,
+            return_field=None,
             volatility=None,
         )
 
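
As a quick illustration of the two helpers added above, here is a hedged sketch of how mixed ``DataType``/``Field`` inputs are normalized (assuming the helpers stay importable from ``datafusion.user_defined``):

    import pyarrow as pa
    from datafusion.user_defined import (
        data_type_or_field_to_field,
        data_types_or_fields_to_field_list,
    )

    # Bare DataTypes are wrapped in Fields named "value_<idx>" (nullable, no metadata);
    # Fields pass through untouched.
    fields = data_types_or_fields_to_field_list(
        [pa.int64(), pa.field("b", pa.string(), nullable=False)]
    )
    assert fields[0] == pa.field("value_0", pa.int64())
    assert fields[1].name == "b" and not fields[1].nullable

    # A lone DataType for the return value becomes a Field named "value".
    assert data_type_or_field_to_field(pa.float64(), "value") == pa.field("value", pa.float64())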

python/tests/test_udf.py

Lines changed: 85 additions & 1 deletion
@@ -17,7 +17,8 @@
 
 import pyarrow as pa
 import pytest
-from datafusion import column, udf
+from datafusion import SessionContext, column, udf
+from datafusion import functions as f
 
 
 @pytest.fixture
@@ -124,3 +125,86 @@ def udf_with_param(values: pa.Array) -> pa.Array:
     result = df2.collect()[0].column(0)
 
     assert result == pa.array([False, True, True])
+
+
+def test_udf_with_metadata(ctx) -> None:
+    from uuid import UUID
+
+    @udf([pa.string()], pa.uuid(), "stable")
+    def uuid_from_string(uuid_string):
+        return pa.array((UUID(s).bytes for s in uuid_string.to_pylist()), pa.uuid())
+
+    @udf([pa.uuid()], pa.int64(), "stable")
+    def uuid_version(uuid):
+        return pa.array(s.version for s in uuid.to_pylist())
+
+    batch = pa.record_batch({"idx": pa.array(range(5))})
+    results = (
+        ctx.create_dataframe([[batch]])
+        .with_column("uuid_string", f.uuid())
+        .with_column("uuid", uuid_from_string(column("uuid_string")))
+        .select(uuid_version(column("uuid").alias("uuid_version")))
+        .collect()
+    )
+
+    assert results[0][0].to_pylist() == [4, 4, 4, 4, 4]
+
+
+def test_udf_with_nullability(ctx: SessionContext) -> None:
+    import pyarrow.compute as pc
+
+    field_nullable_i64 = pa.field("with_nulls", type=pa.int64(), nullable=True)
+    field_non_nullable_i64 = pa.field("no_nulls", type=pa.int64(), nullable=False)
+
+    @udf([field_nullable_i64], field_nullable_i64, "stable")
+    def nullable_abs(input_col):
+        return pc.abs(input_col)
+
+    @udf([field_non_nullable_i64], field_non_nullable_i64, "stable")
+    def non_nullable_abs(input_col):
+        return pc.abs(input_col)
+
+    batch = pa.record_batch(
+        {
+            "with_nulls": pa.array([-2, None, 0, 1, 2]),
+            "no_nulls": pa.array([-2, -1, 0, 1, 2]),
+        },
+        schema=pa.schema(
+            [
+                field_nullable_i64,
+                field_non_nullable_i64,
+            ]
+        ),
+    )
+    ctx.register_record_batches("t", [[batch]])
+    df = ctx.table("t")
+
+    # Input matches expected, nullable
+    df_result = df.select(nullable_abs(column("with_nulls")))
+    returned_field = df_result.schema().field(0)
+    assert returned_field.nullable
+    results = df_result.collect()
+    assert results[0][0].to_pylist() == [2, None, 0, 1, 2]
+
+    # Input coercible to expected, nullable
+    df_result = df.select(nullable_abs(column("no_nulls")))
+    returned_field = df_result.schema().field(0)
+    assert returned_field.nullable
+    results = df_result.collect()
+    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]
+
+    # Input matches expected, no nulls
+    df_result = df.select(non_nullable_abs(column("no_nulls")))
+    returned_field = df_result.schema().field(0)
+    assert not returned_field.nullable
+    results = df_result.collect()
+    assert results[0][0].to_pylist() == [2, 1, 0, 1, 2]
+
+    # Invalid - requires non-nullable input but that is not possible
+    df_result = df.select(non_nullable_abs(column("with_nulls")))
+    returned_field = df_result.schema().field(0)
+    assert not returned_field.nullable
+
+    with pytest.raises(Exception) as e_info:
+        _results = df_result.collect()
+    assert "InvalidArgumentError" in str(e_info)

src/array.rs

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::{Field, FieldRef};
+use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
+use arrow::pyarrow::ToPyArrow;
+use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
+use pyo3::types::PyCapsule;
+use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
+
+use crate::errors::PyDataFusionResult;
+use crate::utils::validate_pycapsule;
+
+/// A Python object which implements the Arrow PyCapsule for importing
+/// into other libraries.
+#[pyclass(name = "ArrowArrayExportable", module = "datafusion", frozen)]
+#[derive(Clone)]
+pub struct PyArrowArrayExportable {
+    array: ArrayRef,
+    field: FieldRef,
+}
+
+#[pymethods]
+impl PyArrowArrayExportable {
+    #[pyo3(signature = (requested_schema=None))]
+    fn __arrow_c_array__<'py>(
+        &'py self,
+        py: Python<'py>,
+        requested_schema: Option<Bound<'py, PyCapsule>>,
+    ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
+        let field = if let Some(schema_capsule) = requested_schema {
+            validate_pycapsule(&schema_capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
+            let desired_field = Field::try_from(schema_ptr)?;
+
+            Arc::new(desired_field)
+        } else {
+            Arc::clone(&self.field)
+        };
+
+        let ffi_schema = FFI_ArrowSchema::try_from(&field)?;
+        let schema_capsule = PyCapsule::new(py, ffi_schema, Some(cr"arrow_schema".into()))?;
+
+        let ffi_array = FFI_ArrowArray::new(&self.array.to_data());
+        let array_capsule = PyCapsule::new(py, ffi_array, Some(cr"arrow_array".into()))?;
+
+        Ok((schema_capsule, array_capsule))
+    }
+}
+
+impl ToPyArrow for PyArrowArrayExportable {
+    fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let module = py.import("pyarrow")?;
+        let method = module.getattr("array")?;
+        let array = method.call((self.clone(),), None)?;
+        Ok(array)
+    }
+}
+
+impl PyArrowArrayExportable {
+    pub fn new(array: ArrayRef, field: FieldRef) -> Self {
+        Self { array, field }
+    }
+}
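
The ``ToPyArrow`` implementation above simply hands the object to ``pyarrow.array``, which imports it through the Arrow PyCapsule interface. A rough Python analogue of that handshake, assuming a pyarrow release that consumes ``__arrow_c_array__`` (the ``Wrapper`` class is purely illustrative):

    import pyarrow as pa

    class Wrapper:
        """Illustrative object exposing __arrow_c_array__, as ArrowArrayExportable does."""

        def __init__(self, array: pa.Array) -> None:
            self._array = array

        def __arrow_c_array__(self, requested_schema=None):
            # Must return (schema_capsule, array_capsule); here we delegate to
            # pyarrow's own implementation instead of building FFI structs by hand.
            return self._array.__arrow_c_array__(requested_schema)

    wrapped = Wrapper(pa.array([1, 2, 3]))
    imported = pa.array(wrapped)  # pyarrow imports the data via the two capsules
    assert imported.equals(pa.array([1, 2, 3]))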

src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ pub mod store;
 pub mod table;
 pub mod unparser;
 
+mod array;
 #[cfg(feature = "substrait")]
 pub mod substrait;
 #[allow(clippy::borrow_deref_ref)]
