diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 40da98c3eb3a2..cc42b6c22bdbe 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -23,11 +23,10 @@ use arrow::array::{ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int8Type, Int16Type, Int32Type, - Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, - UInt64Type, + ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int32Type, Int64Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type, }; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::ScalarValue; @@ -40,6 +39,10 @@ use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; +use datafusion_functions_aggregate_common::aggregate::count_distinct::{ + Bitmap65536DistinctCountAccumulator, Bitmap65536DistinctCountAccumulatorI16, + BoolArray256DistinctCountAccumulator, BoolArray256DistinctCountAccumulatorI8, +}; use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; use datafusion_macros::user_doc; use std::fmt::{Debug, Formatter}; @@ -84,6 +87,36 @@ impl TryFrom<&ScalarValue> for HyperLogLog { } } +#[derive(Debug)] +struct ApproxDistinctBitmapWrapper { + inner: A, +} + +impl Accumulator for ApproxDistinctBitmapWrapper { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.inner.update_batch(values) + } + + fn evaluate(&mut self) -> Result { + match self.inner.evaluate()? { + ScalarValue::Int64(Some(v)) => Ok(ScalarValue::UInt64(Some(v as u64))), + other => internal_err!("unexpected: {other}"), + } + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> Result> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } +} + #[derive(Debug)] struct NumericHLLAccumulator where @@ -302,6 +335,39 @@ impl ApproxDistinct { } } +#[cold] +fn get_small_int_approx_accumulator( + data_type: &DataType, +) -> Result> { + match data_type { + DataType::UInt8 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BoolArray256DistinctCountAccumulator::new(), + })), + DataType::Int8 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BoolArray256DistinctCountAccumulatorI8::new(), + })), + DataType::UInt16 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: Bitmap65536DistinctCountAccumulator::new(), + })), + DataType::Int16 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: Bitmap65536DistinctCountAccumulatorI16::new(), + })), + _ => internal_err!("unsupported small int type: {}", data_type), + } +} + +#[cold] +fn get_small_int_state_field(name: &str, data_type: &DataType) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(name, "approx_distinct"), + Field::new_list_field(data_type.clone(), true), + false, + ) + .into(), + ]) +} + impl AggregateUDFImpl for ApproxDistinct { fn name(&self) -> &str { "approx_distinct" @@ -316,24 +382,27 @@ impl AggregateUDFImpl for ApproxDistinct { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - if args.input_fields[0].data_type().is_null() { - Ok(vec![ + let data_type = args.input_fields[0].data_type(); + match data_type { + DataType::Null => Ok(vec![ Field::new( format_state_name(args.name, self.name()), DataType::Null, true, ) .into(), - ]) - } else { - Ok(vec![ + ]), + DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => { + get_small_int_state_field(args.name, data_type) + } + _ => Ok(vec![ Field::new( format_state_name(args.name, "hll_registers"), DataType::Binary, false, ) .into(), - ]) + ]), } } @@ -341,15 +410,11 @@ impl AggregateUDFImpl for ApproxDistinct { let data_type = acc_args.expr_fields[0].data_type(); let accumulator: Box = match data_type { - // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL - // TODO support for boolean (trivial case) - // https://github.com/apache/datafusion/issues/1109 - DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => { + return get_small_int_approx_accumulator(data_type); + } DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), DataType::Date32 => Box::new(NumericHLLAccumulator::::new()),