Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ async fn window_using_aggregates() -> Result<()> {
| -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 |
| -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 |
| -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 |
| -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 |
| -85 | -56 | 2 | -70 | -70 | -56 | -85 | 1 | -25 |
| -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 |
| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |
| -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 |
Expand Down
22 changes: 19 additions & 3 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,25 @@ fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::
let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp);
// Get the maximum of the low (left side after bi-partitioning)
let left_max = slice_max::<T>(low);
let median = left_max
.add_wrapping(*high)
.div_wrapping(T::Native::usize_as(2));
// Calculate median as the average of the two middle values.
// Use checked arithmetic to detect overflow and fall back to safe formula.
let two = T::Native::usize_as(2);
let median = match left_max.add_checked(*high) {
Ok(sum) => sum.div_wrapping(two),
Err(_) => {
// Overflow detected - use safe midpoint formula:
// a/2 + b/2 + ((a%2 + b%2) / 2)
// This avoids overflow by dividing before adding.
let half_left = left_max.div_wrapping(two);
let half_right = (*high).div_wrapping(two);
let rem_left = left_max.mod_wrapping(two);
let rem_right = (*high).mod_wrapping(two);
// The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed)
// divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed)
let correction = rem_left.add_wrapping(rem_right).div_wrapping(two);
half_left.add_wrapping(half_right).add_wrapping(correction)
}
};
Some(median)
} else {
let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp);
Expand Down
56 changes: 56 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,62 @@ SELECT approx_median(col_f64_nan) FROM median_table
----
NaN


# median_i8_overflow_negative
# Test overflow with negative values: -85 + -56 = -141 < -128 (min i8)
query I
SELECT median(v) FROM (VALUES (arrow_cast(-85, 'Int8')), (arrow_cast(-56, 'Int8'))) AS t(v);
----
-70

# median_i8_overflow_positive
# Test overflow with positive values: 100 + 120 = 220 > 127 (max i8)
query I
SELECT median(v) FROM (VALUES (arrow_cast(100, 'Int8')), (arrow_cast(120, 'Int8'))) AS t(v);
----
110

# median_u8_overflow
# Test unsigned overflow: 200 + 250 = 450 > 255 (max u8)
query I
SELECT median(v) FROM (VALUES (arrow_cast(200, 'UInt8')), (arrow_cast(250, 'UInt8'))) AS t(v);
----
225

# median_i8_no_overflow_normal_case
# Normal case that doesn't overflow, for comparison: (4 + 5) / 2 = 4 (4.5 truncated toward zero)
query I
SELECT median(v) FROM (VALUES (arrow_cast(4, 'Int8')), (arrow_cast(5, 'Int8'))) AS t(v);
----
4

# median_i8_max_values
# Test with both i8::MAX values: 127 + 127 = 254 > 127, overflow
query I
SELECT median(v) FROM (VALUES (arrow_cast(127, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v);
----
127

# median_i8_min_values
# Test with both i8::MIN values: -128 + -128 = -256 < -128 (min i8), overflow
query I
SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(-128, 'Int8'))) AS t(v);
----
-128

# median_i8_min_max_values
# Test with i8::MIN and i8::MAX: -128 + 127 = -1, no overflow, median = 0 (truncated from -0.5)
query I
SELECT median(v) FROM (VALUES (arrow_cast(-128, 'Int8')), (arrow_cast(127, 'Int8'))) AS t(v);
----
0

# median_u8_max_values
# Test with both u8::MAX values: 255 + 255 = 510 > 255, overflow
query I
SELECT median(v) FROM (VALUES (arrow_cast(255, 'UInt8')), (arrow_cast(255, 'UInt8'))) AS t(v);
----
255

# median_sliding_window
statement ok
CREATE TABLE median_window_test (
Expand Down