diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c8b825d576e02..5a2080990e386 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -3791,6 +3791,90 @@ mod tests {
         Ok(())
     }
 
+    /// When `skip_partial_aggregation_probe_ratio_threshold` is set to 1.0,
+    /// the feature must be effectively disabled: even with 100% cardinality
+    /// (every row is a unique group), no rows should be skipped.
+    #[tokio::test]
+    async fn test_skip_aggregation_disabled_at_threshold_one() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("key", DataType::Int32, true),
+            Field::new("val", DataType::Int32, true),
+        ]));
+
+        let group_by =
+            PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);
+
+        let aggr_expr = vec![
+            AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias(String::from("COUNT(val)"))
+                .build()
+                .map(Arc::new)?,
+        ];
+
+        // Two batches are required: batch 1 triggers the probe threshold so the
+        // skip decision is evaluated; batch 2 is what would be skipped on main
+        // (where >= caused threshold=1.0 to still skip at 100% cardinality).
+        // All rows have unique keys => ratio = 1.0 (100% cardinality).
+        let input_data = vec![
+            // Batch 1: fires the probe check (ratio = 5/5 = 1.0)
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+                    Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0])),
+                ],
+            )
+            .unwrap(),
+            // Batch 2: would be skipped if threshold=1.0 did not disable the feature
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10])),
+                    Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0])),
+                ],
+            )
+            .unwrap(),
+        ];
+
+        let input =
+            TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?;
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            group_by,
+            aggr_expr,
+            vec![None],
+            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
+            schema,
+        )?);
+
+        let session_config = SessionConfig::default()
+            .set(
+                "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
+                &ScalarValue::Int64(Some(1)),
+            )
+            .set(
+                "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
+                &ScalarValue::Float64(Some(1.0)),
+            );
+
+        let ctx = TaskContext::default().with_session_config(session_config);
+        collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
+
+        let metrics = aggregate_exec.metrics().unwrap();
+        let skipped_rows = metrics
+            .sum_by_name("skipped_aggregation_rows")
+            .map(|m| m.as_usize())
+            .unwrap_or(0);
+
+        assert_eq!(
+            skipped_rows, 0,
+            "threshold=1.0 should disable skip aggregation, but {skipped_rows} rows were skipped"
+        );
+
+        Ok(())
+    }
+
     #[test]
     fn group_exprs_nullable() -> Result<()> {
         let input_schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 1164fb37b384a..c3f73976c721a 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -195,7 +195,7 @@ impl SkipAggregationProbe {
         self.num_groups = num_groups;
         if self.input_rows >= self.probe_rows_threshold {
             self.should_skip = self.num_groups as f64 / self.input_rows as f64
-                >= self.probe_ratio_threshold;
+                > self.probe_ratio_threshold;
             // Set is_locked to true only if we have decided to skip, otherwise we can try to skip
             // during processing the next record_batch.
             self.is_locked = self.should_skip;
@@ -644,14 +644,20 @@ impl GroupedHashAggregateStream {
                 options.skip_partial_aggregation_probe_rows_threshold;
             let probe_ratio_threshold =
                 options.skip_partial_aggregation_probe_ratio_threshold;
-            let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
-                .with_category(MetricCategory::Rows)
-                .counter("skipped_aggregation_rows", partition);
-            Some(SkipAggregationProbe::new(
-                probe_rows_threshold,
-                probe_ratio_threshold,
-                skipped_aggregation_rows,
-            ))
+            // A threshold >= 1.0 means the ratio (num_groups / input_rows) can
+            // never exceed it, so the feature is effectively disabled.
+            if probe_ratio_threshold >= 1.0 {
+                None
+            } else {
+                let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
+                    .with_category(MetricCategory::Rows)
+                    .counter("skipped_aggregation_rows", partition);
+                Some(SkipAggregationProbe::new(
+                    probe_rows_threshold,
+                    probe_ratio_threshold,
+                    skipped_aggregation_rows,
+                ))
+            }
         } else {
             None
         };
@@ -1630,11 +1636,11 @@ mod tests {
             ],
         )?;
 
-        // Batch 2: 350 rows with 350 unique NEW groups (starting from group 10)
-        // After batch 2, total: 450 rows, 360 groups
-        // Ratio: 360/450 = 0.8 (80%) >= 0.8 -> SHOULD decide to skip
-        let batch2_rows = 350;
-        let batch2_groups = 350;
+        // Batch 2: 360 rows with 360 unique NEW groups (starting from group 10)
+        // After batch 2, total: 460 rows, 370 groups
+        // Ratio: 370/460 ≈ 0.804 (80.4%) > 0.8 -> SHOULD decide to skip
+        let batch2_rows = 360;
+        let batch2_groups = 360;
         let group_ids_batch2: Vec<i32> = (batch1_groups..(batch1_groups + batch2_groups))
             .map(|x| x as i32)
             .collect();
@@ -1817,4 +1823,25 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_skip_aggregation_probe_equality_does_not_skip() {
+        // When num_groups / input_rows == probe_ratio_threshold, the `>` boundary
+        // means we must NOT skip — equality is not sufficient to trigger skip.
+        let threshold_ratio = 0.5_f64;
+        let threshold_rows = 10_usize;
+        let mut probe = SkipAggregationProbe::new(
+            threshold_rows,
+            threshold_ratio,
+            metrics::Count::new(),
+        );
+
+        // 10 rows, 5 groups → ratio = 5/10 = 0.5 exactly equals threshold
+        probe.update_state(10, 5);
+
+        assert!(
+            !probe.should_skip(),
+            "ratio == threshold should not trigger skip (boundary is exclusive)"
+        );
+    }
 }