From 6bafabe43666c5a25ef8d2eb81ea4cb627cea1bd Mon Sep 17 00:00:00 2001 From: blaginin Date: Sun, 29 Mar 2026 22:21:10 +0100 Subject: [PATCH 01/10] `Count` and `Mean` aggregates Signed-off-by: blaginin Co-authored-by: Claude --- .../src/aggregate_fn/accumulator_grouped.rs | 4 +- .../src/aggregate_fn/fns/count/mod.rs | 267 ++++++++++ vortex-array/src/aggregate_fn/fns/mean/mod.rs | 472 ++++++++++++++++++ vortex-array/src/aggregate_fn/fns/mod.rs | 2 + 4 files changed, 743 insertions(+), 2 deletions(-) create mode 100644 vortex-array/src/aggregate_fn/fns/count/mod.rs create mode 100644 vortex-array/src/aggregate_fn/fns/mean/mod.rs diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index a4d9c38b60e..15b0fbc9326 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -237,7 +237,7 @@ impl GroupedAccumulator { if validity.value(offset) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } @@ -309,7 +309,7 @@ impl GroupedAccumulator { if validity.value(i) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs new file mode 100644 index 00000000000..3c483372186 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; + +/// Return the count of non-null elements in an array. +/// +/// See [`Count`] for details. +pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + let result = acc.finish()?; + + Ok(result + .as_primitive() + .typed_value::() + .vortex_expect("count result should not be null")) +} + +/// Count the number of non-null elements in an array. +/// +/// Applies to all types. Returns a `u64` count. +/// The identity value is zero. +#[derive(Clone, Debug)] +pub struct Count; + +impl AggregateFnVTable for Count { + type Options = EmptyOptions; + type Partial = u64; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.count") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option { + Some(DType::Primitive(PType::U64, Nullability::NonNullable)) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + _options: &Self::Options, + _input_dtype: &DType, + ) -> VortexResult { + Ok(0u64) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("count partial should not be null"); + *partial += val; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + Ok(Scalar::primitive(*partial, Nullability::NonNullable)) + } + + fn reset(&self, partial: &mut Self::Partial) { + *partial = 0; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + false + } + + fn accumulate( + &self, + partial: &mut Self::Partial, + batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + match batch { + Columnar::Constant(c) => { + if !c.scalar().is_null() { + *partial += c.len() as u64; + } + } + Columnar::Canonical(c) => { + let valid = c.as_ref().valid_count()?; + *partial += valid as u64; + } + } + Ok(()) + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + self.to_scalar(partial) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::count::Count; + use crate::aggregate_fn::fns::count::count; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn count_all_valid() -> VortexResult<()> { + let array = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 5); + Ok(()) + } + + #[test] + fn count_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 3); + Ok(()) + } + + #[test] + fn count_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 0); + Ok(()) + } + + #[test] + fn count_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn count_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(3)); + Ok(()) + } + + #[test] + fn count_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let result1 = acc.finish()?; + assert_eq!(result1.as_primitive().typed_value::(), Some(1)); + + let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().typed_value::(), Some(2)); + Ok(()) + } + + #[test] + fn count_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut state = Count.empty_partial(&EmptyOptions, &dtype)?; + + let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable); + Count.combine_partials(&mut state, scalar1)?; + + let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable); + Count.combine_partials(&mut state, scalar2)?; + + let result = Count.to_scalar(&state)?; + Count.reset(&mut state); + assert_eq!(result.as_primitive().typed_value::(), Some(8)); + Ok(()) + } + + #[test] + fn count_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(42i32, 10); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array.into_array(), &mut ctx)?, 10); + Ok(()) + } + + #[test] + fn count_constant_null() -> VortexResult<()> { + let array = ConstantArray::new( + Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), + 10, + ); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array.into_array(), &mut ctx)?, 0); + Ok(()) + } + + #[test] + fn count_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]); + let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs new file mode 100644 index 00000000000..dfa981b59af --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::ToPrimitive; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::Canonical; +use crate::Columnar; +use crate::DynArray; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::dtype::StructFields; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; +use crate::validity::Validity; + +/// Compute the arithmetic mean of an array. +/// +/// See [`Mean`] for details. +pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(Mean, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Compute the arithmetic mean of an array, returning `f64`. +/// +/// Applies to boolean and primitive numeric types. Returns a nullable `f64`. +/// Internally tracks sum and count, returning `sum / count` on finalize. +/// If there are no valid elements, returns null. +/// +/// The partial state is a struct `{sum: f64, count: u64}` so that partials from +/// different accumulators can be correctly combined via weighted addition. +#[derive(Clone, Debug)] +pub struct Mean; + +/// Internal accumulation state for [`Mean`]. +pub struct MeanPartial { + sum: f64, + count: u64, +} + +fn partial_struct_dtype() -> DType { + DType::Struct( + StructFields::new( + [FieldName::from("sum"), FieldName::from("count")].into(), + vec![ + DType::Primitive(PType::F64, Nullability::NonNullable), + DType::Primitive(PType::U64, Nullability::NonNullable), + ], + ), + Nullability::Nullable, + ) +} + +impl AggregateFnVTable for Mean { + type Options = EmptyOptions; + type Partial = MeanPartial; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.mean") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => { + Some(DType::Primitive(PType::F64, Nullability::Nullable)) + } + _ => None, + } + } + + fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => Some(partial_struct_dtype()), + _ => None, + } + } + + fn empty_partial( + &self, + _options: &Self::Options, + _input_dtype: &DType, + ) -> VortexResult { + Ok(MeanPartial { sum: 0.0, count: 0 }) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + return Ok(()); + } + let s = other.as_struct(); + let sum_scalar = s + .field("sum") + .vortex_expect("mean partial must have sum field"); + let count_scalar = s + .field("count") + .vortex_expect("mean partial must have count field"); + + partial.sum += sum_scalar + .as_primitive() + .typed_value::() + .vortex_expect("sum field should not be null"); + partial.count += count_scalar + .as_primitive() + .typed_value::() + .vortex_expect("count field should not be null"); + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + if partial.count == 0 { + Ok(Scalar::null(partial_struct_dtype())) + } else { + Ok(Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(partial.sum, Nullability::NonNullable), + Scalar::primitive(partial.count, Nullability::NonNullable), + ], + )) + } + } + + fn reset(&self, partial: &mut Self::Partial) { + partial.sum = 0.0; + partial.count = 0; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + false + } + + fn accumulate( + &self, + partial: &mut Self::Partial, + batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + match batch { + Columnar::Constant(c) => { + if !c.scalar().is_null() { + let val = scalar_to_f64(c.scalar())?; + partial.sum += val * c.len() as f64; + partial.count += c.len() as u64; + } + } + Columnar::Canonical(canonical) => match canonical { + Canonical::Primitive(prim) => { + let mask = prim.validity_mask()?; + match_each_native_ptype!(prim.ptype(), |T| { + accumulate_values(partial, prim.as_slice::(), &mask); + }); + } + Canonical::Bool(bool_arr) => { + let mask = bool_arr.validity_mask()?; + let bits = bool_arr.to_bit_buffer(); + match &mask { + Mask::AllTrue(_) => { + partial.sum += bits.true_count() as f64; + partial.count += bool_arr.len() as u64; + } + Mask::AllFalse(_) => {} + Mask::Values(validity) => { + let valid_count = validity.true_count(); + let valid_and_true = (&bits & validity.bit_buffer()).true_count(); + partial.sum += valid_and_true as f64; + partial.count += valid_count as u64; + } + } + } + _ => vortex_bail!("Unsupported canonical type for mean: {}", batch.dtype()), + }, + } + Ok(()) + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + let struct_arr = partials.to_struct(); + let sums = struct_arr.unmasked_field_by_name("sum")?; + let counts = struct_arr.unmasked_field_by_name("count")?; + let validity_mask = struct_arr.validity_mask()?; + + let sum_prim = sums.to_primitive(); + let count_prim = counts.to_primitive(); + let sum_values = sum_prim.as_slice::(); + let count_values = count_prim.as_slice::(); + + let means: vortex_buffer::Buffer = sum_values + .iter() + .zip(count_values.iter()) + .map(|(s, c)| if *c == 0 { 0.0 } else { s / *c as f64 }) + .collect(); + + let validity = match validity_mask { + Mask::AllTrue(_) => { + let valid_bits: Vec = count_values.iter().map(|c| *c > 0).collect(); + if valid_bits.iter().all(|v| *v) { + Validity::AllValid + } else { + Validity::from_iter(valid_bits) + } + } + Mask::AllFalse(_) => Validity::AllInvalid, + Mask::Values(v) => { + let valid_bits: Vec = count_values + .iter() + .zip(v.bit_buffer().iter()) + .map(|(c, group_valid)| group_valid && *c > 0) + .collect(); + Validity::from_iter(valid_bits) + } + }; + + Ok(PrimitiveArray::new(means, validity).into_array()) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + if partial.count == 0 { + Ok(Scalar::null(DType::Primitive( + PType::F64, + Nullability::Nullable, + ))) + } else { + Ok(Scalar::primitive( + partial.sum / partial.count as f64, + Nullability::Nullable, + )) + } + } +} + +fn scalar_to_f64(scalar: &Scalar) -> VortexResult { + match scalar.dtype() { + DType::Bool(_) => { + let v = scalar.as_bool().value().vortex_expect("checked non-null"); + Ok(if v { 1.0 } else { 0.0 }) + } + DType::Primitive(..) => f64::try_from(scalar), + _ => vortex_bail!("Cannot convert {} to f64 for mean", scalar.dtype()), + } +} + +fn accumulate_values(partial: &mut MeanPartial, values: &[T], mask: &Mask) { + match mask { + Mask::AllTrue(_) => { + partial.count += values.len() as u64; + for v in values { + partial.sum += v.to_f64().unwrap_or(0.0); + } + } + Mask::AllFalse(_) => {} + Mask::Values(v) => { + for (val, valid) in values.iter().zip(v.bit_buffer().iter()) { + if valid { + partial.count += 1; + partial.sum += val.to_f64().unwrap_or(0.0); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::mean::Mean; + use crate::aggregate_fn::fns::mean::mean; + use crate::aggregate_fn::fns::mean::partial_struct_dtype; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn mean_all_valid() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_integers() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![2.0f64, 4.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let result1 = acc.finish()?; + assert_eq!(result1.as_primitive().as_::(), Some(3.0)); + + let batch2 = + PrimitiveArray::new(buffer![10.0f64, 20.0, 30.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut state = Mean.empty_partial(&EmptyOptions, &dtype)?; + + // Partition 1: mean of [2, 4] → sum=6, count=2 + let partial1 = Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(6.0f64, Nullability::NonNullable), + Scalar::primitive(2u64, Nullability::NonNullable), + ], + ); + Mean.combine_partials(&mut state, partial1)?; + + // Partition 2: mean of [10, 20, 30] → sum=60, count=3 + let partial2 = Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(60.0f64, Nullability::NonNullable), + Scalar::primitive(3u64, Nullability::NonNullable), + ], + ); + Mean.combine_partials(&mut state, partial2)?; + + // Combined: (6 + 60) / (2 + 3) = 13.2 + let result = Mean.finalize_scalar(&state)?; + assert_eq!(result.as_primitive().as_::(), Some(13.2)); + Ok(()) + } + + #[test] + fn mean_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(5.0f64, 4); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(5.0)); + Ok(()) + } + + #[test] + fn mean_constant_null() -> VortexResult<()> { + let array = ConstantArray::new( + Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)), + 10, + ); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_bool() -> VortexResult<()> { + let array: BoolArray = [true, false, true, true].into_iter().collect(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(0.75)); + Ok(()) + } + + #[test] + fn mean_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]); + let chunk2 = PrimitiveArray::from_option_iter([Some(5.0f64), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&chunked.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 4c233ba4d27..4e6df22299a 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -1,8 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +pub mod count; pub mod is_constant; pub mod is_sorted; +pub mod mean; pub mod min_max; pub mod nan_count; pub mod sum; From ee67d8025f2b62b8b7314c56bfab222459bd44fb Mon Sep 17 00:00:00 2001 From: blaginin Date: Sun, 29 Mar 2026 22:24:22 +0100 Subject: [PATCH 02/10] validity cleanup Co-authored-by: Claude Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/fns/mean/mod.rs | 27 ++++++------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index dfa981b59af..024fb4ea16b 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -220,25 +220,14 @@ impl AggregateFnVTable for Mean { .map(|(s, c)| if *c == 0 { 0.0 } else { s / *c as f64 }) .collect(); - let validity = match validity_mask { - Mask::AllTrue(_) => { - let valid_bits: Vec = count_values.iter().map(|c| *c > 0).collect(); - if valid_bits.iter().all(|v| *v) { - Validity::AllValid - } else { - Validity::from_iter(valid_bits) - } - } - Mask::AllFalse(_) => Validity::AllInvalid, - Mask::Values(v) => { - let valid_bits: Vec = count_values - .iter() - .zip(v.bit_buffer().iter()) - .map(|(c, group_valid)| group_valid && *c > 0) - .collect(); - Validity::from_iter(valid_bits) - } - }; + // A mean is valid when the group itself was valid AND had at least one + // non-null element (count > 0). + let validity = Validity::from_iter( + count_values + .iter() + .enumerate() + .map(|(i, c)| validity_mask.value(i) && *c > 0), + ); Ok(PrimitiveArray::new(means, validity).into_array()) } From 4e16945e3e07313917a96058848f9ee7e6001467 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 2 Apr 2026 16:13:10 +0100 Subject: [PATCH 03/10] generic kernels Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/accumulator.rs | 25 +++++++++++++++++- .../src/aggregate_fn/fns/count/mod.rs | 26 +++++++++++++++++++ vortex-array/src/aggregate_fn/session.rs | 19 ++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 95d5268b969..9aa15149ff7 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -101,7 +101,9 @@ impl DynAccumulator for Accumulator { ); let session = ctx.session().clone(); - let kernels = &session.aggregate_fns().kernels; + let agg_fns = session.aggregate_fns(); + let kernels = &agg_fns.kernels; + let generic_kernels = &agg_fns.generic_kernels; let mut batch = batch.clone(); for _ in 0..*MAX_ITERATIONS { @@ -131,6 +133,27 @@ impl DynAccumulator for Accumulator { return Ok(()); } + // Try encoding-agnostic kernels before decompressing. + let generic_r = generic_kernels.read(); + if let Some(result) = generic_r + .get(&self.aggregate_fn.id()) + .and_then(|kernel| { + kernel + .aggregate(&self.aggregate_fn, &batch, ctx) + .transpose() + }) + .transpose()? + { + vortex_ensure!( + result.dtype() == &self.partial_dtype, + "Aggregate kernel returned {}, expected {}", + result.dtype(), + self.partial_dtype, + ); + self.vtable.combine_partials(&mut self.partial, result)?; + return Ok(()); + } + // Execute one step and try again batch = batch.execute(ctx)?; } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 3c483372186..cf398ce6274 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -6,12 +6,15 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::Columnar; +use crate::DynArray; use crate::ExecutionCtx; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -31,6 +34,29 @@ pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { .vortex_expect("count result should not be null")) } +/// Encoding-agnostic count kernel. +/// +/// Count-non-null only depends on validity, not data values. Since every encoding +/// exposes validity independently, this avoids decompressing the data. +#[derive(Debug)] +pub(crate) struct CountKernel; + +impl DynAggregateKernel for CountKernel { + fn aggregate( + &self, + aggregate_fn: &AggregateFnRef, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + if !aggregate_fn.is::() { + return Ok(None); + } + + let count = batch.valid_count()? as u64; + Ok(Some(Scalar::primitive(count, Nullability::NonNullable))) + } +} + /// Count the number of non-null elements in an array. /// /// Applies to all types. Returns a `u64` count. diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 6e85ee97b6b..02fdc66243b 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -12,6 +12,8 @@ use vortex_utils::aliases::hash_map::HashMap; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::count::CountKernel; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; use crate::aggregate_fn::fns::min_max::MinMax; @@ -36,6 +38,9 @@ pub struct AggregateFnSession { registry: AggregateFnRegistry, pub(super) kernels: RwLock>, + /// Encoding-agnostic kernels keyed by aggregate function ID. + /// These are checked as a fallback after encoding-specific kernels. + pub(super) generic_kernels: RwLock>, pub(super) grouped_kernels: RwLock>, } @@ -46,10 +51,12 @@ impl Default for AggregateFnSession { let this = Self { registry: AggregateFnRegistry::default(), kernels: RwLock::new(HashMap::default()), + generic_kernels: RwLock::new(HashMap::default()), grouped_kernels: RwLock::new(HashMap::default()), }; // Register the built-in aggregate functions + this.register(Count); this.register(IsConstant); this.register(IsSorted); this.register(MinMax); @@ -62,6 +69,9 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict::ID, Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict::ID, Some(IsSorted.id()), &DictIsSortedKernel); + // Register encoding-agnostic kernels. + this.register_generic_kernel(Count.id(), &CountKernel); + this } } @@ -88,6 +98,15 @@ impl AggregateFnSession { ) { self.kernels.write().insert((array_id, agg_fn_id), kernel); } + + /// Register an encoding-agnostic aggregate kernel for a specific aggregate function. + pub fn register_generic_kernel( + &self, + agg_fn_id: AggregateFnId, + kernel: &'static dyn DynAggregateKernel, + ) { + self.generic_kernels.write().insert(agg_fn_id, kernel); + } } /// Extension trait for accessing aggregate function session data. From 575672ad697c71e28991e496b5676c03053672ca Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 2 Apr 2026 16:16:27 +0100 Subject: [PATCH 04/10] `try_accumulate` feels cleaner Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/accumulator.rs | 30 +++-------- .../src/aggregate_fn/fns/count/mod.rs | 52 +++++-------------- vortex-array/src/aggregate_fn/session.rs | 19 ------- vortex-array/src/aggregate_fn/vtable.rs | 16 ++++++ 4 files changed, 35 insertions(+), 82 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index 9aa15149ff7..d49820f6dd1 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -100,10 +100,13 @@ impl DynAccumulator for Accumulator { batch.dtype() ); + // Allow the vtable to short-circuit on the raw array before decompression. + if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? { + return Ok(()); + } + let session = ctx.session().clone(); - let agg_fns = session.aggregate_fns(); - let kernels = &agg_fns.kernels; - let generic_kernels = &agg_fns.generic_kernels; + let kernels = &session.aggregate_fns().kernels; let mut batch = batch.clone(); for _ in 0..*MAX_ITERATIONS { @@ -133,27 +136,6 @@ impl DynAccumulator for Accumulator { return Ok(()); } - // Try encoding-agnostic kernels before decompressing. - let generic_r = generic_kernels.read(); - if let Some(result) = generic_r - .get(&self.aggregate_fn.id()) - .and_then(|kernel| { - kernel - .aggregate(&self.aggregate_fn, &batch, ctx) - .transpose() - }) - .transpose()? - { - vortex_ensure!( - result.dtype() == &self.partial_dtype, - "Aggregate kernel returned {}, expected {}", - result.dtype(), - self.partial_dtype, - ); - self.vtable.combine_partials(&mut self.partial, result)?; - return Ok(()); - } - // Execute one step and try again batch = batch.execute(ctx)?; } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index cf398ce6274..546315b066c 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -10,11 +10,9 @@ use crate::DynArray; use crate::ExecutionCtx; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; -use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; -use crate::aggregate_fn::kernels::DynAggregateKernel; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -34,29 +32,6 @@ pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { .vortex_expect("count result should not be null")) } -/// Encoding-agnostic count kernel. -/// -/// Count-non-null only depends on validity, not data values. Since every encoding -/// exposes validity independently, this avoids decompressing the data. -#[derive(Debug)] -pub(crate) struct CountKernel; - -impl DynAggregateKernel for CountKernel { - fn aggregate( - &self, - aggregate_fn: &AggregateFnRef, - batch: &ArrayRef, - _ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if !aggregate_fn.is::() { - return Ok(None); - } - - let count = batch.valid_count()? as u64; - Ok(Some(Scalar::primitive(count, Nullability::NonNullable))) - } -} - /// Count the number of non-null elements in an array. /// /// Applies to all types. Returns a `u64` count. @@ -122,24 +97,23 @@ impl AggregateFnVTable for Count { false } + fn try_accumulate( + &self, + state: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + *state += batch.valid_count()? as u64; + Ok(true) + } + fn accumulate( &self, - partial: &mut Self::Partial, - batch: &Columnar, + _partial: &mut Self::Partial, + _batch: &Columnar, _ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - match batch { - Columnar::Constant(c) => { - if !c.scalar().is_null() { - *partial += c.len() as u64; - } - } - Columnar::Canonical(c) => { - let valid = c.as_ref().valid_count()?; - *partial += valid as u64; - } - } - Ok(()) + unreachable!("Count::try_accumulate handles all arrays") } fn finalize(&self, partials: ArrayRef) -> VortexResult { diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index 02fdc66243b..6e85ee97b6b 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -12,8 +12,6 @@ use vortex_utils::aliases::hash_map::HashMap; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::fns::count::Count; -use crate::aggregate_fn::fns::count::CountKernel; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; use crate::aggregate_fn::fns::min_max::MinMax; @@ -38,9 +36,6 @@ pub struct AggregateFnSession { registry: AggregateFnRegistry, pub(super) kernels: RwLock>, - /// Encoding-agnostic kernels keyed by aggregate function ID. - /// These are checked as a fallback after encoding-specific kernels. - pub(super) generic_kernels: RwLock>, pub(super) grouped_kernels: RwLock>, } @@ -51,12 +46,10 @@ impl Default for AggregateFnSession { let this = Self { registry: AggregateFnRegistry::default(), kernels: RwLock::new(HashMap::default()), - generic_kernels: RwLock::new(HashMap::default()), grouped_kernels: RwLock::new(HashMap::default()), }; // Register the built-in aggregate functions - this.register(Count); this.register(IsConstant); this.register(IsSorted); this.register(MinMax); @@ -69,9 +62,6 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict::ID, Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict::ID, Some(IsSorted.id()), &DictIsSortedKernel); - // Register encoding-agnostic kernels. - this.register_generic_kernel(Count.id(), &CountKernel); - this } } @@ -98,15 +88,6 @@ impl AggregateFnSession { ) { self.kernels.write().insert((array_id, agg_fn_id), kernel); } - - /// Register an encoding-agnostic aggregate kernel for a specific aggregate function. - pub fn register_generic_kernel( - &self, - agg_fn_id: AggregateFnId, - kernel: &'static dyn DynAggregateKernel, - ) { - self.generic_kernels.write().insert(agg_fn_id, kernel); - } } /// Extension trait for accessing aggregate function session data. diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index feaebe61f56..e2bdf7ab602 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -105,6 +105,22 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// final result is fully determined. fn is_saturated(&self, state: &Self::Partial) -> bool; + /// Try to accumulate the raw array before decompression. + /// + /// Returns `true` if the array was handled, `false` to fall through to + /// the default kernel dispatch and canonicalization path. + /// + /// This is useful for aggregates that only depend on array metadata (e.g., validity) + /// rather than the encoded data, avoiding unnecessary decompression. + fn try_accumulate( + &self, + _state: &mut Self::Partial, + _batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + /// Accumulate a new canonical array into the accumulator state. fn accumulate( &self, From 3c8b12b8e8f0d7b9fbe844c6b637ae75476faee8 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 2 Apr 2026 17:43:36 +0100 Subject: [PATCH 05/10] only mean; mean is not serializable Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/accumulator.rs | 5 - .../src/aggregate_fn/fns/count/mod.rs | 267 ------------------ vortex-array/src/aggregate_fn/fns/mean/mod.rs | 7 +- vortex-array/src/aggregate_fn/fns/mod.rs | 1 - vortex-array/src/aggregate_fn/vtable.rs | 16 -- 5 files changed, 5 insertions(+), 291 deletions(-) delete mode 100644 vortex-array/src/aggregate_fn/fns/count/mod.rs diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index d49820f6dd1..95d5268b969 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -100,11 +100,6 @@ impl DynAccumulator for Accumulator { batch.dtype() ); - // Allow the vtable to short-circuit on the raw array before decompression. - if self.vtable.try_accumulate(&mut self.partial, batch, ctx)? { - return Ok(()); - } - let session = ctx.session().clone(); let kernels = &session.aggregate_fns().kernels; diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs deleted file mode 100644 index 546315b066c..00000000000 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ /dev/null @@ -1,267 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_error::VortexExpect; -use vortex_error::VortexResult; - -use crate::ArrayRef; -use crate::Columnar; -use crate::DynArray; -use crate::ExecutionCtx; -use crate::aggregate_fn::Accumulator; -use crate::aggregate_fn::AggregateFnId; -use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::EmptyOptions; -use crate::dtype::DType; -use crate::dtype::Nullability; -use crate::dtype::PType; -use crate::scalar::Scalar; - -/// Return the count of non-null elements in an array. -/// -/// See [`Count`] for details. -pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?; - acc.accumulate(array, ctx)?; - let result = acc.finish()?; - - Ok(result - .as_primitive() - .typed_value::() - .vortex_expect("count result should not be null")) -} - -/// Count the number of non-null elements in an array. -/// -/// Applies to all types. Returns a `u64` count. -/// The identity value is zero. -#[derive(Clone, Debug)] -pub struct Count; - -impl AggregateFnVTable for Count { - type Options = EmptyOptions; - type Partial = u64; - - fn id(&self) -> AggregateFnId { - AggregateFnId::new_ref("vortex.count") - } - - fn serialize(&self, _options: &Self::Options) -> VortexResult>> { - Ok(Some(vec![])) - } - - fn deserialize( - &self, - _metadata: &[u8], - _session: &vortex_session::VortexSession, - ) -> VortexResult { - Ok(EmptyOptions) - } - - fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option { - Some(DType::Primitive(PType::U64, Nullability::NonNullable)) - } - - fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { - self.return_dtype(options, input_dtype) - } - - fn empty_partial( - &self, - _options: &Self::Options, - _input_dtype: &DType, - ) -> VortexResult { - Ok(0u64) - } - - fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { - let val = other - .as_primitive() - .typed_value::() - .vortex_expect("count partial should not be null"); - *partial += val; - Ok(()) - } - - fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { - Ok(Scalar::primitive(*partial, Nullability::NonNullable)) - } - - fn reset(&self, partial: &mut Self::Partial) { - *partial = 0; - } - - #[inline] - fn is_saturated(&self, _partial: &Self::Partial) -> bool { - false - } - - fn try_accumulate( - &self, - state: &mut Self::Partial, - batch: &ArrayRef, - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - *state += batch.valid_count()? as u64; - Ok(true) - } - - fn accumulate( - &self, - _partial: &mut Self::Partial, - _batch: &Columnar, - _ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - unreachable!("Count::try_accumulate handles all arrays") - } - - fn finalize(&self, partials: ArrayRef) -> VortexResult { - Ok(partials) - } - - fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { - self.to_scalar(partial) - } -} - -#[cfg(test)] -mod tests { - use vortex_buffer::buffer; - use vortex_error::VortexResult; - - use crate::IntoArray; - use crate::LEGACY_SESSION; - use crate::VortexSessionExecute; - use crate::aggregate_fn::Accumulator; - use crate::aggregate_fn::AggregateFnVTable; - use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; - use crate::aggregate_fn::fns::count::Count; - use crate::aggregate_fn::fns::count::count; - use crate::arrays::ChunkedArray; - use crate::arrays::ConstantArray; - use crate::arrays::PrimitiveArray; - use crate::dtype::DType; - use crate::dtype::Nullability; - use crate::dtype::PType; - use crate::scalar::Scalar; - use crate::validity::Validity; - - #[test] - fn count_all_valid() -> VortexResult<()> { - let array = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&array, &mut ctx)?, 5); - Ok(()) - } - - #[test] - fn count_with_nulls() -> VortexResult<()> { - let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]) - .into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&array, &mut ctx)?, 3); - Ok(()) - } - - #[test] - fn count_all_null() -> VortexResult<()> { - let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&array, &mut ctx)?, 0); - Ok(()) - } - - #[test] - fn count_empty() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(0)); - Ok(()) - } - - #[test] - fn count_multi_batch() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; - - let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(); - acc.accumulate(&batch1, &mut ctx)?; - - let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array(); - acc.accumulate(&batch2, &mut ctx)?; - - let result = acc.finish()?; - assert_eq!(result.as_primitive().typed_value::(), Some(3)); - Ok(()) - } - - #[test] - fn count_finish_resets_state() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; - - let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(); - acc.accumulate(&batch1, &mut ctx)?; - let result1 = acc.finish()?; - assert_eq!(result1.as_primitive().typed_value::(), Some(1)); - - let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array(); - acc.accumulate(&batch2, &mut ctx)?; - let result2 = acc.finish()?; - assert_eq!(result2.as_primitive().typed_value::(), Some(2)); - Ok(()) - } - - #[test] - fn count_state_merge() -> VortexResult<()> { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let mut state = Count.empty_partial(&EmptyOptions, &dtype)?; - - let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable); - Count.combine_partials(&mut state, scalar1)?; - - let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable); - Count.combine_partials(&mut state, scalar2)?; - - let result = Count.to_scalar(&state)?; - Count.reset(&mut state); - assert_eq!(result.as_primitive().typed_value::(), Some(8)); - Ok(()) - } - - #[test] - fn count_constant_non_null() -> VortexResult<()> { - let array = ConstantArray::new(42i32, 10); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&array.into_array(), &mut ctx)?, 10); - Ok(()) - } - - #[test] - fn count_constant_null() -> VortexResult<()> { - let array = ConstantArray::new( - Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), - 10, - ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&array.into_array(), &mut ctx)?, 0); - Ok(()) - } - - #[test] - fn count_chunked() -> VortexResult<()> { - let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]); - let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]); - let dtype = chunk1.dtype().clone(); - let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3); - Ok(()) - } -} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 024fb4ea16b..6eed9b0aa6f 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -77,7 +77,10 @@ impl AggregateFnVTable for Mean { } fn serialize(&self, _options: &Self::Options) -> VortexResult>> { - Ok(Some(vec![])) + // This function is not serializable until: + // - we decide on algo for compilation (and hence what should be the intermediate state) + // - we decide on return type (should mean(decimals) be a decimal?) + Ok(None) } fn deserialize( @@ -85,7 +88,7 @@ impl AggregateFnVTable for Mean { _metadata: &[u8], _session: &vortex_session::VortexSession, ) -> VortexResult { - Ok(EmptyOptions) + unimplemented!("Mean is not deserializable") } fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 4e6df22299a..13afde10f6c 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -pub mod count; pub mod is_constant; pub mod is_sorted; pub mod mean; diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index e2bdf7ab602..feaebe61f56 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -105,22 +105,6 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// final result is fully determined. fn is_saturated(&self, state: &Self::Partial) -> bool; - /// Try to accumulate the raw array before decompression. - /// - /// Returns `true` if the array was handled, `false` to fall through to - /// the default kernel dispatch and canonicalization path. - /// - /// This is useful for aggregates that only depend on array metadata (e.g., validity) - /// rather than the encoded data, avoiding unnecessary decompression. - fn try_accumulate( - &self, - _state: &mut Self::Partial, - _batch: &ArrayRef, - _ctx: &mut ExecutionCtx, - ) -> VortexResult { - Ok(false) - } - /// Accumulate a new canonical array into the accumulator state. fn accumulate( &self, From 46db9fab01a5f9831f9b10872b8b9dd10e70eec1 Mon Sep 17 00:00:00 2001 From: blaginin Date: Thu, 2 Apr 2026 17:56:04 +0100 Subject: [PATCH 06/10] lock file update Signed-off-by: blaginin --- vortex-array/public-api.lock | 84 ++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index c193f10da92..0c6c1ff06cd 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -186,6 +186,56 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::is_strict_sorted(array: &vort pub fn vortex_array::aggregate_fn::fns::is_sorted::make_is_sorted_partial_dtype(element_dtype: &vortex_array::dtype::DType) -> vortex_array::dtype::DType +pub mod vortex_array::aggregate_fn::fns::mean + +pub struct vortex_array::aggregate_fn::fns::mean::Mean + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::clone(&self) -> vortex_array::aggregate_fn::fns::mean::Mean + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Partial = vortex_array::aggregate_fn::fns::mean::MeanPartial + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::partial_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::fns::mean::MeanPartial + +pub fn vortex_array::aggregate_fn::fns::mean::mean(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::min_max pub struct vortex_array::aggregate_fn::fns::min_max::MinMax @@ -684,6 +734,40 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::serialize(&self, op pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Partial = vortex_array::aggregate_fn::fns::mean::MeanPartial + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::partial_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::min_max::MinMax pub type vortex_array::aggregate_fn::fns::min_max::MinMax::Options = vortex_array::aggregate_fn::EmptyOptions From 0515378dd20d2a1bbba73517efa6f8d2875edf6b Mon Sep 17 00:00:00 2001 From: blaginin Date: Wed, 8 Apr 2026 16:05:31 +0100 Subject: [PATCH 07/10] Combined Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/combined.rs | 254 ++++++++++++++++++ vortex-array/src/aggregate_fn/fns/mean/mod.rs | 2 + vortex-array/src/aggregate_fn/mod.rs | 1 + 3 files changed, 257 insertions(+) create mode 100644 vortex-array/src/aggregate_fn/combined.rs diff --git a/vortex-array/src/aggregate_fn/combined.rs b/vortex-array/src/aggregate_fn/combined.rs new file mode 100644 index 00000000000..aa0fa8e2d30 --- /dev/null +++ b/vortex-array/src/aggregate_fn/combined.rs @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Generic adapter for aggregates whose result is computed from two child +//! aggregate functions, e.g. `Mean = Sum / Count`. + +use std::fmt::{self, Debug, Display, Formatter}; +use std::hash::Hash; + +use vortex_error::{VortexResult, vortex_bail, vortex_err}; +use vortex_session::VortexSession; + +use crate::aggregate_fn::{AggregateFnId, AggregateFnVTable}; +use crate::builtins::ArrayBuiltins; +use crate::dtype::{DType, FieldName, FieldNames, Nullability, StructFields}; +use crate::scalar::Scalar; +use crate::{ArrayRef, Columnar, ExecutionCtx}; + +/// Pair of options for the two children of a [`BinaryCombined`] aggregate. +/// +/// Wrapper around `(L, R)` because the [`AggregateFnVTable::Options`] bound +/// requires `Display`, which tuples don't implement. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PairOptions(pub L, pub R); + +impl Display for PairOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "({}, {})", self.0, self.1) + } +} + +// Convenience aliases so signatures stay readable. +type LeftOptions = <::Left as AggregateFnVTable>::Options; +type RightOptions = <::Right as AggregateFnVTable>::Options; +type LeftPartial = <::Left as AggregateFnVTable>::Partial; +type RightPartial = <::Right as AggregateFnVTable>::Partial; +/// Combined options for a [`BinaryCombined`] aggregate. +pub type CombinedOptions = PairOptions, RightOptions>; + +/// Declare an aggregate function in terms of two child aggregates. +pub trait BinaryCombined: 'static + Send + Sync + Clone { + /// The left child aggregate vtable. + type Left: AggregateFnVTable; + /// The right child aggregate vtable. + type Right: AggregateFnVTable; + + /// Stable identifier for the combined aggregate. + fn id(&self) -> AggregateFnId; + + /// Construct the left child vtable. + fn left(&self) -> Self::Left; + + /// Construct the right child vtable. + fn right(&self) -> Self::Right; + + /// Field name for the left child in the partial struct dtype. + fn left_name(&self) -> &'static str { + "left" + } + + /// Field name for the right child in the partial struct dtype. + fn right_name(&self) -> &'static str { + "right" + } + + /// Return type of the combined aggregate. + fn return_dtype(&self, input_dtype: &DType) -> Option; + + /// Combine the finalized left and right results into the final aggregate. + fn finalize(&self, left: ArrayRef, right: ArrayRef) -> VortexResult; + + /// Serialize the options for this combined aggregate. Default: not serializable. + fn serialize(&self, options: &CombinedOptions) -> VortexResult>> { + let _ = options; + Ok(None) + } + + /// Deserialize the options for this combined aggregate. Default: bails. + fn deserialize( + &self, + metadata: &[u8], + session: &VortexSession, + ) -> VortexResult> { + let _ = (metadata, session); + vortex_bail!( + "Combined aggregate function {} is not deserializable", + BinaryCombined::id(self) + ); + } + + /// Coerce the input type. Default: chains `right.coerce_args(left.coerce_args(input))`. + fn coerce_args( + &self, + options: &CombinedOptions, + input_dtype: &DType, + ) -> VortexResult { + let left_coerced = self.left().coerce_args(&options.0, input_dtype)?; + self.right().coerce_args(&options.1, &left_coerced) + } +} + +/// Adapter that exposes any [`BinaryCombined`] as an [`AggregateFnVTable`]. +#[derive(Clone, Debug)] +pub struct Combined(pub T); + +impl Combined { + /// Construct a new combined aggregate vtable. + pub fn new(inner: T) -> Self { + Self(inner) + } +} + +impl AggregateFnVTable for Combined { + type Options = CombinedOptions; + type Partial = (LeftPartial, RightPartial); + + fn id(&self) -> AggregateFnId { + self.0.id() + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + BinaryCombined::serialize(&self.0, options) + } + + fn deserialize( + &self, + metadata: &[u8], + session: &VortexSession, + ) -> VortexResult { + BinaryCombined::deserialize(&self.0, metadata, session) + } + + fn coerce_args(&self, options: &Self::Options, input_dtype: &DType) -> VortexResult { + BinaryCombined::coerce_args(&self.0, options, input_dtype) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + BinaryCombined::return_dtype(&self.0, input_dtype) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + let l = self.0.left().partial_dtype(&options.0, input_dtype)?; + let r = self.0.right().partial_dtype(&options.1, input_dtype)?; + Some(struct_dtype(self.0.left_name(), self.0.right_name(), l, r)) + } + + fn empty_partial( + &self, + options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + Ok(( + self.0.left().empty_partial(&options.0, input_dtype)?, + self.0.right().empty_partial(&options.1, input_dtype)?, + )) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + return Ok(()); + } + let s = other.as_struct(); + let lname = self.0.left_name(); + let rname = self.0.right_name(); + let l_field = s + .field(lname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", lname))?; + let r_field = s + .field(rname) + .ok_or_else(|| vortex_err!("BinaryCombined partial missing `{}` field", rname))?; + self.0.left().combine_partials(&mut partial.0, l_field)?; + self.0.right().combine_partials(&mut partial.1, r_field)?; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + let l_scalar = self.0.left().to_scalar(&partial.0)?; + let r_scalar = self.0.right().to_scalar(&partial.1)?; + let dtype = struct_dtype( + self.0.left_name(), + self.0.right_name(), + l_scalar.dtype().clone(), + r_scalar.dtype().clone(), + ); + Ok(Scalar::struct_(dtype, vec![l_scalar, r_scalar])) + } + + fn reset(&self, partial: &mut Self::Partial) { + self.0.left().reset(&mut partial.0); + self.0.right().reset(&mut partial.1); + } + + fn is_saturated(&self, partial: &Self::Partial) -> bool { + self.0.left().is_saturated(&partial.0) && self.0.right().is_saturated(&partial.1) + } + + /// Fans out to each child's `try_accumulate`, falling back to `accumulate` + /// against a lazily-canonicalized batch. We always claim to handle the + /// batch ourselves so [`Self::accumulate`] is unreachable — this is the + /// same trick `Count` uses to opt out of the canonicalization path. + fn try_accumulate( + &self, + state: &mut Self::Partial, + batch: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mut canonical: Option = None; + if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? { + canonical = Some(batch.clone().execute::(ctx)?); + self.0 + .left() + .accumulate(&mut state.0, canonical.as_ref().expect("just set"), ctx)?; + } + if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? { + if canonical.is_none() { + canonical = Some(batch.clone().execute::(ctx)?); + } + self.0 + .right() + .accumulate(&mut state.1, canonical.as_ref().expect("just set"), ctx)?; + } + Ok(true) + } + + fn accumulate( + &self, + _state: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("Combined::try_accumulate handles all batches") + } + + fn finalize(&self, states: ArrayRef) -> VortexResult { + let l_field = states.get_item(FieldName::from(self.0.left_name()))?; + let r_field = states.get_item(FieldName::from(self.0.right_name()))?; + let l_finalized = self.0.left().finalize(l_field)?; + let r_finalized = self.0.right().finalize(r_field)?; + BinaryCombined::finalize(&self.0, l_finalized, r_finalized) + } +} + +fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType { + DType::Struct( + StructFields::new( + FieldNames::from_iter([ + FieldName::from(left_name), + FieldName::from(right_name), + ]), + vec![left, right], + ), + Nullability::NonNullable, + ) +} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 6eed9b0aa6f..1e1b7d2eb4c 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -18,7 +18,9 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::bool::BoolArrayExt; use crate::arrays::PrimitiveArray; +use crate::arrays::struct_::StructArrayExt; use crate::canonical::ToCanonical; use crate::dtype::DType; use crate::dtype::FieldName; diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs index 0dad56f3222..de4d2633959 100644 --- a/vortex-array/src/aggregate_fn/mod.rs +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -33,6 +33,7 @@ pub mod fns; pub mod kernels; pub mod proto; pub mod session; +pub mod combined; /// A unique identifier for an aggregate function. pub type AggregateFnId = ArcRef; From ce2e52ce388ce676b22d1ce1a639424bb18786a5 Mon Sep 17 00:00:00 2001 From: blaginin Date: Wed, 8 Apr 2026 16:13:18 +0100 Subject: [PATCH 08/10] new Mean Co-authored-by: Claude Signed-off-by: blaginin --- vortex-array/src/aggregate_fn/fns/mean/mod.rs | 455 ++++-------------- 1 file changed, 100 insertions(+), 355 deletions(-) diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 1e1b7d2eb4c..766ea2ee147 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -1,285 +1,118 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use num_traits::ToPrimitive; -use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_mask::Mask; use crate::ArrayRef; -use crate::Canonical; -use crate::Columnar; -use crate::DynArray; use crate::ExecutionCtx; -use crate::IntoArray; -use crate::aggregate_fn::Accumulator; -use crate::aggregate_fn::AggregateFnId; -use crate::aggregate_fn::AggregateFnVTable; -use crate::aggregate_fn::DynAccumulator; -use crate::aggregate_fn::EmptyOptions; -use crate::arrays::bool::BoolArrayExt; -use crate::arrays::PrimitiveArray; -use crate::arrays::struct_::StructArrayExt; -use crate::canonical::ToCanonical; -use crate::dtype::DType; -use crate::dtype::FieldName; -use crate::dtype::Nullability; -use crate::dtype::PType; -use crate::dtype::StructFields; -use crate::match_each_native_ptype; +use crate::aggregate_fn::combined::{BinaryCombined, Combined, PairOptions}; +use crate::aggregate_fn::fns::count::Count; +use crate::aggregate_fn::fns::sum::Sum; +use crate::aggregate_fn::{ + Accumulator, AggregateFnId, AggregateFnVTable, DynAccumulator, EmptyOptions, +}; +use crate::builtins::ArrayBuiltins; +use crate::dtype::{DType, DecimalDType, MAX_PRECISION, MAX_SCALE, Nullability, PType}; use crate::scalar::Scalar; -use crate::validity::Validity; +use crate::scalar_fn::fns::operators::Operator; /// Compute the arithmetic mean of an array. /// /// See [`Mean`] for details. pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let mut acc = Accumulator::try_new(Mean, EmptyOptions, array.dtype().clone())?; - acc.accumulate(array, ctx)?; + let options = PairOptions(EmptyOptions, EmptyOptions); + let vtable = Mean::combined(); + + let coerced_dtype = vtable.coerce_args(&options, array.dtype())?; + let coerced = array.cast(coerced_dtype.clone())?; + + let mut acc = Accumulator::try_new(vtable, options, coerced_dtype)?; + acc.accumulate(&coerced, ctx)?; acc.finish() } -/// Compute the arithmetic mean of an array, returning `f64`. +/// Compute the arithmetic mean of an array. /// -/// Applies to boolean and primitive numeric types. Returns a nullable `f64`. -/// Internally tracks sum and count, returning `sum / count` on finalize. -/// If there are no valid elements, returns null. +/// Implemented as `Sum / Count` via [`BinaryCombined`]. /// -/// The partial state is a struct `{sum: f64, count: u64}` so that partials from -/// different accumulators can be correctly combined via weighted addition. +/// Coercion / return type: +/// - Booleans and primitive numeric types are coerced to `f64` and the result +/// is a nullable `f64`. +/// - Decimals are kept as decimals with widened precision and scale +/// (`+4` each, capped at [`MAX_PRECISION`] / [`MAX_SCALE`]), matching +/// DataFusion's `coerce_avg_type`. #[derive(Clone, Debug)] pub struct Mean; -/// Internal accumulation state for [`Mean`]. -pub struct MeanPartial { - sum: f64, - count: u64, -} - -fn partial_struct_dtype() -> DType { - DType::Struct( - StructFields::new( - [FieldName::from("sum"), FieldName::from("count")].into(), - vec![ - DType::Primitive(PType::F64, Nullability::NonNullable), - DType::Primitive(PType::U64, Nullability::NonNullable), - ], - ), - Nullability::Nullable, - ) +impl Mean { + pub fn combined() -> Combined { + Combined(Mean) + } } -impl AggregateFnVTable for Mean { - type Options = EmptyOptions; - type Partial = MeanPartial; +impl BinaryCombined for Mean { + type Left = Sum; + type Right = Count; fn id(&self) -> AggregateFnId { AggregateFnId::new_ref("vortex.mean") } - fn serialize(&self, _options: &Self::Options) -> VortexResult>> { - // This function is not serializable until: - // - we decide on algo for compilation (and hence what should be the intermediate state) - // - we decide on return type (should mean(decimals) be a decimal?) - Ok(None) - } - - fn deserialize( - &self, - _metadata: &[u8], - _session: &vortex_session::VortexSession, - ) -> VortexResult { - unimplemented!("Mean is not deserializable") - } - - fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { - match input_dtype { - DType::Bool(_) | DType::Primitive(..) => { - Some(DType::Primitive(PType::F64, Nullability::Nullable)) - } - _ => None, - } - } - - fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { - match input_dtype { - DType::Bool(_) | DType::Primitive(..) => Some(partial_struct_dtype()), - _ => None, - } + fn left(&self) -> Sum { + Sum } - fn empty_partial( - &self, - _options: &Self::Options, - _input_dtype: &DType, - ) -> VortexResult { - Ok(MeanPartial { sum: 0.0, count: 0 }) + fn right(&self) -> Count { + Count } - fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { - if other.is_null() { - return Ok(()); - } - let s = other.as_struct(); - let sum_scalar = s - .field("sum") - .vortex_expect("mean partial must have sum field"); - let count_scalar = s - .field("count") - .vortex_expect("mean partial must have count field"); - - partial.sum += sum_scalar - .as_primitive() - .typed_value::() - .vortex_expect("sum field should not be null"); - partial.count += count_scalar - .as_primitive() - .typed_value::() - .vortex_expect("count field should not be null"); - Ok(()) + fn left_name(&self) -> &'static str { + "sum" } - fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { - if partial.count == 0 { - Ok(Scalar::null(partial_struct_dtype())) - } else { - Ok(Scalar::struct_( - partial_struct_dtype(), - vec![ - Scalar::primitive(partial.sum, Nullability::NonNullable), - Scalar::primitive(partial.count, Nullability::NonNullable), - ], - )) - } + fn right_name(&self) -> &'static str { + "count" } - fn reset(&self, partial: &mut Self::Partial) { - partial.sum = 0.0; - partial.count = 0; + fn return_dtype(&self, input_dtype: &DType) -> Option { + let coerced = coerced_dtype(input_dtype)?; + // Mean is always nullable: an empty / all-null group returns null. + Some(coerced.with_nullability(Nullability::Nullable)) } - #[inline] - fn is_saturated(&self, _partial: &Self::Partial) -> bool { - false + fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult { + let count_cast = count.cast(sum.dtype().clone())?; + sum.binary(count_cast, Operator::Div) } - fn accumulate( + fn coerce_args( &self, - partial: &mut Self::Partial, - batch: &Columnar, - _ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - match batch { - Columnar::Constant(c) => { - if !c.scalar().is_null() { - let val = scalar_to_f64(c.scalar())?; - partial.sum += val * c.len() as f64; - partial.count += c.len() as u64; - } - } - Columnar::Canonical(canonical) => match canonical { - Canonical::Primitive(prim) => { - let mask = prim.validity_mask()?; - match_each_native_ptype!(prim.ptype(), |T| { - accumulate_values(partial, prim.as_slice::(), &mask); - }); - } - Canonical::Bool(bool_arr) => { - let mask = bool_arr.validity_mask()?; - let bits = bool_arr.to_bit_buffer(); - match &mask { - Mask::AllTrue(_) => { - partial.sum += bits.true_count() as f64; - partial.count += bool_arr.len() as u64; - } - Mask::AllFalse(_) => {} - Mask::Values(validity) => { - let valid_count = validity.true_count(); - let valid_and_true = (&bits & validity.bit_buffer()).true_count(); - partial.sum += valid_and_true as f64; - partial.count += valid_count as u64; - } - } - } - _ => vortex_bail!("Unsupported canonical type for mean: {}", batch.dtype()), - }, - } - Ok(()) - } - - fn finalize(&self, partials: ArrayRef) -> VortexResult { - let struct_arr = partials.to_struct(); - let sums = struct_arr.unmasked_field_by_name("sum")?; - let counts = struct_arr.unmasked_field_by_name("count")?; - let validity_mask = struct_arr.validity_mask()?; - - let sum_prim = sums.to_primitive(); - let count_prim = counts.to_primitive(); - let sum_values = sum_prim.as_slice::(); - let count_values = count_prim.as_slice::(); - - let means: vortex_buffer::Buffer = sum_values - .iter() - .zip(count_values.iter()) - .map(|(s, c)| if *c == 0 { 0.0 } else { s / *c as f64 }) - .collect(); - - // A mean is valid when the group itself was valid AND had at least one - // non-null element (count > 0). - let validity = Validity::from_iter( - count_values - .iter() - .enumerate() - .map(|(i, c)| validity_mask.value(i) && *c > 0), - ); - - Ok(PrimitiveArray::new(means, validity).into_array()) - } - - fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { - if partial.count == 0 { - Ok(Scalar::null(DType::Primitive( - PType::F64, - Nullability::Nullable, - ))) - } else { - Ok(Scalar::primitive( - partial.sum / partial.count as f64, - Nullability::Nullable, - )) - } - } -} - -fn scalar_to_f64(scalar: &Scalar) -> VortexResult { - match scalar.dtype() { - DType::Bool(_) => { - let v = scalar.as_bool().value().vortex_expect("checked non-null"); - Ok(if v { 1.0 } else { 0.0 }) - } - DType::Primitive(..) => f64::try_from(scalar), - _ => vortex_bail!("Cannot convert {} to f64 for mean", scalar.dtype()), + _options: &PairOptions<::Options, ::Options>, + input_dtype: &DType, + ) -> VortexResult { + Ok(coerced_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone())) } } -fn accumulate_values(partial: &mut MeanPartial, values: &[T], mask: &Mask) { - match mask { - Mask::AllTrue(_) => { - partial.count += values.len() as u64; - for v in values { - partial.sum += v.to_f64().unwrap_or(0.0); - } +/// Decide what to coerce the input dtype to before feeding it to `Sum` and `Count`. +/// +/// Returns `None` for unsupported input types so callers can fall through. +fn coerced_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(n) | DType::Primitive(_, n) => { + Some(DType::Primitive(PType::F64, *n)) } - Mask::AllFalse(_) => {} - Mask::Values(v) => { - for (val, valid) in values.iter().zip(v.bit_buffer().iter()) { - if valid { - partial.count += 1; - partial.sum += val.to_f64().unwrap_or(0.0); - } - } + DType::Decimal(d, n) => { + // Mirrors DataFusion's `coerce_avg_type`: precision and scale each + // grow by 4, capped at the maximum allowed. + let new_precision = u8::min(MAX_PRECISION, d.precision().saturating_add(4)); + let new_scale = i8::min(MAX_SCALE, d.scale().saturating_add(4)); + Some(DType::Decimal( + DecimalDType::new(new_precision, new_scale), + *n, + )) } + _ => None, } } @@ -288,30 +121,18 @@ mod tests { use vortex_buffer::buffer; use vortex_error::VortexResult; + use super::*; use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; - use crate::aggregate_fn::Accumulator; - use crate::aggregate_fn::AggregateFnVTable; - use crate::aggregate_fn::DynAccumulator; - use crate::aggregate_fn::EmptyOptions; - use crate::aggregate_fn::fns::mean::Mean; - use crate::aggregate_fn::fns::mean::mean; - use crate::aggregate_fn::fns::mean::partial_struct_dtype; - use crate::arrays::BoolArray; - use crate::arrays::ChunkedArray; - use crate::arrays::ConstantArray; - use crate::arrays::PrimitiveArray; - use crate::dtype::DType; - use crate::dtype::Nullability; - use crate::dtype::PType; - use crate::scalar::Scalar; + use crate::arrays::{BoolArray, ChunkedArray, ConstantArray, PrimitiveArray}; use crate::validity::Validity; #[test] fn mean_all_valid() -> VortexResult<()> { - let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) - .into_array(); + let array = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); let mut ctx = LEGACY_SESSION.create_execution_ctx(); let result = mean(&array, &mut ctx)?; assert_eq!(result.as_primitive().as_::(), Some(3.0)); @@ -327,24 +148,6 @@ mod tests { Ok(()) } - #[test] - fn mean_all_null() -> VortexResult<()> { - let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = mean(&array, &mut ctx)?; - assert!(result.is_null()); - Ok(()) - } - - #[test] - fn mean_empty() -> VortexResult<()> { - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; - let result = acc.finish()?; - assert!(result.is_null()); - Ok(()) - } - #[test] fn mean_integers() -> VortexResult<()> { let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); @@ -355,70 +158,11 @@ mod tests { } #[test] - fn mean_multi_batch() -> VortexResult<()> { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; - - let batch1 = - PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); - acc.accumulate(&batch1, &mut ctx)?; - - let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); - acc.accumulate(&batch2, &mut ctx)?; - - let result = acc.finish()?; - assert_eq!(result.as_primitive().as_::(), Some(3.0)); - Ok(()) - } - - #[test] - fn mean_finish_resets_state() -> VortexResult<()> { + fn mean_bool() -> VortexResult<()> { + let array: BoolArray = [true, false, true, true].into_iter().collect(); let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; - - let batch1 = PrimitiveArray::new(buffer![2.0f64, 4.0], Validity::NonNullable).into_array(); - acc.accumulate(&batch1, &mut ctx)?; - let result1 = acc.finish()?; - assert_eq!(result1.as_primitive().as_::(), Some(3.0)); - - let batch2 = - PrimitiveArray::new(buffer![10.0f64, 20.0, 30.0], Validity::NonNullable).into_array(); - acc.accumulate(&batch2, &mut ctx)?; - let result2 = acc.finish()?; - assert_eq!(result2.as_primitive().as_::(), Some(20.0)); - Ok(()) - } - - #[test] - fn mean_state_merge() -> VortexResult<()> { - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let mut state = Mean.empty_partial(&EmptyOptions, &dtype)?; - - // Partition 1: mean of [2, 4] → sum=6, count=2 - let partial1 = Scalar::struct_( - partial_struct_dtype(), - vec![ - Scalar::primitive(6.0f64, Nullability::NonNullable), - Scalar::primitive(2u64, Nullability::NonNullable), - ], - ); - Mean.combine_partials(&mut state, partial1)?; - - // Partition 2: mean of [10, 20, 30] → sum=60, count=3 - let partial2 = Scalar::struct_( - partial_struct_dtype(), - vec![ - Scalar::primitive(60.0f64, Nullability::NonNullable), - Scalar::primitive(3u64, Nullability::NonNullable), - ], - ); - Mean.combine_partials(&mut state, partial2)?; - - // Combined: (6 + 60) / (2 + 3) = 13.2 - let result = Mean.finalize_scalar(&state)?; - assert_eq!(result.as_primitive().as_::(), Some(13.2)); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(0.75)); Ok(()) } @@ -431,27 +175,6 @@ mod tests { Ok(()) } - #[test] - fn mean_constant_null() -> VortexResult<()> { - let array = ConstantArray::new( - Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)), - 10, - ); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = mean(&array.into_array(), &mut ctx)?; - assert!(result.is_null()); - Ok(()) - } - - #[test] - fn mean_bool() -> VortexResult<()> { - let array: BoolArray = [true, false, true, true].into_iter().collect(); - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let result = mean(&array.into_array(), &mut ctx)?; - assert_eq!(result.as_primitive().as_::(), Some(0.75)); - Ok(()) - } - #[test] fn mean_chunked() -> VortexResult<()> { let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]); @@ -463,4 +186,26 @@ mod tests { assert_eq!(result.as_primitive().as_::(), Some(3.0)); Ok(()) } + + #[test] + fn mean_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + dtype, + )?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } } From 0e56b5cbacbec08bc4a453ab38aad002240b1595 Mon Sep 17 00:00:00 2001 From: blaginin Date: Wed, 8 Apr 2026 16:23:22 +0100 Subject: [PATCH 09/10] decimals Signed-off-by: blaginin --- vortex-array/public-api.lock | 212 ++++++++++++++++++ vortex-array/src/aggregate_fn/combined.rs | 50 +++-- vortex-array/src/aggregate_fn/fns/mean/mod.rs | 119 +++++++--- vortex-array/src/aggregate_fn/mod.rs | 2 +- 4 files changed, 326 insertions(+), 57 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 687b9d33fef..5f41a34d6b8 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -28,6 +28,138 @@ pub fn vortex_array::arrays::PrimitiveArray::with_iterator(&self, f: F) -> pub mod vortex_array::aggregate_fn +pub mod vortex_array::aggregate_fn::combined + +pub struct vortex_array::aggregate_fn::combined::Combined(pub T) + +impl vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::new(inner: T) -> Self + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::clone(&self) -> vortex_array::aggregate_fn::combined::Combined + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::Combined + +pub fn vortex_array::aggregate_fn::combined::Combined::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::combined::PairOptions(pub L, pub R) + +impl core::marker::StructuralPartialEq for vortex_array::aggregate_fn::combined::PairOptions + +impl core::clone::Clone for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::clone(&self) -> vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::Eq for vortex_array::aggregate_fn::combined::PairOptions + +impl core::cmp::PartialEq for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::eq(&self, other: &vortex_array::aggregate_fn::combined::PairOptions) -> bool + +impl core::fmt::Debug for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_array::aggregate_fn::combined::PairOptions + +pub fn vortex_array::aggregate_fn::combined::PairOptions::hash<__H: core::hash::Hasher>(&self, state: &mut __H) + +pub trait vortex_array::aggregate_fn::combined::BinaryCombined: 'static + core::marker::Send + core::marker::Sync + core::clone::Clone + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Left: vortex_array::aggregate_fn::AggregateFnVTable + +pub type vortex_array::aggregate_fn::combined::BinaryCombined::Right: vortex_array::aggregate_fn::AggregateFnVTable + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::coerce_args(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::finalize(&self, left: vortex_array::ArrayRef, right: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left(&self) -> Self::Left + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right(&self) -> Self::Right + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::combined::BinaryCombined::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub type vortex_array::aggregate_fn::combined::CombinedOptions = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + pub mod vortex_array::aggregate_fn::fns pub mod vortex_array::aggregate_fn::fns::count @@ -342,6 +474,50 @@ pub struct vortex_array::aggregate_fn::fns::last::LastPartial pub fn vortex_array::aggregate_fn::fns::last::last(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub mod vortex_array::aggregate_fn::fns::mean + +pub struct vortex_array::aggregate_fn::fns::mean::Mean + +impl vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::combined() -> vortex_array::aggregate_fn::combined::Combined + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::clone(&self) -> vortex_array::aggregate_fn::fns::mean::Mean + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::mean::Mean + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::combined::BinaryCombined for vortex_array::aggregate_fn::fns::mean::Mean + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Left = vortex_array::aggregate_fn::fns::sum::Sum + +pub type vortex_array::aggregate_fn::fns::mean::Mean::Right = vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::coerce_args(&self, _options: &vortex_array::aggregate_fn::combined::PairOptions<::Options, ::Options>, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::finalize(&self, sum: vortex_array::ArrayRef, count: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left(&self) -> vortex_array::aggregate_fn::fns::sum::Sum + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::left_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::return_dtype(&self, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array::aggregate_fn::fns::count::Count + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str + +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::mean::mean(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::min_max pub struct vortex_array::aggregate_fn::fns::min_max::MinMax @@ -1068,6 +1244,42 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Sel pub fn vortex_array::aggregate_fn::fns::sum::Sum::try_accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::combined::Combined + +pub type vortex_array::aggregate_fn::combined::Combined::Options = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> + +pub type vortex_array::aggregate_fn::combined::Combined::Partial = (<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Partial, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::combined::Combined::deserialize(&self, metadata: &[u8], session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize(&self, states: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::combined::Combined::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::combined::Combined::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::combined::Combined::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::combined::Combined::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::combined::Combined::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::combined::Combined::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub trait vortex_array::aggregate_fn::AggregateFnVTableExt: vortex_array::aggregate_fn::AggregateFnVTable pub fn vortex_array::aggregate_fn::AggregateFnVTableExt::bind(&self, options: Self::Options) -> vortex_array::aggregate_fn::AggregateFnRef diff --git a/vortex-array/src/aggregate_fn/combined.rs b/vortex-array/src/aggregate_fn/combined.rs index aa0fa8e2d30..9ecc0dcc8d1 100644 --- a/vortex-array/src/aggregate_fn/combined.rs +++ b/vortex-array/src/aggregate_fn/combined.rs @@ -4,17 +4,29 @@ //! Generic adapter for aggregates whose result is computed from two child //! aggregate functions, e.g. `Mean = Sum / Count`. -use std::fmt::{self, Debug, Display, Formatter}; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::fmt::{self}; use std::hash::Hash; -use vortex_error::{VortexResult, vortex_bail, vortex_err}; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_session::VortexSession; -use crate::aggregate_fn::{AggregateFnId, AggregateFnVTable}; +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; use crate::builtins::ArrayBuiltins; -use crate::dtype::{DType, FieldName, FieldNames, Nullability, StructFields}; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::FieldNames; +use crate::dtype::Nullability; +use crate::dtype::StructFields; use crate::scalar::Scalar; -use crate::{ArrayRef, Columnar, ExecutionCtx}; /// Pair of options for the two children of a [`BinaryCombined`] aggregate. /// @@ -122,11 +134,7 @@ impl AggregateFnVTable for Combined { BinaryCombined::serialize(&self.0, options) } - fn deserialize( - &self, - metadata: &[u8], - session: &VortexSession, - ) -> VortexResult { + fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult { BinaryCombined::deserialize(&self.0, metadata, session) } @@ -206,18 +214,15 @@ impl AggregateFnVTable for Combined { ) -> VortexResult { let mut canonical: Option = None; if !self.0.left().try_accumulate(&mut state.0, batch, ctx)? { - canonical = Some(batch.clone().execute::(ctx)?); - self.0 - .left() - .accumulate(&mut state.0, canonical.as_ref().expect("just set"), ctx)?; + let c = canonical.insert(batch.clone().execute::(ctx)?); + self.0.left().accumulate(&mut state.0, c, ctx)?; } if !self.0.right().try_accumulate(&mut state.1, batch, ctx)? { - if canonical.is_none() { - canonical = Some(batch.clone().execute::(ctx)?); - } - self.0 - .right() - .accumulate(&mut state.1, canonical.as_ref().expect("just set"), ctx)?; + let c = match canonical.as_ref() { + Some(c) => c, + None => canonical.insert(batch.clone().execute::(ctx)?), + }; + self.0.right().accumulate(&mut state.1, c, ctx)?; } Ok(true) } @@ -243,10 +248,7 @@ impl AggregateFnVTable for Combined { fn struct_dtype(left_name: &str, right_name: &str, left: DType, right: DType) -> DType { DType::Struct( StructFields::new( - FieldNames::from_iter([ - FieldName::from(left_name), - FieldName::from(right_name), - ]), + FieldNames::from_iter([FieldName::from(left_name), FieldName::from(right_name)]), vec![left, right], ), Nullability::NonNullable, diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 766ea2ee147..f6f6950f9e0 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -5,14 +5,23 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; -use crate::aggregate_fn::combined::{BinaryCombined, Combined, PairOptions}; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::aggregate_fn::combined::BinaryCombined; +use crate::aggregate_fn::combined::Combined; +use crate::aggregate_fn::combined::PairOptions; use crate::aggregate_fn::fns::count::Count; use crate::aggregate_fn::fns::sum::Sum; -use crate::aggregate_fn::{ - Accumulator, AggregateFnId, AggregateFnVTable, DynAccumulator, EmptyOptions, -}; use crate::builtins::ArrayBuiltins; -use crate::dtype::{DType, DecimalDType, MAX_PRECISION, MAX_SCALE, Nullability, PType}; +use crate::dtype::DType; +use crate::dtype::DecimalDType; +use crate::dtype::MAX_PRECISION; +use crate::dtype::MAX_SCALE; +use crate::dtype::Nullability; +use crate::dtype::PType; use crate::scalar::Scalar; use crate::scalar_fn::fns::operators::Operator; @@ -20,14 +29,12 @@ use crate::scalar_fn::fns::operators::Operator; /// /// See [`Mean`] for details. pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let options = PairOptions(EmptyOptions, EmptyOptions); - let vtable = Mean::combined(); - - let coerced_dtype = vtable.coerce_args(&options, array.dtype())?; - let coerced = array.cast(coerced_dtype.clone())?; - - let mut acc = Accumulator::try_new(vtable, options, coerced_dtype)?; - acc.accumulate(&coerced, ctx)?; + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + array.dtype().clone(), + )?; + acc.accumulate(array, ctx)?; acc.finish() } @@ -75,36 +82,45 @@ impl BinaryCombined for Mean { } fn return_dtype(&self, input_dtype: &DType) -> Option { - let coerced = coerced_dtype(input_dtype)?; - // Mean is always nullable: an empty / all-null group returns null. - Some(coerced.with_nullability(Nullability::Nullable)) + Some(mean_output_dtype(input_dtype)?.with_nullability(Nullability::Nullable)) } fn finalize(&self, sum: ArrayRef, count: ArrayRef) -> VortexResult { - let count_cast = count.cast(sum.dtype().clone())?; - sum.binary(count_cast, Operator::Div) + let target = match sum.dtype() { + DType::Decimal(..) => sum.dtype().with_nullability(Nullability::Nullable), + _ => DType::Primitive(PType::F64, Nullability::Nullable), + }; + let sum_cast = sum.cast(target.clone())?; + let count_cast = count.cast(target)?; + sum_cast.binary(count_cast, Operator::Div) } fn coerce_args( &self, - _options: &PairOptions<::Options, ::Options>, + _options: &PairOptions< + ::Options, + ::Options, + >, input_dtype: &DType, ) -> VortexResult { - Ok(coerced_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone())) + // Advisory hint for query planners: where possible, cast input to the + // type we're going to compute the mean in. + Ok(coerced_input_dtype(input_dtype).unwrap_or_else(|| input_dtype.clone())) } } -/// Decide what to coerce the input dtype to before feeding it to `Sum` and `Count`. +/// Hint for callers: what to cast the input to before accumulation. /// -/// Returns `None` for unsupported input types so callers can fall through. -fn coerced_dtype(input_dtype: &DType) -> Option { +/// - Bool stays as bool — `Sum` has a native bool path and bool → f64 isn't +/// currently a direct cast in vortex. +/// - Primitive numerics → `f64` so the sum and finalize work without overflow. +/// - Decimals → decimal with widened precision and scale (`+4` each, capped), +/// matching DataFusion's `coerce_avg_type`. +fn coerced_input_dtype(input_dtype: &DType) -> Option { match input_dtype { - DType::Bool(n) | DType::Primitive(_, n) => { - Some(DType::Primitive(PType::F64, *n)) - } + DType::Bool(_) => Some(input_dtype.clone()), + DType::Primitive(_, n) => Some(DType::Primitive(PType::F64, *n)), DType::Decimal(d, n) => { - // Mirrors DataFusion's `coerce_avg_type`: precision and scale each - // grow by 4, capped at the maximum allowed. let new_precision = u8::min(MAX_PRECISION, d.precision().saturating_add(4)); let new_scale = i8::min(MAX_SCALE, d.scale().saturating_add(4)); Some(DType::Decimal( @@ -116,6 +132,22 @@ fn coerced_dtype(input_dtype: &DType) -> Option { } } +fn mean_output_dtype(input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => { + Some(DType::Primitive(PType::F64, Nullability::Nullable)) + } + DType::Decimal(d, _) => { + let new_precision = u8::min(MAX_PRECISION, d.precision().saturating_add(10)); + Some(DType::Decimal( + DecimalDType::new(new_precision, d.scale()), + Nullability::Nullable, + )) + } + _ => None, + } +} + #[cfg(test)] mod tests { use vortex_buffer::buffer; @@ -125,14 +157,18 @@ mod tests { use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; - use crate::arrays::{BoolArray, ChunkedArray, ConstantArray, PrimitiveArray}; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::DecimalArray; + use crate::arrays::PrimitiveArray; + use crate::scalar::DecimalValue; use crate::validity::Validity; #[test] fn mean_all_valid() -> VortexResult<()> { - let array = - PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) - .into_array(); + let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); let mut ctx = LEGACY_SESSION.create_execution_ctx(); let result = mean(&array, &mut ctx)?; assert_eq!(result.as_primitive().as_::(), Some(3.0)); @@ -187,6 +223,25 @@ mod tests { Ok(()) } + // TODO: vortex's cast kernel doesn't currently support `u64 → decimal`, + #[test] + #[ignore = "u64 → decimal cast not yet supported"] + fn mean_decimal() -> VortexResult<()> { + // 1.00, 2.00, 3.00 in decimal(5, 2). Mean = 2.00. + let values = buffer![100i128, 200i128, 300i128]; + let dt = DecimalDType::new(5, 2); + let array = DecimalArray::new(values, dt, Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + // `Sum` widens precision by +10, so the result lives in decimal(15, 2). + // 2.00 in scale=2 is the integer 200. + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I128(200)) + ); + Ok(()) + } + #[test] fn mean_multi_batch() -> VortexResult<()> { let mut ctx = LEGACY_SESSION.create_execution_ctx(); diff --git a/vortex-array/src/aggregate_fn/mod.rs b/vortex-array/src/aggregate_fn/mod.rs index de4d2633959..af32e94bb15 100644 --- a/vortex-array/src/aggregate_fn/mod.rs +++ b/vortex-array/src/aggregate_fn/mod.rs @@ -29,11 +29,11 @@ pub use erased::*; mod options; pub use options::*; +pub mod combined; pub mod fns; pub mod kernels; pub mod proto; pub mod session; -pub mod combined; /// A unique identifier for an aggregate function. pub type AggregateFnId = ArcRef; From 1e0d743a9fc29834210ba0f37b236912000fc3bf Mon Sep 17 00:00:00 2001 From: blaginin Date: Wed, 8 Apr 2026 16:25:04 +0100 Subject: [PATCH 10/10] block serialization Signed-off-by: blaginin --- vortex-array/public-api.lock | 4 ++-- vortex-array/src/aggregate_fn/fns/mean/mod.rs | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 5f41a34d6b8..8d0c45badd9 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -156,7 +156,7 @@ pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str -pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> pub type vortex_array::aggregate_fn::combined::CombinedOptions = vortex_array::aggregate_fn::combined::PairOptions<<::Left as vortex_array::aggregate_fn::AggregateFnVTable>::Options, <::Right as vortex_array::aggregate_fn::AggregateFnVTable>::Options> @@ -514,7 +514,7 @@ pub fn vortex_array::aggregate_fn::fns::mean::Mean::right(&self) -> vortex_array pub fn vortex_array::aggregate_fn::fns::mean::Mean::right_name(&self) -> &'static str -pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> +pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, _options: &vortex_array::aggregate_fn::combined::CombinedOptions) -> vortex_error::VortexResult>> pub fn vortex_array::aggregate_fn::fns::mean::mean(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index f6f6950f9e0..fe4829b82f2 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -10,7 +10,7 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; -use crate::aggregate_fn::combined::BinaryCombined; +use crate::aggregate_fn::combined::{BinaryCombined, CombinedOptions}; use crate::aggregate_fn::combined::Combined; use crate::aggregate_fn::combined::PairOptions; use crate::aggregate_fn::fns::count::Count; @@ -95,6 +95,10 @@ impl BinaryCombined for Mean { sum_cast.binary(count_cast, Operator::Div) } + fn serialize(&self, _options: &CombinedOptions) -> VortexResult>> { + unimplemented!("Mean is not yet serializable"); + } + fn coerce_args( &self, _options: &PairOptions<