[WIP][SPARK-54647][PYTHON] Support User-Defined Aggregate Functions (UDAF) #53400
Yicong-Huang wants to merge 21 commits into apache:master from
Conversation
def test_udaf_mixed_with_other_agg_not_supported(self):
    """Test that mixing UDAF with other aggregate functions raises error."""

class MySum(Aggregator):
Can we add some tests for more complicated data structures? like dictionary?
added more data types!
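As a sketch of the kind of richer-type coverage the review asks about, here is an aggregator whose buffer is a dictionary (`CountByKey` is an illustrative name, not part of the PR; Spark plumbing is omitted):

```python
# Illustrative aggregator with a dict buffer: counts occurrences per value.
class CountByKey:
    def zero(self):
        return {}

    def reduce(self, buf, value):
        buf = dict(buf)  # avoid mutating the shared buffer
        buf[value] = buf.get(value, 0) + 1
        return buf

    def merge(self, b1, b2):
        out = dict(b1)
        for k, v in b2.items():
            out[k] = out.get(k, 0) + v
        return out

    def finish(self, buf):
        return buf

agg = CountByKey()
buf = agg.zero()
for v in ["a", "b", "a"]:
    buf = agg.reduce(buf, v)
print(agg.finish(agg.merge(buf, {"b": 2})))  # {'a': 2, 'b': 3}
```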
]

class Aggregator:
do we necessarily need this class?
I see UDTF doesn't need a base class.
>>> class TestUDTF:
... def eval(self, *args: Any):
... yield "hello", "world"
we could use duck typing if we don't go with inheritance. I think it is debatable, and we could offer both solutions (with or without a base class).
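The duck-typing option could look like the following sketch (`is_aggregator_like` and `NoBaseSum` are hypothetical names for illustration, not from the PR):

```python
# Hypothetical duck-typing check: accept any object that implements the
# expected aggregator methods, without requiring an Aggregator base class.
def is_aggregator_like(obj):
    """Return True if obj implements the aggregator protocol."""
    required = ("zero", "reduce", "merge", "finish")
    return all(callable(getattr(obj, name, None)) for name in required)

class NoBaseSum:  # no base class, just the right methods
    def zero(self):
        return 0

    def reduce(self, buffer, value):
        return buffer + value

    def merge(self, b1, b2):
        return b1 + b2

    def finish(self, buffer):
        return buffer

print(is_aggregator_like(NoBaseSum()))  # True
print(is_aggregator_like(object()))     # False
```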
Apply this UDAF to the given columns.

This creates a Column expression that can be used in DataFrame operations.
The actual aggregation is performed using mapInArrow and applyInArrow.
why not a dedicated physical plan?
-----
This implementation uses mapInArrow and applyInArrow internally to perform
the aggregation. The approach follows:
1. mapInArrow: Performs partial aggregation (reduce) on each partition
If we want to support partial aggregation with existing arrow UDFs, I think we should use a modified FlatMapGroupsInArrowExec with requiredChildDistribution = UnspecifiedDistribution.
* MapInArrow, Aggregate, and FlatMapGroupsInArrow operators.
*
* This implements a three-phase aggregation pattern:
* 1. Partial aggregation (MapInArrow): Applies reduce() on each partition, outputs
MapInArrowExec doesn't specify requiredChildOrdering; where does it sort the data for partial aggregation?
The Sort is now explicitly added in RewritePythonAggregatorUDAF before MapInArrow.
I think there should be a

The whole approach is based on
group_buffers[grouping_key] = agg.zero()

if value is not None:
    group_buffers[grouping_key] = agg.reduce(group_buffers[grouping_key], value)
group_buffers holds a buffer for every aggregation key within a partition, so it will cause memory issues if the key cardinality is large.
A reasonable physical plan should sort the partition by the key, and then output each group's partial aggregation result as soon as that group is finished.
it mimics HashAggregateExec, while SortAggregateExec is more stable
thanks for the suggestion. I will take a look at the different AggregateExec implementations.
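The sort-based alternative can be sketched in plain Python, assuming rows are `(key, value)` pairs and the aggregator exposes the `zero()`/`reduce()` methods used in the diff above (`partial_agg_sorted` and `SumAgg` are illustrative names):

```python
# Sketch of sort-based partial aggregation: sorting the partition by key
# lets us emit each group's partial buffer as soon as the key changes,
# instead of holding every key's buffer in a dict at once.
from itertools import groupby

def partial_agg_sorted(partition, agg):
    """Yield (key, partial_buffer) per group from a partition."""
    ordered = sorted(partition, key=lambda r: r[0])
    for key, rows in groupby(ordered, key=lambda r: r[0]):
        buf = agg.zero()
        for _, value in rows:
            if value is not None:
                buf = agg.reduce(buf, value)
        yield key, buf  # group done: emit and drop the buffer

class SumAgg:  # illustrative aggregator
    def zero(self):
        return 0

    def reduce(self, buf, v):
        return buf + v

print(list(partial_agg_sorted([("b", 2), ("a", 1), ("b", 3)], SumAgg())))
# [('a', 1), ('b', 5)]
```

Memory stays bounded by the largest single group rather than the number of distinct keys, which is the trade-off SortAggregateExec makes on the JVM side.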
@Yicong-Huang please let me help you as a reviewer for this; I implemented remote UDAFs several times for other systems previously.
@dtenedor thanks! could you please take a pass on the current implementation?
287e949 to 0abe9be
91b725d to 54e2d62
What changes were proposed in this pull request?
Add support for User-Defined Aggregate Functions (UDAF) in PySpark. Currently PySpark supports User-Defined Functions (UDF) and User-Defined Table Functions (UDTF), but lacks support for UDAF. Users need to write custom aggregation logic in Scala/Java or use less efficient workarounds.
This change adds UDAF support using a two-stage aggregation pattern with `mapInArrow` and `applyInArrow`. The basic idea is to implement aggregation (and partial aggregation) in two steps, where `func1` calls `Aggregator.reduce()` for partial aggregation within each partition, and `func2` calls `Aggregator.merge()` to combine partial results, then `Aggregator.finish()` to produce the final results.

Aligned with the Scala side, the implementation provides a Python `Aggregator` base class that users can subclass. Users can create UDAF instances using the `udaf()` function and use them with `DataFrame.agg()`.

Key changes:

- New `pyspark.sql.udaf` module with an `Aggregator` base class, a `UserDefinedAggregateFunction` wrapper, and a `udaf()` factory function
- `GroupedData.agg()` detects UDAF columns via the `_udaf_func` attribute

Why are the changes needed?
Currently PySpark lacks support for User-Defined Aggregate Functions (UDAF), which limits users' ability to express complex aggregation logic directly in Python. Users must either write custom aggregation logic in Scala/Java or use less efficient workarounds. This change adds UDAF support to complement existing UDF and UDTF support in PySpark, aligning with the Scala/Java
`Aggregator` interface in `org.apache.spark.sql.expressions.Aggregator`.

Does this PR introduce any user-facing change?
Yes. This PR adds a new feature: User-Defined Aggregate Functions (UDAF) support in PySpark. Users can now define custom aggregation logic by subclassing the `Aggregator` class and using the `udaf()` function to create UDAF instances that can be used with `DataFrame.agg()` and `GroupedData.agg()`.

Example:
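A minimal sketch of the two-stage pattern described above, with the Spark plumbing replaced by plain Python (`MyAvg` and the sample data are illustrative, not from the PR; real usage goes through `udaf()` and `DataFrame.agg()`):

```python
# Illustrative Aggregator-style class: computes an average via
# (sum, count) buffers, mirroring the zero/reduce/merge/finish contract.
class MyAvg:
    def zero(self):
        return (0.0, 0)  # (sum, count)

    def reduce(self, buf, value):
        return (buf[0] + value, buf[1] + 1)

    def merge(self, b1, b2):
        return (b1[0] + b2[0], b1[1] + b2[1])

    def finish(self, buf):
        return buf[0] / buf[1] if buf[1] else None

agg = MyAvg()
partitions = [[1.0, 2.0], [3.0, 4.0]]  # two simulated partitions

# Stage 1 (mapInArrow analogue): partial aggregation per partition.
partials = []
for part in partitions:
    buf = agg.zero()
    for v in part:
        buf = agg.reduce(buf, v)
    partials.append(buf)

# Stage 2 (applyInArrow analogue): merge partials, then finish.
final = agg.zero()
for p in partials:
    final = agg.merge(final, p)
print(agg.finish(final))  # 2.5
```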
How was this patch tested?
Added comprehensive unit tests in `python/pyspark/sql/tests/test_udaf.py` covering:

- `groupBy().agg()`
- `df.agg()` and `df.groupBy().agg()`

Was this patch authored or co-authored using generative AI tooling?
No.