Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,10 @@ impl ScalarUDFImpl for ConcatFunc {
}
});

let array_len = args
.iter()
.filter_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
_ => None,
})
.next();
let array_len = args.iter().find_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
_ => None,
});

// Scalar
if array_len.is_none() {
Expand Down
174 changes: 133 additions & 41 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,21 @@ impl ScalarUDFImpl for ConcatWsFunc {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;

// do not accept 0 arguments.
if args.len() < 2 {
return exec_err!(
"concat_ws was called with {} arguments. It requires at least 2.",
args.len()
);
}

let array_len = args
.iter()
.filter_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
_ => None,
})
.next();
let array_len = args.iter().find_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
_ => None,
});

// Scalar
if array_len.is_none() {
let ColumnarValue::Scalar(scalar) = &args[0] else {
// loop above checks for all args being scalar
unreachable!()
};
let sep = match scalar.try_as_str() {
Expand All @@ -139,7 +134,6 @@ impl ScalarUDFImpl for ConcatWsFunc {
let mut values = Vec::with_capacity(args.len() - 1);
for arg in &args[1..] {
let ColumnarValue::Scalar(scalar) = arg else {
// loop above checks for all args being scalar
unreachable!()
};

Expand All @@ -162,23 +156,55 @@ impl ScalarUDFImpl for ConcatWsFunc {

// parse sep
let sep = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
data_size += s.len() * len * (args.len() - 2); // estimate
ColumnarValueRef::Scalar(s.as_bytes())
}
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
}
ColumnarValue::Array(array) => {
let string_array = as_string_array(array)?;
data_size += string_array.values().len() * (args.len() - 2); // estimate
if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(s)) => {
data_size += s.len() * len * (args.len() - 2); // estimate
ColumnarValueRef::Scalar(s.as_bytes())
}
}
_ => unreachable!("concat ws"),
Some(None) => {
return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(
len,
))));
}
None => {
return internal_err!("Expected string separator, got {scalar:?}");
}
},
ColumnarValue::Array(array) => match array.data_type() {
DataType::Utf8 => {
let string_array = as_string_array(array)?;
data_size += string_array.values().len() * (args.len() - 2);
if array.is_nullable() {
ColumnarValueRef::NullableArray(string_array)
} else {
ColumnarValueRef::NonNullableArray(string_array)
}
}
DataType::LargeUtf8 => {
let string_array = as_largestring_array(array);
data_size += string_array.values().len() * (args.len() - 2);
if array.is_nullable() {
ColumnarValueRef::NullableLargeStringArray(string_array)
} else {
ColumnarValueRef::NonNullableLargeStringArray(string_array)
}
}
DataType::Utf8View => {
let string_array = as_string_view_array(array)?;
data_size +=
string_array.total_buffer_bytes_used() * (args.len() - 2);
if array.is_nullable() {
ColumnarValueRef::NullableStringViewArray(string_array)
} else {
ColumnarValueRef::NonNullableStringViewArray(string_array)
}
}
other => {
return plan_err!(
"Input was {other} which is not a supported datatype for concat_ws separator"
);
}
},
};

let mut columns = Vec::with_capacity(args.len() - 1);
Expand Down Expand Up @@ -221,11 +247,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
DataType::Utf8View => {
let string_array = as_string_view_array(array)?;

data_size += string_array
.data_buffers()
.iter()
.map(|buf| buf.len())
.sum::<usize>();
data_size += string_array.total_buffer_bytes_used();
let column = if array.is_nullable() {
ColumnarValueRef::NullableStringViewArray(string_array)
} else {
Expand All @@ -251,18 +273,14 @@ impl ScalarUDFImpl for ConcatWsFunc {
continue;
}

let mut iter = columns.iter();
for column in iter.by_ref() {
let mut first = true;
for column in &columns {
if column.is_valid(i) {
if !first {
builder.write::<false>(&sep, i);
}
builder.write::<false>(column, i);
break;
}
}

for column in iter {
if column.is_valid(i) {
builder.write::<false>(&sep, i);
builder.write::<false>(column, i);
first = false;
}
}

Expand Down Expand Up @@ -546,4 +564,78 @@ mod tests {

Ok(())
}

#[test]
fn concat_ws_utf8view_scalar_separator() -> Result<()> {
let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
let c1 =
ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
Some("x"),
None,
Some("z"),
])));

let arg_fields = vec![
Field::new("a", Utf8, true).into(),
Field::new("a", Utf8, true).into(),
Field::new("a", Utf8, true).into(),
];
let args = ScalarFunctionArgs {
args: vec![c0, c1, c2],
arg_fields,
number_rows: 3,
return_field: Field::new("f", Utf8, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};

let result = ConcatWsFunc::new().invoke_with_args(args)?;
let expected =
Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
match &result {
ColumnarValue::Array(array) => {
assert_eq!(&expected, array);
}
_ => panic!("Expected array result"),
}

Ok(())
}

#[test]
fn concat_ws_largeutf8_scalar_separator() -> Result<()> {
let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string())));
let c1 =
ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
Some("x"),
None,
Some("z"),
])));

let arg_fields = vec![
Field::new("a", Utf8, true).into(),
Field::new("a", Utf8, true).into(),
Field::new("a", Utf8, true).into(),
];
let args = ScalarFunctionArgs {
args: vec![c0, c1, c2],
arg_fields,
number_rows: 3,
return_field: Field::new("f", Utf8, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};

let result = ConcatWsFunc::new().invoke_with_args(args)?;
let expected =
Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
match &result {
ColumnarValue::Array(array) => {
assert_eq!(&expected, array);
}
_ => panic!("Expected array result"),
}

Ok(())
}
}
19 changes: 19 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,25 @@ abc
statement ok
drop table foo

# concat_ws with a Utf8View column as separator
statement ok
create table test_concat_ws_sep (sep varchar, val1 varchar, val2 varchar) as values (',', 'foo', 'bar'), ('|', 'a', 'b');

query T
SELECT concat_ws(arrow_cast(sep, 'Utf8View'), val1, val2) FROM test_concat_ws_sep ORDER BY val1
----
a|b
foo,bar

query T
SELECT concat_ws(arrow_cast(sep, 'LargeUtf8'), val1, val2) FROM test_concat_ws_sep ORDER BY val1
----
a|b
foo,bar

statement ok
drop table test_concat_ws_sep

query T
SELECT initcap('')
----
Expand Down