@@ -47,6 +47,7 @@ use datafusion_expr::{
4747 LogicalPlanBuilder , OperateFunctionArg , ReturnFieldArgs , ScalarFunctionArgs ,
4848 ScalarUDF , ScalarUDFImpl , Signature , Volatility ,
4949} ;
50+ use datafusion_expr_common:: signature:: TypeSignature ;
5051use datafusion_functions_nested:: range:: range_udf;
5152use parking_lot:: Mutex ;
5253use regex:: Regex ;
@@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
945946 expr : Expr ,
946947 signature : Signature ,
947948 return_type : DataType ,
949+ defaults : Vec < Option < Expr > > ,
948950}
949951
950952impl ScalarUDFImpl for ScalarFunctionWrapper {
@@ -973,15 +975,19 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
973975 args : Vec < Expr > ,
974976 _info : & dyn SimplifyInfo ,
975977 ) -> Result < ExprSimplifyResult > {
976- let replacement = Self :: replacement ( & self . expr , & args) ?;
978+ let replacement = Self :: replacement ( & self . expr , & args, & self . defaults ) ?;
977979
978980 Ok ( ExprSimplifyResult :: Simplified ( replacement) )
979981 }
980982}
981983
982984impl ScalarFunctionWrapper {
983985 // replaces placeholders with actual arguments
984- fn replacement ( expr : & Expr , args : & [ Expr ] ) -> Result < Expr > {
986+ fn replacement (
987+ expr : & Expr ,
988+ args : & [ Expr ] ,
989+ defaults : & [ Option < Expr > ] ,
990+ ) -> Result < Expr > {
985991 let result = expr. clone ( ) . transform ( |e| {
986992 let r = match e {
987993 Expr :: Placeholder ( placeholder) => {
@@ -990,10 +996,13 @@ impl ScalarFunctionWrapper {
990996 if placeholder_position < args. len ( ) {
991997 Transformed :: yes ( args[ placeholder_position] . clone ( ) )
992998 } else {
993- exec_err ! (
994- "Function argument {} not provided, argument missing!" ,
995- placeholder. id
996- ) ?
999+ match defaults[ placeholder_position] {
1000+ Some ( ref default) => Transformed :: yes ( default. clone ( ) ) ,
1001+ None => exec_err ! (
1002+ "Function argument {} not provided, argument missing!" ,
1003+ placeholder. id
1004+ ) ?,
1005+ }
9971006 }
9981007 }
9991008 _ => Transformed :: no ( e) ,
@@ -1021,6 +1030,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10211030 type Error = DataFusionError ;
10221031
10231032 fn try_from ( definition : CreateFunction ) -> std:: result:: Result < Self , Self :: Error > {
1033+ let args = definition. args . unwrap_or_default ( ) ;
1034+ let defaults: Vec < Option < Expr > > =
1035+ args. iter ( ) . map ( |a| a. default_expr . clone ( ) ) . collect ( ) ;
1036+ let signature: Signature = match defaults. iter ( ) . position ( |v| v. is_some ( ) ) {
1037+ Some ( pos) => {
1038+ let mut type_signatures: Vec < TypeSignature > = vec ! [ ] ;
1039+ // Generate all valid signatures
1040+ for n in pos..defaults. len ( ) + 1 {
1041+ if n == 0 {
1042+ type_signatures. push ( TypeSignature :: Nullary )
1043+ } else {
1044+ type_signatures. push ( TypeSignature :: Exact (
1045+ args. iter ( ) . take ( n) . map ( |a| a. data_type . clone ( ) ) . collect ( ) ,
1046+ ) )
1047+ }
1048+ }
1049+ Signature :: one_of (
1050+ type_signatures,
1051+ definition. params . behavior . unwrap_or ( Volatility :: Volatile ) ,
1052+ )
1053+ }
1054+ None => Signature :: exact (
1055+ args. iter ( ) . map ( |a| a. data_type . clone ( ) ) . collect ( ) ,
1056+ definition. params . behavior . unwrap_or ( Volatility :: Volatile ) ,
1057+ ) ,
1058+ } ;
10241059 Ok ( Self {
10251060 name : definition. name ,
10261061 expr : definition
@@ -1030,15 +1065,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
10301065 return_type : definition
10311066 . return_type
10321067 . expect ( "Return type has to be defined!" ) ,
1033- signature : Signature :: exact (
1034- definition
1035- . args
1036- . unwrap_or_default ( )
1037- . into_iter ( )
1038- . map ( |a| a. data_type )
1039- . collect ( ) ,
1040- definition. params . behavior . unwrap_or ( Volatility :: Volatile ) ,
1041- ) ,
1068+ signature,
1069+ defaults,
10421070 } )
10431071 }
10441072}
@@ -1112,6 +1140,115 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
11121140 Ok ( ( ) )
11131141}
11141142
1143+ #[ tokio:: test]
1144+ async fn create_scalar_function_from_sql_statement_named_arguments ( ) -> Result < ( ) > {
1145+ let function_factory = Arc :: new ( CustomFunctionFactory :: default ( ) ) ;
1146+ let ctx = SessionContext :: new ( ) . with_function_factory ( function_factory. clone ( ) ) ;
1147+
1148+ let sql = r#"
1149+ CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
1150+ RETURNS DOUBLE
1151+ RETURN $a + $b
1152+ "# ;
1153+
1154+ assert ! ( ctx. sql( sql) . await . is_ok( ) ) ;
1155+
1156+ let result = ctx
1157+ . sql ( "select better_add(2.0, 2.0)" )
1158+ . await ?
1159+ . collect ( )
1160+ . await ?;
1161+
1162+ assert_batches_eq ! (
1163+ & [
1164+ "+-----------------------------------+" ,
1165+ "| better_add(Float64(2),Float64(2)) |" ,
1166+ "+-----------------------------------+" ,
1167+ "| 4.0 |" ,
1168+ "+-----------------------------------+" ,
1169+ ] ,
1170+ & result
1171+ ) ;
1172+
1173+ // cannot mix named and positional style
1174+ let bad_expression_sql = r#"
1175+ CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
1176+ RETURNS DOUBLE
1177+ RETURN $1 $b
1178+ "# ;
1179+ assert ! ( ctx. sql( bad_expression_sql) . await . is_err( ) ) ;
1180+ Ok ( ( ) )
1181+ }
1182+
1183+ #[ tokio:: test]
1184+ async fn create_scalar_function_from_sql_statement_default_arguments ( ) -> Result < ( ) > {
1185+ let function_factory = Arc :: new ( CustomFunctionFactory :: default ( ) ) ;
1186+ let ctx = SessionContext :: new ( ) . with_function_factory ( function_factory. clone ( ) ) ;
1187+
1188+ let sql = r#"
1189+ CREATE FUNCTION better_add(a DOUBLE DEFAULT 2.0, b DOUBLE DEFAULT 2.0)
1190+ RETURNS DOUBLE
1191+ RETURN $a + $b
1192+ "# ;
1193+
1194+ assert ! ( ctx. sql( sql) . await . is_ok( ) ) ;
1195+
1196+ // Check all function arity supported
1197+ let result = ctx. sql ( "select better_add()" ) . await ?. collect ( ) . await ?;
1198+
1199+ assert_batches_eq ! (
1200+ & [
1201+ "+--------------+" ,
1202+ "| better_add() |" ,
1203+ "+--------------+" ,
1204+ "| 4.0 |" ,
1205+ "+--------------+" ,
1206+ ] ,
1207+ & result
1208+ ) ;
1209+
1210+ let result = ctx. sql ( "select better_add(2.0)" ) . await ?. collect ( ) . await ?;
1211+
1212+ assert_batches_eq ! (
1213+ & [
1214+ "+------------------------+" ,
1215+ "| better_add(Float64(2)) |" ,
1216+ "+------------------------+" ,
1217+ "| 4.0 |" ,
1218+ "+------------------------+" ,
1219+ ] ,
1220+ & result
1221+ ) ;
1222+
1223+ let result = ctx
1224+ . sql ( "select better_add(2.0, 2.0)" )
1225+ . await ?
1226+ . collect ( )
1227+ . await ?;
1228+
1229+ assert_batches_eq ! (
1230+ & [
1231+ "+-----------------------------------+" ,
1232+ "| better_add(Float64(2),Float64(2)) |" ,
1233+ "+-----------------------------------+" ,
1234+ "| 4.0 |" ,
1235+ "+-----------------------------------+" ,
1236+ ] ,
1237+ & result
1238+ ) ;
1239+
1240+ assert ! ( ctx. sql( "select better_add(2.0, 2.0, 2.0)" ) . await . is_err( ) ) ;
1241+
1242+ // non-default argument cannot follow default argument
1243+ let bad_expression_sql = r#"
1244+ CREATE FUNCTION bad_expression_fun(a DOUBLE DEFAULT 2.0, b DOUBLE)
1245+ RETURNS DOUBLE
1246+ RETURN $a $b
1247+ "# ;
1248+ assert ! ( ctx. sql( bad_expression_sql) . await . is_err( ) ) ;
1249+ Ok ( ( ) )
1250+ }
1251+
11151252/// Saves whatever is passed to it as a scalar function
11161253#[ derive( Debug , Default ) ]
11171254struct RecordingFunctionFactory {
0 commit comments