Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions datafusion/spark/src/function/array/array_contains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::{Result, exec_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::array_has::array_has_udf;
use std::any::Any;
use std::sync::Arc;

/// Spark-compatible `array_contains` function.
///
/// Calls DataFusion's `array_has` and then applies Spark's null semantics:
/// - If the result from `array_has` is `true`, return `true`.
/// - If the result is `false` and the input array row contains any null elements,
/// return `null` (because the element might have been the null).
/// - If the result is `false` and the input array row has no null elements,
/// return `false`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayContains {
    // Accepts (array, element); coercion is handled by `Signature::array_and_element`.
    signature: Signature,
}

impl Default for SparkArrayContains {
fn default() -> Self {
Self::new()
}
}

impl SparkArrayContains {
    /// Creates the UDF with an `(array, element)` signature whose argument
    /// coercion is delegated to `Signature::array_and_element`.
    pub fn new() -> Self {
        let signature = Signature::array_and_element(Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for SparkArrayContains {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "array_contains"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Delegates to DataFusion's `array_has` and then patches the result to
    /// follow Spark's three-valued null semantics (see module docs above).
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // Capture the batch row count before `args` is consumed. If every
        // argument is a scalar, `array_has` returns a `ColumnarValue::Scalar`;
        // expanding it with a hard-coded length of 1 would produce an output
        // array whose length disagrees with the rest of the batch whenever
        // `number_rows > 1`.
        let number_rows = args.number_rows;
        let haystack = args.args[0].clone();
        let array_has_result = array_has_udf().invoke_with_args(args)?;

        let result_array = array_has_result.to_array(number_rows)?;
        let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
        Ok(ColumnarValue::Array(Arc::new(patched)))
    }
}

/// For each row where `array_has` returned `false`, set the output to null
/// if that row's input array contains any null elements.
/// For each row where `array_has` returned `false`, set the output to null
/// if that row's input array contains any null elements.
fn apply_spark_null_semantics(
    result: &BooleanArray,
    haystack_arg: &ColumnarValue,
) -> Result<BooleanArray> {
    // Fast path: nothing needs patching when no row produced `false`, or
    // when the haystack is the untyped NULL literal.
    if haystack_arg.data_type() == DataType::Null || result.false_count() == 0 {
        return Ok(result.clone());
    }

    let haystack = haystack_arg.to_array_of_size(result.len())?;
    let row_has_nulls = compute_row_has_nulls(&haystack)?;

    // A row stays valid when `array_has` said true OR the row is null-free.
    let null_free = !&row_has_nulls;
    let keep = result.values() | &null_free;

    // Rows that were already null (null list / null needle) must remain null,
    // so intersect with any pre-existing validity.
    let validity = if let Some(existing) = result.nulls() {
        existing.inner() & &keep
    } else {
        keep
    };

    Ok(BooleanArray::new(
        result.values().clone(),
        Some(NullBuffer::new(validity)),
    ))
}

/// Returns a per-row bitmap where bit i is set if row i's list contains any null element.
/// Returns a per-row bitmap where bit i is set if row i's list contains any null element.
fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
    match haystack.data_type() {
        DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
        DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
        DataType::FixedSizeList(_, _) => {
            let list = haystack.as_fixed_size_list();
            let n = list.len();
            let per_row = match list.values().nulls() {
                // No child validity bitmap means no element anywhere is null.
                None => BooleanBuffer::new_unset(n),
                Some(child_nulls) => {
                    let bits = child_nulls.inner();
                    let width = list.value_length() as usize;
                    let mut out = BooleanBufferBuilder::new(n);
                    for row in 0..n {
                        // A row has a null element iff fewer than `width`
                        // validity bits are set in its fixed-width window.
                        let set = bits.slice(row * width, width).count_set_bits();
                        out.append(set < width);
                    }
                    out.finish()
                }
            };
            Ok(mask_with_list_nulls(per_row, list.nulls()))
        }
        other => exec_err!("compute_row_has_nulls: unsupported data type {other}"),
    }
}

/// Computes per-row null presence for `List` and `LargeList` arrays.
/// Computes per-row null presence for `List` and `LargeList` arrays.
fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
    list: &GenericListArray<O>,
) -> Result<BooleanBuffer> {
    let n = list.len();
    let per_row = match list.values().nulls() {
        // No child validity bitmap means no element anywhere is null.
        None => BooleanBuffer::new_unset(n),
        Some(child_nulls) => {
            let bits = child_nulls.inner();
            let offsets = list.offsets();
            let mut out = BooleanBufferBuilder::new(n);
            for row in 0..n {
                let start = offsets[row].as_usize();
                let len = offsets[row + 1].as_usize() - start;
                // A row has a null element iff fewer than `len` validity
                // bits are set in its offset window.
                out.append(bits.slice(start, len).count_set_bits() < len);
            }
            out.finish()
        }
    };
    Ok(mask_with_list_nulls(per_row, list.nulls()))
}

/// Rows where the list itself is null should not be marked as "has nulls".
/// Rows where the list itself is null should not be marked as "has nulls".
fn mask_with_list_nulls(
    buf: BooleanBuffer,
    list_nulls: Option<&NullBuffer>,
) -> BooleanBuffer {
    if let Some(validity) = list_nulls {
        // Clear bits for rows whose list value is itself null.
        &buf & validity.inner()
    } else {
        buf
    }
}
15 changes: 14 additions & 1 deletion datafusion/spark/src/function/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

