From b76539fee01a8b78c0f27e52e07cd0eb7e403024 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Thu, 25 Dec 2025 20:23:12 +0200 Subject: [PATCH 01/15] feat: Improve sort memory resilience --- .../physical-plan/src/aggregates/row_hash.rs | 2 +- .../src/sorts/multi_level_merge.rs | 8 +- datafusion/physical-plan/src/sorts/sort.rs | 988 ++++++++++++++++-- datafusion/physical-plan/src/stream.rs | 252 +++++ 4 files changed, 1163 insertions(+), 87 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index cb22fbf9a06a1..258157b316ab8 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1198,7 +1198,7 @@ impl GroupedHashAggregateStream { // instead. // Spilling to disk and reading back also ensures batch size is consistent // rather than potentially having one significantly larger last batch. - self.spill()?; + self.spill()?; // TODO: use sort_batch_chunked instead // Mark that we're switching to stream merging mode. self.spill_state.is_stream_merging = true; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 3540f1de3ed10..7545c63bed38d 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -27,10 +27,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; +use arrow::util::bit_util::round_upto_multiple_of_64; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; -use crate::sorts::sort::get_reserved_byte_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::stream::RecordBatchStreamAdapter; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -360,9 +360,9 @@ impl MultiLevelMergeBuilder { for spill in &self.sorted_spill_files { // For memory pools that are not shared this is good, for other this is not // and there should be some upper limit to memory reservation so we won't starve the system - match reservation.try_grow(get_reserved_byte_for_record_batch_size( - spill.max_record_batch_memory * buffer_len, - )) { + match reservation.try_grow( + round_upto_multiple_of_64(spill.max_record_batch_memory) * buffer_len, + ) { Ok(_) => { number_of_spills_to_read_for_current_phase += 1; } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 18cdcbe9debcc..757ec81d29ffe 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -34,16 +34,15 @@ use crate::filter_pushdown::{ }; use crate::limit::LimitStream; use crate::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, SpillMetrics, - SplitMetrics, + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{ProjectionExec, make_with_child, update_ordering}; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; -use crate::stream::BatchSplitStream; use crate::stream::RecordBatchStreamAdapter; +use crate::stream::ReservationStream; use crate::topk::TopK; use crate::topk::TopKDynamicFilters; use crate::{ @@ -55,6 +54,7 @@ use crate::{ use 
arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; use arrow::datatypes::SchemaRef; +use arrow::util::bit_util::round_upto_multiple_of_64; use datafusion_common::config::SpillCompression; use datafusion_common::{ DataFusionError, Result, assert_or_internal_err, internal_datafusion_err, @@ -75,8 +75,6 @@ struct ExternalSorterMetrics { baseline: BaselineMetrics, spill_metrics: SpillMetrics, - - split_metrics: SplitMetrics, } impl ExternalSorterMetrics { @@ -84,7 +82,6 @@ impl ExternalSorterMetrics { Self { baseline: BaselineMetrics::new(metrics, partition), spill_metrics: SpillMetrics::new(metrics, partition), - split_metrics: SplitMetrics::new(metrics, partition), } } } @@ -545,7 +542,7 @@ impl ExternalSorter { while let Some(batch) = sorted_stream.next().await { let batch = batch?; - let sorted_size = get_reserved_byte_for_record_batch(&batch); + let sorted_size = get_reserved_byte_for_record_batch(&batch)?; if self.reservation.try_grow(sorted_size).is_err() { // Although the reservation is not enough, the batch is // already in memory, so it's okay to combine it with previously @@ -662,7 +659,7 @@ impl ExternalSorter { if self.in_mem_batches.len() == 1 { let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation, true); + return self.sort_batch_stream(batch, &metrics, reservation); } // If less than sort_in_place_threshold_bytes, concatenate and sort in place @@ -671,10 +668,10 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_reserved_byte_for_record_batch(&batch)) + .try_resize(get_reserved_byte_for_record_batch(&batch)?) .map_err(Self::err_with_oom_context)?; let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation, true); + return self.sort_batch_stream(batch, &metrics, reservation); } let streams = std::mem::take(&mut self.in_mem_batches) @@ -683,15 +680,8 @@ impl ExternalSorter { let metrics = self.metrics.baseline.intermediate(); let reservation = self .reservation - .split(get_reserved_byte_for_record_batch(&batch)); - let input = self.sort_batch_stream( - batch, - metrics, - reservation, - // Passing false as `StreamingMergeBuilder` will split the - // stream into batches of `self.batch_size` rows. 
- false, - )?; + .split(get_reserved_byte_for_record_batch(&batch)?); + let input = self.sort_batch_stream(batch, &metrics, reservation)?; Ok(spawn_buffered(input, 1)) }) .collect::>()?; @@ -718,43 +708,68 @@ impl ExternalSorter { fn sort_batch_stream( &self, batch: RecordBatch, - metrics: BaselineMetrics, - reservation: MemoryReservation, - mut split: bool, + metrics: &BaselineMetrics, + mut reservation: MemoryReservation, ) -> Result { assert_eq!( - get_reserved_byte_for_record_batch(&batch), + get_reserved_byte_for_record_batch(&batch)?, reservation.size() ); - split = split && batch.num_rows() > self.batch_size; - let schema = batch.schema(); - let expressions = self.expr.clone(); - let stream = futures::stream::once(async move { - let _timer = metrics.elapsed_compute().timer(); + let batch_size = self.batch_size; + let output_row_metrics = metrics.output_rows().clone(); - let sorted = sort_batch(&batch, &expressions, None)?; + let stream = futures::stream::once(async move { + let schema = batch.schema(); - (&sorted).record_output(&metrics); + // Sort the batch immediately and get all output batches + let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; drop(batch); - drop(reservation); - Ok(sorted) - }); - let mut output: SendableRecordBatchStream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + // Free the old reservation and grow it to match the actual sorted output size + reservation.free(); - if split { - output = Box::pin(BatchSplitStream::new( - output, - self.batch_size, - self.metrics.split_metrics.clone(), - )); - } + Result::<_, DataFusionError>::Ok((schema, sorted_batches, reservation)) + }) + .then({ + move |batches| async move { + match batches { + Ok((schema, sorted_batches, mut reservation)) => { + // Calculate the total size of sorted batches + let total_sorted_size: usize = sorted_batches + .iter() + .map(get_record_batch_memory_size) + .sum(); + reservation + .try_grow(total_sorted_size) + .map_err(Self::err_with_oom_context)?; + + // Wrap in ReservationStream to hold the reservation + Ok(Box::pin(ReservationStream::new( + Arc::clone(&schema), + Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(sorted_batches.into_iter().map(Ok)), + )), + reservation, + )) as SendableRecordBatchStream) + } + Err(e) => Err(e), + } + } + }) + .try_flatten() + .map(move |batch| match batch { + Ok(batch) => { + output_row_metrics.add(batch.num_rows()); + Ok(batch) + } + Err(e) => Err(e), + }); - Ok(output) + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } /// If this sort may spill, pre-allocates @@ -780,7 +795,7 @@ impl ExternalSorter { &mut self, input: &RecordBatch, ) -> Result<()> { - let size = get_reserved_byte_for_record_batch(input); + let size = get_reserved_byte_for_record_batch(input)?; match self.reservation.try_grow(size) { Ok(_) => Ok(()), @@ -813,22 +828,14 @@ impl ExternalSorter { } /// Estimate how much memory is needed to sort a `RecordBatch`. -/// -/// This is used to pre-reserve memory for the sort/merge. The sort/merge process involves -/// creating sorted copies of sorted columns in record batches for speeding up comparison -/// in sorting and merging. The sorted copies are in either row format or array format. -/// Please refer to cursor.rs and stream.rs for more details. No matter what format the -/// sorted copies are, they will use more memory than the original record batch. 
-pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize {
-    // 2x may not be enough for some cases, but it's a good start.
-    // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes`
-    // to compensate for the extra memory needed.
-    record_batch_size * 2
-}
-
-/// Estimate how much memory is needed to sort a `RecordBatch`.
-fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize {
-    get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch))
+/// This is calculated by adding the record batch's memory size
+/// (which can be much larger than expected for sliced record batches)
+/// with the sliced buffer sizes, as that is the amount that will be needed to create the new buffer.
+/// The latter is rounded up to the nearest multiple of 64 based on the architecture,
+/// as this is how arrow creates buffers.
+pub(super) fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> Result<usize> {
+    Ok(get_record_batch_memory_size(batch)
+        + round_upto_multiple_of_64(batch.get_sliced_size()?))
 }
 
 impl Debug for ExternalSorter {
@@ -853,15 +860,7 @@ pub fn sort_batch(
         .collect::<Result<Vec<_>>>()?;
 
     let indices = lexsort_to_indices(&sort_columns, fetch)?;
-    let mut columns = take_arrays(batch.columns(), &indices, None)?;
-
-    // The columns may be larger than the unsorted columns in `batch` especially for variable length
-    // data types due to exponential growth when building the sort columns. We shrink the columns
-    // to prevent memory reservation failures, as well as excessive memory allocation when running
-    // merges in `SortPreservingMergeStream`.
-    columns.iter_mut().for_each(|c| {
-        c.shrink_to_fit();
-    });
+    let columns = take_arrays(batch.columns(), &indices, None)?;
 
     let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
     Ok(RecordBatch::try_new_with_options(
@@ -871,6 +870,48 @@
 }
 
+/// Sort a batch and return the result as multiple batches of size `batch_size`.
+/// This is useful when you want to avoid creating one large sorted batch in memory,
+/// and instead want to process the sorted data in smaller chunks.
+pub fn sort_batch_chunked(
+    batch: &RecordBatch,
+    expressions: &LexOrdering,
+    batch_size: usize,
+) -> Result<Vec<RecordBatch>> {
+    let sort_columns = expressions
+        .iter()
+        .map(|expr| expr.evaluate_to_sort_column(batch))
+        .collect::<Result<Vec<_>>>()?;
+
+    let indices = lexsort_to_indices(&sort_columns, None)?;
+
+    // Split indices into chunks of batch_size
+    let num_rows = indices.len();
+    let num_chunks = num_rows.div_ceil(batch_size);
+
+    let result_batches = (0..num_chunks)
+        .map(|chunk_idx| {
+            let start = chunk_idx * batch_size;
+            let end = (start + batch_size).min(num_rows);
+            let chunk_len = end - start;
+
+            // Create a slice of indices for this chunk
+            let chunk_indices = indices.slice(start, chunk_len);
+
+            // Take the columns using this chunk of indices
+            let columns = take_arrays(batch.columns(), &chunk_indices, None)?;
+
+            let options = RecordBatchOptions::new().with_row_count(Some(chunk_len));
+            let chunk_batch =
+                RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;
+
+            Ok(chunk_batch)
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(result_batches)
+}
+
 /// Sort execution plan.
 ///
 /// Support sorting datasets that are larger than the memory allotted
@@ -1173,10 +1214,7 @@ impl ExecutionPlan for SortExec {
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut new_sort = self.cloned();
-        assert!(
-            children.len() == 1,
-            "SortExec should have exactly one child"
-        );
+        assert_eq!(children.len(), 1, "SortExec should have exactly one child");
         new_sort.input = Arc::clone(&children[0]);
         // Recompute the properties based on the new input since they may have changed
         let (cache, sort_prefix) = Self::compute_properties(
@@ -1623,13 +1661,24 @@ mod tests {
     #[tokio::test]
     async fn test_batch_reservation_error() -> Result<()> {
         // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail.
-        // These values assume that the ExternalSorter will reserve 800 bytes for the first batch.
-        let expected_batch_reservation = 800;
         let merge_reservation: usize = 0; // Set to 0 for simplicity
-        let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; // Just short of what we need
         let session_config =
             SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation);
+
+        let plan = test::scan_partitioned(1);
+
+        // Read the first record batch to determine the actual memory requirement
+        let expected_batch_reservation = {
+            let temp_ctx = Arc::new(TaskContext::default());
+            let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?;
+            let first_batch = stream.next().await.unwrap()?;
+            get_reserved_byte_for_record_batch(&first_batch)?
+        };
+
+        // Set memory limit just short of what we need
+        let memory_limit: usize = expected_batch_reservation + merge_reservation - 1;
+
         let runtime = RuntimeEnvBuilder::new()
             .with_memory_limit(memory_limit, 1.0)
             .build_arc()?;
@@ -1639,14 +1688,11 @@
             .with_runtime(runtime),
         );
 
-        let plan = test::scan_partitioned(1);
-
-        // Read the first record batch to assert that our memory limit and sort_spill_reservation
-        // settings trigger the test scenario.
+ // Verify that our memory limit is insufficient { let mut stream = plan.execute(0, Arc::clone(&task_ctx))?; let first_batch = stream.next().await.unwrap()?; - let batch_reservation = get_reserved_byte_for_record_batch(&first_batch); + let batch_reservation = get_reserved_byte_for_record_batch(&first_batch)?; assert_eq!(batch_reservation, expected_batch_reservation); assert!(memory_limit < (merge_reservation + batch_reservation)); @@ -2288,6 +2334,9 @@ mod tests { .map(|b| b.get_array_memory_size()) .sum::(); + // Use half the batch memory to force spilling + let memory_limit = batches_memory / 2; + TaskContext::default() .with_session_config( SessionConfig::new() @@ -2298,7 +2347,7 @@ mod tests { ) .with_runtime( RuntimeEnvBuilder::default() - .with_memory_limit(batches_memory, 1.0) + .with_memory_limit(memory_limit, 1.0) .build_arc() .unwrap(), ) @@ -2402,4 +2451,779 @@ mod tests { Ok((sorted_batches, metrics)) } + + // ======================================================================== + // Tests for sort_batch_chunked() + // ======================================================================== + + #[tokio::test] + async fn test_sort_batch_chunked_basic() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 1000 rows + let mut values: Vec = (0..1000).collect(); + // Shuffle to make it unsorted + values.reverse(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + // Sort with batch_size = 250 + let result_batches = sort_batch_chunked(&batch, &expressions, 250)?; + + // Verify 4 batches are returned + assert_eq!(result_batches.len(), 4); + + // Verify each batch has <= 250 rows + let mut total_rows = 0; + for (i, batch) in result_batches.iter().enumerate() { + assert!( + batch.num_rows() <= 250, + "Batch {} has {} rows, expected <= 250", + i, + batch.num_rows() + ); + total_rows += batch.num_rows(); + } + + // Verify total row count matches input + assert_eq!(total_rows, 1000); + + // Verify data is correctly sorted across all chunks + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!( + array.value(i) <= array.value(i + 1), + "Array not sorted at position {}: {} > {}", + i, + array.value(i), + array.value(i + 1) + ); + } + assert_eq!(array.value(0), 0); + assert_eq!(array.value(array.len() - 1), 999); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 50 rows + let values: Vec = (0..50).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + // Sort with batch_size = 100 + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Should return exactly 1 batch + assert_eq!(result_batches.len(), 1); + assert_eq!(result_batches[0].num_rows(), 50); + + // Verify it's correctly sorted + let array = as_primitive_array::(result_batches[0].column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + 
assert_eq!(array.value(0), 0); + assert_eq!(array.value(49), 49); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_exact_multiple() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 1000 rows + let values: Vec = (0..1000).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + // Sort with batch_size = 100 + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Should return exactly 10 batches of 100 rows each + assert_eq!(result_batches.len(), 10); + for batch in &result_batches { + assert_eq!(batch.num_rows(), 100); + } + + // Verify sorted correctly across all batches + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_with_nulls() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + // Create a batch with nulls + let values = Int32Array::from(vec![ + Some(5), + None, + Some(2), + Some(8), + None, + Some(1), + Some(9), + None, + Some(3), + Some(7), + ]); + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?; + + // Test with nulls_first = true + { + let expressions: LexOrdering = [PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }] + .into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 4)?; + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + + // First 3 should be null + assert!(array.is_null(0)); + assert!(array.is_null(1)); + assert!(array.is_null(2)); + // Then sorted values + assert_eq!(array.value(3), 1); + assert_eq!(array.value(4), 2); + } + + // Test with nulls_first = false + { + let expressions: LexOrdering = [PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 4)?; + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + + // First should be 1 + assert_eq!(array.value(0), 1); + // Last 3 should be null + assert!(array.is_null(7)); + assert!(array.is_null(8)); + assert!(array.is_null(9)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_multi_column() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + // Create a batch with multiple columns + let a_values = Int32Array::from(vec![3, 1, 2, 1, 3, 2, 1, 3, 2, 1]); + let b_values = Int32Array::from(vec![1, 2, 3, 1, 2, 1, 3, 3, 2, 4]); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(a_values), Arc::new(b_values)], + )?; + + let expressions: LexOrdering = [ + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + ] + .into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 3)?; + 
let concatenated = concat_batches(&schema, &result_batches)?; + + let a_array = as_primitive_array::(concatenated.column(0))?; + let b_array = as_primitive_array::(concatenated.column(1))?; + + // Verify multi-column sort ordering + for i in 0..a_array.len() - 1 { + let a_curr = a_array.value(i); + let a_next = a_array.value(i + 1); + let b_curr = b_array.value(i); + let b_next = b_array.value(i + 1); + + assert!( + a_curr < a_next || (a_curr == a_next && b_curr <= b_next), + "Not properly sorted at position {}: ({}, {}) -> ({}, {})", + i, + a_curr, + b_curr, + a_next, + b_next + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_empty_batch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Empty input produces no output batches (0 chunks) + assert_eq!(result_batches.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_single_row() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![42]))], + )?; + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + assert_eq!(result_batches.len(), 1); + assert_eq!(result_batches[0].num_rows(), 1); + let array = as_primitive_array::(result_batches[0].column(0))?; + assert_eq!(array.value(0), 42); + + Ok(()) + } + + // ======================================================================== + // Tests for get_reserved_byte_for_record_batch() + // ======================================================================== + + #[tokio::test] + async fn test_get_reserved_byte_for_record_batch_normal_batch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + )?; + + let reserved = get_reserved_byte_for_record_batch(&batch)?; + + // Should be greater than 0 + assert!(reserved > 0); + + Ok(()) + } + + #[tokio::test] + async fn test_get_reserved_byte_for_record_batch_with_sliced_batches( + ) -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a larger batch then slice it + let large_array = Int32Array::from((0..1000).collect::>()); + let sliced_array = large_array.slice(100, 50); // Take 50 elements starting at 100 + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(sliced_array)])?; + + let reserved = get_reserved_byte_for_record_batch(&batch)?; + + // Reserved should account for the sliced nature + assert!(reserved > 0); + + // The reservation should include memory for the full underlying buffer + // plus the sliced size rounded to 64 + let record_batch_size = get_record_batch_memory_size(&batch); + assert!(reserved >= record_batch_size); + + Ok(()) + } + + #[tokio::test] + async fn test_get_reserved_byte_for_record_batch_rounding() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with a size that's not a multiple of 64 
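+        // (three Int32 values occupy 12 bytes of logical data, so the sliced-size term is rounded up to the next 64-byte boundary)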
+ let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + let reserved = get_reserved_byte_for_record_batch(&batch)?; + + // Should be rounded to multiple of 64 + assert!(reserved > 0); + // The rounding is applied to the sliced size component + // Total = record_batch_memory_size + round_upto_multiple_of_64(sliced_size) + + Ok(()) + } + + #[tokio::test] + async fn test_get_reserved_byte_for_record_batch_with_string_view() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Utf8View, + false, + )])); + + let string_array = StringViewArray::from(vec!["hello", "world", "test"]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(string_array)])?; + + let reserved = get_reserved_byte_for_record_batch(&batch)?; + + // Should handle variable-length data correctly + assert!(reserved > 0); + + Ok(()) + } + + // ======================================================================== + // Tests for ReservationStream (in stream.rs, but we test integration here) + // ======================================================================== + + #[tokio::test] + async fn test_sort_batch_stream_memory_tracking() -> Result<()> { + use crate::stream::ReservationStream; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB limit + .build_arc()?; + + let reservation = MemoryConsumer::new("test") + .register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch + let values: Vec = (0..1000).collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let batch_size = get_record_batch_memory_size(&batch); + + // Create a simple stream with one batch + let stream = futures::stream::iter(vec![Ok(batch)]); + let inner = Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + stream, + )) as SendableRecordBatchStream; + + // Create reservation and grow it + let mut reservation = reservation; + reservation.try_grow(batch_size)?; + + let initial_reserved = runtime.memory_pool.reserved(); + assert!(initial_reserved > 0); + + // Create ReservationStream + let mut res_stream = + ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Consume the batch + let result = res_stream.next().await; + assert!(result.is_some()); + + // Memory should be reduced after consuming + let after_consume = runtime.memory_pool.reserved(); + assert!(after_consume < initial_reserved); + + // Consume until end + while res_stream.next().await.is_some() {} + + // Memory should be freed + let final_reserved = runtime.memory_pool.reserved(); + assert_eq!(final_reserved, 0); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_stream_chunked_output() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a large batch (5000 rows) + let values: Vec = (0..5000).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( + Column::new("a", 0), + ))] + .into(); + + let batch_size = 500; + let result_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; + + // Verify multiple output batches + assert_eq!(result_batches.len(), 10); + + // Each batch should be <= batch_size + let mut total_rows = 0; + for batch in &result_batches { + 
assert!(batch.num_rows() <= batch_size); + total_rows += batch.num_rows(); + } + + assert_eq!(total_rows, 5000); + + // Verify data is sorted + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_sort_no_batch_split_stream_metrics() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let partitions = 2; + let csv = test::scan_partitioned(partitions); + let schema = csv.schema(); + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + Arc::new(CoalescePartitionsExec::new(csv)), + )); + + let _result = collect(sort_exec.clone(), task_ctx).await?; + + let metrics = sort_exec.metrics().unwrap(); + + // Verify that SplitMetrics are not present + // The metrics should only include baseline and spill metrics + let metrics_str = format!("{:?}", metrics); + + // Should not contain split-related metrics + assert!( + !metrics_str.contains("split_count"), + "Should not have split_count metric" + ); + assert!( + !metrics_str.contains("split_time"), + "Should not have split_time metric" + ); + + // Should still have baseline and spill metrics + assert!(metrics.output_rows().is_some()); + assert!(metrics.elapsed_compute().is_some()); + + Ok(()) + } + + #[tokio::test] + async fn test_external_sorter_with_chunked_batches() -> Result<()> { + // Test with memory limits that trigger spilling + let session_config = SessionConfig::new().with_batch_size(100); + let sort_spill_reservation_bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(sort_spill_reservation_bytes + 8192, 1.0) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let partitions = 50; + let input = test::scan_partitioned(partitions); + let schema = input.schema(); + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + Arc::new(CoalescePartitionsExec::new(input)), + )); + + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; + + // Verify results are correctly sorted + let concatenated = concat_batches(&schema, &result)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!( + array.value(i) <= array.value(i + 1), + "Not sorted at position {}: {} > {}", + i, + array.value(i), + array.value(i + 1) + ); + } + + // Verify spilling occurred + let metrics = sort_exec.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Verify no memory leaks + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "Memory should be fully released" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_exec_batch_size_respected() -> Result<()> { + let batch_size = 50; + let session_config = SessionConfig::new().with_batch_size(batch_size); + let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create multiple batches with various sizes + let batches = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + 
vec![Arc::new(Int32Array::from( + (0..200).rev().collect::>(), + ))], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from( + (200..350).rev().collect::>(), + ))], + )?, + ]; + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?, + )); + + let result = collect(sort_exec, task_ctx).await?; + + // Verify output batches respect batch_size + for (i, batch) in result.iter().enumerate() { + // All batches except possibly the last should have batch_size rows + if i < result.len() - 1 { + assert_eq!( + batch.num_rows(), + batch_size, + "Batch {} should have {} rows", + i, + batch_size + ); + } else { + // Last batch can be smaller + assert!( + batch.num_rows() <= batch_size, + "Last batch should have <= {} rows", + batch_size + ); + } + } + + // Verify total rows + let total_rows: usize = result.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 350); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_exec_with_multiple_partitions_chunked() -> Result<()> { + let batch_size = 100; + let session_config = SessionConfig::new().with_batch_size(batch_size); + let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + + let partitions = 4; + let csv = test::scan_partitioned(partitions); + let schema = csv.schema(); + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + Arc::new(CoalescePartitionsExec::new(csv)), + )); + + let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + + // Verify correct sorting across partitions + let concatenated = concat_batches(&schema, &result)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + + // Verify batch sizes + for batch in &result { + assert!(batch.num_rows() <= batch_size); + } + + // Verify memory is released + assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_large_batch_memory_handling() -> Result<()> { + // Test with large batches and reasonable memory + let batch_size = 1000; + let session_config = SessionConfig::new() + .with_batch_size(batch_size) + .with_sort_spill_reservation_bytes(200 * 1024); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500 * 1024, 1.0) // 500KB limit + .build_arc()?; + + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a large batch + let values: Vec = (0..10000).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, + )); + + let result = collect(sort_exec, Arc::clone(&task_ctx)).await?; + + // Should handle memory pressure correctly without panicking + assert!(!result.is_empty()); + + // Verify data is sorted + let concatenated = concat_batches(&schema, &result)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + + // Verify 
memory is released + assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_with_fetch_limit_chunked() -> Result<()> { + let batch_size = 50; + let session_config = SessionConfig::new().with_batch_size(batch_size); + let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let values: Vec = (0..1000).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let fetch_limit = 10; + let sort_exec = Arc::new( + SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, + ) + .with_fetch(Some(fetch_limit)), + ); + + let result = collect(sort_exec, task_ctx.clone()).await?; + + // Verify correct number of rows returned + let total_rows: usize = result.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, fetch_limit); + + // Verify data is sorted + let concatenated = concat_batches(&schema, &result)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + assert_eq!(array.value(0), 0); + assert_eq!(array.value(array.len() - 1), 9); + + // Verify memory is released + assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 8b2ea1006893e..edd0dd673edb5 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -27,11 +27,13 @@ use super::metrics::ExecutionPlanMetricsSet; use super::metrics::{BaselineMetrics, SplitMetrics}; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use crate::displayable; +use crate::spill::get_record_batch_memory_size; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion_common::{Result, exec_err}; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryReservation; use futures::ready; use futures::stream::BoxStream; @@ -699,6 +701,69 @@ impl RecordBatchStream for BatchSplitStream { } } +/// A stream that holds a memory reservation for its lifetime, +/// shrinking the reservation as batches are consumed. 
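+/// Each yielded batch shrinks the reservation by that batch's memory size; once the inner stream is exhausted, the remaining reservation is freed.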
+/// The original reservation must have its batch sizes calculated using [`get_record_batch_memory_size`]
+pub struct ReservationStream {
+    schema: SchemaRef,
+    inner: SendableRecordBatchStream,
+    reservation: MemoryReservation,
+}
+
+impl ReservationStream {
+    pub fn new(
+        schema: SchemaRef,
+        inner: SendableRecordBatchStream,
+        reservation: MemoryReservation,
+    ) -> Self {
+        Self {
+            schema,
+            inner,
+            reservation,
+        }
+    }
+}
+
+impl Stream for ReservationStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let res = self.inner.poll_next_unpin(cx);
+
+        match res {
+            Poll::Ready(res) => {
+                match res {
+                    Some(Ok(batch)) => {
+                        self.reservation
+                            .shrink(get_record_batch_memory_size(&batch));
+                        Poll::Ready(Some(Ok(batch)))
+                    }
+                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
+                    None => {
+                        // Stream is done so free the reservation completely
+                        self.reservation.free();
+                        Poll::Ready(None)
+                    }
+                }
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+impl RecordBatchStream for ReservationStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -927,4 +992,191 @@
         "Should have received exactly one empty batch"
     );
 }
+
+    #[tokio::test]
+    async fn test_reservation_stream_shrinks_on_poll() {
+        use arrow::array::Int32Array;
+        use datafusion_execution::memory_pool::MemoryConsumer;
+        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_limit(10 * 1024 * 1024, 1.0)
+            .build_arc()
+            .unwrap();
+
+        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create batches
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
+        )
+        .unwrap();
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
+        )
+        .unwrap();
+
+        let batch1_size = get_record_batch_memory_size(&batch1);
+        let batch2_size = get_record_batch_memory_size(&batch2);
+
+        // Reserve memory upfront
+        reservation.try_grow(batch1_size + batch2_size).unwrap();
+        let initial_reserved = runtime.memory_pool.reserved();
+        assert_eq!(initial_reserved, batch1_size + batch2_size);
+
+        // Create stream with batches
+        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+        let inner =
+            Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+                as SendableRecordBatchStream;
+
+        let mut res_stream =
+            ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+        // Poll first batch
+        let result1 = res_stream.next().await;
+        assert!(result1.is_some());
+
+        // Memory should be reduced by batch1_size
+        let after_first = runtime.memory_pool.reserved();
+        assert_eq!(after_first, batch2_size);
+
+        // Poll second batch
+        let result2 = res_stream.next().await;
+        assert!(result2.is_some());
+
+        // Memory should be reduced by batch2_size
+        let after_second = runtime.memory_pool.reserved();
+        assert_eq!(after_second, 0);
+
+        // Poll None (end of stream)
+        let result3 = res_stream.next().await;
+        assert!(result3.is_none());
+
+        // Memory should still be 0
+        assert_eq!(runtime.memory_pool.reserved(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_reservation_stream_frees_on_completion() {
+        use arrow::array::Int32Array;
+        use 
datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc() + .unwrap(); + + let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) + .unwrap(); + + let batch_size = get_record_batch_memory_size(&batch); + reservation.try_grow(batch_size).unwrap(); + + assert!(runtime.memory_pool.reserved() > 0); + + let stream = futures::stream::iter(vec![Ok(batch)]); + let inner = + Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; + + let mut res_stream = + ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Consume all batches + while let Some(_) = res_stream.next().await {} + + // Memory should be fully freed + assert_eq!(runtime.memory_pool.reserved(), 0); + } + + #[tokio::test] + async fn test_reservation_stream_error_handling() { + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc() + .unwrap(); + + let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + reservation.try_grow(1000).unwrap(); + let initial = runtime.memory_pool.reserved(); + assert_eq!(initial, 1000); + + // Create a stream that errors + let stream = futures::stream::iter(vec![exec_err!("Test error")]); + let inner = + Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; + + let mut res_stream = + ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Get the error + let result = res_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + + // Stream should be done, but reservation might not be freed yet + // since we didn't consume to None + // This is expected behavior - the reservation is only freed when the stream ends normally + } + + #[tokio::test] + async fn test_reservation_stream_schema_preserved() { + use arrow::array::Int32Array; + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc() + .unwrap(); + + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![4, 5, 6])), + ], + ) + .unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch)]); + let inner = + Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; + + let res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Verify schema is preserved + let stream_schema = res_stream.schema(); + assert_eq!(stream_schema.fields().len(), 2); + assert_eq!(stream_schema.field(0).name(), "a"); + 
assert_eq!(stream_schema.field(1).name(), "b"); + } } From a06760f111e7827169ff1aa3ad356ec4f13d3ee1 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Thu, 25 Dec 2025 20:35:58 +0200 Subject: [PATCH 02/15] fmt --- datafusion/physical-plan/src/sorts/sort.rs | 60 +++++++++------------- datafusion/physical-plan/src/stream.rs | 20 +++----- 2 files changed, 31 insertions(+), 49 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 757ec81d29ffe..16c70ddc208f8 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2470,10 +2470,8 @@ mod tests { vec![Arc::new(Int32Array::from(values))], )?; - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); // Sort with batch_size = 250 let result_batches = sort_batch_chunked(&batch, &expressions, 250)?; @@ -2525,10 +2523,8 @@ mod tests { vec![Arc::new(Int32Array::from(values))], )?; - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); // Sort with batch_size = 100 let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; @@ -2559,10 +2555,8 @@ mod tests { vec![Arc::new(Int32Array::from(values))], )?; - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); // Sort with batch_size = 100 let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; @@ -2600,8 +2594,7 @@ mod tests { Some(3), Some(7), ]); - let batch = - RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?; // Test with nulls_first = true { @@ -2708,10 +2701,8 @@ mod tests { let batch = RecordBatch::new_empty(Arc::clone(&schema)); - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; @@ -2730,10 +2721,8 @@ mod tests { vec![Arc::new(Int32Array::from(vec![42]))], )?; - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; @@ -2766,8 +2755,7 @@ mod tests { } #[tokio::test] - async fn test_get_reserved_byte_for_record_batch_with_sliced_batches( - ) -> Result<()> { + async fn test_get_reserved_byte_for_record_batch_with_sliced_batches() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); // Create a larger batch then slice it @@ -2841,8 +2829,7 @@ mod tests { .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB limit .build_arc()?; - let reservation = MemoryConsumer::new("test") - .register(&runtime.memory_pool); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ 
-2857,10 +2844,8 @@ mod tests { // Create a simple stream with one batch let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - stream, - )) as SendableRecordBatchStream; + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; // Create reservation and grow it let mut reservation = reservation; @@ -2902,10 +2887,8 @@ mod tests { vec![Arc::new(Int32Array::from(values))], )?; - let expressions: LexOrdering = [PhysicalSortExpr::new_default(Arc::new( - Column::new("a", 0), - ))] - .into(); + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); let batch_size = 500; let result_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; @@ -3039,7 +3022,8 @@ mod tests { async fn test_sort_exec_batch_size_respected() -> Result<()> { let batch_size = 50; let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -3098,7 +3082,8 @@ mod tests { async fn test_sort_exec_with_multiple_partitions_chunked() -> Result<()> { let batch_size = 100; let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); let partitions = 4; let csv = test::scan_partitioned(partitions); @@ -3187,7 +3172,8 @@ mod tests { async fn test_sort_with_fetch_limit_chunked() -> Result<()> { let batch_size = 50; let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index edd0dd673edb5..7f08a221c25ea 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -1030,9 +1030,8 @@ mod test { // Create stream with batches let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); - let inner = - Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; let mut res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); @@ -1088,9 +1087,8 @@ mod test { assert!(runtime.memory_pool.reserved() > 0); let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = - Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; let mut res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); @@ -1122,9 +1120,8 @@ mod test { // Create a stream that errors let stream = futures::stream::iter(vec![exec_err!("Test error")]); - let inner = - Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; + let inner = 
Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; let mut res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); @@ -1167,9 +1164,8 @@ mod test { .unwrap(); let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = - Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; let res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); From 9ccf67243d61760fabeca0991a8a56786478a46b Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Thu, 25 Dec 2025 20:51:08 +0200 Subject: [PATCH 03/15] fix clippy --- datafusion/physical-plan/src/sorts/sort.rs | 22 +++++++--------------- datafusion/physical-plan/src/stream.rs | 2 +- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 16c70ddc208f8..4025f2557fbeb 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2683,12 +2683,7 @@ mod tests { assert!( a_curr < a_next || (a_curr == a_next && b_curr <= b_next), - "Not properly sorted at position {}: ({}, {}) -> ({}, {})", - i, - a_curr, - b_curr, - a_next, - b_next + "Not properly sorted at position {i}: ({a_curr}, {b_curr}) -> ({a_next}, {b_next})", ); } @@ -2931,13 +2926,13 @@ mod tests { Arc::new(CoalescePartitionsExec::new(csv)), )); - let _result = collect(sort_exec.clone(), task_ctx).await?; + let _result = collect(Arc::clone(&sort_exec) as _, task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); // Verify that SplitMetrics are not present // The metrics should only include baseline and spill metrics - let metrics_str = format!("{:?}", metrics); + let metrics_str = format!("{metrics:?}"); // Should not contain split-related metrics assert!( @@ -3057,16 +3052,13 @@ mod tests { assert_eq!( batch.num_rows(), batch_size, - "Batch {} should have {} rows", - i, - batch_size + "Batch {i} should have {batch_size} rows", ); } else { // Last batch can be smaller assert!( batch.num_rows() <= batch_size, - "Last batch should have <= {} rows", - batch_size + "Last batch should have <= {batch_size} rows", ); } } @@ -3098,7 +3090,7 @@ mod tests { Arc::new(CoalescePartitionsExec::new(csv)), )); - let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; // Verify correct sorting across partitions let concatenated = concat_batches(&schema, &result)?; @@ -3192,7 +3184,7 @@ mod tests { .with_fetch(Some(fetch_limit)), ); - let result = collect(sort_exec, task_ctx.clone()).await?; + let result = collect(sort_exec, Arc::clone(&task_ctx)).await?; // Verify correct number of rows returned let total_rows: usize = result.iter().map(|b| b.num_rows()).sum(); diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 7f08a221c25ea..b7d5e0bd7e7f2 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -1094,7 +1094,7 @@ mod test { ReservationStream::new(Arc::clone(&schema), inner, reservation); // Consume all batches - while let Some(_) = res_stream.next().await {} + while res_stream.next().await.is_some() {} // Memory should be fully freed assert_eq!(runtime.memory_pool.reserved(), 0); From 5420c3b9dee7cb88edf8ed2fdac9ee0154ab9fe7 Mon Sep 
17 00:00:00 2001 From: Emily Matheys Date: Thu, 25 Dec 2025 20:56:14 +0200 Subject: [PATCH 04/15] update docs --- datafusion/physical-plan/src/sorts/sort.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 4025f2557fbeb..8f0ce0ee6317c 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -699,12 +699,13 @@ impl ExternalSorter { /// Sorts a single `RecordBatch` into a single stream. /// - /// `reservation` accounts for the memory used by this batch and - /// is released when the sort is complete - /// - /// passing `split` true will return a [`BatchSplitStream`] where each batch maximum row count - /// will be `self.batch_size`. - /// If `split` is false, the stream will return a single batch + /// This may output multiple batches depending on the size of the + /// sorted data and the target batch size. + /// For single-batch output cases, `reservation` will be freed immediately after sorting, + /// as the batch will be output and is expected to be reserved by the consumer of the stream. + /// For multi-batch output cases, `reservation` will be grown to match the actual + /// size of sorted output, and as each batch is output, its memory will be freed from the reservation. + /// (This leads to the same behaviour, as futures are only evaluated when polled by the consumer.) fn sort_batch_stream( &self, batch: RecordBatch, From 17523e83b44cd4c7ca3471a6d550bc0c61e027d3 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 12:12:29 +0200 Subject: [PATCH 05/15] address CR --- datafusion/physical-plan/src/aggregates/row_hash.rs | 2 +- datafusion/physical-plan/src/sorts/sort.rs | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 258157b316ab8..1ae7202711112 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1198,7 +1198,7 @@ impl GroupedHashAggregateStream { // instead. // Spilling to disk and reading back also ensures batch size is consistent // rather than potentially having one significantly larger last batch. - self.spill()?; // TODO: use sort_batch_chunked instead + self.spill()?; // TODO: use sort_batch_chunked instead? // Mark that we're switching to stream merging mode. self.spill_state.is_stream_merging = true; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 8f0ce0ee6317c..86193457ec87e 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -829,11 +829,13 @@ impl ExternalSorter { } /// Estimate how much memory is needed to sort a `RecordBatch`. -/// This is calculated by adding the record batch's memory size -/// (which can be much larger than expected for sliced record batches) -/// with the sliced buffer sizes, as that is the amount that will be needed to create the new buffer. -/// The latter is rounded up to the nearest multiple of 64 based on the architecture, -/// as this is how arrow creates buffers. +/// +/// For sliced batches, `get_record_batch_memory_size` returns the size of the +/// underlying shared buffers (which may be larger than the logical data). 
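+/// (For example, a 50-row slice of a 1000-row Int32 array still reports the full 4000-byte data buffer here.)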
+/// We add `get_sliced_size()` (the actual logical data size, rounded to 64 bytes) +/// because sorting will create new buffers containing only the referenced data. +/// +/// Total = existing buffer size + new sorted buffer size pub(super) fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> Result<usize> { Ok(get_record_batch_memory_size(batch) + round_upto_multiple_of_64(batch.get_sliced_size()?)) From 7b0214cbaddd71145307ae8879e3319e430b9dde Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 14:45:39 +0200 Subject: [PATCH 06/15] remove all sorts of tests --- datafusion/physical-plan/src/sorts/sort.rs | 338 ++------------------- 1 file changed, 30 insertions(+), 308 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 86193457ec87e..6b6441bace0fc 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2580,75 +2580,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_sort_batch_chunked_with_nulls() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); - - // Create a batch with nulls - let values = Int32Array::from(vec![ - Some(5), - None, - Some(2), - Some(8), - None, - Some(1), - Some(9), - None, - Some(3), - Some(7), - ]); - let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)])?; - - // Test with nulls_first = true - { - let expressions: LexOrdering = [PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }] - .into(); - - let result_batches = sort_batch_chunked(&batch, &expressions, 4)?; - let concatenated = concat_batches(&schema, &result_batches)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - - // First 3 should be null - assert!(array.is_null(0)); - assert!(array.is_null(1)); - assert!(array.is_null(2)); - // Then sorted values - assert_eq!(array.value(3), 1); - assert_eq!(array.value(4), 2); - } - - // Test with nulls_first = false - { - let expressions: LexOrdering = [PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }] - .into(); - - let result_batches = sort_batch_chunked(&batch, &expressions, 4)?; - let concatenated = concat_batches(&schema, &result_batches)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - - // First should be 1 - assert_eq!(array.value(0), 1); - // Last 3 should be null - assert!(array.is_null(7)); - assert!(array.is_null(8)); - assert!(array.is_null(9)); - } - - Ok(()) - } - #[tokio::test] async fn test_sort_batch_chunked_multi_column() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), ])); @@ -2710,28 +2641,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_sort_batch_chunked_single_row() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(vec![42]))], - )?; - - let expressions: LexOrdering = - [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); - - let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; - - assert_eq!(result_batches.len(), 1); - assert_eq!(result_batches[0].num_rows(), 1); - let array = as_primitive_array::<Int32Type>(result_batches[0].column(0))?; - assert_eq!(array.value(0), 42); - - Ok(()) - } - //
======================================================================== // Tests for get_reserved_byte_for_record_batch() // ======================================================================== #[tokio::test] async fn test_get_reserved_byte_for_record_batch_normal_batch() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); let batch = RecordBatch::try_new( schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], )?; let reserved = get_reserved_byte_for_record_batch(&batch)?; - // Should be greater than 0 - assert!(reserved > 0); + // Calculate expected value: + // Total = existing buffer size + new sorted buffer size (rounded to 64) + let record_batch_size = get_record_batch_memory_size(&batch); + let sliced_size = batch.get_sliced_size()?; + let expected = record_batch_size + round_upto_multiple_of_64(sliced_size); + + assert_eq!(reserved, expected); + assert!(reserved > 0, "Reserved bytes should be greater than 0"); Ok(()) } @@ -2788,168 +2703,32 @@ let reserved = get_reserved_byte_for_record_batch(&batch)?; - // Should be rounded to multiple of 64 - assert!(reserved > 0); - // The rounding is applied to the sliced size component - // Total = record_batch_memory_size + round_upto_multiple_of_64(sliced_size) - - Ok(()) - } - - #[tokio::test] - async fn test_get_reserved_byte_for_record_batch_with_string_view() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Utf8View, - false, - )])); - - let string_array = StringViewArray::from(vec!["hello", "world", "test"]); - let batch = RecordBatch::try_new(schema, vec![Arc::new(string_array)])?; - - let reserved = get_reserved_byte_for_record_batch(&batch)?; - - // Should handle variable-length data correctly - assert!(reserved > 0); - - Ok(()) - } - - // ======================================================================== - // Tests for ReservationStream (in stream.rs, but we test integration here) - // ======================================================================== - - #[tokio::test] - async fn test_sort_batch_stream_memory_tracking() -> Result<()> { - use crate::stream::ReservationStream; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB limit - .build_arc()?; - - let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // Create a batch - let values: Vec<i32> = (0..1000).collect(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(values))], - )?; - - let batch_size = get_record_batch_memory_size(&batch); - - // Create a simple stream with one batch - let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; - - // Create reservation and grow it - let mut reservation = reservation; - reservation.try_grow(batch_size)?; - - let initial_reserved = runtime.memory_pool.reserved(); - assert!(initial_reserved > 0); - - // Create ReservationStream - let mut res_stream = - ReservationStream::new(Arc::clone(&schema), inner, reservation); - - // Consume the batch - let result = res_stream.next().await; - assert!(result.is_some()); - - // Memory should be reduced after consuming - let after_consume = runtime.memory_pool.reserved(); - assert!(after_consume < initial_reserved); - - // Consume until end - while res_stream.next().await.is_some() {} - - // Memory should be freed - let final_reserved = runtime.memory_pool.reserved(); - assert_eq!(final_reserved, 0); - - Ok(()) - } - - #[tokio::test] - async fn test_sort_batch_stream_chunked_output() -> Result<()> { - let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // Create a large batch (5000 rows) - let values: Vec<i32> = (0..5000).rev().collect(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(values))], - )?; - - let expressions: LexOrdering = - [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); - - let batch_size = 500; - let result_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; - - // Verify multiple output batches - assert_eq!(result_batches.len(), 10); - - // Each batch should be <= batch_size - let mut total_rows = 0; - for batch in &result_batches { - assert!(batch.num_rows() <= batch_size); - total_rows += batch.num_rows(); - } - - assert_eq!(total_rows, 5000); - - // Verify data is sorted - let concatenated = concat_batches(&schema, &result_batches)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - for i in 0..array.len() - 1 { - assert!(array.value(i) <= array.value(i + 1)); - } - - Ok(()) - } - - #[tokio::test] - async fn test_sort_no_batch_split_stream_metrics() -> Result<()> { - let task_ctx = Arc::new(TaskContext::default()); - let partitions = 2; - let csv = test::scan_partitioned(partitions); - let schema = csv.schema(); - - let sort_exec = Arc::new(SortExec::new( - [PhysicalSortExpr { - expr: col("i", &schema)?, - options: SortOptions::default(), - }] - .into(), - Arc::new(CoalescePartitionsExec::new(csv)), - )); - - let _result = collect(Arc::clone(&sort_exec) as _, task_ctx).await?; - - let metrics = sort_exec.metrics().unwrap(); + // Calculate expected value + let record_batch_size = get_record_batch_memory_size(&batch); + let sliced_size = batch.get_sliced_size()?; + let rounded_sliced_size = round_upto_multiple_of_64(sliced_size); + let expected = record_batch_size + rounded_sliced_size; - // Verify that SplitMetrics are not present - // The metrics should only include baseline and spill metrics - let metrics_str = format!("{metrics:?}"); + assert_eq!(reserved, expected); - // Should not contain split-related metrics - assert!( - !metrics_str.contains("split_count"), - "Should not have split_count metric" - ); - assert!( - !metrics_str.contains("split_time"), - "Should not have split_time metric" + // Verify that the sliced size component was indeed rounded to multiple of 64 + assert_eq!( + rounded_sliced_size % 64, + 0, + "Rounded sliced size should be a multiple of 64" ); - // Should still have baseline and spill metrics - assert!(metrics.output_rows().is_some()); - assert!(metrics.elapsed_compute().is_some()); + // If sliced_size is not already a multiple of 64, verify rounding occurred + if sliced_size % 64 != 0 { + assert!( + rounded_sliced_size > sliced_size, + "Rounding should have increased the size" + ); + assert!( + rounded_sliced_size - sliced_size < 64, + "Rounding should add less than 64 bytes" + ); + } Ok(()) } @@ -3016,63 +2795,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_sort_exec_batch_size_respected() -> Result<()> { - let batch_size = 50; - let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = - Arc::new(TaskContext::default().with_session_config(session_config)); - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // Create multiple batches with various sizes - let batches = vec![ - RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from( - (0..200).rev().collect::<Vec<i32>>(), - ))], - )?, - RecordBatch::try_new( - Arc::clone(&schema), -
vec![Arc::new(Int32Array::from( - (200..350).rev().collect::<Vec<i32>>(), - ))], - )?, - ]; - - let sort_exec = Arc::new(SortExec::new( - [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), - TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?, - )); - - let result = collect(sort_exec, task_ctx).await?; - - // Verify output batches respect batch_size - for (i, batch) in result.iter().enumerate() { - // All batches except possibly the last should have batch_size rows - if i < result.len() - 1 { - assert_eq!( - batch.num_rows(), - batch_size, - "Batch {i} should have {batch_size} rows", - ); - } else { - // Last batch can be smaller - assert!( - batch.num_rows() <= batch_size, - "Last batch should have <= {batch_size} rows", - ); - } - } - - // Verify total rows - let total_rows: usize = result.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, 350); - - Ok(()) - } - #[tokio::test] async fn test_sort_exec_with_multiple_partitions_chunked() -> Result<()> { let batch_size = 100; From 66249630abe7590412a4d22399459a490213c4e3 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 14:50:12 +0200 Subject: [PATCH 07/15] remove more bullshit --- datafusion/physical-plan/src/sorts/sort.rs | 8 -- datafusion/physical-plan/src/stream.rs | 100 ++++----------------- 2 files changed, 18 insertions(+), 90 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 6b6441bace0fc..6f0c02b812518 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2455,10 +2455,6 @@ mod tests { Ok((sorted_batches, metrics)) } - // ======================================================================== - // Tests for sort_batch_chunked() - // ======================================================================== - #[tokio::test] async fn test_sort_batch_chunked_basic() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -2641,10 +2637,6 @@ mod tests { Ok(()) } - // ======================================================================== - // Tests for get_reserved_byte_for_record_batch() - // ======================================================================== - #[tokio::test] async fn test_get_reserved_byte_for_record_batch_normal_batch() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index b7d5e0bd7e7f2..39603e0f24888 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -989,7 +989,7 @@ mod test { assert_eq!( number_of_batches, 2, - "Should have received exactly one empty batch" + "Should have received exactly two empty batches" ); } @@ -1060,46 +1060,6 @@ mod test { assert_eq!(runtime.memory_pool.reserved(), 0); } - #[tokio::test] - async fn test_reservation_stream_frees_on_completion() { - use arrow::array::Int32Array; - use datafusion_execution::memory_pool::MemoryConsumer; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(10 * 1024 * 1024, 1.0) - .build_arc() - .unwrap(); - - let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(vec![1,
2, 3]))], - ) - .unwrap(); - - let batch_size = get_record_batch_memory_size(&batch); - reservation.try_grow(batch_size).unwrap(); - - assert!(runtime.memory_pool.reserved() > 0); - - let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; - - let mut res_stream = - ReservationStream::new(Arc::clone(&schema), inner, reservation); - - // Consume all batches - while res_stream.next().await.is_some() {} - - // Memory should be fully freed - assert_eq!(runtime.memory_pool.reserved(), 0); - } - #[tokio::test] async fn test_reservation_stream_error_handling() { use datafusion_execution::memory_pool::MemoryConsumer; @@ -1131,48 +1091,24 @@ mod test { assert!(result.is_some()); assert!(result.unwrap().is_err()); - // Stream should be done, but reservation might not be freed yet - // since we didn't consume to None - // This is expected behavior - the reservation is only freed when the stream ends normally - } - - #[tokio::test] - async fn test_reservation_stream_schema_preserved() { - use arrow::array::Int32Array; - use datafusion_execution::memory_pool::MemoryConsumer; - use datafusion_execution::runtime_env::RuntimeEnvBuilder; - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(10 * 1024 * 1024, 1.0) - .build_arc() - .unwrap(); - - let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); - - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 2, 3])), - Arc::new(Int32Array::from(vec![4, 5, 6])), - ], - ) - .unwrap(); - - let stream = futures::stream::iter(vec![Ok(batch)]); - let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) - as SendableRecordBatchStream; + // Verify reservation is NOT automatically freed on error + // The reservation is only freed when poll_next returns Poll::Ready(None) + // After an error, the stream may continue to hold the reservation + // until it's explicitly dropped or polled to None + let after_error = runtime.memory_pool.reserved(); + assert_eq!( + after_error, 1000, + "Reservation should still be held after error" + ); - let res_stream = ReservationStream::new(Arc::clone(&schema), inner, reservation); + // Drop the stream to free the reservation + drop(res_stream); - // Verify schema is preserved - let stream_schema = res_stream.schema(); - assert_eq!(stream_schema.fields().len(), 2); - assert_eq!(stream_schema.field(0).name(), "a"); - assert_eq!(stream_schema.field(1).name(), "b"); + // Now memory should be freed + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "Memory should be freed when stream is dropped" + ); } } From cba1443982526a684e7d51e8721b01f72282c532 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 14:58:19 +0200 Subject: [PATCH 08/15] Add a comment, change the assert --- datafusion/physical-plan/src/sorts/sort.rs | 204 +-------------------- 1 file changed, 3 insertions(+), 201 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 6f0c02b812518..803a6b08e8672 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2675,10 +2675,9 @@ mod tests { // Reserved should account for the sliced nature assert!(reserved > 0); - // The reservation should include memory 
for the full underlying buffer - // plus the sliced size rounded to 64 - let record_batch_size = get_record_batch_memory_size(&batch); - assert!(reserved >= record_batch_size); + // Verify that even if the calculation changes, we still have at least twice the actual slice size, + // otherwise there is no way we could perform the sort. + assert!(reserved >= batch.get_sliced_size()? * 2); Ok(()) } @@ -2724,201 +2723,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_external_sorter_with_chunked_batches() -> Result<()> { - // Test with memory limits that trigger spilling - let session_config = SessionConfig::new().with_batch_size(100); - let sort_spill_reservation_bytes = session_config - .options() - .execution - .sort_spill_reservation_bytes; - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(sort_spill_reservation_bytes + 8192, 1.0) - .build_arc()?; - let task_ctx = Arc::new( - TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime), - ); - - let partitions = 50; - let input = test::scan_partitioned(partitions); - let schema = input.schema(); - - let sort_exec = Arc::new(SortExec::new( - [PhysicalSortExpr { - expr: col("i", &schema)?, - options: SortOptions::default(), - }] - .into(), - Arc::new(CoalescePartitionsExec::new(input)), - )); - - let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; - - // Verify results are correctly sorted - let concatenated = concat_batches(&schema, &result)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - for i in 0..array.len() - 1 { - assert!( - array.value(i) <= array.value(i + 1), - "Not sorted at position {}: {} > {}", - i, - array.value(i), - array.value(i + 1) - ); - } - - // Verify spilling occurred - let metrics = sort_exec.metrics().unwrap(); - assert!( - metrics.spill_count().unwrap() > 0, - "Expected spilling to occur" - ); - - // Verify no memory leaks - assert_eq!( - task_ctx.runtime_env().memory_pool.reserved(), - 0, - "Memory should be fully released" - ); - - Ok(()) - } - - #[tokio::test] - async fn test_sort_exec_with_multiple_partitions_chunked() -> Result<()> { - let batch_size = 100; - let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = - Arc::new(TaskContext::default().with_session_config(session_config)); - - let partitions = 4; - let csv = test::scan_partitioned(partitions); - let schema = csv.schema(); - - let sort_exec = Arc::new(SortExec::new( - [PhysicalSortExpr { - expr: col("i", &schema)?, - options: SortOptions::default(), - }] - .into(), - Arc::new(CoalescePartitionsExec::new(csv)), - )); - - let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; - - // Verify correct sorting across partitions - let concatenated = concat_batches(&schema, &result)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - for i in 0..array.len() - 1 { - assert!(array.value(i) <= array.value(i + 1)); - } - - // Verify batch sizes - for batch in &result { - assert!(batch.num_rows() <= batch_size); - } - - // Verify memory is released - assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); - - Ok(()) - } - - #[tokio::test] - async fn test_large_batch_memory_handling() -> Result<()> { - // Test with large batches and reasonable memory - let batch_size = 1000; - let session_config = SessionConfig::new() - .with_batch_size(batch_size) - .with_sort_spill_reservation_bytes(200 * 1024); - - let runtime = RuntimeEnvBuilder::new() - .with_memory_limit(500 * 1024, 1.0) // 500KB limit - .build_arc()?;
- - let task_ctx = Arc::new( - TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime), - ); - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // Create a large batch - let values: Vec<i32> = (0..10000).rev().collect(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(values))], - )?; - - let sort_exec = Arc::new(SortExec::new( - [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), - TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, - )); - - let result = collect(sort_exec, Arc::clone(&task_ctx)).await?; - - // Should handle memory pressure correctly without panicking - assert!(!result.is_empty()); - - // Verify data is sorted - let concatenated = concat_batches(&schema, &result)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - for i in 0..array.len() - 1 { - assert!(array.value(i) <= array.value(i + 1)); - } - - // Verify memory is released - assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); - - Ok(()) - } - - #[tokio::test] - async fn test_sort_with_fetch_limit_chunked() -> Result<()> { - let batch_size = 50; - let session_config = SessionConfig::new().with_batch_size(batch_size); - let task_ctx = - Arc::new(TaskContext::default().with_session_config(session_config)); - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - let values: Vec<i32> = (0..1000).rev().collect(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(values))], - )?; - - let fetch_limit = 10; - let sort_exec = Arc::new( - SortExec::new( - [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), - TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, - ) - .with_fetch(Some(fetch_limit)), - ); - - let result = collect(sort_exec, Arc::clone(&task_ctx)).await?; - - // Verify correct number of rows returned - let total_rows: usize = result.iter().map(|b| b.num_rows()).sum(); - assert_eq!(total_rows, fetch_limit); - - // Verify data is sorted - let concatenated = concat_batches(&schema, &result)?; - let array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - for i in 0..array.len() - 1 { - assert!(array.value(i) <= array.value(i + 1)); - } - assert_eq!(array.value(0), 0); - assert_eq!(array.value(array.len() - 1), 9); - - // Verify memory is released - assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0); - - Ok(()) - } } From 2a6db02c8d67b96074e6c8ba14da92f7afee4389 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 15:04:10 +0200 Subject: [PATCH 09/15] fix test --- datafusion/physical-plan/src/sorts/sort.rs | 79 ++-------------------- 1 file changed, 6 insertions(+), 73 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 803a6b08e8672..d5cb6b3680029 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -2576,50 +2576,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_sort_batch_chunked_multi_column() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ])); - - // Create a batch with multiple columns - let a_values = Int32Array::from(vec![3, 1, 2, 1, 3, 2, 1, 3, 2, 1]); - let b_values = Int32Array::from(vec![1, 2, 3, 1, 2, 1, 3, 3, 2, 4]); - - let batch = RecordBatch::try_new( -
Arc::clone(&schema), - vec![Arc::new(a_values), Arc::new(b_values)], - )?; - - let expressions: LexOrdering = [ - PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), - PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), - ] - .into(); - - let result_batches = sort_batch_chunked(&batch, &expressions, 3)?; - let concatenated = concat_batches(&schema, &result_batches)?; - - let a_array = as_primitive_array::<Int32Type>(concatenated.column(0))?; - let b_array = as_primitive_array::<Int32Type>(concatenated.column(1))?; - - // Verify multi-column sort ordering - for i in 0..a_array.len() - 1 { - let a_curr = a_array.value(i); - let a_next = a_array.value(i + 1); - let b_curr = b_array.value(i); - let b_next = b_array.value(i + 1); - - assert!( - a_curr < a_next || (a_curr == a_next && b_curr <= b_next), - "Not properly sorted at position {i}: ({a_curr}, {b_curr}) -> ({a_next}, {b_next})", - ); - } - - Ok(()) - } - #[tokio::test] async fn test_sort_batch_chunked_empty_batch() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -2637,28 +2593,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_get_reserved_byte_for_record_batch_normal_batch() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], - )?; - - let reserved = get_reserved_byte_for_record_batch(&batch)?; - - // Calculate expected value: - // Total = existing buffer size + new sorted buffer size (rounded to 64) - let record_batch_size = get_record_batch_memory_size(&batch); - let sliced_size = batch.get_sliced_size()?; - let expected = record_batch_size + round_upto_multiple_of_64(sliced_size); - - assert_eq!(reserved, expected); - assert!(reserved > 0, "Reserved bytes should be greater than 0"); - - Ok(()) - } - #[tokio::test] async fn test_get_reserved_byte_for_record_batch_with_sliced_batches() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -2667,17 +2601,16 @@ mod tests { let large_array = Int32Array::from((0..1000).collect::<Vec<i32>>()); let sliced_array = large_array.slice(100, 50); // Take 50 elements starting at 100 - let batch = + let sliced_batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(sliced_array)])?; + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?; + let sliced_reserved = get_reserved_byte_for_record_batch(&batch)?; let reserved = get_reserved_byte_for_record_batch(&batch)?; - // Reserved should account for the sliced nature - assert!(reserved > 0); - - // Verify that even if the calculation changes, we still have at least twice the actual slice size, - // otherwise there is no way we could perform the sort. - assert!(reserved >= batch.get_sliced_size()?
* 2); + // The reserved memory for the sliced batch should be less than that of the full batch + assert!(reserved > sliced_reserved); Ok(()) } From 504ed6d397e08566e1f31a223eb3a0ba2bc66dcb Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 15:17:07 +0200 Subject: [PATCH 10/15] add test, add comment --- datafusion/physical-plan/src/sorts/sort.rs | 93 +++++++++++++++++++++- datafusion/physical-plan/src/stream.rs | 1 + 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index d5cb6b3680029..081c436fcbb3c 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -1863,6 +1863,97 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_sort_memory_reduction_per_batch() -> Result<()> { + // This test verifies that memory reservation is reduced for every batch emitted + // during the sort process. This is important to ensure we don't hold onto + // memory longer than necessary. + + // Create a large enough batch that will be split into multiple output batches + let batch_size = 50; // Small batch size to force multiple output batches + let num_rows = 1000; // Create enough data for multiple batches + + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + .with_sort_in_place_threshold_bytes(usize::MAX), // Ensure we don't concat batches + ), + ); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create unsorted data + let mut values: Vec<i32> = (0..num_rows).collect(); + values.reverse(); + + let input_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let batches = vec![input_batch]; + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }] + .into(), + TestMemoryExec::try_new_exec( + std::slice::from_ref(&batches), + Arc::clone(&schema), + None, + )?, + )); + + let mut stream = sort_exec.execute(0, Arc::clone(&task_ctx))?; + + let mut previous_reserved = task_ctx.runtime_env().memory_pool.reserved(); + let mut batch_count = 0; + + // Collect batches and verify memory is reduced with each batch + while let Some(result) = stream.next().await { + let batch = result?; + batch_count += 1; + + // Verify we got a non-empty batch + assert!(batch.num_rows() > 0, "Batch should not be empty"); + + let current_reserved = task_ctx.runtime_env().memory_pool.reserved(); + + // After the first batch, memory should be reducing or staying the same + // (it should not increase as we emit batches) + if batch_count > 1 { + assert!( + current_reserved <= previous_reserved, + "Memory reservation should decrease or stay same as batches are emitted.
\ + Batch {}: previous={}, current={}", + batch_count, + previous_reserved, + current_reserved + ); + } + + previous_reserved = current_reserved; + } + + assert!( + batch_count > 1, + "Expected multiple batches to be emitted, got {}", + batch_count + ); + + // Verify all memory is returned at the end + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "All memory should be returned after consuming all batches" + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_metadata() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -2606,7 +2697,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?; - let sliced_reserved = get_reserved_byte_for_record_batch(&batch)?; + let sliced_reserved = get_reserved_byte_for_record_batch(&sliced_batch)?; let reserved = get_reserved_byte_for_record_batch(&batch)?; diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 39603e0f24888..7071ccf8811ee 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -704,6 +704,7 @@ impl RecordBatchStream for BatchSplitStream { /// A stream that holds a memory reservation for its lifetime, /// shrinking the reservation as batches are consumed. /// The original reservation must have its batch sizes calculated using [`get_record_batch_memory_size`] +/// On error, the reservation is *NOT* freed until the stream is dropped. pub struct ReservationStream { schema: SchemaRef, inner: SendableRecordBatchStream, reservation: MemoryReservation, } From 98b6202d8da661899cf0cd2b15315c251476d9dc Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 15:18:46 +0200 Subject: [PATCH 11/15] more silly clippy --- datafusion/physical-plan/src/sorts/sort.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 081c436fcbb3c..e22e3508af3e0 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -1928,10 +1928,7 @@ mod tests { assert!( current_reserved <= previous_reserved, "Memory reservation should decrease or stay same as batches are emitted.
\ - Batch {}: previous={}, current={}", - batch_count, - previous_reserved, - current_reserved + Batch {batch_count}: previous={previous_reserved}, current={current_reserved}" ); } @@ -1940,8 +1937,7 @@ mod tests { assert!( batch_count > 1, - "Expected multiple batches to be emitted, got {}", - batch_count + "Expected multiple batches to be emitted, got {batch_count}" ); // Verify all memory is returned at the end From de6a47cdad436d340e063cdecf9a8d802a20ebd9 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 15:58:45 +0200 Subject: [PATCH 12/15] fix mem calculation issue --- .../src/sorts/multi_level_merge.rs | 7 ++++-- datafusion/physical-plan/src/sorts/sort.rs | 25 ++++++++++++------- datafusion/physical-plan/src/stream.rs | 4 +-- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 7545c63bed38d..8da75adc3c923 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -27,10 +27,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; -use arrow::util::bit_util::round_upto_multiple_of_64; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; +use crate::sorts::sort::get_reserved_byte_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::stream::RecordBatchStreamAdapter; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -361,7 +361,10 @@ impl MultiLevelMergeBuilder { // For memory pools that are not shared this is good, for other this is not // and there should be some upper limit to memory reservation so we won't starve the system match reservation.try_grow( - round_upto_multiple_of_64(spill.max_record_batch_memory) * buffer_len, + get_reserved_byte_for_record_batch_size( + spill.max_record_batch_memory, + spill.max_record_batch_memory, + ) * buffer_len, ) { Ok(_) => { number_of_spills_to_read_for_current_phase += 1; } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index e22e3508af3e0..39d60e51d75b4 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -828,17 +828,24 @@ impl ExternalSorter { } } +/// Calculate how much memory to reserve for sorting a `RecordBatch` from its size, +/// this can be calculated as the sum of the actual space it takes in memory (which would be larger for a sliced batch), +/// and the size of the actual data, rounded up to 64 bytes, as that is what arrow will use when creating new buffers. +pub(crate) fn get_reserved_byte_for_record_batch_size( + record_batch_size: usize, + sliced_size: usize, +) -> usize { + record_batch_size + round_upto_multiple_of_64(sliced_size) +} + /// Estimate how much memory is needed to sort a `RecordBatch`. -/// -/// For sliced batches, `get_record_batch_memory_size` returns the size of the -/// underlying shared buffers (which may be larger than the logical data). +/// We add `get_sliced_size()` (the actual logical data size, rounded to 64 bytes) -/// because sorting will create new buffers containing only the referenced data. -/// -/// Total = existing buffer size + new sorted buffer size +/// This will just call `get_reserved_byte_for_record_batch_size` with the /// memory size of the record batch and its sliced size.
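+/// For illustration only (hypothetical sizes, not taken from the patch): a batch sliced out of a +/// larger shared buffer might report `record_batch_size = 4096` while holding only `sliced_size = 200` +/// bytes of logical data, so at this point in the series the reservation would come out to +/// `4096 + round_upto_multiple_of_64(200) = 4096 + 256 = 4352` bytes.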
pub(super) fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> Result<usize> { Ok(get_reserved_byte_for_record_batch_size( get_record_batch_memory_size(batch), batch.get_sliced_size()?, )) } impl Debug for ExternalSorter { diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 7071ccf8811ee..80c2233d05db6 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -705,14 +705,14 @@ impl RecordBatchStream for BatchSplitStream { /// shrinking the reservation as batches are consumed. /// The original reservation must have its batch sizes calculated using [`get_record_batch_memory_size`] /// On error, the reservation is *NOT* freed until the stream is dropped. -pub struct ReservationStream { +pub(crate) struct ReservationStream { schema: SchemaRef, inner: SendableRecordBatchStream, reservation: MemoryReservation, } impl ReservationStream { - pub fn new( + pub(crate) fn new( schema: SchemaRef, inner: SendableRecordBatchStream, reservation: MemoryReservation, From 94ad6cf83b867b64ac32205973b959c26505ce11 Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 16:43:23 +0200 Subject: [PATCH 13/15] Remove roundup --- datafusion/physical-plan/src/sorts/sort.rs | 52 ++----------------- .../physical-plan/src/spill/spill_manager.rs | 5 +- 2 files changed, 5 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 39d60e51d75b4..1d7286becc7ab 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -54,7 +54,6 @@ use crate::{ use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; use arrow::datatypes::SchemaRef; -use arrow::util::bit_util::round_upto_multiple_of_64; use datafusion_common::config::SpillCompression; use datafusion_common::{ DataFusionError, Result, assert_or_internal_err, internal_datafusion_err, @@ -830,12 +829,12 @@ impl ExternalSorter { /// Calculate how much memory to reserve for sorting a `RecordBatch` from its size, /// this can be calculated as the sum of the actual space it takes in memory (which would be larger for a sliced batch), -/// and the size of the actual data, rounded up to 64 bytes, as that is what arrow will use when creating new buffers. +/// and the size of the actual data. pub(crate) fn get_reserved_byte_for_record_batch_size( record_batch_size: usize, sliced_size: usize, ) -> usize { - record_batch_size + round_upto_multiple_of_64(sliced_size) + record_batch_size + sliced_size } /// Estimate how much memory is needed to sort a `RecordBatch`.
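(Aside: a minimal sketch of how this estimate is consumed, illustrative rather than part of the patch; `buffer_batch` is a hypothetical helper, while `MemoryReservation::try_grow` is the existing API from `datafusion_execution::memory_pool`, mirroring the `reserve_memory_for_batch` call sites shown later in this series.)

fn buffer_batch(batch: &RecordBatch, reservation: &mut MemoryReservation) -> Result<()> {
    // Estimate the memory the sort will need for this batch ...
    let size = get_reserved_byte_for_record_batch(batch)?;
    // ... and grow the reservation before buffering; an Err here is the
    // signal for the caller to spill instead of keeping the batch in memory.
    reservation.try_grow(size)
}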
@@ -2431,9 +2430,6 @@ mod tests { .map(|b| b.get_array_memory_size()) .sum::<usize>(); - - // Use half the batch memory to force spilling - let memory_limit = batches_memory / 2; - TaskContext::default() .with_session_config( SessionConfig::new() @@ -2444,7 +2440,7 @@ mod tests { ) .with_runtime( RuntimeEnvBuilder::default() - .with_memory_limit(memory_limit, 1.0) + .with_memory_limit(batches_memory, 1.0) .build_arc() .unwrap(), ) @@ -2708,46 +2704,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_get_reserved_byte_for_record_batch_rounding() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - - // Create a batch with a size that's not a multiple of 64 - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], - )?; - - let reserved = get_reserved_byte_for_record_batch(&batch)?; - - // Calculate expected value - let record_batch_size = get_record_batch_memory_size(&batch); - let sliced_size = batch.get_sliced_size()?; - let rounded_sliced_size = round_upto_multiple_of_64(sliced_size); - let expected = record_batch_size + rounded_sliced_size; - - assert_eq!(reserved, expected); - - // Verify that the sliced size component was indeed rounded to multiple of 64 - assert_eq!( - rounded_sliced_size % 64, - 0, - "Rounded sliced size should be a multiple of 64" - ); - - // If sliced_size is not already a multiple of 64, verify rounding occurred - if sliced_size % 64 != 0 { - assert!( - rounded_sliced_size > sliced_size, - "Rounding should have increased the size" - ); - assert!( - rounded_sliced_size - sliced_size < 64, - "Rounding should add less than 64 bytes" - ); - } - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index d4600673394b7..89b0276206774 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -20,12 +20,11 @@ use arrow::array::StringViewArray; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_execution::runtime_env::RuntimeEnv; -use std::sync::Arc; - use datafusion_common::{Result, config::SpillCompression}; use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::runtime_env::RuntimeEnv; +use std::sync::Arc; use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; use crate::coop::cooperative; From 96b158d3d99dc4b8a44c871e564d2c2d6c4f0f3e Mon Sep 17 00:00:00 2001 From: Emily Matheys Date: Mon, 29 Dec 2025 16:53:58 +0200 Subject: [PATCH 14/15] address verbal CR --- .../src/sorts/multi_level_merge.rs | 5 ++- datafusion/physical-plan/src/sorts/sort.rs | 45 ++++++++++++------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 8da75adc3c923..2e0d668a29559 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -30,7 +30,7 @@ use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; -use crate::sorts::sort::get_reserved_byte_for_record_batch_size; +use crate::sorts::sort::get_reserved_bytes_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::stream::RecordBatchStreamAdapter;
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -361,8 +361,9 @@ impl MultiLevelMergeBuilder { // For memory pools that are not shared this is good, for other this is not // and there should be some upper limit to memory reservation so we won't starve the system match reservation.try_grow( - get_reserved_byte_for_record_batch_size( + get_reserved_bytes_for_record_batch_size( spill.max_record_batch_memory, + // Size will be the same as the sliced size, because it is a spilled batch. spill.max_record_batch_memory, ) * buffer_len, ) { diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 1d7286becc7ab..d6372cc1be3c1 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -541,7 +541,7 @@ impl ExternalSorter { while let Some(batch) = sorted_stream.next().await { let batch = batch?; - let sorted_size = get_reserved_byte_for_record_batch(&batch)?; + let sorted_size = get_reserved_bytes_for_record_batch(&batch)?; if self.reservation.try_grow(sorted_size).is_err() { // Although the reservation is not enough, the batch is // already in memory, so it's okay to combine it with previously @@ -667,7 +667,7 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_reserved_byte_for_record_batch(&batch)?) + .try_resize(get_reserved_bytes_for_record_batch(&batch)?) .map_err(Self::err_with_oom_context)?; let reservation = self.reservation.take(); return self.sort_batch_stream(batch, &metrics, reservation); @@ -679,7 +679,7 @@ impl ExternalSorter { let metrics = self.metrics.baseline.intermediate(); let reservation = self .reservation - .split(get_reserved_byte_for_record_batch(&batch)?); + .split(get_reserved_bytes_for_record_batch(&batch)?); let input = self.sort_batch_stream(batch, &metrics, reservation)?; Ok(spawn_buffered(input, 1)) }) @@ -712,7 +712,7 @@ impl ExternalSorter { mut reservation: MemoryReservation, ) -> Result { assert_eq!( - get_reserved_byte_for_record_batch(&batch)?, + get_reserved_bytes_for_record_batch(&batch)?, reservation.size() ); @@ -795,7 +795,7 @@ impl ExternalSorter { &mut self, input: &RecordBatch, ) -> Result<()> { - let size = get_reserved_byte_for_record_batch(input)?; + let size = get_reserved_bytes_for_record_batch(input)?; match self.reservation.try_grow(size) { Ok(_) => Ok(()), @@ -827,21 +827,31 @@ impl ExternalSorter { } } -/// Calculate how much memory to reserve for sorting a `RecordBatch` from its size, -/// this can be calculated as the sum of the actual space it takes in memory (which would be larger for a sliced batch), -/// and the size of the actual data. -pub(crate) fn get_reserved_byte_for_record_batch_size( +/// Calculate how much memory to reserve for sorting a `RecordBatch` from its size. +/// +/// This is used to pre-reserve memory for the sort/merge. The sort/merge process involves +/// creating sorted copies of sorted columns in record batches for speeding up comparison +/// in sorting and merging. The sorted copies are in either row format or array format. +/// Please refer to cursor.rs and stream.rs for more details. No matter what format the +/// sorted copies are, they will use more memory than the original record batch. +/// +/// This can basically be calculated as the sum of the actual space it takes in +/// memory (which would be larger for a sliced batch), and the size of the actual data.
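+/// Illustration (hypothetical numbers): for a batch read back from a spill file the two inputs
+/// are equal, so a 1 MiB spilled batch reserves roughly 2 MiB before merging, while a sliced
+/// batch sharing a 1 MiB buffer but referencing only 100 KiB of logical data reserves roughly
+/// 1 MiB + 100 KiB.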
+pub(crate) fn get_reserved_bytes_for_record_batch_size( record_batch_size: usize, sliced_size: usize, ) -> usize { + // Even 2x may not be enough for some cases, but it's a good enough estimation as a baseline. + // If 2x is not enough, the user can set a larger value for `sort_spill_reservation_bytes` + // to compensate for the extra memory needed. record_batch_size + sliced_size } /// Estimate how much memory is needed to sort a `RecordBatch`. -/// This will just call `get_reserved_byte_for_record_batch_size` with the +/// This will just call `get_reserved_bytes_for_record_batch_size` with the /// memory size of the record batch and its sliced size. -pub(super) fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> Result<usize> { +pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> { Ok(get_reserved_bytes_for_record_batch_size( get_record_batch_memory_size(batch), batch.get_sliced_size()?, )) @@ -1682,7 +1692,7 @@ mod tests { let temp_ctx = Arc::new(TaskContext::default()); let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?; let first_batch = stream.next().await.unwrap()?; - get_reserved_byte_for_record_batch(&first_batch)? + get_reserved_bytes_for_record_batch(&first_batch)? }; // Set memory limit just short of what we need @@ -1701,7 +1711,7 @@ { let mut stream = plan.execute(0, Arc::clone(&task_ctx))?; let first_batch = stream.next().await.unwrap()?; - let batch_reservation = get_reserved_byte_for_record_batch(&first_batch)?; + let batch_reservation = get_reserved_bytes_for_record_batch(&first_batch)?; assert_eq!(batch_reservation, expected_batch_reservation); assert!(memory_limit < (merge_reservation + batch_reservation)); @@ -2684,7 +2694,8 @@ } #[tokio::test] - async fn test_get_reserved_byte_for_record_batch_with_sliced_batches() -> Result<()> { + async fn test_get_reserved_bytes_for_record_batch_with_sliced_batches() -> Result<()> + { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); // Create a larger batch then slice it @@ -2696,8 +2707,8 @@ let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?; - let sliced_reserved = get_reserved_byte_for_record_batch(&sliced_batch)?; - let reserved = get_reserved_byte_for_record_batch(&batch)?; + let sliced_reserved = get_reserved_bytes_for_record_batch(&sliced_batch)?; + let reserved = get_reserved_bytes_for_record_batch(&batch)?; // The reserved memory for the sliced batch should be less than that of the full batch assert!(reserved > sliced_reserved); From 13a61a1d94143fc297eead80e57d82bdcbfde7e4 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:03:23 +0200 Subject: [PATCH 15/15] revert comment change as it is still an estimation --- datafusion/physical-plan/src/sorts/sort.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index d6372cc1be3c1..3e8fdf1f3ed7e 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -827,7 +827,7 @@ impl ExternalSorter { } } -/// Calculate how much memory to reserve for sorting a `RecordBatch` from its size. +/// Estimate how much memory is needed to sort a `RecordBatch`. /// /// This is used to pre-reserve memory for the sort/merge.
The sort/merge process involves /// creating sorted copies of sorted columns in record batches for speeding up comparison