Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 58 additions & 6 deletions datafusion/spark/src/function/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::types::{NativeType, logical_string};
use datafusion_expr::ColumnarValue;
use datafusion_common::{Result, internal_err};
use datafusion_expr::{
Coercion, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass,
Volatility,
Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
Signature, TypeSignatureClass, Volatility,
};
use datafusion_functions::string::ascii::ascii;
use datafusion_functions::utils::make_scalar_function;
Expand Down Expand Up @@ -75,10 +76,61 @@ impl ScalarUDFImpl for SparkAscii {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
internal_err!("return_field_from_args should be used instead")
}

fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
// ascii returns an Int32 value
// The result is nullable only if any of the input arguments is nullable
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(ascii, vec![])(&args.args)
}
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion_expr::ReturnFieldArgs;

#[test]
fn test_return_field_nullable_input() {
let ascii_func = SparkAscii::new();
let nullable_field = Arc::new(Field::new("input", DataType::Utf8, true));

let result = ascii_func
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[nullable_field],
scalar_arguments: &[],
})
.unwrap();

assert_eq!(result.data_type(), &DataType::Int32);
assert!(
result.is_nullable(),
"Output should be nullable when input is nullable"
);
}

#[test]
fn test_return_field_non_nullable_input() {
let ascii_func = SparkAscii::new();
let non_nullable_field = Arc::new(Field::new("input", DataType::Utf8, false));

let result = ascii_func
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[non_nullable_field],
scalar_arguments: &[],
})
.unwrap();

assert_eq!(result.data_type(), &DataType::Int32);
assert!(
!result.is_nullable(),
"Output should not be nullable when input is not nullable"
);
}
}