pub mod array_contains;
pub mod repeat;
pub mod shuffle;
pub mod slice;
Expand All @@ -24,6 +25,7 @@ use datafusion_expr::ScalarUDF;
use datafusion_functions::make_udf_function;
use std::sync::Arc;

make_udf_function!(array_contains::SparkArrayContains, spark_array_contains);
make_udf_function!(spark_array::SparkArray, array);
make_udf_function!(shuffle::SparkShuffle, shuffle);
make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
Expand All @@ -32,6 +34,11 @@ make_udf_function!(slice::SparkSlice, slice);
pub mod expr_fn {
use datafusion_functions::export_functions;

export_functions!((
spark_array_contains,
"Returns true if the array contains the element (Spark semantics).",
array element
));
export_functions!((array, "Returns an array with the given elements.", args));
export_functions!((
shuffle,
Expand All @@ -51,5 +58,11 @@ pub mod expr_fn {
}

/// All Spark array functions exported by this module.
pub fn functions() -> Vec<Arc<ScalarUDF>> {
    vec![
        spark_array_contains(),
        array(),
        shuffle(),
        array_repeat(),
        slice(),
    ]
}
140 changes: 140 additions & 0 deletions datafusion/sqllogictest/test_files/spark/array/array_contains.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Tests for Spark-compatible array_contains function.
# Spark semantics: if element is found -> true; if not found and array has nulls -> null; if not found and no nulls -> false.

###
### Scalar tests
###

# Element found in array
query B
SELECT array_contains(array(1, 2, 3), 2);
----
true

# Element not found, no nulls in array
query B
SELECT array_contains(array(1, 2, 3), 4);
----
false

# Element not found, array has null elements -> null
query B
SELECT array_contains(array(1, NULL, 3), 2);
----
NULL

# Element found, array has null elements -> true (nulls don't matter)
query B
SELECT array_contains(array(1, NULL, 3), 1);
----
true

# Element found at the end, array has null elements -> true
query B
SELECT array_contains(array(1, NULL, 3), 3);
----
true

# Null array -> null
query B
SELECT array_contains(NULL, 1);
----
NULL

# Null element -> null
query B
SELECT array_contains(array(1, 2, 3), NULL);
----
NULL

# Empty array, element not found -> false
query B
SELECT array_contains(array(), 1);
----
false

# Array with only nulls, element not found -> null
query B
SELECT array_contains(array(NULL, NULL), 1);
----
NULL

# String array, element found
query B
SELECT array_contains(array('a', 'b', 'c'), 'b');
----
true

# String array, element not found, no nulls
query B
SELECT array_contains(array('a', 'b', 'c'), 'd');
----
false

# String array, element not found, has null
query B
SELECT array_contains(array('a', NULL, 'c'), 'd');
----
NULL

###
### Columnar tests with a table
###

statement ok
CREATE TABLE test_arrays AS VALUES
(1, make_array(1, 2, 3), 10),
(2, make_array(4, NULL, 6), 5),
(3, make_array(7, 8, 9), 10),
(4, NULL, 1),
(5, make_array(10, NULL, NULL), 10);

# Column needle against column array
query IBB
SELECT column1,
array_contains(column2, column3),
array_contains(column2, 10)
FROM test_arrays
ORDER BY column1;
----
1 false false
2 NULL NULL
3 false false
4 NULL NULL
5 true true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that what Spark does? It short circuits and returns true even if there are NULLs in the array? What if the match is after the NULL? Would it return NULL? I thought if any element of the array was NULL the outcome of the expression is NULL.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scala> spark.sql("select array_contains(array(1, null, 2), 2)").show(false)
+------------------------------------+
|array_contains(array(1, NULL, 2), 2)|
+------------------------------------+
|true                                |
+------------------------------------+


scala> spark.sql("select array_contains(array(1, 2, null), 2)").show(false)
+------------------------------------+
|array_contains(array(1, 2, NULL), 2)|
+------------------------------------+
|true                                |
+------------------------------------+


scala> spark.sql("select array_contains(array(1, null), 2)").show(false)
+---------------------------------+
|array_contains(array(1, NULL), 2)|
+---------------------------------+
|null                             |
+---------------------------------+

I think the explanation in https://issues.apache.org/jira/browse/SPARK-55749 describes this behavior very accurately

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on this data, Spark produces

=== Test Data ===
+----+----------------+----+                                                    
|col1|col2            |col3|
+----+----------------+----+
|1   |[1, 2, 3]       |10  |
|2   |[4, null, 6]    |5   |
|3   |[7, 8, 9]       |10  |
|4   |null            |1   |
|5   |[10, null, null]|10  |
+----+----------------+----+

=== array_contains(col2, col3) and array_contains(col2, 10) ===
+----+--------------------------+------------------------+
|col1|array_contains(col2, col3)|array_contains(col2, 10)|
+----+--------------------------+------------------------+
|1   |false                     |false                   |
|2   |null                      |null                    |
|3   |false                     |false                   |
|4   |null                      |null                    |
|5   |true                      |true                    |
+----+--------------------------+------------------------+


statement ok
DROP TABLE test_arrays;

###
### Nested array tests
###

# Nested array element found
query B
SELECT array_contains(array(array(1, 2), array(3, 4)), array(3, 4));
----
true

# Nested array element not found, no nulls
query B
SELECT array_contains(array(array(1, 2), array(3, 4)), array(5, 6));
----
false