diff --git a/native/spark-expr/src/json_funcs/to_json.rs b/native/spark-expr/src/json_funcs/to_json.rs index 3cc827f210..b8153e78e3 100644 --- a/native/spark-expr/src/json_funcs/to_json.rs +++ b/native/spark-expr/src/json_funcs/to_json.rs @@ -235,7 +235,7 @@ fn struct_to_json(array: &StructArray, timezone: &str) -> Result { json.push_str("\":"); // value let string_value = string_arrays[col_index].value(row_index); - if is_string[col_index] { + if is_string[col_index] || is_infinity(string_value) || is_nan(string_value) { json.push('"'); json.push_str(&escape_string(string_value)); json.push('"'); @@ -252,6 +252,14 @@ fn struct_to_json(array: &StructArray, timezone: &str) -> Result { Ok(Arc::new(builder.finish())) } +fn is_infinity(input: &str) -> bool { + input == "Infinity" || input == "-Infinity" +} + +fn is_nan(input: &str) -> bool { + input == "NaN" +} + #[cfg(test)] mod test { use crate::json_funcs::to_json::struct_to_json; diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala index 449d0fc5b9..43b999a7d1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/structs.scala +++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala @@ -105,53 +105,37 @@ object CometGetArrayStructFields extends CometExpressionSerde[GetArrayStructFiel object CometStructsToJson extends CometExpressionSerde[StructsToJson] { - override def getSupportLevel(expr: StructsToJson): SupportLevel = - Incompatible( - Some( - "Does not support Infinity/-Infinity for numeric types" + - " (https://github.com/apache/datafusion-comet/issues/3016)")) + override def getSupportLevel(expr: StructsToJson): SupportLevel = { + if (expr.options.nonEmpty) { + return Unsupported(Some("StructsToJson with options is not supported")) + } + val dataType = expr.child.dataType + if (!isSupportedType(dataType)) { + return Unsupported(Some(s"Struct type: $dataType contains unsupported types")) + } + Compatible() + } override def 
convert( expr: StructsToJson, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - if (expr.options.nonEmpty) { - withInfo(expr, "StructsToJson with options is not supported") - None - } else { - val isSupported = expr.child.dataType match { - case s: StructType => - s.fields.forall(f => isSupportedType(f.dataType)) - case _: MapType | _: ArrayType => - // Spark supports map and array in StructsToJson but this is not yet - // implemented in Comet - false - case _ => - false - } - - if (isSupported) { - exprToProtoInternal(expr.child, inputs, binding) match { - case Some(p) => - val toJson = ExprOuterClass.ToJson - .newBuilder() - .setChild(p) - .setTimezone(expr.timeZoneId.getOrElse("UTC")) - .setIgnoreNullFields(true) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToJson(toJson) - .build()) - case _ => - withInfo(expr, expr.child) - None - } - } else { - withInfo(expr, "Unsupported data type", expr.child) + exprToProtoInternal(expr.child, inputs, binding) match { + case Some(p) => + val toJson = ExprOuterClass.ToJson + .newBuilder() + .setChild(p) + .setTimezone(expr.timeZoneId.getOrElse("UTC")) + .setIgnoreNullFields(true) + .build() + Some( + ExprOuterClass.Expr + .newBuilder() + .setToJson(toJson) + .build()) + case _ => + withInfo(expr, expr.child) None - } } } diff --git a/spark/src/test/resources/sql-tests/expressions/struct/structs_to_json.sql b/spark/src/test/resources/sql-tests/expressions/struct/structs_to_json.sql index 7f2310f147..776d11241e 100644 --- a/spark/src/test/resources/sql-tests/expressions/struct/structs_to_json.sql +++ b/spark/src/test/resources/sql-tests/expressions/struct/structs_to_json.sql @@ -29,3 +29,11 @@ SELECT to_json(named_struct('a', a, 'b', b)) FROM test_to_json -- literal arguments query spark_answer_only SELECT to_json(named_struct('a', 1, 'b', 'hello')) + +-- query expect_fallback(StructsToJson with options is not supported) +query ignore("Need support Spark 4.0.0") +SELECT 
to_json(named_struct('a', a, 'b', b), map('timestampFormat', 'dd/MM/yyyy')) + +-- query expect_fallback(Struct type: StructType(StructField(a,ArrayType(IntegerType,false),false)) contains unsupported types) +query ignore("Need support Spark 4.0.0") +SELECT to_json(named_struct('a', array(b))) diff --git a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala index 64c330dbdd..50a3c0c79f 100644 --- a/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometJsonExpressionSuite.scala @@ -60,7 +60,7 @@ class CometJsonExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe filename, 100, SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false), - DataGenOptions(generateNaN = false, generateInfinity = false)) + DataGenOptions(generateNaN = true, generateInfinity = true)) } val table = spark.read.parquet(filename) val fieldsNames = table.schema.fields @@ -72,6 +72,20 @@ class CometJsonExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelpe } } + test("to_json - fallback reasons") { + assume(!isSpark40Plus) + withTable("t") { + sql("CREATE TABLE t(a INT, b STRING) USING parquet") + sql("INSERT INTO t VALUES (1, 'hello')") + checkSparkAnswerAndFallbackReason( + "SELECT to_json(named_struct('a', a, 'b', b), map('timestampFormat', 'dd/MM/yyyy')) FROM t", + "StructsToJson with options is not supported") + checkSparkAnswerAndFallbackReason( + "SELECT to_json(named_struct('b', array(b))) FROM t", + "Struct type: StructType(StructField(b,ArrayType(StringType,true),false)) contains unsupported types") + } + } + test("from_json - basic primitives") { Seq(true, false).foreach { dictionaryEnabled => withParquetTable(