From 3f37e417d3018fd3996b956fd6b633c493bb41d9 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 16:39:46 +0800 Subject: [PATCH 1/6] [AURON #2067] Implement native function of instr --- .../apache/spark/sql/AuronInstrSuite.scala | 130 +++++++++++++++++ native-engine/auron/src/exec.rs | 9 +- .../datafusion-ext-functions/src/lib.rs | 2 + .../src/spark_instr.rs | 136 ++++++++++++++++++ .../spark/sql/auron/NativeConverters.scala | 2 + 5 files changed, 278 insertions(+), 1 deletion(-) create mode 100644 auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala create mode 100644 native-engine/datafusion-ext-functions/src/spark_instr.rs diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala new file mode 100644 index 000000000..cc62e5b68 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +class AuronInstrSuite extends QueryTest with SparkQueryTestsBase { + + test("test instr function - basic functionality") { + val data = Seq( + ("hello world", "world"), + ("hello world", "hello"), + ("hello world", "o"), + ("hello world", "z"), + (null, "test"), + ("test", null) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.get(0)) + + assert(result(0) == 7, "instr('hello world', 'world') should return 7") + assert(result(1) == 1, "instr('hello world', 'hello') should return 1") + assert(result(2) == 5, "instr('hello world', 'o') should return 5") + assert(result(3) == 0, "instr('hello world', 'z') should return 0") + assert(result(4) == null, "instr(null, 'test') should return null") + assert(result(5) == null, "instr('test', null) should return null") + } + + test("test instr function - multiple occurrences") { + val data = Seq( + ("banana", "a"), + ("testtesttest", "test"), + ("abcabcabc", "abc") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 2, "instr('banana', 'a') should return 2") + assert(result(1) == 1, "instr('testtesttest', 'test') should return 1") + assert(result(2) == 1, "instr('abcabcabc', 'abc') should return 1") + } + + test("test instr function - case sensitive") { + val data = Seq( + ("Hello", "hello"), + ("HELLO", "hello"), + ("Hello", "Hello"), + ("hElLo", "hello") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0)) + + assert(result(0) == 0, "instr('Hello', 'hello') should return 0 (case sensitive)") + assert(result(1) == 0, "instr('HELLO', 'hello') should return 0 (case sensitive)") + assert(result(2) == 1, "instr('Hello', 'Hello') should return 1") + assert(result(3) == 0, "instr('hElLo', 'hello') should 
return 0 (case sensitive)") + } + + test("test instr function - with filter") { + val data = Seq( + ("hello world", "world", 1), + ("hello", "world", 0), + ("hello", "hello", 1), + ("test", "abc", 0) + ) + + val df = spark.createDataFrame(data).toDF("str", "substr", "expected") + val result = df + .filter("instr(str, substr) > 0") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length == 2, "Should find 2 matching strings") + assert(result.contains("hello world")) + assert(result.contains("hello")) + } + + test("test instr function - in group by") { + val data = Seq( + ("test1", "test"), + ("test2", "test"), + ("hello", "world"), + ("testing", "test") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .groupBy("substr") + .count() + .filter("count > 0") + .orderBy("substr") + .collect() + + assert(result.length >= 1) + } + + test("test instr function - in where clause") { + val data = Seq( + ("hello world", "world"), + ("hello", "world"), + ("testing", "test"), + ("abc", "def") + ) + + val df = spark.createDataFrame(data).toDF("str", "substr") + val result = df + .filter("instr(str, substr) = 1") + .select("str") + .collect() + .map(_.getString(0)) + + assert(result.length >= 1) + } +} diff --git a/native-engine/auron/src/exec.rs b/native-engine/auron/src/exec.rs index fa4fec4af..ee51eba05 100644 --- a/native-engine/auron/src/exec.rs +++ b/native-engine/auron/src/exec.rs @@ -141,9 +141,16 @@ pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative( #[allow(non_snake_case)] #[unsafe(no_mangle)] -pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) { +pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(env: JNIEnv, _: JClass) { log::info!("exiting native environment"); if MemManager::initialized() { MemManager::get().dump_status(); } + // Clear Java-side resources to prevent memory leaks + let _ = env.call_static_method( + 
jni_bridge::JavaClasses::get().cJniBridge.class, + "clearResources", + "()V", + &[] + ); } diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index a65dc0d44..2eeb8d36b 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -26,6 +26,7 @@ mod spark_dates; pub mod spark_get_json_object; mod spark_hash; mod spark_initcap; +mod spark_instr; mod spark_isnan; mod spark_make_array; mod spark_make_decimal; @@ -85,6 +86,7 @@ pub fn create_auron_ext_function( Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero) } "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan), + "Spark_Instr" => Arc::new(spark_instr::spark_instr), _ => df_unimplemented_err!("spark ext function not implemented: {name}")?, }) } diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs new file mode 100644 index 000000000..db1a9b96b --- /dev/null +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; +use datafusion::{ + common::{Result, ScalarValue, cast::as_string_array}, + physical_plan::ColumnarValue, +}; +use datafusion_ext_commons::df_execution_err; + +/// instr(str, substr) - Returns the (1-based) index of the first occurrence of +/// substr in str Compatible with Spark's instr function +/// Returns 0 if substr is not found or if either argument is null +pub fn spark_instr(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + df_execution_err!("instr requires exactly 2 arguments")?; + } + + let string_array = args[0].clone().into_array(1)?; + let substr = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) if !substr.is_empty() => substr, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); + } + _ => df_execution_err!("instr substring only supports non-empty literal string")?, + }; + + let result_array: ArrayRef = Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? 
+ .into_iter() + .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), + )); + + Ok(ColumnarValue::Array(result_array)) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::StringArray; + use datafusion::{ + common::{Result, ScalarValue, cast::as_int32_array}, + physical_plan::ColumnarValue, + }; + + use super::spark_instr; + + #[test] + fn test_spark_instr() -> Result<()> { + // Test basic functionality + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("abc".to_string()), + Some("abcabc".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + let s = r.into_array(4)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(0), None,] + ); + + // Test with empty substring should return 0 + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::from("")), + ]); + assert!(r.is_err()); + + // Test with null substring + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( + "hello".to_string(), + )]))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ])?; + if !matches!(r, ColumnarValue::Scalar(ScalarValue::Int32(None))) { + return datafusion::common::internal_err!("Expected null Int32 scalar"); + } + Ok(()) + } + + #[test] + fn test_spark_instr_multiple_matches() -> Result<()> { + let r = spark_instr(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("banana".to_string()), + Some("testtesttest".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("test")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(1),] + ); + Ok(()) + } + + #[test] + fn test_spark_instr_case_sensitive() -> Result<()> { + let r = spark_instr(&vec![ 
+ ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("Hello".to_string()), + Some("HELLO".to_string()), + ]))), + ColumnarValue::Scalar(ScalarValue::from("hello")), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0),] + ); + Ok(()) + } +} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 7a3bde2c8..11ad3797f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -922,6 +922,8 @@ object NativeConverters extends Logging { case e: Levenshtein => buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType) + case e: StringInstr => + buildExtScalarFunction("Spark_Instr", e.children, e.dataType) case e: Hour if datetimeExtractEnabled => buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback) From fe51cadc9eb0f617dec980a53fb6659f2c4e60e2 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 16:42:36 +0800 Subject: [PATCH 2/6] [AURON #2067] Implement native function of instr --- native-engine/auron/src/exec.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/native-engine/auron/src/exec.rs b/native-engine/auron/src/exec.rs index ee51eba05..fa4fec4af 100644 --- a/native-engine/auron/src/exec.rs +++ b/native-engine/auron/src/exec.rs @@ -141,16 +141,9 @@ pub extern "system" fn Java_org_apache_auron_jni_JniBridge_finalizeNative( #[allow(non_snake_case)] #[unsafe(no_mangle)] -pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(env: JNIEnv, _: JClass) { +pub extern "system" fn Java_org_apache_auron_jni_JniBridge_onExit(_: JNIEnv, _: JClass) { log::info!("exiting native environment"); if MemManager::initialized() { 
MemManager::get().dump_status(); } - // Clear Java-side resources to prevent memory leaks - let _ = env.call_static_method( - jni_bridge::JavaClasses::get().cJniBridge.class, - "clearResources", - "()V", - &[] - ); } From d23615f1acc172d96e059918aab066e2e97ed783 Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 17:23:48 +0800 Subject: [PATCH 3/6] fix tests --- .../src/spark_instr.rs | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index db1a9b96b..34bf6ae0f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -23,8 +23,9 @@ use datafusion::{ use datafusion_ext_commons::df_execution_err; /// instr(str, substr) - Returns the (1-based) index of the first occurrence of -/// substr in str Compatible with Spark's instr function -/// Returns 0 if substr is not found or if either argument is null +/// substr in str. Compatible with Spark's instr function. +/// Returns 0 if substr is not found or if substr is empty. +/// Returns null if str is null. 
pub fn spark_instr(args: &[ColumnarValue]) -> Result { if args.len() != 2 { df_execution_err!("instr requires exactly 2 arguments")?; @@ -32,18 +33,27 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { let string_array = args[0].clone().into_array(1)?; let substr = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) if !substr.is_empty() => substr, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) => substr, ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); } - _ => df_execution_err!("instr substring only supports non-empty literal string")?, + _ => df_execution_err!("instr substring only supports literal string")?, }; - let result_array: ArrayRef = Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? - .into_iter() - .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), - )); + // If substr is empty, return 0 for all non-null strings + let result_array: ArrayRef = if substr.is_empty() { + Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? + .into_iter() + .map(|s| s.map(|_| 0)), + )) + } else { + Arc::new(Int32Array::from_iter( + as_string_array(&string_array)? 
+ .into_iter() + .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), + )) + }; Ok(ColumnarValue::Array(result_array)) } @@ -80,12 +90,18 @@ mod test { // Test with empty substring should return 0 let r = spark_instr(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some( - "hello".to_string(), - )]))), + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello".to_string()), + Some("world".to_string()), + None, + ]))), ColumnarValue::Scalar(ScalarValue::from("")), - ]); - assert!(r.is_err()); + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(0), Some(0), None,] + ); // Test with null substring let r = spark_instr(&vec![ From 2b4134d6610f3481df5c3c8155f4fa1bb5995f4b Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 18:01:47 +0800 Subject: [PATCH 4/6] fix tests --- .../src/spark_instr.rs | 110 +++++++++++++----- 1 file changed, 83 insertions(+), 27 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index 34bf6ae0f..fbdc4a01f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -17,7 +17,10 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; use datafusion::{ - common::{Result, ScalarValue, cast::as_string_array}, + common::{ + Result, ScalarValue, + cast::{as_int32_array, as_string_array}, + }, physical_plan::ColumnarValue, }; use datafusion_ext_commons::df_execution_err; @@ -25,37 +28,58 @@ use datafusion_ext_commons::df_execution_err; /// instr(str, substr) - Returns the (1-based) index of the first occurrence of /// substr in str. Compatible with Spark's instr function. /// Returns 0 if substr is not found or if substr is empty. -/// Returns null if str is null. 
+/// Returns null if str is null or substr is null. pub fn spark_instr(args: &[ColumnarValue]) -> Result { if args.len() != 2 { df_execution_err!("instr requires exactly 2 arguments")?; } - let string_array = args[0].clone().into_array(1)?; - let substr = match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(substr))) => substr, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); - } - _ => df_execution_err!("instr substring only supports literal string")?, + // Ensure both arguments are arrays for element-wise comparison + let left: ArrayRef = match &args[0] { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, + ColumnarValue::Array(array) => array.clone(), }; - // If substr is empty, return 0 for all non-null strings - let result_array: ArrayRef = if substr.is_empty() { - Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? - .into_iter() - .map(|s| s.map(|_| 0)), - )) - } else { - Arc::new(Int32Array::from_iter( - as_string_array(&string_array)? 
- .into_iter() - .map(|s| s.map(|s| s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))), - )) + let right: ArrayRef = match &args[1] { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(left.len())?, + ColumnarValue::Array(array) => array.clone(), }; - Ok(ColumnarValue::Array(result_array)) + let str_array = as_string_array(&left)?; + let substr_array = as_string_array(&right)?; + + // Perform element-wise instr operation + let result_array: ArrayRef = Arc::new(Int32Array::from_iter( + str_array + .into_iter() + .zip(substr_array.into_iter()) + .map(|(s, substr)| { + match (s, substr) { + (Some(_), None) => None, // substr is null + (None, _) => None, // str is null + (Some(s), Some(substr)) => { + // Empty substr returns 0 + if substr.is_empty() { + Some(0) + } else { + Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + } + } + } + }), + )); + + // If both inputs were scalars, return a scalar + if matches!(args[0], ColumnarValue::Scalar(_)) && matches!(args[1], ColumnarValue::Scalar(_)) { + let scalar = as_int32_array(&result_array)?.value(0); + Ok(ColumnarValue::Scalar(if result_array.is_null(0) { + ScalarValue::Int32(None) + } else { + ScalarValue::Int32(Some(scalar)) + })) + } else { + Ok(ColumnarValue::Array(result_array)) + } } #[cfg(test)] @@ -72,7 +96,7 @@ mod test { #[test] fn test_spark_instr() -> Result<()> { - // Test basic functionality + // Test basic functionality with scalar substring let r = spark_instr(&vec![ ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ Some("hello world".to_string()), @@ -110,9 +134,41 @@ mod test { )]))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ])?; - if !matches!(r, ColumnarValue::Scalar(ScalarValue::Int32(None))) { - return datafusion::common::internal_err!("Expected null Int32 scalar"); - } + let s = r.into_array(1)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![None,] + ); + + // Test with array substring (element-wise) + let r = spark_instr(&vec![ + 
ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("hello world".to_string()), + Some("hello".to_string()), + Some("test".to_string()), + ]))), + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("world".to_string()), + Some("test".to_string()), + Some("test".to_string()), + ]))), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_int32_array(&s)?.into_iter().collect::>(), + vec![Some(7), Some(0), Some(1),] + ); + + // Test with both scalars + let r = spark_instr(&vec![ + ColumnarValue::Scalar(ScalarValue::from("hello world")), + ColumnarValue::Scalar(ScalarValue::from("world")), + ])?; + assert!(matches!( + r, + ColumnarValue::Scalar(ScalarValue::Int32(Some(7))) + )); + Ok(()) } From 152e20583c45214ca30249946df7d03a47525fbb Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 19:07:08 +0800 Subject: [PATCH 5/6] fix styles --- .../src/spark_instr.rs | 62 ++++++++++--------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index fbdc4a01f..69a74b59b 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -34,43 +34,49 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { df_execution_err!("instr requires exactly 2 arguments")?; } - // Ensure both arguments are arrays for element-wise comparison - let left: ArrayRef = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1)?, - ColumnarValue::Array(array) => array.clone(), - }; - - let right: ArrayRef = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(left.len())?, - ColumnarValue::Array(array) => array.clone(), - }; - - let str_array = as_string_array(&left)?; - let substr_array = as_string_array(&right)?; + let is_scalar = args + .iter() + .all(|arg| matches!(arg, 
ColumnarValue::Scalar(_))); + let len = args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }) + .max() + .unwrap_or(0); + + let arrays = args + .iter() + .map(|arg| { + Ok(match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len)?, + }) + }) + .collect::>>()?; + + let str_array = as_string_array(&arrays[0])?; + let substr_array = as_string_array(&arrays[1])?; - // Perform element-wise instr operation let result_array: ArrayRef = Arc::new(Int32Array::from_iter( str_array .into_iter() .zip(substr_array.into_iter()) - .map(|(s, substr)| { - match (s, substr) { - (Some(_), None) => None, // substr is null - (None, _) => None, // str is null - (Some(s), Some(substr)) => { - // Empty substr returns 0 - if substr.is_empty() { - Some(0) - } else { - Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) - } + .map(|(s, substr)| match (s, substr) { + (Some(_), None) => None, // substr is null + (None, _) => None, // str is null + (Some(s), Some(substr)) => { + if substr.is_empty() { + Some(0) + } else { + Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) } } }), )); - // If both inputs were scalars, return a scalar - if matches!(args[0], ColumnarValue::Scalar(_)) && matches!(args[1], ColumnarValue::Scalar(_)) { + if is_scalar { let scalar = as_int32_array(&result_array)?.value(0); Ok(ColumnarValue::Scalar(if result_array.is_null(0) { ScalarValue::Int32(None) @@ -86,7 +92,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { mod test { use std::sync::Arc; - use arrow::array::StringArray; + use arrow::array::{ArrayRef, Int32Array, StringArray}; use datafusion::{ common::{Result, ScalarValue, cast::as_int32_array}, physical_plan::ColumnarValue, From 91d05c77d92abb77e47658764e4f2e33e7a6015b Mon Sep 17 00:00:00 2001 From: xuzifu666 <1206332514@qq.com> Date: Wed, 11 Mar 2026 19:11:31 +0800 Subject: [PATCH 6/6] fix 
styles --- native-engine/datafusion-ext-functions/src/spark_instr.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_instr.rs b/native-engine/datafusion-ext-functions/src/spark_instr.rs index 69a74b59b..b970cc38f 100644 --- a/native-engine/datafusion-ext-functions/src/spark_instr.rs +++ b/native-engine/datafusion-ext-functions/src/spark_instr.rs @@ -61,8 +61,8 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { let result_array: ArrayRef = Arc::new(Int32Array::from_iter( str_array - .into_iter() - .zip(substr_array.into_iter()) + .iter() + .zip(substr_array.iter()) .map(|(s, substr)| match (s, substr) { (Some(_), None) => None, // substr is null (None, _) => None, // str is null @@ -70,7 +70,7 @@ pub fn spark_instr(args: &[ColumnarValue]) -> Result { if substr.is_empty() { Some(0) } else { - Some(s.find(&substr).map(|pos| (pos + 1) as i32).unwrap_or(0)) + Some(s.find(substr).map_or(0, |pos| s[..pos].chars().count() as i32 + 1)) } } }),