Skip to content

Commit 93314e6

Browse files
committed
feat: support named arguments, defaults in udfs
1 parent efcc216 commit 93314e6

File tree

5 files changed

+200
-20
lines changed

5 files changed

+200
-20
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 152 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion_expr::{
4747
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
4848
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
4949
};
50+
use datafusion_expr_common::signature::TypeSignature;
5051
use datafusion_functions_nested::range::range_udf;
5152
use parking_lot::Mutex;
5253
use regex::Regex;
@@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
945946
expr: Expr,
946947
signature: Signature,
947948
return_type: DataType,
949+
defaults: Vec<Option<Expr>>,
948950
}
949951

950952
impl ScalarUDFImpl for ScalarFunctionWrapper {
@@ -973,15 +975,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
973975
args: Vec<Expr>,
974976
_info: &dyn SimplifyInfo,
975977
) -> Result<ExprSimplifyResult> {
976-
let replacement = Self::replacement(&self.expr, &args)?;
978+
let replacement = Self::replacement(&self.expr, &args, &self.defaults)?;
977979

978980
Ok(ExprSimplifyResult::Simplified(replacement))
979981
}
980982
}
981983

982984
impl ScalarFunctionWrapper {
983985
// replaces placeholders with actual arguments
984-
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
986+
fn replacement(
987+
expr: &Expr,
988+
args: &[Expr],
989+
defaults: &[Option<Expr>],
990+
) -> Result<Expr> {
985991
let result = expr.clone().transform(|e| {
986992
let r = match e {
987993
Expr::Placeholder(placeholder) => {
@@ -990,10 +996,13 @@ impl ScalarFunctionWrapper {
990996
if placeholder_position < args.len() {
991997
Transformed::yes(args[placeholder_position].clone())
992998
} else {
993-
exec_err!(
994-
"Function argument {} not provided, argument missing!",
995-
placeholder.id
996-
)?
999+
match defaults[placeholder_position] {
1000+
Some(ref default) => Transformed::yes(default.clone()),
1001+
None => exec_err!(
1002+
"Function argument {} not provided, argument missing!",
1003+
placeholder.id
1004+
)?,
1005+
}
9971006
}
9981007
}
9991008
_ => Transformed::no(e),
@@ -1021,6 +1030,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10211030
type Error = DataFusionError;
10221031

10231032
fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> {
1033+
let args = definition.args.unwrap_or_default();
1034+
let defaults: Vec<Option<Expr>> =
1035+
args.iter().map(|a| a.default_expr.clone()).collect();
1036+
let signature: Signature = match defaults.iter().position(|v| v.is_some()) {
1037+
Some(pos) => {
1038+
let mut type_signatures: Vec<TypeSignature> = vec![];
1039+
// Generate all valid signatures
1040+
for n in pos..defaults.len() + 1 {
1041+
if n == 0 {
1042+
type_signatures.push(TypeSignature::Nullary)
1043+
} else {
1044+
type_signatures.push(TypeSignature::Exact(
1045+
args.iter().take(n).map(|a| a.data_type.clone()).collect(),
1046+
))
1047+
}
1048+
}
1049+
Signature::one_of(
1050+
type_signatures,
1051+
definition.params.behavior.unwrap_or(Volatility::Volatile),
1052+
)
1053+
}
1054+
None => Signature::exact(
1055+
args.iter().map(|a| a.data_type.clone()).collect(),
1056+
definition.params.behavior.unwrap_or(Volatility::Volatile),
1057+
),
1058+
};
10241059
Ok(Self {
10251060
name: definition.name,
10261061
expr: definition
@@ -1030,15 +1065,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10301065
return_type: definition
10311066
.return_type
10321067
.expect("Return type has to be defined!"),
1033-
signature: Signature::exact(
1034-
definition
1035-
.args
1036-
.unwrap_or_default()
1037-
.into_iter()
1038-
.map(|a| a.data_type)
1039-
.collect(),
1040-
definition.params.behavior.unwrap_or(Volatility::Volatile),
1041-
),
1068+
signature,
1069+
defaults,
10421070
})
10431071
}
10441072
}
@@ -1112,6 +1140,115 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
11121140
Ok(())
11131141
}
11141142

1143+
#[tokio::test]
1144+
async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
1145+
let function_factory = Arc::new(CustomFunctionFactory::default());
1146+
let ctx = SessionContext::new().with_function_factory(function_factory.clone());
1147+
1148+
let sql = r#"
1149+
CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
1150+
RETURNS DOUBLE
1151+
RETURN $a + $b
1152+
"#;
1153+
1154+
assert!(ctx.sql(sql).await.is_ok());
1155+
1156+
let result = ctx
1157+
.sql("select better_add(2.0, 2.0)")
1158+
.await?
1159+
.collect()
1160+
.await?;
1161+
1162+
assert_batches_eq!(
1163+
&[
1164+
"+-----------------------------------+",
1165+
"| better_add(Float64(2),Float64(2)) |",
1166+
"+-----------------------------------+",
1167+
"| 4.0 |",
1168+
"+-----------------------------------+",
1169+
],
1170+
&result
1171+
);
1172+
1173+
// cannot mix named and positional style
1174+
let bad_expression_sql = r#"
1175+
CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
1176+
RETURNS DOUBLE
1177+
RETURN $1 $b
1178+
"#;
1179+
assert!(ctx.sql(bad_expression_sql).await.is_err());
1180+
Ok(())
1181+
}
1182+
1183+
#[tokio::test]
1184+
async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> {
1185+
let function_factory = Arc::new(CustomFunctionFactory::default());
1186+
let ctx = SessionContext::new().with_function_factory(function_factory.clone());
1187+
1188+
let sql = r#"
1189+
CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
1190+
RETURNS DOUBLE
1191+
RETURN $a + $b
1192+
"#;
1193+
1194+
assert!(ctx.sql(sql).await.is_ok());
1195+
1196+
// Check all function arity supported
1197+
let result = ctx.sql("select better_add()").await?.collect().await?;
1198+
1199+
assert_batches_eq!(
1200+
&[
1201+
"+--------------+",
1202+
"| better_add() |",
1203+
"+--------------+",
1204+
"| 4.0 |",
1205+
"+--------------+",
1206+
],
1207+
&result
1208+
);
1209+
1210+
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
1211+
1212+
assert_batches_eq!(
1213+
&[
1214+
"+------------------------+",
1215+
"| better_add(Float64(2)) |",
1216+
"+------------------------+",
1217+
"| 4.0 |",
1218+
"+------------------------+",
1219+
],
1220+
&result
1221+
);
1222+
1223+
let result = ctx
1224+
.sql("select better_add(2.0, 2.0)")
1225+
.await?
1226+
.collect()
1227+
.await?;
1228+
1229+
assert_batches_eq!(
1230+
&[
1231+
"+-----------------------------------+",
1232+
"| better_add(Float64(2),Float64(2)) |",
1233+
"+-----------------------------------+",
1234+
"| 4.0 |",
1235+
"+-----------------------------------+",
1236+
],
1237+
&result
1238+
);
1239+
1240+
assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());
1241+
1242+
// non-default argument cannot follow default argument
1243+
let bad_expression_sql = r#"
1244+
CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
1245+
RETURNS DOUBLE
1246+
RETURN $a $b
1247+
"#;
1248+
assert!(ctx.sql(bad_expression_sql).await.is_err());
1249+
Ok(())
1250+
}
1251+
11151252
/// Saves whatever is passed to it as a scalar function
11161253
#[derive(Debug, Default)]
11171254
struct RecordingFunctionFactory {

datafusion/sql/src/expr/value.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
123123
return if param_data_types.is_empty() {
124124
Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
125125
} else {
126-
// when PREPARE Statement, param_data_types length is always 0
127-
plan_err!("Invalid placeholder, not a number: {param}")
126+
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
127+
// only CREATE FUNCTION currently supports named params. For now, we rewrite
128+
// these to positional params.
129+
let named_param_pos = param_data_types
130+
.iter()
131+
.position(|v| v.name() == &param[1..]);
132+
match named_param_pos {
133+
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
134+
format!("${}", pos + 1),
135+
param_data_types.get(pos).cloned(),
136+
))),
137+
None => plan_err!("Invalid placeholder: {param}"),
138+
}
128139
};
129140
}
130141
};

