diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d599373f0a313..a753c91162bea 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -40,7 +40,9 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_functions_nested::expr_fn::{array_has, array_max, array_min}; +use datafusion_functions_nested::expr_fn::{ + array_has, array_max, array_min, array_position, cardinality, +}; mod binary_op; mod function; @@ -635,7 +637,11 @@ impl SqlToRel<'_, S> { schema, planner_context, ), - _ => not_impl_err!("ALL only supports subquery comparison currently"), + _ => { + let left_expr = self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = self.sql_to_expr(*right, schema, planner_context)?; + plan_all_op(&left_expr, &right_expr, &compare_op) + } }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { @@ -1297,6 +1303,64 @@ fn plan_any_op( } } +/// Plans `needle ALL(haystack)` with proper SQL NULL semantics. +/// +/// CASE/WHEN structure: +/// WHEN arr IS NULL → NULL +/// WHEN empty → TRUE +/// WHEN lhs IS NULL → NULL +/// WHEN decisive_condition → FALSE +/// WHEN has_nulls → NULL +/// ELSE → TRUE +fn plan_all_op( + needle: &Expr, + haystack: &Expr, + compare_op: &BinaryOperator, +) -> Result { + let null_arr_check = haystack.clone().is_null(); + let empty_check = cardinality(haystack.clone()).eq(lit(0u64)); + let null_lhs_check = needle.clone().is_null(); + // DataFusion's array_position uses is_null() checks internally (not equality), + // so it can locate NULL elements even though NULL = NULL is NULL in standard SQL. + let has_nulls = + array_position(haystack.clone(), lit(ScalarValue::Null), lit(1i64)).is_not_null(); + + let decisive_condition = match compare_op { + BinaryOperator::NotEq => array_has(haystack.clone(), needle.clone()), + BinaryOperator::Eq => { + let all_equal = array_min(haystack.clone()) + .eq(needle.clone()) + .and(array_max(haystack.clone()).eq(needle.clone())); + Expr::Not(Box::new(all_equal)) + } + BinaryOperator::Gt => { + Expr::Not(Box::new(needle.clone().gt(array_max(haystack.clone())))) + } + BinaryOperator::Lt => { + Expr::Not(Box::new(needle.clone().lt(array_min(haystack.clone())))) + } + BinaryOperator::GtEq => { + Expr::Not(Box::new(needle.clone().gt_eq(array_max(haystack.clone())))) + } + BinaryOperator::LtEq => { + Expr::Not(Box::new(needle.clone().lt_eq(array_min(haystack.clone())))) + } + _ => { + return plan_err!( + "Unsupported AllOp: '{compare_op}', only '=', '<>', '>', '<', '>=', '<=' are supported" + ); + } + }; + + let null_bool = lit(ScalarValue::Boolean(None)); + when(null_arr_check, null_bool.clone()) + .when(empty_check, lit(true)) + .when(null_lhs_check, null_bool.clone()) + .when(decisive_condition, lit(false)) + .when(has_nulls, null_bool) + .otherwise(lit(true)) +} + #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/sqllogictest/test_files/array/array_all.slt b/datafusion/sqllogictest/test_files/array/array_all.slt new file mode 100644 index 0000000000000..70ba15edbf47b --- /dev/null +++ b/datafusion/sqllogictest/test_files/array/array_all.slt @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +## all operator + +# = ALL: true when all elements equal val +query B +select 5 = ALL(make_array(5, 5, 5)); +---- +true + +query B +select 5 = ALL(make_array(5, 5, 3)); +---- +false + +# <> ALL: true when val differs from every element +query B +select 5 <> ALL(make_array(1, 2, 3)); +---- +true + +query B +select 5 <> ALL(make_array(1, 2, 5)); +---- +false + +# > ALL: true when val greater than all elements +query B +select 10 > ALL(make_array(1, 2, 3)); +---- +true + +query B +select 3 > ALL(make_array(1, 2, 3)); +---- +false + +# < ALL: true when val less than all elements +query B +select 0 < ALL(make_array(1, 2, 3)); +---- +true + +query B +select 2 < ALL(make_array(1, 2, 3)); +---- +false + +# >= ALL: true when val >= all elements +query B +select 5 >= ALL(make_array(1, 2, 5)); +---- +true + +query B +select 4 >= ALL(make_array(1, 2, 5)); +---- +false + +# <= ALL: true when val <= all elements +query B +select 1 <= ALL(make_array(1, 2, 5)); +---- +true + +query B +select 2 <= ALL(make_array(1, 2, 5)); +---- +false + +# Empty arrays: all operators return TRUE (vacuous truth) +query B +select 5 = ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 <> ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 > ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 < ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 >= ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +query B +select 5 <= ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +# NULL LHS with empty array returns TRUE (vacuous truth) +query B +select NULL = ALL(arrow_cast(make_array(), 'List(Int64)')); +---- +true + +# NULL LHS with non-empty array returns NULL +query B +select NULL = ALL(make_array(1, 2, 3)); +---- +NULL + +query B +select NULL > ALL(make_array(1, 2, 3)); +---- +NULL + +query B +select NULL <> ALL(make_array(1, 2, 3)); +---- +NULL + +# All-NULL arrays: returns NULL +query B +select 5 = ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 <> ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 > ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 < ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 >= ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +query B +select 5 <= ALL(make_array(NULL::INT, NULL::INT)); +---- +NULL + +# Mixed NULL + non-NULL (non-NULL elements satisfy, but NULLs present → NULL) +query B +select 5 > ALL(make_array(3, NULL)); +---- +NULL + +query B +select 5 >= ALL(make_array(5, NULL)); +---- +NULL + +query B +select 1 < ALL(make_array(3, NULL)); +---- +NULL + +query B +select 1 <= ALL(make_array(1, NULL)); +---- +NULL + +# Mixed NULL + non-NULL (not satisfying condition → FALSE wins over NULL) +query B +select 5 > ALL(make_array(6, NULL)); +---- +false + +query B +select 5 < ALL(make_array(3, NULL)); +---- +false + +query B +select 5 = ALL(make_array(5, 3, NULL)); +---- +false + +# NULL array input returns NULL +query B +select 5 = ALL(NULL::INT[]); +---- +NULL + +query B +select 5 > ALL(NULL::INT[]); +---- +NULL + +query B +select 5 < ALL(NULL::INT[]); +---- +NULL