Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 96 additions & 24 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,25 +533,60 @@ impl SlidingDistinctSumAccumulator {
data_type: data_type.clone(),
})
}

/// Registers one more occurrence of `value` inside the current window.
///
/// The running sum only grows when `value` was not already being tracked,
/// so duplicates contribute to the total exactly once.
fn update_value(&mut self, value: i64) {
    let occurrences = self.counts.entry(value).or_insert(0);
    let newly_seen = *occurrences == 0;
    if newly_seen {
        // wrapping arithmetic: overflow wraps around instead of panicking
        self.sum = self.sum.wrapping_add(value);
    }
    *occurrences += 1;
}

/// Removes one occurrence of `value` from the current window.
///
/// Values that are not being tracked are ignored. When the final copy of
/// a tracked value leaves, its entry is dropped and its contribution is
/// subtracted from the running sum.
fn retract_value(&mut self, value: i64) {
    let exhausted = match self.counts.get_mut(&value) {
        Some(remaining) => {
            *remaining -= 1;
            *remaining == 0
        }
        // never recorded, so there is nothing to undo
        None => return,
    };

    if exhausted {
        // that was the last copy in the window
        self.counts.remove(&value);
        self.sum = self.sum.wrapping_sub(value);
    }
}

fn apply_valid_values<F>(
&mut self,
arr: &arrow::array::PrimitiveArray<Int64Type>,
mut op: F,
) where
F: FnMut(&mut Self, i64),
{
if arr.null_count() == 0 {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup overall. One small thought: the NULL-aware iteration logic looks very similar between update_batch and retract_batch. Would it make sense to extract a small helper that keeps the current no-null fast path and accepts the value operation, something like apply_valid_values(values, Self::update_value) and apply_valid_values(values, Self::retract_value)?

That would keep the invariant that only valid slots affect counts and sum in one place and help avoid the two paths drifting apart over time.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, @kosiew

for &value in arr.values() {
op(self, value);
}
} else {
for (idx, &value) in arr.values().iter().enumerate() {
if arr.is_valid(idx) {
op(self, value);
}
}
}
}
}

impl Accumulator for SlidingDistinctSumAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let arr = values[0].as_primitive::<Int64Type>();
for &v in arr.values() {
let cnt = self.counts.entry(v).or_insert(0);
if *cnt == 0 {
// first occurrence in window
self.sum = self.sum.wrapping_add(v);
}
*cnt += 1;
}
self.apply_valid_values(arr, Self::update_value);
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
// O(1) wrap of running sum
Ok(ScalarValue::Int64(Some(self.sum)))
Ok(ScalarValue::Int64(
(!self.counts.is_empty()).then_some(self.sum),
))
}

fn size(&self) -> usize {
Expand Down Expand Up @@ -581,11 +616,7 @@ impl Accumulator for SlidingDistinctSumAccumulator {
if let ScalarValue::Int64(Some(v)) =
ScalarValue::try_from_array(&*maybe_inner, idx)?
{
let cnt = self.counts.entry(v).or_insert(0);
if *cnt == 0 {
self.sum = self.sum.wrapping_add(v);
}
*cnt += 1;
self.update_value(v);
}
}
}
Expand All @@ -594,20 +625,61 @@ impl Accumulator for SlidingDistinctSumAccumulator {

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let arr = values[0].as_primitive::<Int64Type>();
for &v in arr.values() {
if let Some(cnt) = self.counts.get_mut(&v) {
*cnt -= 1;
if *cnt == 0 {
// last copy leaving window
self.sum = self.sum.wrapping_sub(v);
self.counts.remove(&v);
}
}
}
self.apply_valid_values(arr, Self::retract_value);
Ok(())
}

/// Always returns `true`: this accumulator provides `retract_batch`, so
/// previously added values can be removed from its state again.
fn supports_retract_batch(&self) -> bool {
true
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::Int64Array,
        buffer::{NullBuffer, ScalarBuffer},
    };
    use std::sync::Arc;

    /// Builds an `Int64` column from raw slot values plus an optional
    /// per-slot validity mask (`false` marks the slot as NULL).
    fn int64_column(slots: Vec<i64>, validity: Option<Vec<bool>>) -> ArrayRef {
        Arc::new(Int64Array::new(
            ScalarBuffer::from(slots),
            validity.map(NullBuffer::from),
        ))
    }

    #[test]
    fn sliding_distinct_sum_ignores_null_slots() -> Result<()> {
        let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?;

        // Slot 0 is NULL, so the 42 stored behind it must not be counted;
        // the two valid 5s are distinct-summed once.
        let incoming = int64_column(vec![42, 5, 5], Some(vec![false, true, true]));
        acc.update_batch(&[incoming])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5)));

        // Retracting a NULL slot is a no-op; one copy of 5 is still present.
        let outgoing = int64_column(vec![42, 5], Some(vec![false, true]));
        acc.retract_batch(&[outgoing])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5)));

        // Once the last 5 leaves, nothing is tracked and the result is NULL.
        let final_copy = int64_column(vec![5], None);
        acc.retract_batch(&[final_copy])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));

        Ok(())
    }

    #[test]
    fn sliding_distinct_sum_returns_null_for_all_null_frame() -> Result<()> {
        let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?;

        // The only slot is NULL, so the 99 in the value buffer is ignored.
        let all_null = int64_column(vec![99], Some(vec![false]));
        acc.update_batch(&[all_null])?;
        assert_eq!(acc.evaluate()?, ScalarValue::Int64(None));

        Ok(())
    }
}
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/window.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5959,6 +5959,28 @@ physical_plan
07)------------DataSourceExec: partitions=2, partition_sizes=[5, 4]


# SUM(DISTINCT) over sliding frames must skip NULLs and return NULL
# for frames containing no non-null values.
statement ok
CREATE TABLE table_distinct_sum_nulls(ts INT, v BIGINT) AS VALUES
(1, NULL), (2, 3), (3, NULL), (4, NULL), (5, 5);

query II
SELECT
ts,
SUM(DISTINCT v) OVER (
ORDER BY ts
ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
) AS s
FROM table_distinct_sum_nulls;
----
1 NULL
2 3
3 3
4 NULL
5 5


# FILTER clause with window functions

# Verify FILTER clause with non-aggregate window functions fails with a clear message
Expand Down
Loading