datafusion/sql/src/statement.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12221222
}
12231223
None => None,
12241224
};
1225+
// Validate default arguments
1226+
let first_default = match args.as_ref() {
1227+
Some(arg) => arg.iter().position(|t| t.default_expr.is_some()),
1228+
None => None,
1229+
};
1230+
let last_non_default = match args.as_ref() {
1231+
Some(arg) => arg.iter().rev().position(|t| t.default_expr.is_none()),
1232+
None => None,
1233+
};
1234+
if let (Some(pos_default), Some(pos_non_default)) =
1235+
(first_default, last_non_default)
1236+
{
1237+
if pos_non_default > pos_default {
1238+
return plan_err!(
1239+
"Non-default arguments cannot follow default arguments."
1240+
);
1241+
}
1242+
}
12251243
// At the moment functions can't be qualified `schema.name`
12261244
let name = match &name.0[..] {
12271245
[] => exec_err!("Function should have name")?,
@@ -1233,9 +1251,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
12331251
//
12341252
let arg_types = args.as_ref().map(|arg| {
12351253
arg.iter()
1236-
.map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
1254+
.map(|t| {
1255+
let name = match t.name.clone() {
1256+
Some(name) => name.value,
1257+
None => "".to_string(),
1258+
};
1259+
Arc::new(Field::new(name, t.data_type.clone(), true))
1260+
})
12371261
.collect::<Vec<_>>()
12381262
});
1263+
// Validate parameter style
1264+
if let Some(ref fields) = arg_types {
1265+
let count_positional =
1266+
fields.iter().filter(|f| f.name() == "").count();
1267+
if !(count_positional == 0 || count_positional == fields.len()) {
1268+
return plan_err!("All function arguments must use either named or positional style.");
1269+
}
1270+
}
12391271
let mut planner_context = PlannerContext::new()
12401272
.with_prepare_param_data_types(arg_types.unwrap_or_default());
12411273

datafusion/sql/tests/cases/params.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() {
105105
assert_snapshot!(
106106
logical_plan(sql).unwrap_err().strip_backtrace(),
107107
@r###"
108-
Error during planning: Invalid placeholder, not a number: $foo
108+
Error during planning: Invalid placeholder: $foo
109109
"###
110110
);
111111
}

datafusion/sqllogictest/test_files/prepare.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ statement error DataFusion error: SQL error: ParserError
3434
PREPARE AS SELECT id, age FROM person WHERE age = $foo;
3535

3636
# param following a non-number, $foo, not supported
37-
statement error Invalid placeholder, not a number: \$foo
37+
statement error Invalid placeholder: \$foo
3838
PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo;
3939

4040
# not specify table hence cannot specify columns

0 commit comments

Comments
 (0)