diff --git a/native/Cargo.lock b/native/Cargo.lock index 11e9b1ccff..9cd5f15878 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2140,6 +2140,7 @@ dependencies = [ "num", "rand 0.10.1", "regex", + "serde", "serde_json", "tokio", "twox-hash", diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 1b0359059c..800fe3ecb1 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -35,6 +35,7 @@ num = { workspace = true } regex = { workspace = true } # preserve_order: needed for get_json_object to match Spark's JSON key ordering serde_json = { version = "1.0", features = ["preserve_order"] } +serde = { version = "1.0", features = ["derive"] } datafusion-comet-common = { workspace = true } datafusion-comet-jni-bridge = { workspace = true } jni = "0.22.4" diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 7108105dcb..13b0589e71 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -16,6 +16,7 @@ // under the License. use crate::hash_funcs::*; +use crate::json_funcs::JsonArrayLength; use crate::map_funcs::spark_map_sort; use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; @@ -221,6 +222,7 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkMakeTime::default())), Arc::new(ScalarUDF::new_from_impl(SparkSecondsToTimestamp::default())), Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())), + Arc::new(ScalarUDF::new_from_impl(JsonArrayLength::default())), ] } diff --git a/native/spark-expr/src/json_funcs/json_array_length.rs b/native/spark-expr/src/json_funcs/json_array_length.rs new file mode 100644 index 0000000000..eef879da3c --- /dev/null +++ b/native/spark-expr/src/json_funcs/json_array_length.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, Int32Builder, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion::common::cast::as_generic_string_array; +use datafusion::common::{exec_err, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +use std::any::Any; + +use serde::de::{IgnoredAny, SeqAccess, Visitor}; +use serde::Deserializer; +use std::fmt; +use std::sync::Arc; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct JsonArrayLength { + signature: Signature, +} + +impl Default for JsonArrayLength { + fn default() -> Self { + Self::new() + } +} + +impl JsonArrayLength { + pub fn new() -> Self { + Self { + signature: Signature::variadic( + vec![DataType::Utf8, DataType::LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for JsonArrayLength { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "json_array_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_json_array_length(&args.args) + } +} + +fn spark_json_array_length(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!("json_array_length function takes exactly one argument"); + } + match &args[0] { + ColumnarValue::Array(array) => { + let result = spark_json_array_length_array(array)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar) => { + let result = spark_json_array_length_scalar(scalar)?; + Ok(ColumnarValue::Scalar(result)) + } + } +} + +fn spark_json_array_length_array(array: &ArrayRef) -> Result { + match array.data_type() { + DataType::Utf8 => spark_json_array_length_array_inner::(array), + DataType::LargeUtf8 => spark_json_array_length_array_inner::(array), + other => { + exec_err!("Unsupported data type {other:?} for function `json_array_length`") + } + } +} + +fn spark_json_array_length_scalar(scalar: &ScalarValue) -> Result { + match scalar { + ScalarValue::Utf8(value) => spark_json_array_length_scalar_inner(value), + ScalarValue::LargeUtf8(value) => spark_json_array_length_scalar_inner(value), + other => { + exec_err!("Unsupported data type {other:?} for function `json_array_length`") + } + } +} + +fn spark_json_array_length_scalar_inner(json_str: &Option) -> Result { + let array_length = json_str + .clone() + .and_then(|json_str| get_json_array_length(&json_str)); + Ok(ScalarValue::Int32(array_length)) +} + +fn spark_json_array_length_array_inner(array: &ArrayRef) -> Result { + let str_array = as_generic_string_array::(array)?; + let mut builder = Int32Builder::with_capacity(str_array.len()); + for row_idx in 0..str_array.len() { + if str_array.is_null(row_idx) { + builder.append_null(); + } else { + let json_str = str_array.value(row_idx); + if let Some(json_array_length) = get_json_array_length(json_str) { + builder.append_value(json_array_length); + } else { + builder.append_null() + } + } + } + Ok(Arc::new(builder.finish())) +} + +struct ArrayItemCounter; + +impl<'de> Visitor<'de> for ArrayItemCounter { + type Value = i32; + + fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("a JSON array") + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut len = 0i32; + while seq.next_element::()?.is_some() { + len += 1; + } + Ok(len) + } +} + +fn get_json_array_length(json: &str) -> Option { + let mut deserializer = serde_json::Deserializer::from_str(json); + deserializer.deserialize_seq(ArrayItemCounter).ok() +} diff --git a/native/spark-expr/src/json_funcs/mod.rs b/native/spark-expr/src/json_funcs/mod.rs index 9f025070d7..59e1e5dd58 100644 --- a/native/spark-expr/src/json_funcs/mod.rs +++ b/native/spark-expr/src/json_funcs/mod.rs @@ -16,7 +16,9 @@ // under the License. mod from_json; +mod json_array_length; mod to_json; pub use from_json::FromJson; +pub use json_array_length::JsonArrayLength; pub use to_json::ToJson; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b818b61b1b..a842996961 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -266,6 +266,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { private val conversionExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Cast] -> CometCast) + private val jsonExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[LengthOfJsonArray] -> CometLengthOfJsonArray) + private[comet] val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision classOf[Alias] -> CometAlias, @@ -291,7 +294,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { mathExpressions ++ hashExpressions ++ stringExpressions ++ conditionalExpressions ++ mapExpressions ++ predicateExpressions ++ structExpressions ++ bitwiseExpressions ++ miscExpressions ++ arrayExpressions ++ - temporalExpressions ++ conversionExpressions ++ urlExpressions + temporalExpressions ++ conversionExpressions ++ urlExpressions ++ jsonExpressions /** * Mapping of Spark aggregate expression class to Comet expression handler. diff --git a/spark/src/main/scala/org/apache/comet/serde/json.scala b/spark/src/main/scala/org/apache/comet/serde/json.scala new file mode 100644 index 0000000000..5f296599d6 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/json.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.LengthOfJsonArray + +object CometLengthOfJsonArray + extends CometScalarFunction[LengthOfJsonArray]("json_array_length") { + + private val IncompatibleReason: String = + "Spark's lenient JSON parser allows single quotes, unescaped controls, " + + "and trailing content, " + + "while Comet's serde_json requires strict JSON." + + override def getIncompatibleReasons(): Seq[String] = Seq(IncompatibleReason) + + override def getSupportLevel(expr: LengthOfJsonArray): SupportLevel = Incompatible( + Some(IncompatibleReason)) +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 868b09de9d..6a72fd04a7 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -160,6 +160,20 @@ trait CometExprShim extends CommonStringExprs { case _ => None } + case s: StaticInvoke => + (s.staticObject, s.functionName, s.arguments) match { + case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] => + val lengthOfJsonArray = LengthOfJsonArray(child) + val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding) + if (exprProto.isEmpty) { + lengthOfJsonArray + .getTagValue(CometExplainInfo.EXTENSION_INFO) + .foreach(reasons => s.setTagValue(CometExplainInfo.EXTENSION_INFO, reasons)) + } + exprProto + case _ => None + } + case ms: MapSort => val keyType = ms.dataType.asInstanceOf[MapType].keyType if (!supportedScalarSortElementType(keyType)) { diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala index 676cb468b4..e8cfda5fa2 100644 --- a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala @@ -21,7 +21,7 @@ package org.apache.comet.shims import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum -import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -191,6 +191,20 @@ trait CometExprShim extends CommonStringExprs { case _ => None } + case s: StaticInvoke => + (s.staticObject, s.functionName, s.arguments) match { + case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] => + val lengthOfJsonArray = LengthOfJsonArray(child) + val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding) + if (exprProto.isEmpty) { + lengthOfJsonArray + .getTagValue(CometExplainInfo.EXTENSION_INFO) + .foreach(reasons => s.setTagValue(CometExplainInfo.EXTENSION_INFO, reasons)) + } + exprProto + case _ => None + } + case ms: MapSort => val keyType = ms.dataType.asInstanceOf[MapType].keyType if (!supportedScalarSortElementType(keyType)) { diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala index 676cb468b4..3deed1b0d7 100644 --- a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala @@ -191,6 +191,20 @@ trait CometExprShim extends CommonStringExprs { case _ => None } + case s: StaticInvoke => + (s.staticObject, s.functionName, s.arguments) match { + case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] => + val lengthOfJsonArray = LengthOfJsonArray(child) + val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding) + if (exprProto.isEmpty) { + lengthOfJsonArray + .getTagValue(CometExplainInfo.EXTENSION_INFO) + .foreach(reasons => s.setTagValue(CometExplainInfo.EXTENSION_INFO, reasons)) + } + exprProto + case _ => None + } + case ms: MapSort => val keyType = ms.dataType.asInstanceOf[MapType].keyType if (!supportedScalarSortElementType(keyType)) { diff --git a/spark/src/test/resources/sql-tests/expressions/json/json_array_length.sql b/spark/src/test/resources/sql-tests/expressions/json/json_array_length.sql new file mode 100644 index 0000000000..9b7b332e50 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/json/json_array_length.sql @@ -0,0 +1,64 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +statement +CREATE TABLE test_json_array_length(j string) USING parquet + +statement +INSERT INTO test_json_array_length VALUES + ('[1,2,3,4]'), + ('[]'), + ('[1]'), + (NULL), + ('[1,2,3,{"f1":1,"f2":[5,6]},4]'), + ('[[1,2],[3,4],[5,6]]'), + ('[{"a":1},{"b":2},{"c":3}]'), + ('[1,2'), + ('[1,2,3,]'), + ('not a json'), + ('{"object": "not array"}'), + (''), + (' '), + ('[true, false, null]'), + ('["string1", "string2", "string3"]'), + ('[1, "mixed", true, null, {"key":"value"}]'), + ('[1,2,3,4,5,6,7,8,9,10]'), + ('["line1\nline2", "tab\tseparated", "quote\"here"]'), + ('{"outer": [1,2,3], "inner": [[1,2],[3,4]]}'), + ('{"arrays": {"first": [1,2], "second": [3,4,5]}}'), + ('[{"arr": [1,2,3]}, {"arr": [4,5]}]') + +query spark_answer_only +SELECT json_array_length(j) FROM test_json_array_length + +query spark_answer_only +SELECT json_array_length('[1,2,3,4]') + +query spark_answer_only +SELECT json_array_length('not an array') + +query spark_answer_only +SELECT json_array_length('{"key":"value"}') + +query spark_answer_only +SELECT json_array_length(NULL) + +query spark_answer_only +SELECT json_array_length('[]') + +query expect_fallback(Spark's lenient JSON parser allows single quotes, unescaped controls, and trailing content, while Comet's serde_json requires strict JSON.) +SELECT json_array_length("[{'key':'value'}]") diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometLengthOfJsonArrayBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometLengthOfJsonArrayBenchmark.scala new file mode 100644 index 0000000000..8c09ce01cf --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometLengthOfJsonArrayBenchmark.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.sql.catalyst.expressions.LengthOfJsonArray + +import org.apache.comet.CometConf + +/** + * Benchmark to measure performance of Comet json_array_length expression. To run this benchmark: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometLengthOfJsonArrayBenchmark + * }}} + * Results will be written to "spark/benchmarks/CometLengthOfJsonArray-**results.txt". + */ +object CometLengthOfJsonArrayBenchmark extends CometBenchmarkBase { + + override def runCometBenchmark(args: Array[String]): Unit = { + val numRows = 1024 * 1024 + runBenchmarkWithTable("json_array_length", numRows) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + import spark.implicits._ + prepareTable( + dir, + spark + .range(numRows) + .map { i => + val arrayLength = (i % 100).toInt + (0 until arrayLength) + .map(j => s""""item_${i}_$j"""") + .mkString("[", ",", "]") + } + .toDF("c1")) + + val extraConfigs = + Map(CometConf.getExprAllowIncompatConfigKey(classOf[LengthOfJsonArray]) -> "true") + + val benchmarks = List( + StringExprConfig( + "get json array length", + "select json_array_length(c1) from parquetV1Table", + extraConfigs)) + + benchmarks.foreach { config => + runBenchmark(config.name) { + runExpressionBenchmark(config.name, v, config.query, config.extraCometConfigs) + } + } + } + } + } + } +}