diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 2621fcf0bf3c7..5a95cfe8320fc 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -281,6 +281,10 @@ impl Accumulator for CorrelationAccumulator { self.stddev2.retract_batch(&values[1..2])?; Ok(()) } + + fn supports_retract_batch(&self) -> bool { + true + } } #[derive(Default)] diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 18d602ab33940..bd7c8a039076a 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -305,6 +305,14 @@ impl Accumulator for CovarianceAccumulator { _ => continue, }; + if self.count <= 1 { + self.count = 0; + self.mean1 = 0.0; + self.mean2 = 0.0; + self.algo_const = 0.0; + continue; + } + let new_count = self.count - 1; let delta1 = self.mean1 - value1; let new_mean1 = delta1 / new_count as f64 + self.mean1; @@ -373,4 +381,8 @@ impl Accumulator for CovarianceAccumulator { fn size(&self) -> usize { size_of_val(self) } + + fn supports_retract_batch(&self) -> bool { + true + } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index ce3e00b9ffd91..d5fddf01f2d52 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -348,6 +348,13 @@ impl Accumulator for VarianceAccumulator { fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = as_float64_array(&values[0])?; for value in arr.iter().flatten() { + if self.count <= 1 { + self.count = 0; + self.mean = 0.0; + self.m2 = 0.0; + continue; + } + let new_count = self.count - 1; let delta1 = self.mean - value; let new_mean = delta1 / new_count as f64 + self.mean; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index bc2f1bfcbc73f..f811646b0853b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -6604,6 +6604,142 @@ ORDER BY i; 3 1 4 NULL +# Covariance/correlation sliding-window regression test. Verifies correct +# results across row removals and a NULL-gap empty-frame transition. +query IRRR +SELECT + column1, + covar_pop(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ), + covar_samp(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ), + corr(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) +FROM ( + VALUES + (1, 10.0, 5.0), + (2, NULL, NULL), + (3, NULL, NULL), + (4, 30.0, 10.0), + (5, 40.0, 20.0), + (6, 50.0, 10.0) +); +---- +1 0 NULL NULL +2 0 NULL NULL +3 NULL NULL NULL +4 0 NULL NULL +5 25 50 1 +6 -25 -50 -1 + +# Multi-row covariance/correlation sliding-window regression test. Verifies +# correct accumulation when valid rows enter the frame after a reset. +query IRRR +SELECT + column1, + covar_pop(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 2 PRECEDING AND CURRENT ROW + ), + covar_samp(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 2 PRECEDING AND CURRENT ROW + ), + corr(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 2 PRECEDING AND CURRENT ROW + ) +FROM ( + VALUES + (1, 10.0, 5.0), + (2, NULL, NULL), + (3, NULL, NULL), + (4, 30.0, 10.0), + (5, 40.0, 20.0), + (6, 50.0, 10.0) +); +---- +1 0 NULL NULL +2 0 NULL NULL +3 0 NULL NULL +4 0 NULL NULL +5 25 50 1 +6 0 0 0 + +# Covariance/correlation sliding-window regression test. Rows with NULL in +# either input column must not contribute to the aggregate state. +query IRRR +SELECT + column1, + covar_pop(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + covar_samp(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + corr(column2, column3) OVER ( + ORDER BY column1 + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ) +FROM ( + VALUES + (1, 10.0, 5.0), + (2, 20.0, NULL), + (3, NULL, 15.0), + (4, 30.0, 10.0), + (5, 40.0, 20.0) +); +---- +1 0 NULL NULL +2 0 NULL NULL +3 0 NULL NULL +4 25 50 1 +5 25 50 1 + +# Variance/stddev sliding-window regression test. Verifies that retracting +# the last valid row resets the aggregate state. +query IRRRR +SELECT + column1, + var_pop(column2) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ), + var_samp(column2) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ), + stddev_pop(column2) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ), + stddev_samp(column2) OVER ( + ORDER BY column1 + ROWS BETWEEN 1 PRECEDING AND CURRENT ROW + ) +FROM ( + VALUES + (1, 10.0), + (2, NULL), + (3, NULL), + (4, 30.0), + (5, 40.0) +); +---- +1 0 NULL 0 NULL +2 0 NULL 0 NULL +3 NULL NULL NULL NULL +4 0 NULL 0 NULL +5 25 50 5 7.071067811865 + # Decimal variant — the integer-division path would otherwise panic on an # empty frame. query IR