diff --git a/datafusion/spark/src/function/array/array_contains.rs b/datafusion/spark/src/function/array/array_contains.rs new file mode 100644 index 0000000000000..2bc5d64d8bff8 --- /dev/null +++ b/datafusion/spark/src/function/array/array_contains.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::array_has::array_has_udf; +use std::any::Any; +use std::sync::Arc; + +/// Spark-compatible `array_contains` function. +/// +/// Calls DataFusion's `array_has` and then applies Spark's null semantics: +/// - If the result from `array_has` is `true`, return `true`. +/// - If the result is `false` and the input array row contains any null elements, +/// return `null` (because the element might have been the null). +/// - If the result is `false` and the input array row has no null elements, +/// return `false`. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayContains { + signature: Signature, +} + +impl Default for SparkArrayContains { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayContains { + pub fn new() -> Self { + Self { + signature: Signature::array_and_element(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayContains { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let haystack = args.args[0].clone(); + let array_has_result = array_has_udf().invoke_with_args(args)?; + + let result_array = array_has_result.to_array(1)?; + let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?; + Ok(ColumnarValue::Array(Arc::new(patched))) + } +} + +/// For each row where `array_has` returned `false`, set the output to null +/// if that row's input array contains any null elements. +fn apply_spark_null_semantics( + result: &BooleanArray, + haystack_arg: &ColumnarValue, +) -> Result { + // happy path + if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null { + return Ok(result.clone()); + } + + let haystack = haystack_arg.to_array_of_size(result.len())?; + + let row_has_nulls = compute_row_has_nulls(&haystack)?; + + // A row keeps its validity when result is true OR the row has no nulls. + let keep_mask = result.values() | &!&row_has_nulls; + let new_validity = match result.nulls() { + Some(n) => n.inner() & &keep_mask, + None => keep_mask, + }; + + Ok(BooleanArray::new( + result.values().clone(), + Some(NullBuffer::new(new_validity)), + )) +} + +/// Returns a per-row bitmap where bit i is set if row i's list contains any null element. +fn compute_row_has_nulls(haystack: &dyn Array) -> Result { + match haystack.data_type() { + DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::FixedSizeList(_, _) => { + let list = haystack.as_fixed_size_list(); + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let vl = list.value_length() as usize; + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + builder.append(validity.slice(i * vl, vl).count_set_bits() < vl); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) + } + dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"), + } +} + +/// Computes per-row null presence for `List` and `LargeList` arrays. +fn generic_list_row_has_nulls( + list: &GenericListArray, +) -> Result { + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let offsets = list.offsets(); + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + let s = offsets[i].as_usize(); + let len = offsets[i + 1].as_usize() - s; + builder.append(validity.slice(s, len).count_set_bits() < len); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) +} + +/// Rows where the list itself is null should not be marked as "has nulls". +fn mask_with_list_nulls( + buf: BooleanBuffer, + list_nulls: Option<&NullBuffer>, +) -> BooleanBuffer { + match list_nulls { + Some(n) => &buf & n.inner(), + None => buf, + } +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index 0d4cd40d99329..6c16e05361641 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod array_contains; pub mod repeat; pub mod shuffle; pub mod slice; @@ -24,6 +25,7 @@ use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(array_contains::SparkArrayContains, spark_array_contains); make_udf_function!(spark_array::SparkArray, array); make_udf_function!(shuffle::SparkShuffle, shuffle); make_udf_function!(repeat::SparkArrayRepeat, array_repeat); @@ -32,6 +34,11 @@ make_udf_function!(slice::SparkSlice, slice); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + spark_array_contains, + "Returns true if the array contains the element (Spark semantics).", + array element + )); export_functions!((array, "Returns an array with the given elements.", args)); export_functions!(( shuffle, @@ -51,5 +58,11 @@ pub mod expr_fn { } pub fn functions() -> Vec> { - vec![array(), shuffle(), array_repeat(), slice()] + vec![ + spark_array_contains(), + array(), + shuffle(), + array_repeat(), + slice(), + ] } diff --git a/datafusion/sqllogictest/test_files/spark/array/array_contains.slt b/datafusion/sqllogictest/test_files/spark/array/array_contains.slt new file mode 100644 index 0000000000000..db9ac6b122e3f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array_contains.slt @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for Spark-compatible array_contains function. +# Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false. + +### +### Scalar tests +### + +# Element found in array +query B +SELECT array_contains(array(1, 2, 3), 2); +---- +true + +# Element not found, no nulls in array +query B +SELECT array_contains(array(1, 2, 3), 4); +---- +false + +# Element not found, array has null elements -> null +query B +SELECT array_contains(array(1, NULL, 3), 2); +---- +NULL + +# Element found, array has null elements -> true (nulls don't matter) +query B +SELECT array_contains(array(1, NULL, 3), 1); +---- +true + +# Element found at the end, array has null elements -> true +query B +SELECT array_contains(array(1, NULL, 3), 3); +---- +true + +# Null array -> null +query B +SELECT array_contains(NULL, 1); +---- +NULL + +# Null element -> null +query B +SELECT array_contains(array(1, 2, 3), NULL); +---- +NULL + +# Empty array, element not found -> false +query B +SELECT array_contains(array(), 1); +---- +false + +# Array with only nulls, element not found -> null +query B +SELECT array_contains(array(NULL, NULL), 1); +---- +NULL + +# String array, element found +query B +SELECT array_contains(array('a', 'b', 'c'), 'b'); +---- +true + +# String array, element not found, no nulls +query B +SELECT array_contains(array('a', 'b', 'c'), 'd'); +---- +false + +# String array, element not found, has null +query B +SELECT array_contains(array('a', NULL, 'c'), 'd'); +---- +NULL + +### +### Columnar tests with a table +### + +statement ok +CREATE TABLE test_arrays AS VALUES + (1, make_array(1, 2, 3), 10), + (2, make_array(4, NULL, 6), 5), + (3, make_array(7, 8, 9), 10), + (4, NULL, 1), + (5, make_array(10, NULL, NULL), 10); + +# Column needle against column array +query IBB +SELECT column1, + array_contains(column2, column3), + array_contains(column2, 10) +FROM test_arrays +ORDER BY column1; +---- +1 false false +2 NULL NULL +3 false false +4 NULL NULL +5 true true + +statement ok +DROP TABLE test_arrays; + +### +### Nested array tests +### + +# Nested array element found +query B +SELECT array_contains(array(array(1, 2), array(3, 4)), array(3, 4)); +---- +true + +# Nested array element not found, no nulls +query B +SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6)); +---- +false