-
Notifications
You must be signed in to change notification settings - Fork 323
feat: support Spark expression json_array_length #4365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
768b3e9
c68c342
d887555
231aa90
9500bbb
9577481
3791557
7c2f082
609a605
a151b2c
ad3e7f5
ea92e4b
8dfeca3
559741e
ebda14e
408152e
d7857b2
aef41be
5ac1c58
9ae8e23
5ca3888
160a817
88fc313
e14c180
610a885
f8acb2c
ec94897
43405e4
47b4915
26e2682
6cb5f07
ec194fb
256fccb
912c8f9
561a664
d926ef4
671412c
c9f52d1
67f72d9
314e594
ac8292f
c9c140e
decca58
0919b33
21a5771
7495e21
57076f4
0dfa19c
0a37a60
060bf07
abbba84
e65284f
678e417
88b5d71
690f79d
d4936a9
aaf5509
09dc0d4
2c5e2cf
a0fad37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use arrow::array::{Array, ArrayRef, Int32Builder, OffsetSizeTrait}; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion::common::cast::as_generic_string_array; | ||
| use datafusion::common::{exec_err, Result, ScalarValue}; | ||
| use datafusion::logical_expr::{ | ||
| ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, | ||
| }; | ||
|
|
||
| use std::any::Any; | ||
|
|
||
| use serde::de::{IgnoredAny, SeqAccess, Visitor}; | ||
| use serde::Deserializer; | ||
| use std::fmt; | ||
| use std::sync::Arc; | ||
|
|
||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||
| pub struct JsonArrayLength { | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl Default for JsonArrayLength { | ||
| fn default() -> Self { | ||
| Self::new() | ||
| } | ||
| } | ||
|
|
||
| impl JsonArrayLength { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::variadic( | ||
| vec![DataType::Utf8, DataType::LargeUtf8], | ||
| Volatility::Immutable, | ||
| ), | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl ScalarUDFImpl for JsonArrayLength { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
|
|
||
| fn name(&self) -> &str { | ||
| "json_array_length" | ||
| } | ||
|
|
||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Int32) | ||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| spark_json_array_length(&args.args) | ||
| } | ||
| } | ||
|
|
||
| fn spark_json_array_length(args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
| if args.len() != 1 { | ||
| return exec_err!("json_array_length function takes exactly one argument"); | ||
| } | ||
| match &args[0] { | ||
| ColumnarValue::Array(array) => { | ||
| let result = spark_json_array_length_array(array)?; | ||
| Ok(ColumnarValue::Array(result)) | ||
| } | ||
| ColumnarValue::Scalar(scalar) => { | ||
| let result = spark_json_array_length_scalar(scalar)?; | ||
| Ok(ColumnarValue::Scalar(result)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn spark_json_array_length_array(array: &ArrayRef) -> Result<ArrayRef> { | ||
| match array.data_type() { | ||
| DataType::Utf8 => spark_json_array_length_array_inner::<i32>(array), | ||
| DataType::LargeUtf8 => spark_json_array_length_array_inner::<i64>(array), | ||
| other => { | ||
| exec_err!("Unsupported data type {other:?} for function `json_array_length`") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn spark_json_array_length_scalar(scalar: &ScalarValue) -> Result<ScalarValue> { | ||
| match scalar { | ||
| ScalarValue::Utf8(value) => spark_json_array_length_scalar_inner(value), | ||
| ScalarValue::LargeUtf8(value) => spark_json_array_length_scalar_inner(value), | ||
| other => { | ||
| exec_err!("Unsupported data type {other:?} for function `json_array_length`") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn spark_json_array_length_scalar_inner(json_str: &Option<String>) -> Result<ScalarValue> { | ||
| let array_length = json_str | ||
| .clone() | ||
| .and_then(|json_str| get_json_array_length(&json_str)); | ||
| Ok(ScalarValue::Int32(array_length)) | ||
| } | ||
|
|
||
| fn spark_json_array_length_array_inner<T: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> { | ||
| let str_array = as_generic_string_array::<T>(array)?; | ||
| let mut builder = Int32Builder::with_capacity(str_array.len()); | ||
| for row_idx in 0..str_array.len() { | ||
| if str_array.is_null(row_idx) { | ||
| builder.append_null(); | ||
| } else { | ||
| let json_str = str_array.value(row_idx); | ||
| if let Some(json_array_length) = get_json_array_length(json_str) { | ||
| builder.append_value(json_array_length); | ||
| } else { | ||
| builder.append_null() | ||
| } | ||
| } | ||
| } | ||
| Ok(Arc::new(builder.finish())) | ||
| } | ||
|
|
||
| struct ArrayItemCounter; | ||
|
|
||
| impl<'de> Visitor<'de> for ArrayItemCounter { | ||
| type Value = i32; | ||
|
|
||
| fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
| f.write_str("a JSON array") | ||
| } | ||
|
|
||
| fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> { | ||
| let mut len = 0i32; | ||
| while seq.next_element::<IgnoredAny>()?.is_some() { | ||
| len += 1; | ||
| } | ||
| Ok(len) | ||
| } | ||
| } | ||
|
|
||
| fn get_json_array_length(json: &str) -> Option<i32> { | ||
| let mut deserializer = serde_json::Deserializer::from_str(json); | ||
| deserializer.deserialize_seq(ArrayItemCounter).ok() | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| package org.apache.comet.serde | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.LengthOfJsonArray | ||
|
|
||
/**
 * Serde for Spark's `LengthOfJsonArray` expression, mapped to the native scalar function
 * named `json_array_length`.
 *
 * The expression is always reported as `Incompatible`, with a single reason describing the
 * difference between Spark's JSON parsing and serde_json's.
 */
object CometLengthOfJsonArray
    extends CometScalarFunction[LengthOfJsonArray]("json_array_length") {

  // Shared by both overrides so the reported reason cannot drift between them.
  private val IncompatibleReason: String =
    "Spark's lenient JSON parser allows single quotes, unescaped controls, " +
      "and trailing content, while Comet's serde_json requires strict JSON."

  /** The only documented incompatibility for this expression. */
  override def getIncompatibleReasons(): Seq[String] = Seq(IncompatibleReason)

  /** Always `Incompatible`, regardless of the expression's arguments. */
  override def getSupportLevel(expr: LengthOfJsonArray): SupportLevel =
    Incompatible(Some(IncompatibleReason))
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| -- Licensed to the Apache Software Foundation (ASF) under one | ||
| -- or more contributor license agreements. See the NOTICE file | ||
| -- distributed with this work for additional information | ||
| -- regarding copyright ownership. The ASF licenses this file | ||
| -- to you under the Apache License, Version 2.0 (the | ||
| -- "License"); you may not use this file except in compliance | ||
| -- with the License. You may obtain a copy of the License at | ||
| -- | ||
| -- http://www.apache.org/licenses/LICENSE-2.0 | ||
| -- | ||
| -- Unless required by applicable law or agreed to in writing, | ||
| -- software distributed under the License is distributed on an | ||
| -- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| -- KIND, either express or implied. See the License for the | ||
| -- specific language governing permissions and limitations | ||
| -- under the License. | ||
|
|
||
| statement | ||
| CREATE TABLE test_json_array_length(j string) USING parquet | ||
|
|
||
| statement | ||
| INSERT INTO test_json_array_length VALUES | ||
| ('[1,2,3,4]'), | ||
| ('[]'), | ||
| ('[1]'), | ||
| (NULL), | ||
| ('[1,2,3,{"f1":1,"f2":[5,6]},4]'), | ||
| ('[[1,2],[3,4],[5,6]]'), | ||
| ('[{"a":1},{"b":2},{"c":3}]'), | ||
| ('[1,2'), | ||
| ('[1,2,3,]'), | ||
| ('not a json'), | ||
| ('{"object": "not array"}'), | ||
| (''), | ||
| (' '), | ||
| ('[true, false, null]'), | ||
| ('["string1", "string2", "string3"]'), | ||
| ('[1, "mixed", true, null, {"key":"value"}]'), | ||
| ('[1,2,3,4,5,6,7,8,9,10]'), | ||
| ('["line1\nline2", "tab\tseparated", "quote\"here"]'), | ||
| ('{"outer": [1,2,3], "inner": [[1,2],[3,4]]}'), | ||
| ('{"arrays": {"first": [1,2], "second": [3,4,5]}}'), | ||
| ('[{"arr": [1,2,3]}, {"arr": [4,5]}]') | ||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length(j) FROM test_json_array_length | ||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length('[1,2,3,4]') | ||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length('not an array') | ||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length('{"key":"value"}') | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you also add examples of incompatible behavior, such as using single quotes around keys and values?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added.
||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length(NULL) | ||
|
|
||
| query spark_answer_only | ||
| SELECT json_array_length('[]') | ||
|
|
||
| query expect_fallback(Spark's lenient JSON parser allows single quotes, unescaped controls, and trailing content, while Comet's serde_json requires strict JSON.) | ||
| SELECT json_array_length("[{'key':'value'}]") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