diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 65f941210a..d9897fe94c 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -40,6 +40,9 @@ - [ ] approx_top_k_combine - [ ] array_agg - [x] avg + - Spark 3.4.3 (2026-05-26) + - Spark 3.5.8 (2026-05-26): aggregate logic identical to 3.4.3 + - Spark 4.0.1 (2026-05-26): aggregate logic identical to 3.5.8; only `QueryContext` import path differs. `YearMonthIntervalType` and `DayTimeIntervalType` inputs (supported by Spark) fall back to Spark in Comet. - [x] bit_and - [x] bit_or - [x] bit_xor diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 2714a7e466..a9ae740900 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -155,8 +155,8 @@ object CometCount extends CometAggregateExpressionSerde[Count] { object CometAverage extends CometAggregateExpressionSerde[Average] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Falls back to Spark in ANSI mode. 
Supports all numeric inputs except decimal types.") + override def getUnsupportedReasons(): Seq[String] = Seq( + "YearMonthIntervalType and DayTimeIntervalType inputs are not supported") override def convert( aggExpr: AggregateExpression, diff --git a/spark/src/test/resources/sql-tests/expressions/aggregate/avg.sql b/spark/src/test/resources/sql-tests/expressions/aggregate/avg.sql index c718a80ded..9d41d86367 100644 --- a/spark/src/test/resources/sql-tests/expressions/aggregate/avg.sql +++ b/spark/src/test/resources/sql-tests/expressions/aggregate/avg.sql @@ -26,3 +26,107 @@ SELECT avg(i), avg(l), avg(f), avg(d) FROM test_avg query tolerance=1e-6 SELECT grp, avg(d) FROM test_avg GROUP BY grp ORDER BY grp + +-- single-row group (count == 1) +query tolerance=1e-6 +SELECT grp, avg(d) FROM test_avg WHERE i = 3 GROUP BY grp + +-- byte and short input types +statement +CREATE TABLE test_avg_small(b tinyint, s smallint, grp string) USING parquet + +statement +INSERT INTO test_avg_small VALUES (1, 100, 'a'), (2, 200, 'a'), (3, 300, 'b'), (NULL, NULL, 'b'), (-1, -100, 'a') + +query tolerance=1e-6 +SELECT avg(b), avg(s) FROM test_avg_small + +query tolerance=1e-6 +SELECT grp, avg(b), avg(s) FROM test_avg_small GROUP BY grp ORDER BY grp + +-- all-NULL input returns NULL +statement +CREATE TABLE test_avg_all_null(v double, grp string) USING parquet + +statement +INSERT INTO test_avg_all_null VALUES (NULL, 'a'), (NULL, 'a'), (NULL, 'b') + +query +SELECT avg(v) FROM test_avg_all_null + +query +SELECT grp, avg(v) FROM test_avg_all_null GROUP BY grp ORDER BY grp + +-- empty input (no rows) returns NULL +statement +CREATE TABLE test_avg_empty(v double) USING parquet + +query +SELECT avg(v) FROM test_avg_empty + +-- NaN and infinity input on doubles +statement +CREATE TABLE test_avg_special(v double, grp string) USING parquet + +statement +INSERT INTO test_avg_special VALUES + (double('NaN'), 'nan_only'), + (1.0, 'nan_only'), + (double('Infinity'), 'pos_inf_only'), + (1.0, 
'pos_inf_only'), + (double('Infinity'), 'mixed_inf'), + (double('-Infinity'), 'mixed_inf'), + (double('-Infinity'), 'neg_inf_only'), + (-2.0, 'neg_inf_only') + +query tolerance=1e-6 +SELECT grp, avg(v) FROM test_avg_special GROUP BY grp ORDER BY grp + +-- boundary integer values +statement +CREATE TABLE test_avg_bounds(l long, grp string) USING parquet + +statement +INSERT INTO test_avg_bounds VALUES + (9223372036854775807, 'maxes'), + (9223372036854775807, 'maxes'), + (-9223372036854775808, 'mins'), + (-9223372036854775808, 'mins'), + (9223372036854775807, 'mixed'), + (-9223372036854775808, 'mixed') + +query tolerance=1e-6 +SELECT grp, avg(l) FROM test_avg_bounds GROUP BY grp ORDER BY grp + +-- negative-only inputs +statement +CREATE TABLE test_avg_negative(d double) USING parquet + +statement +INSERT INTO test_avg_negative VALUES (-1.5), (-2.5), (-3.5), (-0.0) + +query tolerance=1e-6 +SELECT avg(d) FROM test_avg_negative + +-- decimal column at higher precision +statement +CREATE TABLE test_avg_decimal(d decimal(20, 5), grp string) USING parquet + +statement +INSERT INTO test_avg_decimal VALUES + (10.50000, 'a'), + (20.25000, 'a'), + (NULL, 'a'), + (-5.00000, 'b'), + (0.00000, 'b'), + (5.00000, 'b') + +query +SELECT avg(d) FROM test_avg_decimal + +query +SELECT grp, avg(d) FROM test_avg_decimal GROUP BY grp ORDER BY grp + +-- count(d) and avg(d) in the same query for cross-check +query tolerance=1e-6 +SELECT count(d), avg(d) FROM test_avg_decimal