diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 4f26d8683902f..83cc5cded8361 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -17,11 +17,13 @@ mod bytes; mod dict; +mod groups; mod native; pub use bytes::BytesDistinctCountAccumulator; pub use bytes::BytesViewDistinctCountAccumulator; pub use dict::DictionaryCountAccumulator; +pub use groups::PrimitiveDistinctCountGroupsAccumulator; pub use native::Bitmap65536DistinctCountAccumulator; pub use native::Bitmap65536DistinctCountAccumulatorI16; pub use native::BoolArray256DistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs new file mode 100644 index 0000000000000..d370d59c90012 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/groups.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{ArrowPrimitiveType, Field}; +use datafusion_common::HashSet; +use datafusion_common::hash_utils::RandomState; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +use std::hash::Hash; +use std::mem::size_of; +use std::sync::Arc; + +use crate::aggregate::groups_accumulator::accumulate::accumulate; + +pub struct PrimitiveDistinctCountGroupsAccumulator +where + T::Native: Eq + Hash, +{ + seen: HashSet<(usize, T::Native), RandomState>, + counts: Vec, +} + +impl PrimitiveDistinctCountGroupsAccumulator +where + T::Native: Eq + Hash, +{ + pub fn new() -> Self { + Self { + seen: HashSet::default(), + counts: Vec::new(), + } + } +} + +impl Default for PrimitiveDistinctCountGroupsAccumulator +where + T::Native: Eq + Hash, +{ + fn default() -> Self { + Self::new() + } +} + +impl GroupsAccumulator + for PrimitiveDistinctCountGroupsAccumulator +where + T::Native: Eq + Hash, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion_common::Result<()> { + debug_assert_eq!(values.len(), 1); + self.counts.resize(total_num_groups, 0); + let arr = values[0].as_primitive::(); + accumulate(group_indices, arr, opt_filter, |group_idx, value| { + if self.seen.insert((group_idx, value)) { + self.counts[group_idx] += 1; + } + }); + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result { + let counts = emit_to.take_needed(&mut self.counts); + + match emit_to { + EmitTo::All => { + self.seen.clear(); + } + EmitTo::First(n) => { + let mut remaining = HashSet::default(); + for (group_idx, value) in self.seen.drain() { + if group_idx >= n { + remaining.insert((group_idx - n, value)); + } + } + self.seen = remaining; + } + } + + Ok(Arc::new(Int64Array::from(counts))) + } + + fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + let num_emitted = match emit_to { + EmitTo::All => self.counts.len(), + EmitTo::First(n) => n, + }; + + // Prefix-sum counts[..num_emitted] into offsets + let mut offsets = Vec::with_capacity(num_emitted + 1); + offsets.push(0i32); + let mut total = 0i32; + for &c in &self.counts[..num_emitted] { + total += c as i32; + offsets.push(total); + } + + let mut all_values = vec![T::Native::default(); total as usize]; + let mut cursors: Vec = offsets[..num_emitted].to_vec(); + + if matches!(emit_to, EmitTo::All) { + for (group_idx, value) in self.seen.drain() { + let pos = cursors[group_idx] as usize; + all_values[pos] = value; + cursors[group_idx] += 1; + } + self.counts.clear(); + } else { + let mut remaining = HashSet::default(); + for (group_idx, value) in self.seen.drain() { + if group_idx < num_emitted { + let pos = cursors[group_idx] as usize; + all_values[pos] = value; + cursors[group_idx] += 1; + } else { + remaining.insert((group_idx - num_emitted, value)); + } + } + self.seen = remaining; + let _ = emit_to.take_needed(&mut self.counts); + } + + let values_array = Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(all_values), + None, + )); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), + OffsetBuffer::new(offsets.into()), + values_array, + None, + ); + + Ok(vec![Arc::new(list_array)]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion_common::Result<()> { + debug_assert_eq!(values.len(), 1); + self.counts.resize(total_num_groups, 0); + let list_array = values[0].as_list::(); + let inner = list_array.values().as_primitive::(); + let inner_values = inner.values(); + let offsets = list_array.offsets(); + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + for &value in &inner_values[start..end] { + if self.seen.insert((group_idx, value)) { + self.counts[group_idx] += 1; + } + } + } + + Ok(()) + } + + fn size(&self) -> usize { + size_of::() + + self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::()) + + self.counts.capacity() * size_of::() + } +} diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 75a7018f32016..eab36d4951a9c 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -21,11 +21,11 @@ use arrow::{ compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - FieldRef, Float16Type, Float32Type, Float64Type, Int32Type, Int64Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, + Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - UInt32Type, UInt64Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, }, }; use datafusion_common::hash_utils::RandomState; @@ -41,6 +41,7 @@ use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, }; +use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator; use datafusion_functions_aggregate_common::aggregate::{ count_distinct::Bitmap65536DistinctCountAccumulator, count_distinct::Bitmap65536DistinctCountAccumulatorI16, @@ -344,20 +345,33 @@ impl AggregateUDFImpl for Count { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - if args.is_distinct { + if args.exprs.len() != 1 { return false; } - args.exprs.len() == 1 + if !args.is_distinct { + return true; + } + matches!( + args.expr_fields[0].data_type(), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) + if !args.is_distinct { + return Ok(Box::new(CountGroupsAccumulator::new())); + } + create_distinct_count_groups_accumulator(&args) } fn reverse_expr(&self) -> ReversedUDAF { @@ -430,6 +444,43 @@ impl AggregateUDFImpl for Count { } } +#[cold] +fn create_distinct_count_groups_accumulator( + args: &AccumulatorArgs, +) -> Result> { + let data_type = args.expr_fields[0].data_type(); + match data_type { + DataType::Int8 => Ok(Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new(), + )), + DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int16Type, + >::new())), + DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int32Type, + >::new())), + DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int64Type, + >::new())), + DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt8Type, + >::new())), + DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt16Type, + >::new())), + DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt32Type, + >::new())), + DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt64Type, + >::new())), + _ => not_impl_err!( + "GroupsAccumulator not supported for COUNT(DISTINCT) with {}", + data_type + ), + } +} + // DistinctCountAccumulator does not support retract_batch and sliding window // this is a specialized accumulator for distinct count that supports retract_batch // and sliding window.