diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index a6cced4e12..906fac0179 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -80,11 +80,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on timezone is UTC. [#2649](https://github.com/apache/datafusion-comet/issues/2649) -### Math Expressions - -- **Tan**: `tan(-0.0)` produces `0.0` instead of `-0.0`. - [#1897](https://github.com/apache/datafusion-comet/issues/1897) - ### Aggregate Expressions - **Corr**: Returns null instead of NaN in some edge cases. diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 136e1e454f..9c28cb4ddb 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -121,48 +121,48 @@ Expressions that are not Spark-compatible will fall back to Spark by default and ## Math Expressions -| Expression | SQL | Spark-Compatible? | Compatibility Notes | -| -------------- | --------- | ----------------- | ----------------------------------------------------------------------------------------------------- | -| Abs | `abs` | Yes | | -| Acos | `acos` | Yes | | -| Add | `+` | Yes | | -| Asin | `asin` | Yes | | -| Atan | `atan` | Yes | | -| Atan2 | `atan2` | Yes | | -| BRound | `bround` | Yes | | -| Ceil | `ceil` | Yes | | -| Cos | `cos` | Yes | | -| Cosh | `cosh` | Yes | | -| Cot | `cot` | Yes | | -| Divide | `/` | Yes | | -| Exp | `exp` | Yes | | -| Expm1 | `expm1` | Yes | | -| Floor | `floor` | Yes | | -| Hex | `hex` | Yes | | -| IntegralDivide | `div` | Yes | | -| IsNaN | `isnan` | Yes | | -| Log | `log` | Yes | | -| Log2 | `log2` | Yes | | -| Log10 | `log10` | Yes | | -| Multiply | `*` | Yes | | -| Pow | `power` | Yes | | -| Rand | `rand` | Yes | | -| Randn | `randn` | Yes | | -| Remainder | `%` | Yes | | -| Round | `round` | Yes | | -| Signum | `signum` | Yes | | -| Sin | `sin` | Yes | | -| Sinh | `sinh` | Yes | | -| Sqrt | `sqrt` | Yes | | -| Subtract | `-` | Yes | | -| Tan | `tan` | No | tan(-0.0) produces incorrect result ([#1897](https://github.com/apache/datafusion-comet/issues/1897)) | -| Tanh | `tanh` | Yes | | -| TryAdd | `try_add` | Yes | Only integer inputs are supported | -| TryDivide | `try_div` | Yes | Only integer inputs are supported | -| TryMultiply | `try_mul` | Yes | Only integer inputs are supported | -| TrySubtract | `try_sub` | Yes | Only integer inputs are supported | -| UnaryMinus | `-` | Yes | | -| Unhex | `unhex` | Yes | | +| Expression | SQL | Spark-Compatible? | Compatibility Notes | +| -------------- | --------- | ----------------- | --------------------------------- | +| Abs | `abs` | Yes | | +| Acos | `acos` | Yes | | +| Add | `+` | Yes | | +| Asin | `asin` | Yes | | +| Atan | `atan` | Yes | | +| Atan2 | `atan2` | Yes | | +| BRound | `bround` | Yes | | +| Ceil | `ceil` | Yes | | +| Cos | `cos` | Yes | | +| Cosh | `cosh` | Yes | | +| Cot | `cot` | Yes | | +| Divide | `/` | Yes | | +| Exp | `exp` | Yes | | +| Expm1 | `expm1` | Yes | | +| Floor | `floor` | Yes | | +| Hex | `hex` | Yes | | +| IntegralDivide | `div` | Yes | | +| IsNaN | `isnan` | Yes | | +| Log | `log` | Yes | | +| Log2 | `log2` | Yes | | +| Log10 | `log10` | Yes | | +| Multiply | `*` | Yes | | +| Pow | `power` | Yes | | +| Rand | `rand` | Yes | | +| Randn | `randn` | Yes | | +| Remainder | `%` | Yes | | +| Round | `round` | Yes | | +| Signum | `signum` | Yes | | +| Sin | `sin` | Yes | | +| Sinh | `sinh` | Yes | | +| Sqrt | `sqrt` | Yes | | +| Subtract | `-` | Yes | | +| Tan | `tan` | Yes | | +| Tanh | `tanh` | Yes | | +| TryAdd | `try_add` | Yes | Only integer inputs are supported | +| TryDivide | `try_div` | Yes | Only integer inputs are supported | +| TryMultiply | `try_mul` | Yes | Only integer inputs are supported | +| TrySubtract | `try_sub` | Yes | Only integer inputs are supported | +| UnaryMinus | `-` | Yes | | +| Unhex | `unhex` | Yes | | ## Hashing Functions diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 59fb0f9819..09d35d533d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -117,7 +117,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sinh] -> CometScalarFunction("sinh"), classOf[Sqrt] -> CometScalarFunction("sqrt"), classOf[Subtract] -> CometSubtract, - classOf[Tan] -> CometTan, + classOf[Tan] -> CometScalarFunction("tan"), classOf[Tanh] -> CometScalarFunction("tanh"), classOf[Cot] -> CometScalarFunction("cot"), classOf[UnaryMinus] -> CometUnaryMinus, diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index 03fe4aaa82..a01d4cdf9d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Unhex} import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -30,8 +30,11 @@ object CometAtan2 extends CometExpressionSerde[Atan2] { expr: Atan2, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - val leftExpr = exprToProtoInternal(expr.left, inputs, binding) - val rightExpr = exprToProtoInternal(expr.right, inputs, binding) + // Spark adds +0.0 to inputs in order to convert -0.0 to +0.0 + val left = Add(expr.left, Literal.default(expr.left.dataType)) + val right = Add(expr.right, Literal.default(expr.right.dataType)) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, expr.left, expr.right) } @@ -189,24 +192,6 @@ object CometAbs extends CometExpressionSerde[Abs] with MathExprBase { } } -object CometTan extends CometExpressionSerde[Tan] { - - override def getSupportLevel(expr: Tan): SupportLevel = - Incompatible( - Some( - "tan(-0.0) produces incorrect result" + - " (https://github.com/apache/datafusion-comet/issues/1897)")) - - override def convert( - expr: Tan, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProto("tan", childExpr: _*) - optExprWithInfo(optExpr, expr, expr.children: _*) - } -} - sealed trait MathExprBase { protected def nullIfNegative(expression: Expression): Expression = { val zero = Literal.default(expression.dataType) diff --git a/spark/src/test/resources/sql-tests/expressions/math/atan2.sql b/spark/src/test/resources/sql-tests/expressions/math/atan2.sql index 7a912930b8..47c7fe8ba5 100644 --- a/spark/src/test/resources/sql-tests/expressions/math/atan2.sql +++ b/spark/src/test/resources/sql-tests/expressions/math/atan2.sql @@ -21,7 +21,15 @@ statement CREATE TABLE test_atan2(y double, x double) USING parquet statement -INSERT INTO test_atan2 VALUES (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (-1.0, -1.0), (0.0, 0.0), (NULL, 1.0), (1.0, NULL), (cast('NaN' as double), 1.0), (cast('Infinity' as double), 1.0) +INSERT INTO test_atan2 VALUES + (0.0, 0.0), (0.0, -0.0), (0.0, 1.0), (0.0, -1.0), (0.0, NULL), (0.0, cast('NaN' as double)), (0.0, cast('Infinity' as double)), (0.0, cast('-Infinity' as double)), + (-0.0, 0.0), (-0.0, -0.0), (-0.0, 1.0), (-0.0, -1.0), (-0.0, NULL), (-0.0, cast('NaN' as double)), (-0.0, cast('Infinity' as double)), (-0.0, cast('-Infinity' as double)), + (1.0, 0.0), (1.0, -0.0), (1.0, 1.0), (1.0, -1.0), (1.0, NULL), (1.0, cast('NaN' as double)), (1.0, cast('Infinity' as double)), (1.0, cast('-Infinity' as double)), + (-1.0, 0.0), (-1.0, -0.0), (-1.0, 1.0), (-1.0, -1.0), (-1.0, NULL), (-1.0, cast('NaN' as double)), (-1.0, cast('Infinity' as double)), (-1.0, cast('-Infinity' as double)), + (NULL, 0.0), (NULL, -0.0), (NULL, 1.0), (NULL, -1.0), (NULL, NULL), (NULL, cast('NaN' as double)), (NULL, cast('Infinity' as double)), (NULL, cast('-Infinity' as double)), + (cast('NaN' as double), 0.0), (cast('NaN' as double), -0.0), (cast('NaN' as double), 1.0), (cast('NaN' as double), -1.0), (cast('NaN' as double), NULL), (cast('NaN' as double), cast('NaN' as double)), (cast('NaN' as double), cast('Infinity' as double)), (cast('NaN' as double), cast('-Infinity' as double)), + (cast('Infinity' as double), 0.0), (cast('Infinity' as double), -0.0), (cast('Infinity' as double), 1.0), (cast('Infinity' as double), -1.0), (cast('Infinity' as double), NULL), (cast('Infinity' as double), cast('NaN' as double)), (cast('Infinity' as double), cast('Infinity' as double)), (cast('Infinity' as double), cast('-Infinity' as double)), + (cast('-Infinity' as double), 0.0), (cast('-Infinity' as double), -0.0), (cast('-Infinity' as double), 1.0), (cast('-Infinity' as double), -1.0), (cast('-Infinity' as double), NULL), (cast('-Infinity' as double), cast('NaN' as double)), (cast('-Infinity' as double), cast('Infinity' as double)), (cast('-Infinity' as double), cast('-Infinity' as double)) query tolerance=1e-6 SELECT atan2(y, x) FROM test_atan2 @@ -34,6 +42,9 @@ SELECT atan2(y, 1.0) FROM test_atan2 query tolerance=1e-6 SELECT atan2(1.0, x) FROM test_atan2 --- literal + literal +-- literal permutations query tolerance=1e-6 -SELECT atan2(1.0, 1.0), atan2(0.0, 0.0), atan2(-1.0, -1.0), atan2(NULL, 1.0) +SELECT atan2(0.0, 0.0), atan2(0.0, -0.0), atan2(0.0, 1.0), atan2(0.0, -1.0), + atan2(-0.0, 0.0), atan2(-0.0, -0.0), atan2(-0.0, 1.0), atan2(-0.0, -1.0), + atan2(1.0, 0.0), atan2(1.0, -0.0), atan2(1.0, 1.0), atan2(1.0, -1.0), + atan2(-1.0, 0.0), atan2(-1.0, -0.0), atan2(-1.0, 1.0), atan2(-1.0, -1.0) diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql b/spark/src/test/resources/sql-tests/expressions/math/tan.sql index 9496844804..dd03ffb620 100644 --- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql +++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql @@ -22,11 +22,11 @@ statement CREATE TABLE test_tan(d double) USING parquet statement -INSERT INTO test_tan VALUES (0.0), (0.7853981633974483), (-0.7853981633974483), (1.0), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)) +INSERT INTO test_tan VALUES (0.0), (-0.0), (0.7853981633974483), (-0.7853981633974483), (1.0), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)) query tolerance=1e-6 SELECT tan(d) FROM test_tan -- literal arguments query tolerance=1e-6 -SELECT tan(0.0), tan(0.7853981633974483), tan(NULL) +SELECT tan(0.0), tan(-0.0), tan(0.7853981633974483), tan(NULL) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 1e66c9f599..78a19983d9 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, Tan, TruncDate, TruncTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp} import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps import org.apache.spark.sql.comet.CometProjectExec import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} @@ -1333,8 +1333,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { private val doubleValues: Seq[Double] = Seq( -1.0, - // TODO we should eventually enable negative zero but there are known issues still - // -0.0, + -0.0, 0.0, +1.0, Double.MinValue, @@ -1345,42 +1344,41 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { Double.NegativeInfinity) test("various math scalar functions") { - val data = doubleValues.map(n => (n, n)) - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Tan]) -> "true") { - withParquetTable(data, "tbl") { - // expressions with single arg - for (expr <- Seq( - "acos", - "asin", - "atan", - "cos", - "cosh", - "exp", - "ln", - "log10", - "log2", - "sin", - "sinh", - "sqrt", - "tan", - "tanh", - "cot")) { - val (_, cometPlan) = - checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) FROM tbl")) - val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec => - op - } - assert(cometProjectExecs.length == 1, expr) - } - // expressions with two args - for (expr <- Seq("atan2", "pow")) { - val (_, cometPlan) = - checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM tbl")) - val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec => - op - } - assert(cometProjectExecs.length == 1, expr) - } + withParquetTable(doubleValues.map(n => (n, n)), "tbl") { + // expressions with single arg + for (expr <- Seq( + "acos", + "asin", + "atan", + "cos", + "cosh", + "exp", + "ln", + "log10", + "log2", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "cot")) { + val (_, cometPlan) = + checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) FROM tbl")) + val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec => + op + } + assert(cometProjectExecs.length == 1, expr) + } + } + withParquetTable(doubleValues.flatMap(m => doubleValues.map(n => (m, n))), "tbl") { + // expressions with two args + for (expr <- Seq("atan2", "pow")) { + val (_, cometPlan) = + checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM tbl")) + val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec => + op + } + assert(cometProjectExecs.length == 1, expr) } } }