From f7df9a0374e39cfb7043b598a9fa2e54a1154368 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sat, 14 Feb 2026 20:09:42 -0500 Subject: [PATCH] fix: handle Utf8View and LargeUtf8 separators in concat_ws concat_ws only handled Utf8 separators; attempting to pass a Utf8View or LargeUtf8 separator would result in a panic or internal error. --- datafusion/functions/src/string/concat.rs | 11 +- datafusion/functions/src/string/concat_ws.rs | 174 ++++++++++++++----- datafusion/sqllogictest/test_files/expr.slt | 19 ++ 3 files changed, 156 insertions(+), 48 deletions(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 9e565342bafbc..c8da67c186726 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -120,13 +120,10 @@ impl ScalarUDFImpl for ConcatFunc { } }); - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index b08799f434aa6..25f7aba56c33e 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -105,7 +105,6 @@ impl ScalarUDFImpl for ConcatWsFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { let ScalarFunctionArgs { args, .. } = args; - // do not accept 0 arguments. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments.
It requires at least 2.", @@ -113,18 +112,14 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { let ColumnarValue::Scalar(scalar) = &args[0] else { - // loop above checks for all args being scalar unreachable!() }; let sep = match scalar.try_as_str() { @@ -139,7 +134,6 @@ impl ScalarUDFImpl for ConcatWsFunc { let mut values = Vec::with_capacity(args.len() - 1); for arg in &args[1..] { let ColumnarValue::Scalar(scalar) = arg else { - // loop above checks for all args being scalar unreachable!() }; @@ -162,23 +156,55 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); // estimate - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) } - } - _ => unreachable!("concat ws"), + Some(None) => { + return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null( + len, + )))); + } + None => { + return internal_err!("Expected string separator, got {scalar:?}"); + } + }, + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let 
string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + } + } + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + data_size += + string_array.total_buffer_bytes_used() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + } + } + other => { + return plan_err!( + "Input was {other} which is not a supported datatype for concat_ws separator" + ); + } + }, }; let mut columns = Vec::with_capacity(args.len() - 1); @@ -221,11 +247,7 @@ impl ScalarUDFImpl for ConcatWsFunc { DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array - .data_buffers() - .iter() - .map(|buf| buf.len()) - .sum::<usize>(); + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { @@ -251,18 +273,14 @@ impl ScalarUDFImpl for ConcatWsFunc { continue; } - let mut iter = columns.iter(); - for column in iter.by_ref() { + let mut first = true; + for column in &columns { if column.is_valid(i) { + if !first { + builder.write::<false>(&sep, i); + } builder.write::<false>(column, i); - break; - } - } - - for column in iter { - if column.is_valid(i) { - builder.write::<false>(&sep, i); - builder.write::<false>(column, i); + first = false; } } @@ -546,4 +564,78 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_utf8view_scalar_separator() -> Result<()> { + let c0 =
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/expr.slt 
b/datafusion/sqllogictest/test_files/expr.slt index 57769941e2a66..c737efca4a6d0 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -504,6 +504,25 @@ abc statement ok drop table foo +# concat_ws with a Utf8View or LargeUtf8 column as separator +statement ok +create table test_concat_ws_sep (sep varchar, val1 varchar, val2 varchar) as values (',', 'foo', 'bar'), ('|', 'a', 'b'); + +query T +SELECT concat_ws(arrow_cast(sep, 'Utf8View'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +query T +SELECT concat_ws(arrow_cast(sep, 'LargeUtf8'), val1, val2) FROM test_concat_ws_sep ORDER BY val1 +---- +a|b +foo,bar + +statement ok +drop table test_concat_ws_sep + query T SELECT initcap('') ----