From 6dc07d4523a127e5e4157d04abbf3bead23eb294 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Fri, 19 Dec 2025 17:28:54 +0900 Subject: [PATCH] Various refactors to string functions --- datafusion/functions/src/string/ends_with.rs | 9 +-- datafusion/functions/src/string/ltrim.rs | 8 +-- datafusion/functions/src/string/rtrim.rs | 8 +-- datafusion/functions/src/string/split_part.rs | 24 +++---- .../functions/src/string/starts_with.rs | 9 +-- datafusion/functions/src/string/to_hex.rs | 69 +++++++++---------- datafusion/functions/src/string/uuid.rs | 2 +- 7 files changed, 53 insertions(+), 76 deletions(-) diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index e3fa7c92ca62b..20415c6ed479b 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -95,14 +95,7 @@ impl ScalarUDFImpl for EndsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(ends_with, vec![])(&args.args) - } - other => internal_err!( - "Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View" - )?, - } + make_scalar_function(ends_with, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 18a61869a8dc2..40f525408f60f 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; @@ -115,11 +115,7 @@ impl ScalarUDFImpl for LtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "ltrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index f0bafc980e324..77a08bf533c20 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; @@ -115,11 +115,7 @@ impl ScalarUDFImpl for RtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "rtrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 8ac505bf360f6..d29d33a154d79 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -24,8 +24,11 @@ use arrow::array::{AsArray, GenericStringBuilder}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::{DataFusionError, Result, exec_err}; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility, +}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; use std::any::Any; @@ -60,19 +63,16 @@ impl Default for SplitPartFunc { impl SplitPartFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( + signature: Signature::coercible( vec![ - TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8View, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 1a60eb91aa621..92220c7698bb1 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -119,14 +119,7 @@ impl ScalarUDFImpl for StartsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(starts_with, vec![])(&args.args) - } - _ => internal_err!( - "Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View" - )?, - } + make_scalar_function(starts_with, vec![])(&args.args) } fn simplify( diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index fb34c96ad83a3..4a0e966f39475 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -21,20 +21,16 @@ use std::sync::Arc; use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, GenericStringBuilder}; -use arrow::datatypes::DataType::{ - Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Utf8, -}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; -use datafusion_common::Result; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, plan_err}; - -use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; -use datafusion_expr_common::signature::TypeSignature::Exact; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. @@ -101,17 +97,8 @@ impl Default for ToHexFunc { impl ToHexFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - Exact(vec![Int8]), - Exact(vec![Int16]), - Exact(vec![Int32]), - Exact(vec![Int64]), - Exact(vec![UInt8]), - Exact(vec![UInt16]), - Exact(vec![UInt32]), - Exact(vec![UInt64]), - ], + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Integer)], Volatility::Immutable, ), } @@ -131,25 +118,37 @@ impl ScalarUDFImpl for ToHexFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { match args.args[0].data_type() { - Int64 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt64 => make_scalar_function(to_hex::, vec![])(&args.args), - Int32 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt32 => make_scalar_function(to_hex::, vec![])(&args.args), - Int16 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt16 => make_scalar_function(to_hex::, vec![])(&args.args), - Int8 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt8 => make_scalar_function(to_hex::, vec![])(&args.args), + DataType::Null => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + DataType::Int64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int16 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt16 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int8 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::UInt8 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 3171eb98fa2bf..3a99412f5ed29 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -56,7 +56,7 @@ impl Default for UuidFunc { impl UuidFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } }