diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs index f14a66d4e484d..50f870ee29b96 100644 --- a/datafusion/spark/src/function/string/ascii.rs +++ b/datafusion/spark/src/function/string/ascii.rs @@ -17,8 +17,12 @@ use arrow::datatypes::DataType; use datafusion_common::Result; +use datafusion_common::types::{NativeType, logical_string}; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Coercion, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, + Volatility, +}; use datafusion_functions::string::ascii::ascii; use datafusion_functions::utils::make_scalar_function; use std::any::Any; @@ -42,8 +46,17 @@ impl Default for SparkAscii { impl SparkAscii { pub fn new() -> Self { + // Spark's ascii uses ImplicitCastInputTypes with StringType, + // which allows numeric types to be implicitly cast to String. + // See: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala + let string_coercion = Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Numeric], + NativeType::String, + ); + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible(vec![string_coercion], Volatility::Immutable), } } } @@ -68,8 +81,4 @@ impl ScalarUDFImpl for SparkAscii { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(ascii, vec![])(&args.args) } - - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - Ok(vec![DataType::Utf8]) - } }