diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ee2d300d9bff8..ad7ec55cf57d4 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -32,7 +32,7 @@ use crate::aggregates::group_values::multi_group_by::{ use arrow::array::{Array, ArrayRef, BooleanBufferBuilder}; use arrow::compute::cast; use arrow::datatypes::{ - BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Float32Type, + BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Field, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, Schema, SchemaRef, StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, @@ -272,6 +272,7 @@ impl GroupValuesColumn { /// Create a new instance of GroupValuesColumn if supported for the specified schema pub fn try_new(schema: SchemaRef) -> Result { let map = HashTable::with_capacity(0); + let group_values = Self::build_group_columns(&schema)?; Ok(Self { schema, map, @@ -279,12 +280,27 @@ impl GroupValuesColumn { emit_group_index_list_buffer: Vec::new(), vectorized_operation_buffers: VectorizedOperationBuffers::default(), map_size: 0, - group_values: vec![], + group_values, hashes_buffer: Default::default(), random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } + /// Build one fresh [`GroupColumn`] per field in the schema. + /// + /// Used at construction time (`try_new`) and to repopulate the column + /// vector after operations that drain it (`emit(EmitTo::All)`, + /// `clear_shrink`). Centralising it keeps the post-condition that + /// `self.group_values` always contains exactly one builder per schema + /// field outside of those transient drain points. + fn build_group_columns(schema: &Schema) -> Result>> { + let mut v: Vec> = Vec::with_capacity(schema.fields().len()); + for f in schema.fields().iter() { + v.push(make_group_column(f.as_ref())?); + } + Ok(v) + } + // ======================================================================== // Scalarized intern // ======================================================================== @@ -898,172 +914,127 @@ macro_rules! instantiate_primitive { }; } -impl GroupValues for GroupValuesColumn { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - if self.group_values.is_empty() { - let mut v = Vec::with_capacity(cols.len()); - - for f in self.schema.fields().iter() { - let nullable = f.is_nullable(); - let data_type = f.data_type(); - match data_type { - &DataType::Int8 => { - instantiate_primitive!(v, nullable, Int8Type, data_type) - } - &DataType::Int16 => { - instantiate_primitive!(v, nullable, Int16Type, data_type) - } - &DataType::Int32 => { - instantiate_primitive!(v, nullable, Int32Type, data_type) - } - &DataType::Int64 => { - instantiate_primitive!(v, nullable, Int64Type, data_type) - } - &DataType::UInt8 => { - instantiate_primitive!(v, nullable, UInt8Type, data_type) - } - &DataType::UInt16 => { - instantiate_primitive!(v, nullable, UInt16Type, data_type) - } - &DataType::UInt32 => { - instantiate_primitive!(v, nullable, UInt32Type, data_type) - } - &DataType::UInt64 => { - instantiate_primitive!(v, nullable, UInt64Type, data_type) - } - &DataType::Float32 => { - instantiate_primitive!(v, nullable, Float32Type, data_type) - } - &DataType::Float64 => { - instantiate_primitive!(v, nullable, Float64Type, data_type) - } - &DataType::Date32 => { - instantiate_primitive!(v, nullable, Date32Type, data_type) - } - &DataType::Date64 => { - instantiate_primitive!(v, nullable, Date64Type, data_type) - } - &DataType::Time32(t) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - Time32SecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - Time32MillisecondType, - data_type - ) - } - _ => {} - }, - &DataType::Time64(t) => match t { - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - Time64MicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - Time64NanosecondType, - data_type - ) - } - _ => {} - }, - &DataType::Timestamp(t, _) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - TimestampSecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - TimestampMillisecondType, - data_type - ) - } - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - TimestampMicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - TimestampNanosecondType, - data_type - ) - } - }, - &DataType::Decimal128(_, _) => { - instantiate_primitive! { - v, - nullable, - Decimal128Type, - data_type - } - } - &DataType::Utf8 => { - let b = ByteGroupValueBuilder::::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::LargeUtf8 => { - let b = ByteGroupValueBuilder::::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::Binary => { - let b = ByteGroupValueBuilder::::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::LargeBinary => { - let b = ByteGroupValueBuilder::::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::Utf8View => { - let b = ByteViewGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - &DataType::BinaryView => { - let b = ByteViewGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - &DataType::Boolean => { - if nullable { - let b = BooleanGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } else { - let b = BooleanGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - } - dt => { - return not_impl_err!("{dt} not supported in GroupValuesColumn"); - } - } +/// Build a [`GroupColumn`] for a single schema field. +/// +/// Extracted from the inline match that used to live in +/// [`GroupValuesColumn::intern`] so the per-field dispatch lives in one +/// place. This factory is the single source of truth for which Arrow types +/// map to which builder, and it is the function that future nested-type +/// specializations (e.g. `Struct`, `List`, `LargeList`) plug into without +/// having to enumerate every combination inline. +/// +/// Returns `Err(not_impl_err!(...))` for any type not in the supported set; +/// callers (`GroupValues::intern`) propagate that error so the +/// `GroupValuesRows` fallback can take over upstream of this builder. +fn make_group_column(field: &Field) -> Result> { + let nullable = field.is_nullable(); + let data_type = field.data_type(); + let mut v: Vec> = Vec::with_capacity(1); + match *data_type { + DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type, data_type), + DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type, data_type), + DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type, data_type), + DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type, data_type), + DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type, data_type), + DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type, data_type), + DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type, data_type), + DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type, data_type), + DataType::Float32 => { + instantiate_primitive!(v, nullable, Float32Type, data_type) + } + DataType::Float64 => { + instantiate_primitive!(v, nullable, Float64Type, data_type) + } + DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type, data_type), + DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type, data_type), + DataType::Time32(t) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, Time32SecondType, data_type) } - self.group_values = v; + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, Time32MillisecondType, data_type) + } + // Time32 with Microsecond / Nanosecond is not a valid Arrow type + // combination; reject explicitly so supported_type and this + // dispatcher stay in lockstep (see consistency fuzz below). + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + }, + DataType::Time64(t) => match t { + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, Time64MicrosecondType, data_type) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, Time64NanosecondType, data_type) + } + // Time64 with Second / Millisecond is not a valid Arrow type + // combination; reject explicitly. + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + }, + DataType::Timestamp(t, _) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, TimestampSecondType, data_type) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, TimestampMillisecondType, data_type) + } + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, TimestampMicrosecondType, data_type) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, TimestampNanosecondType, data_type) + } + }, + DataType::Decimal128(_, _) => { + instantiate_primitive!(v, nullable, Decimal128Type, data_type) + } + DataType::Utf8 => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Utf8, + ))); + } + DataType::LargeUtf8 => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Utf8, + ))); + } + DataType::Binary => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Binary, + ))); + } + DataType::LargeBinary => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Binary, + ))); } + DataType::Utf8View => { + v.push(Box::new(ByteViewGroupValueBuilder::::new())); + } + DataType::BinaryView => { + v.push(Box::new(ByteViewGroupValueBuilder::::new())); + } + DataType::Boolean => { + if nullable { + v.push(Box::new(BooleanGroupValueBuilder::::new())); + } else { + v.push(Box::new(BooleanGroupValueBuilder::::new())); + } + } + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + } + debug_assert_eq!( + v.len(), + 1, + "make_group_column must push exactly one builder" + ); + Ok(v.into_iter().next().unwrap()) +} +impl GroupValues for GroupValuesColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + // `try_new` and the reset points in `emit` / `clear_shrink` keep + // `self.group_values` populated with one builder per schema field, + // so no lazy initialization is needed here. if !STREAMING { self.vectorized_intern(cols, groups) } else { @@ -1091,8 +1062,14 @@ impl GroupValues for GroupValuesColumn { fn emit(&mut self, emit_to: EmitTo) -> Result> { let mut output = match emit_to { EmitTo::All => { - let group_values = mem::take(&mut self.group_values); - debug_assert!(self.group_values.is_empty()); + // Replace the column builders with a fresh set so the + // aggregator is immediately reusable after the drain. + // Same `self.schema` was already validated by `try_new`, + // so `build_group_columns` would only error here if some + // out-of-band schema mutation occurred — propagate it as + // a real Result rather than panicking. + let fresh = Self::build_group_columns(&self.schema)?; + let group_values = mem::replace(&mut self.group_values, fresh); group_values .into_iter() @@ -1191,7 +1168,12 @@ impl GroupValues for GroupValuesColumn { } fn clear_shrink(&mut self, num_rows: usize) { - self.group_values.clear(); + // Reset to a fresh column-builder vector. The schema was validated + // in `try_new`, so rebuilding cannot fail unless something else + // mutated the schema out-of-band — surface that as a panic since + // `clear_shrink` is infallible by trait signature. + self.group_values = Self::build_group_columns(&self.schema) + .expect("schema previously validated in try_new"); self.map.clear(); self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); @@ -1219,7 +1201,13 @@ pub fn supported_schema(schema: &Schema) -> bool { /// Returns true if the specified data type is supported by [`GroupValuesColumn`] /// /// In order to be supported, there must be a specialized implementation of -/// [`GroupColumn`] for the data type, instantiated in [`GroupValuesColumn::intern`] +/// [`GroupColumn`] for the data type, instantiated in +/// [`make_group_column`]. This function is the allow-list that gates the +/// `GroupValuesRows` fallback in [`crate::aggregates::group_values::new_group_values`]; +/// it must accept exactly the set of types that [`make_group_column`] +/// constructs a builder for. The +/// `supported_type_and_make_group_column_stay_in_sync` test below pins +/// this biconditional. fn supported_type(data_type: &DataType) -> bool { matches!( *data_type, @@ -1240,7 +1228,15 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::LargeBinary | DataType::Date32 | DataType::Date64 - | DataType::Time32(_) + // Only the semantically valid Time variants per the Arrow spec. + // The dispatcher in `make_group_column` returns NotImpl for the + // other unit combinations, so accepting them here would cause a + // schema to be routed into GroupValuesColumn and then fail at + // intern. Keep these two arms in lockstep with the dispatcher. + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) | DataType::Timestamp(_, _) | DataType::Utf8View | DataType::BinaryView @@ -1272,7 +1268,126 @@ mod tests { GroupValues, multi_group_by::GroupValuesColumn, }; - use super::GroupIndexView; + use super::{GroupIndexView, make_group_column, supported_schema, supported_type}; + + /// CRITICAL invariant: if `supported_type(t)` returns true the dispatcher + /// must accept that type at intern time, and conversely if + /// `supported_type(t)` returns false the planner must NOT route it through + /// `GroupValuesColumn`. A divergence here would let the planner select + /// `GroupValuesColumn` for a type whose dispatcher arm is missing, + /// producing a runtime `not_impl_err` after the field reaches the + /// builder factory. + /// + /// This test fuzzes a representative cross-section of types and asserts + /// both directions of the biconditional. When a new specialization is + /// added (`Float16`, `FixedSizeList`, `Struct`, ...) it should be added + /// to the supported_cases vector; when a type is intentionally rejected + /// it should be added to unsupported_cases. + #[test] + fn supported_type_and_make_group_column_stay_in_sync() { + let supported_cases: Vec = vec![ + DataType::Int8, + DataType::Int64, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + DataType::Decimal128(38, 10), + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::LargeBinary, + DataType::BinaryView, + DataType::Boolean, + DataType::Date32, + DataType::Date64, + DataType::Time32(arrow::datatypes::TimeUnit::Second), + DataType::Time32(arrow::datatypes::TimeUnit::Millisecond), + DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + ]; + + for dt in &supported_cases { + assert!( + supported_type(dt), + "expected supported_type=true for {dt:?}" + ); + let field = Field::new("col", dt.clone(), true); + make_group_column(&field).unwrap_or_else(|e| { + panic!( + "supported_type accepted {dt:?} but make_group_column rejected: {e}" + ) + }); + } + + let unsupported_cases: Vec = vec![ + DataType::Float16, + DataType::Decimal256(76, 10), + // Invalid Time-unit combinations: Time32 is defined only for + // Second / Millisecond and Time64 only for Microsecond / + // Nanosecond. The TimeUnit enum allows constructing the other + // combinations programmatically, but they are not valid Arrow + // types and must be rejected by both supported_type and the + // dispatcher. + DataType::Time64(arrow::datatypes::TimeUnit::Second), + DataType::Time64(arrow::datatypes::TimeUnit::Millisecond), + DataType::Time32(arrow::datatypes::TimeUnit::Microsecond), + DataType::Time32(arrow::datatypes::TimeUnit::Nanosecond), + ]; + + for dt in &unsupported_cases { + assert!( + !supported_type(dt), + "expected supported_type=false for {dt:?}" + ); + let field = Field::new("col", dt.clone(), true); + assert!( + make_group_column(&field).is_err(), + "supported_type rejected {dt:?} but make_group_column accepted it" + ); + } + } + + #[test] + fn supported_schema_rejects_mix_of_supported_and_unsupported() { + // One Float16 column among supported columns flips the whole + // schema to GroupValuesRows fallback. + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float16, true), + ]); + assert!(!supported_schema(&schema)); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Boolean, true), + ]); + assert!(supported_schema(&schema)); + } + + #[test] + fn try_new_returns_not_impl_for_unsupported_top_level_type() { + // `try_new` now eagerly constructs the per-field GroupColumn + // builders via `make_group_column`, so an unsupported schema is + // rejected at construction time rather than at first `intern`. + // `GroupValuesColumn` doesn't implement `Debug`, so explicit match + // instead of `unwrap_err`. + let schema = + Arc::new(Schema::new(vec![Field::new("x", DataType::Float16, true)])); + match GroupValuesColumn::::try_new(schema) { + Ok(_) => panic!("expected NotImpl error, but try_new succeeded"), + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains("not supported in GroupValuesColumn"), + "expected NotImpl error from dispatcher, got: {msg}" + ); + } + } + } #[test] fn test_intern_for_vectorized_group_values() { @@ -1344,6 +1459,17 @@ mod tests { let schema = Arc::new(Schema::new_with_metadata(vec![field], HashMap::new())); let mut group_values = GroupValuesColumn::::try_new(schema).unwrap(); + // Seed the column with 12 placeholder rows so the upcoming + // `emit(EmitTo::First(4))` calls can `take_n` without panicking. + // The hashmap entries below reference group indices 0..=11, so the + // single column builder needs at least 12 rows to back them. + let seed: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![0_i32; 12])); + for row in 0..12 { + group_values.group_values[0] + .append_val(&seed, row) + .expect("seed append"); + } + // Insert group index views and check if success to insert insert_inline_group_index_view(&mut group_values, 0, 0); insert_non_inline_group_index_view(&mut group_values, 1, vec![1, 2]);