-
Notifications
You must be signed in to change notification settings - Fork 2k
feat: Support Spark array_contains builtin function
#20685
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+322
−1
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use arrow::array::{ | ||
| Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, | ||
| }; | ||
| use arrow::buffer::{BooleanBuffer, NullBuffer}; | ||
| use arrow::datatypes::DataType; | ||
| use datafusion_common::{Result, exec_err}; | ||
| use datafusion_expr::{ | ||
| ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, | ||
| }; | ||
| use datafusion_functions_nested::array_has::array_has_udf; | ||
| use std::any::Any; | ||
| use std::sync::Arc; | ||
|
|
||
| /// Spark-compatible `array_contains` function. | ||
| /// | ||
| /// Calls DataFusion's `array_has` and then applies Spark's null semantics: | ||
| /// - If the result from `array_has` is `true`, return `true`. | ||
| /// - If the result is `false` and the input array row contains any null elements, | ||
| /// return `null` (because the element might have been the null). | ||
| /// - If the result is `false` and the input array row has no null elements, | ||
| /// return `false`. | ||
| #[derive(Debug, PartialEq, Eq, Hash)] | ||
| pub struct SparkArrayContains { | ||
| signature: Signature, | ||
| } | ||
|
|
||
| impl Default for SparkArrayContains { | ||
| fn default() -> Self { | ||
| Self::new() | ||
| } | ||
| } | ||
|
|
||
| impl SparkArrayContains { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| signature: Signature::array_and_element(Volatility::Immutable), | ||
| } | ||
| } | ||
| } | ||
|
|
||
| impl ScalarUDFImpl for SparkArrayContains { | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
| } | ||
|
|
||
| fn name(&self) -> &str { | ||
| "array_contains" | ||
| } | ||
|
|
||
| fn signature(&self) -> &Signature { | ||
| &self.signature | ||
| } | ||
|
|
||
| fn return_type(&self, _: &[DataType]) -> Result<DataType> { | ||
| Ok(DataType::Boolean) | ||
| } | ||
|
|
||
| fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
| let haystack = args.args[0].clone(); | ||
| let array_has_result = array_has_udf().invoke_with_args(args)?; | ||
|
|
||
| let result_array = array_has_result.to_array(1)?; | ||
| let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?; | ||
| Ok(ColumnarValue::Array(Arc::new(patched))) | ||
| } | ||
| } | ||
|
|
||
| /// For each row where `array_has` returned `false`, set the output to null | ||
| /// if that row's input array contains any null elements. | ||
| fn apply_spark_null_semantics( | ||
| result: &BooleanArray, | ||
| haystack_arg: &ColumnarValue, | ||
| ) -> Result<BooleanArray> { | ||
| // happy path | ||
| if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null { | ||
| return Ok(result.clone()); | ||
| } | ||
|
|
||
| let haystack = haystack_arg.to_array_of_size(result.len())?; | ||
|
|
||
| let row_has_nulls = compute_row_has_nulls(&haystack)?; | ||
|
|
||
| // A row keeps its validity when result is true OR the row has no nulls. | ||
| let keep_mask = result.values() | &!&row_has_nulls; | ||
| let new_validity = match result.nulls() { | ||
| Some(n) => n.inner() & &keep_mask, | ||
| None => keep_mask, | ||
| }; | ||
|
|
||
| Ok(BooleanArray::new( | ||
| result.values().clone(), | ||
| Some(NullBuffer::new(new_validity)), | ||
| )) | ||
| } | ||
|
|
||
| /// Returns a per-row bitmap where bit i is set if row i's list contains any null element. | ||
| fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> { | ||
| match haystack.data_type() { | ||
| DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()), | ||
| DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()), | ||
| DataType::FixedSizeList(_, _) => { | ||
| let list = haystack.as_fixed_size_list(); | ||
| let buf = match list.values().nulls() { | ||
| Some(nulls) => { | ||
| let validity = nulls.inner(); | ||
| let vl = list.value_length() as usize; | ||
| let mut builder = BooleanBufferBuilder::new(list.len()); | ||
| for i in 0..list.len() { | ||
| builder.append(validity.slice(i * vl, vl).count_set_bits() < vl); | ||
| } | ||
| builder.finish() | ||
| } | ||
| None => BooleanBuffer::new_unset(list.len()), | ||
| }; | ||
| Ok(mask_with_list_nulls(buf, list.nulls())) | ||
| } | ||
| dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"), | ||
| } | ||
| } | ||
|
|
||
| /// Computes per-row null presence for `List` and `LargeList` arrays. | ||
| fn generic_list_row_has_nulls<O: OffsetSizeTrait>( | ||
| list: &GenericListArray<O>, | ||
| ) -> Result<BooleanBuffer> { | ||
| let buf = match list.values().nulls() { | ||
| Some(nulls) => { | ||
| let validity = nulls.inner(); | ||
| let offsets = list.offsets(); | ||
| let mut builder = BooleanBufferBuilder::new(list.len()); | ||
| for i in 0..list.len() { | ||
| let s = offsets[i].as_usize(); | ||
| let len = offsets[i + 1].as_usize() - s; | ||
| builder.append(validity.slice(s, len).count_set_bits() < len); | ||
| } | ||
| builder.finish() | ||
| } | ||
| None => BooleanBuffer::new_unset(list.len()), | ||
| }; | ||
| Ok(mask_with_list_nulls(buf, list.nulls())) | ||
| } | ||
|
|
||
| /// Rows where the list itself is null should not be marked as "has nulls". | ||
| fn mask_with_list_nulls( | ||
| buf: BooleanBuffer, | ||
| list_nulls: Option<&NullBuffer>, | ||
| ) -> BooleanBuffer { | ||
| match list_nulls { | ||
| Some(n) => &buf & n.inner(), | ||
| None => buf, | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
140 changes: 140 additions & 0 deletions
140
datafusion/sqllogictest/test_files/spark/array/array_contains.slt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| # Tests for Spark-compatible array_contains function. | ||
| # Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false. | ||
|
|
||
| ### | ||
| ### Scalar tests | ||
| ### | ||
|
|
||
| # Element found in array | ||
| query B | ||
| SELECT array_contains(array(1, 2, 3), 2); | ||
| ---- | ||
| true | ||
|
|
||
| # Element not found, no nulls in array | ||
| query B | ||
| SELECT array_contains(array(1, 2, 3), 4); | ||
| ---- | ||
| false | ||
|
|
||
| # Element not found, array has null elements -> null | ||
| query B | ||
| SELECT array_contains(array(1, NULL, 3), 2); | ||
| ---- | ||
| NULL | ||
|
|
||
| # Element found, array has null elements -> true (nulls don't matter) | ||
| query B | ||
| SELECT array_contains(array(1, NULL, 3), 1); | ||
| ---- | ||
| true | ||
|
|
||
| # Element found at the end, array has null elements -> true | ||
| query B | ||
| SELECT array_contains(array(1, NULL, 3), 3); | ||
| ---- | ||
| true | ||
|
|
||
| # Null array -> null | ||
| query B | ||
| SELECT array_contains(NULL, 1); | ||
| ---- | ||
| NULL | ||
|
|
||
| # Null element -> null | ||
| query B | ||
| SELECT array_contains(array(1, 2, 3), NULL); | ||
| ---- | ||
| NULL | ||
|
|
||
| # Empty array, element not found -> false | ||
| query B | ||
| SELECT array_contains(array(), 1); | ||
| ---- | ||
| false | ||
|
|
||
| # Array with only nulls, element not found -> null | ||
| query B | ||
| SELECT array_contains(array(NULL, NULL), 1); | ||
| ---- | ||
| NULL | ||
|
|
||
| # String array, element found | ||
| query B | ||
| SELECT array_contains(array('a', 'b', 'c'), 'b'); | ||
| ---- | ||
| true | ||
|
|
||
| # String array, element not found, no nulls | ||
| query B | ||
| SELECT array_contains(array('a', 'b', 'c'), 'd'); | ||
| ---- | ||
| false | ||
|
|
||
| # String array, element not found, has null | ||
| query B | ||
| SELECT array_contains(array('a', NULL, 'c'), 'd'); | ||
| ---- | ||
| NULL | ||
|
|
||
| ### | ||
| ### Columnar tests with a table | ||
| ### | ||
|
|
||
| statement ok | ||
| CREATE TABLE test_arrays AS VALUES | ||
| (1, make_array(1, 2, 3), 10), | ||
| (2, make_array(4, NULL, 6), 5), | ||
| (3, make_array(7, 8, 9), 10), | ||
| (4, NULL, 1), | ||
| (5, make_array(10, NULL, NULL), 10); | ||
|
|
||
| # Column needle and literal needle against a column array | ||
| query IBB | ||
| SELECT column1, | ||
| array_contains(column2, column3), | ||
| array_contains(column2, 10) | ||
| FROM test_arrays | ||
| ORDER BY column1; | ||
| ---- | ||
| 1 false false | ||
| 2 NULL NULL | ||
| 3 false false | ||
| 4 NULL NULL | ||
| 5 true true | ||
|
|
||
| statement ok | ||
| DROP TABLE test_arrays; | ||
|
|
||
| ### | ||
| ### Nested array tests | ||
| ### | ||
|
|
||
| # Nested array element found | ||
| query B | ||
| SELECT array_contains(array(array(1, 2), array(3, 4)), array(3, 4)); | ||
| ---- | ||
| true | ||
|
|
||
| # Nested array element not found, no nulls | ||
| query B | ||
| SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6)); | ||
| ---- | ||
| false | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that what Spark does? It short circuits and returns true even if there are NULLs in the array? What if the match is after the NULL? Would it return NULL? I thought if any element of the array was NULL the outcome of the expression is NULL.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the explanation in https://issues.apache.org/jira/browse/SPARK-55749 describes this behavior very accurately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
on this data Spark run