diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index d970eccc43a54..80b701516c470 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -28,7 +28,6 @@ use datafusion_expr::{ TypeSignature::Exact, TypeSignature::Uniform, Volatility, }; use datafusion_macros::user_doc; -use itertools::izip; use regex::Regex; use std::collections::HashMap; use std::sync::Arc; @@ -196,8 +195,8 @@ fn regexp_count( match (values.data_type(), regex_array.data_type(), flags_array) { (Utf8, Utf8, None) => regexp_count_inner( - &values.as_string::(), - ®ex_array.as_string::(), + values.as_string::(), + regex_array.as_string::(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, @@ -205,17 +204,17 @@ fn regexp_count( is_flags_scalar, ), (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( - &values.as_string::(), - ®ex_array.as_string::(), + values.as_string::(), + regex_array.as_string::(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, - Some(&flags_array.as_string::()), + Some(flags_array.as_string::()), is_flags_scalar, ), (LargeUtf8, LargeUtf8, None) => regexp_count_inner( - &values.as_string::(), - ®ex_array.as_string::(), + values.as_string::(), + regex_array.as_string::(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, @@ -223,17 +222,17 @@ fn regexp_count( is_flags_scalar, ), (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( - &values.as_string::(), - ®ex_array.as_string::(), + values.as_string::(), + regex_array.as_string::(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, - Some(&flags_array.as_string::()), + Some(flags_array.as_string::()), is_flags_scalar, ), (Utf8View, Utf8View, None) => regexp_count_inner( - &values.as_string_view(), - ®ex_array.as_string_view(), + values.as_string_view(), + regex_array.as_string_view(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, @@ -241,12 +240,12 @@ fn regexp_count( is_flags_scalar, ), (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( - &values.as_string_view(), - ®ex_array.as_string_view(), + values.as_string_view(), + regex_array.as_string_view(), is_regex_scalar, start_array.map(|start| start.as_primitive::()), is_start_scalar, - Some(&flags_array.as_string_view()), + Some(flags_array.as_string_view()), is_flags_scalar, ), _ => Err(ArrowError::ComputeError( @@ -256,298 +255,183 @@ fn regexp_count( } fn regexp_count_inner<'a, S>( - values: &S, - regex_array: &S, + values: S, + regex_array: S, is_regex_scalar: bool, - start_array: Option<&Int64Array>, + start_array: Option<&'a Int64Array>, is_start_scalar: bool, - flags_array: Option<&S>, + flags_array: Option, is_flags_scalar: bool, ) -> Result where - S: StringArrayType<'a>, + S: StringArrayType<'a> + Copy, { - let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { - ( - (!regex_array.is_null(0)).then(|| regex_array.value(0)), - true, - ) + let values_len = values.len(); + let regex = StringValueSource::regex_arg(regex_array, is_regex_scalar); + + // Preserve existing short-circuit behavior: a scalar NULL regex produces zeros + // for every row without validating start or flags. + if regex.is_null_scalar() { + return Ok(Arc::new(Int64Array::from(vec![0; values_len]))); + } + + let start = StartValueSource::new(start_array, is_start_scalar); + let flags = StringValueSource::flags_arg(flags_array, is_flags_scalar); + + regex.validate_len("regex_array", values_len)?; + + let scalar_pattern = if regex.is_array() { + start.validate_len(values_len)?; + flags.validate_len("flags_array", values_len)?; + None + } else if flags.is_array() { + flags.validate_len("flags_array", values_len)?; + start.validate_len(values_len)?; + None } else { - (None, false) + let scalar_pattern = compile_scalar_pattern(®ex, &flags)?; + start.validate_len(values_len)?; + scalar_pattern }; - let (start_array, start_scalar, is_start_scalar) = - if let Some(start_array) = start_array { - if is_start_scalar || start_array.len() == 1 { - (None, Some(start_array.value(0)), true) - } else { - (Some(start_array), None, false) - } - } else { - (None, Some(1), true) - }; + let mut regex_cache = HashMap::new(); + let counts = (0..values_len) + .map(|row| { + let regex = match regex.value(row) { + None => return Ok(0), + Some(regex) => regex, + }; + let start = start.value(row); + let flags = flags.value(row); + let value = string_value_opt(&values, row); - let (flags_array, flags_scalar, is_flags_scalar) = - if let Some(flags_array) = flags_array { - if is_flags_scalar || flags_array.len() == 1 { - (None, Some(flags_array.value(0)), true) + if let Some(pattern) = &scalar_pattern { + count_matches(value, pattern, start) } else { - (Some(flags_array), None, false) + let pattern = compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, pattern, start) } - } else { - (None, None, true) - }; + }) + .collect::>()?; - let mut regex_cache = HashMap::new(); + Ok(Arc::new(counts)) +} - match (is_regex_scalar, is_start_scalar, is_flags_scalar) { - (true, true, true) => { - let regex = match regex_scalar { - None => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); - } - Some(regex) => regex, - }; +fn string_value_opt<'a, S>(array: &S, row: usize) -> Option<&'a str> +where + S: StringArrayType<'a>, +{ + (!array.is_null(row)).then(|| array.value(row)) +} - let pattern = compile_regex(regex, flags_scalar)?; +enum StringValueSource<'a, S> { + Scalar(Option<&'a str>), + Array(S), +} - Ok(Arc::new( - values - .iter() - .map(|value| count_matches(value, &pattern, start_scalar)) - .collect::>()?, - )) +impl<'a, S> StringValueSource<'a, S> +where + S: StringArrayType<'a> + Copy, +{ + fn regex_arg(array: S, is_scalar: bool) -> Self { + if is_scalar || array.len() == 1 { + Self::Scalar(string_value_opt(&array, 0)) + } else { + Self::Array(array) } - (true, true, false) => { - let regex = match regex_scalar { - None => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); - } - Some(regex) => regex, - }; + } - let flags_array = flags_array.unwrap(); - if values.len() != flags_array.len() { - return Err(ArrowError::ComputeError(format!( - "flags_array must be the same length as values array; got {} and {}", - flags_array.len(), - values.len(), - ))); + fn flags_arg(array: Option, is_scalar: bool) -> Self { + match array { + // Preserve prior behavior: scalar flags use value(0), not a null-aware + // lookup, before compile_regex handles the resulting flag value. + Some(array) if is_scalar || array.len() == 1 => { + Self::Scalar(Some(array.value(0))) } - - Ok(Arc::new( - values - .iter() - .zip(flags_array.iter()) - .map(|(value, flags)| { - let pattern = - compile_and_cache_regex(regex, flags, &mut regex_cache)?; - count_matches(value, pattern, start_scalar) - }) - .collect::>()?, - )) + Some(array) => Self::Array(array), + None => Self::Scalar(None), } - (true, false, true) => { - let regex = match regex_scalar { - None => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); - } - Some(regex) => regex, - }; + } - let pattern = compile_regex(regex, flags_scalar)?; + fn is_null_scalar(&self) -> bool { + matches!(self, Self::Scalar(None)) + } - let start_array = start_array.unwrap(); + fn is_array(&self) -> bool { + matches!(self, Self::Array(_)) + } - Ok(Arc::new( - values - .iter() - .zip(start_array.iter()) - .map(|(value, start)| count_matches(value, &pattern, start)) - .collect::>()?, - )) + fn value(&self, row: usize) -> Option<&'a str> { + match self { + Self::Scalar(value) => *value, + Self::Array(array) => string_value_opt(array, row), } - (true, false, false) => { - let regex = match regex_scalar { - None => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); - } - Some(regex) => regex, - }; - - let flags_array = flags_array.unwrap(); - if values.len() != flags_array.len() { - return Err(ArrowError::ComputeError(format!( - "flags_array must be the same length as values array; got {} and {}", - flags_array.len(), - values.len(), - ))); - } + } - Ok(Arc::new( - izip!( - values.iter(), - start_array.unwrap().iter(), - flags_array.iter() - ) - .map(|(value, start, flags)| { - let pattern = - compile_and_cache_regex(regex, flags, &mut regex_cache)?; - - count_matches(value, pattern, start) - }) - .collect::>()?, - )) + fn validate_len(&self, name: &str, values_len: usize) -> Result<(), ArrowError> { + if let Self::Array(array) = self { + validate_array_len(name, array.len(), values_len)?; } - (false, true, true) => { - if values.len() != regex_array.len() { - return Err(ArrowError::ComputeError(format!( - "regex_array must be the same length as values array; got {} and {}", - regex_array.len(), - values.len(), - ))); - } + Ok(()) + } +} - Ok(Arc::new( - values - .iter() - .zip(regex_array.iter()) - .map(|(value, regex)| { - let regex = match regex { - None => return Ok(0), - Some(regex) => regex, - }; - - let pattern = compile_and_cache_regex( - regex, - flags_scalar, - &mut regex_cache, - )?; - count_matches(value, pattern, start_scalar) - }) - .collect::>()?, - )) +fn compile_scalar_pattern<'a, S>( + regex: &StringValueSource<'a, S>, + flags: &StringValueSource<'a, S>, +) -> Result, ArrowError> { + match (regex, flags) { + (StringValueSource::Scalar(Some(regex)), StringValueSource::Scalar(flags)) => { + compile_regex(regex, *flags).map(Some) } - (false, true, false) => { - if values.len() != regex_array.len() { - return Err(ArrowError::ComputeError(format!( - "regex_array must be the same length as values array; got {} and {}", - regex_array.len(), - values.len(), - ))); - } + _ => Ok(None), + } +} - let flags_array = flags_array.unwrap(); - if values.len() != flags_array.len() { - return Err(ArrowError::ComputeError(format!( - "flags_array must be the same length as values array; got {} and {}", - flags_array.len(), - values.len(), - ))); - } +enum StartValueSource<'a> { + Scalar(i64), + Array(&'a Int64Array), +} - Ok(Arc::new( - izip!(values.iter(), regex_array.iter(), flags_array.iter()) - .map(|(value, regex, flags)| { - let regex = match regex { - None => return Ok(0), - Some(regex) => regex, - }; - - let pattern = - compile_and_cache_regex(regex, flags, &mut regex_cache)?; - - count_matches(value, pattern, start_scalar) - }) - .collect::>()?, - )) +impl<'a> StartValueSource<'a> { + fn new(array: Option<&'a Int64Array>, is_scalar: bool) -> Self { + match array { + // Preserve prior behavior: scalar start uses value(0), not a null-aware + // lookup, before count_matches validates the resulting start value. + Some(array) if is_scalar || array.len() == 1 => Self::Scalar(array.value(0)), + Some(array) => Self::Array(array), + None => Self::Scalar(1), } - (false, false, true) => { - if values.len() != regex_array.len() { - return Err(ArrowError::ComputeError(format!( - "regex_array must be the same length as values array; got {} and {}", - regex_array.len(), - values.len(), - ))); - } - - let start_array = start_array.unwrap(); - if values.len() != start_array.len() { - return Err(ArrowError::ComputeError(format!( - "start_array must be the same length as values array; got {} and {}", - start_array.len(), - values.len(), - ))); - } + } - Ok(Arc::new( - izip!(values.iter(), regex_array.iter(), start_array.iter()) - .map(|(value, regex, start)| { - let regex = match regex { - None => return Ok(0), - Some(regex) => regex, - }; - - let pattern = compile_and_cache_regex( - regex, - flags_scalar, - &mut regex_cache, - )?; - count_matches(value, pattern, start) - }) - .collect::>()?, - )) + fn value(&self, row: usize) -> Option { + match self { + Self::Scalar(value) => Some(*value), + Self::Array(array) => (!array.is_null(row)).then(|| array.value(row)), } - (false, false, false) => { - if values.len() != regex_array.len() { - return Err(ArrowError::ComputeError(format!( - "regex_array must be the same length as values array; got {} and {}", - regex_array.len(), - values.len(), - ))); - } - - let start_array = start_array.unwrap(); - if values.len() != start_array.len() { - return Err(ArrowError::ComputeError(format!( - "start_array must be the same length as values array; got {} and {}", - start_array.len(), - values.len(), - ))); - } - - let flags_array = flags_array.unwrap(); - if values.len() != flags_array.len() { - return Err(ArrowError::ComputeError(format!( - "flags_array must be the same length as values array; got {} and {}", - flags_array.len(), - values.len(), - ))); - } + } - Ok(Arc::new( - izip!( - values.iter(), - regex_array.iter(), - start_array.iter(), - flags_array.iter() - ) - .map(|(value, regex, start, flags)| { - let regex = match regex { - None => return Ok(0), - Some(regex) => regex, - }; - - let pattern = - compile_and_cache_regex(regex, flags, &mut regex_cache)?; - count_matches(value, pattern, start) - }) - .collect::>()?, - )) + fn validate_len(&self, values_len: usize) -> Result<(), ArrowError> { + if let Self::Array(array) = self { + validate_array_len("start_array", array.len(), values_len)?; } + Ok(()) } } +fn validate_array_len( + array_name: &str, + array_len: usize, + values_len: usize, +) -> Result<(), ArrowError> { + if values_len != array_len { + return Err(ArrowError::ComputeError(format!( + "{array_name} must be the same length as values array; got {array_len} and {values_len}", + ))); + } + Ok(()) +} + fn count_matches( value: Option<&str>, pattern: &Regex, @@ -625,6 +509,8 @@ mod tests { test_case_sensitive_regexp_count_array_complex::(); test_case_regexp_count_cache_check::>(); + test_regexp_count_error_order_invalid_scalar_regex_before_start_len(); + test_regexp_count_error_order_flags_len_before_start_len(); } fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { @@ -851,6 +737,37 @@ mod tests { }); } + fn assert_regexp_count_error_contains(args: &[ArrayRef], expected: &str) { + let err = regexp_count_func(args).unwrap_err().to_string(); + assert!( + err.contains(expected), + "expected error to contain {expected:?}, got {err:?}" + ); + } + + fn test_regexp_count_error_order_invalid_scalar_regex_before_start_len() { + let values = Arc::new(GenericStringArray::::from(vec!["a", "b"])); + let regex = Arc::new(GenericStringArray::::from(vec!["["])); + let start = Arc::new(Int64Array::from(vec![1, 1, 1])); + + assert_regexp_count_error_contains( + &[values, regex, start], + "Regular expression did not compile", + ); + } + + fn test_regexp_count_error_order_flags_len_before_start_len() { + let values = Arc::new(GenericStringArray::::from(vec!["a", "b"])); + let regex = Arc::new(GenericStringArray::::from(vec!["a"])); + let start = Arc::new(Int64Array::from(vec![1, 1, 1])); + let flags = Arc::new(GenericStringArray::::from(vec!["", "", ""])); + + assert_regexp_count_error_contains( + &[values, regex, start, flags], + "flags_array must be the same length as values array; got 3 and 2", + ); + } + fn test_case_sensitive_regexp_count_array() where A: From> + Array + 'static,