diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index a6cced4e12..94dab0dba3 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -64,9 +64,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on [#3346](https://github.com/apache/datafusion-comet/issues/3346) - **ArrayRemove**: Returns null when the element to remove is null, instead of removing null elements from the array. [#3173](https://github.com/apache/datafusion-comet/issues/3173) -- **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values. - [#3645](https://github.com/apache/datafusion-comet/issues/3645), - [#2036](https://github.com/apache/datafusion-comet/issues/2036) - **ArrayUnion**: Sorts input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. [#3644](https://github.com/apache/datafusion-comet/issues/3644) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 136e1e454f..6db05f4321 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -246,7 +246,7 @@ Comet supports using the following aggregate functions within window contexts wi | ArrayRemove | No | Returns null when element is null instead of removing null elements ([#3173](https://github.com/apache/datafusion-comet/issues/3173)) | | ArrayRepeat | No | | | ArrayUnion | No | Behaves differently than spark. Comet sorts the input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. | -| ArraysOverlap | No | | +| ArraysOverlap | Yes | | | CreateArray | Yes | | | ElementAt | Yes | Input must be an array. Map inputs are not supported. | | Flatten | Yes | | diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 47a6e91421..0c121e02ca 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -246,20 +246,12 @@ object CometArrayMin extends CometExpressionSerde[ArrayMin] { } object CometArraysOverlap extends CometExpressionSerde[ArraysOverlap] { - - override def getSupportLevel(expr: ArraysOverlap): SupportLevel = - Incompatible( - Some( - "Inconsistent behavior with NULL values" + - " (https://github.com/apache/datafusion-comet/issues/3645)" + - " (https://github.com/apache/datafusion-comet/issues/2036)")) - override def convert( expr: ArraysOverlap, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val leftArrayExprProto = exprToProto(expr.children.head, inputs, binding) - val rightArrayExprProto = exprToProto(expr.children(1), inputs, binding) + val leftArrayExprProto = exprToProto(expr.left, inputs, binding) + val rightArrayExprProto = exprToProto(expr.right, inputs, binding) val arraysOverlapScalarExpr = scalarFunctionExprToProtoWithReturnType( "array_has_any", @@ -267,7 +259,41 @@ object CometArraysOverlap extends CometExpressionSerde[ArraysOverlap] { false, leftArrayExprProto, rightArrayExprProto) - optExprWithInfo(arraysOverlapScalarExpr, expr, expr.children: _*) + + val leftIsNull = createUnaryExpr( + expr, + expr.left, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNull(unaryExpr)) + val rightIsNull = createUnaryExpr( + expr, + expr.right, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNull(unaryExpr)) + + val nullLiteralProto = exprToProto(Literal(null, BooleanType), inputs) + + if (arraysOverlapScalarExpr.isDefined && leftIsNull.isDefined && + rightIsNull.isDefined && nullLiteralProto.isDefined) { + val caseWhenExpr = ExprOuterClass.CaseWhen + .newBuilder() + .addWhen(leftIsNull.get) + .addThen(nullLiteralProto.get) + .addWhen(rightIsNull.get) + .addThen(nullLiteralProto.get) + .setElseExpr(arraysOverlapScalarExpr.get) + .build() + Some( + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(caseWhenExpr) + .build()) + } else { + withInfo(expr, expr.children: _*) + None + } } } diff --git a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql index 27d28a7402..88cf38d211 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql @@ -15,7 +15,6 @@ -- specific language governing permissions and limitations -- under the License. --- Config: spark.comet.expression.ArraysOverlap.allowIncompatible=true -- ConfigMatrix: parquet.enable.dictionary=false,true statement @@ -24,11 +23,11 @@ CREATE TABLE test_arrays_overlap(a array, b array) USING parquet statement INSERT INTO test_arrays_overlap VALUES (array(1, 2, 3), array(3, 4, 5)), (array(1, 2), array(3, 4)), (array(), array(1)), (NULL, array(1)), (array(1, NULL), array(NULL, 2)) -query ignore(https://github.com/apache/datafusion-comet/issues/3645) +query SELECT arrays_overlap(a, b) FROM test_arrays_overlap -- column + literal -query ignore(https://github.com/apache/datafusion-comet/issues/3645) +query SELECT arrays_overlap(a, array(3, 4, 5)) FROM test_arrays_overlap -- literal + column diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index fb5531a573..4e28ffda3c 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.sql.CometTestBase -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArrayUnion} import org.apache.spark.sql.catalyst.expressions.{ArrayContains, ArrayRemove} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions._ @@ -545,27 +545,47 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("arrays_overlap") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArraysOverlap]) -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - withTempView("t1") { - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - checkSparkAnswerAndOperator(sql( - "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null")) - checkSparkAnswerAndOperator(sql( - "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null")) - checkSparkAnswerAndOperator(sql( - "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null")) - checkSparkAnswerAndOperator(spark.sql( - "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1")); - } + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + withTempView("t1") { + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null")) + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null")) + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null")) + checkSparkAnswerAndOperator(spark.sql( + "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1")); } } } } + test("arrays_overlap - null handling behavior verification") { + withTable("t") { + sql("create table t using parquet as select CAST(NULL as array) a1 from range(1)") + val data = Seq( + "array(1, 2, 3)", + "array(3, 4, 5)", + "array(1, 2)", + "array(3, 4)", + "array(1, null, 3)", + "array(4, 5)", + "array(1, 4)", + "array(1, null)", + "array(2, null)", + "array()", + "array(null)", + "a1") + for (y <- data; x <- data) { + checkSparkAnswerAndOperator(sql(s"SELECT arrays_overlap($y, $x) from t")) + } + } + } + test("array_compact") { // TODO fix for Spark 4.0.0 assume(!isSpark40Plus)