diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index a1fa124548d0..f65349a83799 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -23,6 +23,7 @@ use arrow::compute::kernels::comparison::ends_with as arrow_ends_with; use arrow::datatypes::DataType; use datafusion_common::types::logical_string; +use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::binary::{binary_to_string_coercion, string_coercion}; use datafusion_expr::{ @@ -95,12 +96,7 @@ impl ScalarUDFImpl for EndsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let [str_arg, suffix_arg] = args.args.as_slice() else { - return exec_err!( - "ends_with was called with {} arguments, expected 2", - args.args.len() - ); - }; + let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?; // Determine the common type for coercion let coercion_type = string_coercion( diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 18a61869a8dc..40f525408f60 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; @@ -115,11 +115,7 @@ impl ScalarUDFImpl for LtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "ltrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index f0bafc980e32..77a08bf533c2 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; @@ -115,11 +115,7 @@ impl ScalarUDFImpl for RtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "rtrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 8ac505bf360f..d29d33a154d7 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -24,8 +24,11 @@ use arrow::array::{AsArray, GenericStringBuilder}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::{DataFusionError, Result, exec_err}; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility, +}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; use std::any::Any; @@ -60,19 +63,16 @@ impl Default for SplitPartFunc { impl SplitPartFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( + signature: Signature::coercible( vec![ - TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8View, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 259612c42997..c38a5bffcb2b 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Scalar}; use arrow::compute::kernels::comparison::starts_with as arrow_starts_with; use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, @@ -92,12 +93,7 @@ impl ScalarUDFImpl for StartsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let [str_arg, prefix_arg] = args.args.as_slice() else { - return exec_err!( - "starts_with was called with {} arguments, expected 2", - args.args.len() - ); - }; + let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?; // Determine the common type for coercion let coercion_type = string_coercion( diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index dd4f4174266f..891cbe254957 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -21,20 +21,16 @@ use std::sync::Arc; use crate::utils::make_scalar_function; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::buffer::{Buffer, OffsetBuffer}; -use arrow::datatypes::DataType::{ - Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Utf8, -}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; -use datafusion_common::Result; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, plan_err}; - -use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; -use datafusion_expr_common::signature::TypeSignature::Exact; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; /// Hex lookup table for fast conversion @@ -201,17 +197,8 @@ impl Default for ToHexFunc { impl ToHexFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - Exact(vec![Int8]), - Exact(vec![Int16]), - Exact(vec![Int32]), - Exact(vec![Int64]), - Exact(vec![UInt8]), - Exact(vec![UInt16]), - Exact(vec![UInt32]), - Exact(vec![UInt64]), - ], + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Integer)], Volatility::Immutable, ), } @@ -231,25 +218,37 @@ impl ScalarUDFImpl for ToHexFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { match args.args[0].data_type() { - Int64 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt64 => make_scalar_function(to_hex::, vec![])(&args.args), - Int32 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt32 => make_scalar_function(to_hex::, vec![])(&args.args), - Int16 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt16 => make_scalar_function(to_hex::, vec![])(&args.args), - Int8 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt8 => make_scalar_function(to_hex::, vec![])(&args.args), + DataType::Null => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + DataType::Int64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int16 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt16 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int8 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt8 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 3171eb98fa2b..3a99412f5ed2 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -56,7 +56,7 @@ impl Default for UuidFunc { impl UuidFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index f715f8f46a48..ef91eade01e5 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -73,7 +73,7 @@ select decode('', null) from test; query error DataFusion error: This feature is not implemented: Encoding must be a scalar; array specified encoding is not yet supported select decode('', hex_field) from test; -query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'to_hex' function +query error DataFusion error: Error during planning: Internal error: Expect TypeSignatureClass::Integer but received NativeType::String, DataType: Utf8View select to_hex(hex_field) from test; query error DataFusion error: Execution error: Failed to decode value using base64