diff --git a/datafusion/spark/src/function/string/elt.rs b/datafusion/spark/src/function/string/elt.rs
index 35a22fe5edb6..4af6d5128e97 100644
--- a/datafusion/spark/src/function/string/elt.rs
+++ b/datafusion/spark/src/function/string/elt.rs
@@ -23,11 +23,12 @@ use arrow::array::{
 };
 use arrow::compute::{can_cast_types, cast};
 use arrow::datatypes::DataType::{Int64, Utf8};
-use arrow::datatypes::{DataType, Int64Type};
+use arrow::datatypes::{DataType, Field, FieldRef, Int64Type};
 use datafusion_common::cast::as_string_array;
-use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
+use datafusion_common::{internal_err, plan_datafusion_err, DataFusionError, Result};
 use datafusion_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
 };
 use datafusion_functions::utils::make_scalar_function;
 
@@ -64,7 +65,21 @@ impl ScalarUDFImpl for SparkElt {
     }
 
     fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-        Ok(Utf8)
+        // A bare list of `DataType`s carries no nullability information, so the
+        // return field is computed in `return_field_from_args` instead.
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    /// Returns the output field: `Utf8`, named after this function, and
+    /// nullable if and only if at least one argument field is nullable.
+    ///
+    /// NOTE(review): Spark documents that `elt` returns NULL for an
+    /// out-of-range index when ANSI mode is off. If this implementation does
+    /// the same, non-nullable inputs can still yield NULL -- confirm that
+    /// deriving nullability from the inputs alone is sound.
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        Ok(Arc::new(Field::new(self.name(), Utf8, nullable)))
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
@@ -248,4 +263,58 @@ mod tests {
         assert_eq!(out.data_type(), &Utf8);
         Ok(())
     }
+
+    /// Checks that the returned field's nullability follows the argument fields.
+    #[test]
+    fn test_elt_nullability() -> Result<()> {
+        use datafusion_expr::ReturnFieldArgs;
+
+        let elt_func = SparkElt::new();
+
+        // Test with all non-nullable args - result should be non-nullable
+        let non_nullable_idx: FieldRef = Arc::new(Field::new("idx", Int64, false));
+        let non_nullable_v1: FieldRef = Arc::new(Field::new("v1", Utf8, false));
+        let non_nullable_v2: FieldRef = Arc::new(Field::new("v2", Utf8, false));
+
+        let result = elt_func.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &[
+                Arc::clone(&non_nullable_idx),
+                Arc::clone(&non_nullable_v1),
+                Arc::clone(&non_nullable_v2),
+            ],
+            scalar_arguments: &[None, None, None],
+        })?;
+        assert!(
+            !result.is_nullable(),
+            "elt should NOT be nullable when all args are non-nullable"
+        );
+
+        // Test with nullable index - result should be nullable
+        let nullable_idx: FieldRef = Arc::new(Field::new("idx", Int64, true));
+        let result = elt_func.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &[
+                nullable_idx,
+                Arc::clone(&non_nullable_v1),
+                Arc::clone(&non_nullable_v2),
+            ],
+            scalar_arguments: &[None, None, None],
+        })?;
+        assert!(
+            result.is_nullable(),
+            "elt should be nullable when index is nullable"
+        );
+
+        // Test with nullable value - result should be nullable
+        let nullable_v1: FieldRef = Arc::new(Field::new("v1", Utf8, true));
+        let result = elt_func.return_field_from_args(ReturnFieldArgs {
+            arg_fields: &[non_nullable_idx, nullable_v1, non_nullable_v2],
+            scalar_arguments: &[None, None, None],
+        })?;
+        assert!(
+            result.is_nullable(),
+            "elt should be nullable when any value is nullable"
+        );
+
+        Ok(())
+    }
 }