diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 81efea1df22b1..c3c2e5e0b9677 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -533,25 +533,60 @@ impl SlidingDistinctSumAccumulator { data_type: data_type.clone(), }) } + + fn update_value(&mut self, value: i64) { + let cnt = self.counts.entry(value).or_insert(0); + if *cnt == 0 { + // first occurrence in window + self.sum = self.sum.wrapping_add(value); + } + *cnt += 1; + } + + fn retract_value(&mut self, value: i64) { + if let Some(cnt) = self.counts.get_mut(&value) { + *cnt -= 1; + if *cnt == 0 { + // last copy leaving window + self.sum = self.sum.wrapping_sub(value); + self.counts.remove(&value); + } + } + } + + fn apply_valid_values( + &mut self, + arr: &arrow::array::PrimitiveArray, + mut op: F, + ) where + F: FnMut(&mut Self, i64), + { + if arr.null_count() == 0 { + for &value in arr.values() { + op(self, value); + } + } else { + for (idx, &value) in arr.values().iter().enumerate() { + if arr.is_valid(idx) { + op(self, value); + } + } + } + } } impl Accumulator for SlidingDistinctSumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = values[0].as_primitive::(); - for &v in arr.values() { - let cnt = self.counts.entry(v).or_insert(0); - if *cnt == 0 { - // first occurrence in window - self.sum = self.sum.wrapping_add(v); - } - *cnt += 1; - } + self.apply_valid_values(arr, Self::update_value); Ok(()) } fn evaluate(&mut self) -> Result { // O(1) wrap of running sum - Ok(ScalarValue::Int64(Some(self.sum))) + Ok(ScalarValue::Int64( + (!self.counts.is_empty()).then_some(self.sum), + )) } fn size(&self) -> usize { @@ -581,11 +616,7 @@ impl Accumulator for SlidingDistinctSumAccumulator { if let ScalarValue::Int64(Some(v)) = ScalarValue::try_from_array(&*maybe_inner, idx)? { - let cnt = self.counts.entry(v).or_insert(0); - if *cnt == 0 { - self.sum = self.sum.wrapping_add(v); - } - *cnt += 1; + self.update_value(v); } } } @@ -594,16 +625,7 @@ impl Accumulator for SlidingDistinctSumAccumulator { fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = values[0].as_primitive::(); - for &v in arr.values() { - if let Some(cnt) = self.counts.get_mut(&v) { - *cnt -= 1; - if *cnt == 0 { - // last copy leaving window - self.sum = self.sum.wrapping_sub(v); - self.counts.remove(&v); - } - } - } + self.apply_valid_values(arr, Self::retract_value); Ok(()) } @@ -611,3 +633,53 @@ impl Accumulator for SlidingDistinctSumAccumulator { true } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::Int64Array, + buffer::{NullBuffer, ScalarBuffer}, + }; + use std::sync::Arc; + + #[test] + fn sliding_distinct_sum_ignores_null_slots() -> Result<()> { + let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?; + + let values: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![42, 5, 5]), + Some(NullBuffer::from(vec![false, true, true])), + )); + acc.update_batch(&[values])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5))); + + let retract: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![42, 5]), + Some(NullBuffer::from(vec![false, true])), + )); + acc.retract_batch(&[retract])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5))); + + let retract_last: ArrayRef = + Arc::new(Int64Array::new(ScalarBuffer::from(vec![5]), None)); + acc.retract_batch(&[retract_last])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + + Ok(()) + } + + #[test] + fn sliding_distinct_sum_returns_null_for_all_null_frame() -> Result<()> { + let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?; + + let values: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![99]), + Some(NullBuffer::from(vec![false])), + )); + acc.update_batch(&[values])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index bc2f1bfcbc73f..1b51950a70e1b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5959,6 +5959,28 @@ physical_plan 07)------------DataSourceExec: partitions=2, partition_sizes=[5, 4] +# SUM(DISTINCT) over sliding frames must skip NULLs and return NULL +# for frames containing no non-null values. +statement ok +CREATE TABLE table_distinct_sum_nulls(ts INT, v BIGINT) AS VALUES + (1, NULL), (2, 3), (3, NULL), (4, NULL), (5, 5); + +query II +SELECT + ts, + SUM(DISTINCT v) OVER ( + ORDER BY ts + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) AS s +FROM table_distinct_sum_nulls; +---- +1 NULL +2 3 +3 3 +4 NULL +5 5 + + # FILTER clause with window functions # Verify FILTER clause with non-aggregate window functions fails with a clear message