diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index cb22fbf9a06a1..1ae7202711112 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -1198,7 +1198,7 @@ impl GroupedHashAggregateStream {
         // instead.
         // Spilling to disk and reading back also ensures batch size is consistent
         // rather than potentially having one significantly larger last batch.
-        self.spill()?;
+        self.spill()?; // TODO: use sort_batch_chunked instead?
 
         // Mark that we're switching to stream merging mode.
         self.spill_state.is_stream_merging = true;
diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
index 3540f1de3ed10..2e0d668a29559 100644
--- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs
+++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs
@@ -30,7 +30,7 @@ use arrow::datatypes::SchemaRef;
 use datafusion_common::Result;
 use datafusion_execution::memory_pool::MemoryReservation;
 
-use crate::sorts::sort::get_reserved_byte_for_record_batch_size;
+use crate::sorts::sort::get_reserved_bytes_for_record_batch_size;
 use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
 use crate::stream::RecordBatchStreamAdapter;
 use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
@@ -360,9 +360,13 @@ impl MultiLevelMergeBuilder {
         for spill in &self.sorted_spill_files {
             // For memory pools that are not shared this is good, for other this is not
             // and there should be some upper limit to memory reservation so we won't starve the system
-            match reservation.try_grow(get_reserved_byte_for_record_batch_size(
-                spill.max_record_batch_memory * buffer_len,
-            )) {
+            match reservation.try_grow(
+                get_reserved_bytes_for_record_batch_size(
+                    spill.max_record_batch_memory,
+                    // Size will be the same as the sliced size, because it is a spilled batch.
+                    spill.max_record_batch_memory,
+                ) * buffer_len,
+            ) {
                 Ok(_) => {
                     number_of_spills_to_read_for_current_phase += 1;
                 }
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 18cdcbe9debcc..3e8fdf1f3ed7e 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -34,16 +34,15 @@ use crate::filter_pushdown::{
 };
 use crate::limit::LimitStream;
 use crate::metrics::{
-    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, SpillMetrics,
-    SplitMetrics,
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics,
 };
 use crate::projection::{ProjectionExec, make_with_child, update_ordering};
 use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
 use crate::spill::get_record_batch_memory_size;
 use crate::spill::in_progress_spill_file::InProgressSpillFile;
 use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
-use crate::stream::BatchSplitStream;
 use crate::stream::RecordBatchStreamAdapter;
+use crate::stream::ReservationStream;
 use crate::topk::TopK;
 use crate::topk::TopKDynamicFilters;
 use crate::{
@@ -75,8 +74,6 @@ struct ExternalSorterMetrics {
     baseline: BaselineMetrics,
 
     spill_metrics: SpillMetrics,
-
-    split_metrics: SplitMetrics,
 }
 
 impl ExternalSorterMetrics {
@@ -84,7 +81,6 @@ impl ExternalSorterMetrics {
         Self {
             baseline: BaselineMetrics::new(metrics, partition),
             spill_metrics: SpillMetrics::new(metrics, partition),
-            split_metrics: SplitMetrics::new(metrics, partition),
         }
     }
 }
@@ -545,7 +541,7 @@ impl ExternalSorter {
         while let Some(batch) = sorted_stream.next().await {
             let batch = batch?;
-            let sorted_size = get_reserved_byte_for_record_batch(&batch);
+            let sorted_size = get_reserved_bytes_for_record_batch(&batch)?;
             if self.reservation.try_grow(sorted_size).is_err() {
                 // Although the reservation is not enough, the batch is
                 // already in memory, so it's okay to combine it with previously
@@ -662,7 +658,7 @@ impl ExternalSorter {
         if self.in_mem_batches.len() == 1 {
             let batch = self.in_mem_batches.swap_remove(0);
             let reservation = self.reservation.take();
-            return self.sort_batch_stream(batch, metrics, reservation, true);
+            return self.sort_batch_stream(batch, &metrics, reservation);
         }
 
         // If less than sort_in_place_threshold_bytes, concatenate and sort in place
@@ -671,10 +667,10 @@ impl ExternalSorter {
             let batch = concat_batches(&self.schema, &self.in_mem_batches)?;
             self.in_mem_batches.clear();
             self.reservation
-                .try_resize(get_reserved_byte_for_record_batch(&batch))
+                .try_resize(get_reserved_bytes_for_record_batch(&batch)?)
                 .map_err(Self::err_with_oom_context)?;
             let reservation = self.reservation.take();
-            return self.sort_batch_stream(batch, metrics, reservation, true);
+            return self.sort_batch_stream(batch, &metrics, reservation);
         }
 
         let streams = std::mem::take(&mut self.in_mem_batches)
@@ -683,15 +679,8 @@ impl ExternalSorter {
                 let metrics = self.metrics.baseline.intermediate();
                 let reservation = self
                     .reservation
-                    .split(get_reserved_byte_for_record_batch(&batch));
-                let input = self.sort_batch_stream(
-                    batch,
-                    metrics,
-                    reservation,
-                    // Passing false as `StreamingMergeBuilder` will split the
-                    // stream into batches of `self.batch_size` rows.
-                    false,
-                )?;
+                    .split(get_reserved_bytes_for_record_batch(&batch)?);
+                let input = self.sort_batch_stream(batch, &metrics, reservation)?;
                 Ok(spawn_buffered(input, 1))
             })
            .collect::<Result<_>>()?;
@@ -709,52 +698,78 @@ impl ExternalSorter {
     /// Sorts a single `RecordBatch` into a single stream.
     ///
-    /// `reservation` accounts for the memory used by this batch and
-    /// is released when the sort is complete
-    ///
-    /// passing `split` true will return a [`BatchSplitStream`] where each batch maximum row count
-    /// will be `self.batch_size`.
-    /// If `split` is false, the stream will return a single batch
+    /// This may output multiple batches depending on the size of the
+    /// sorted data and the target batch size.
+    /// For single-batch output cases, `reservation` will be freed immediately after sorting,
+    /// as the batch will be output and is expected to be reserved by the consumer of the stream.
+    /// For multi-batch output cases, `reservation` will be grown to match the actual
+    /// size of sorted output, and as each batch is output, its memory will be freed from the reservation.
+    /// (Both cases end up behaving the same, as futures are only evaluated when polled by the consumer.)
     fn sort_batch_stream(
         &self,
         batch: RecordBatch,
-        metrics: BaselineMetrics,
-        reservation: MemoryReservation,
-        mut split: bool,
+        metrics: &BaselineMetrics,
+        mut reservation: MemoryReservation,
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(
-            get_reserved_byte_for_record_batch(&batch),
+            get_reserved_bytes_for_record_batch(&batch)?,
             reservation.size()
         );
 
-        split = split && batch.num_rows() > self.batch_size;
-
         let schema = batch.schema();
         let expressions = self.expr.clone();
-        let stream = futures::stream::once(async move {
-            let _timer = metrics.elapsed_compute().timer();
+        let batch_size = self.batch_size;
+        let output_row_metrics = metrics.output_rows().clone();
 
-            let sorted = sort_batch(&batch, &expressions, None)?;
+        let stream = futures::stream::once(async move {
+            let schema = batch.schema();
 
-            (&sorted).record_output(&metrics);
+            // Sort the batch immediately and get all output batches
+            let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?;
 
             drop(batch);
-            drop(reservation);
-            Ok(sorted)
-        });
-
-        let mut output: SendableRecordBatchStream =
-            Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+            // Free the old reservation and grow it to match the actual sorted output size
+            reservation.free();
 
-        if split {
-            output = Box::pin(BatchSplitStream::new(
-                output,
-                self.batch_size,
-                self.metrics.split_metrics.clone(),
-            ));
-        }
+            Result::<_, DataFusionError>::Ok((schema, sorted_batches, reservation))
+        })
+        .then({
+            move |batches| async move {
+                match batches {
+                    Ok((schema, sorted_batches, mut reservation)) => {
+                        // Calculate the total size of sorted batches
+                        let total_sorted_size: usize = sorted_batches
+                            .iter()
+                            .map(get_record_batch_memory_size)
+                            .sum();
+                        reservation
+                            .try_grow(total_sorted_size)
+                            .map_err(Self::err_with_oom_context)?;
+
+                        // Wrap in ReservationStream to hold the reservation
+                        Ok(Box::pin(ReservationStream::new(
+                            Arc::clone(&schema),
+                            Box::pin(RecordBatchStreamAdapter::new(
+                                schema,
+                                futures::stream::iter(sorted_batches.into_iter().map(Ok)),
+                            )),
+                            reservation,
+                        )) as SendableRecordBatchStream)
+                    }
+                    Err(e) => Err(e),
+                }
+            }
+        })
+        .try_flatten()
+        .map(move |batch| match batch {
+            Ok(batch) => {
+                output_row_metrics.add(batch.num_rows());
+                Ok(batch)
+            }
+            Err(e) => Err(e),
+        });
 
-        Ok(output)
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
 
     /// If this sort may spill, pre-allocates
@@ -780,7 +795,7 @@
         &mut self,
         input: &RecordBatch,
     ) -> Result<()> {
-        let size = get_reserved_byte_for_record_batch(input);
+        let size = get_reserved_bytes_for_record_batch(input)?;
         match self.reservation.try_grow(size) {
             Ok(_) => Ok(()),
@@ -819,16 +834,27 @@ impl ExternalSorter {
 /// in sorting and merging. The sorted copies are in either row format or array format.
 /// Please refer to cursor.rs and stream.rs for more details. No matter what format the
 /// sorted copies are, they will use more memory than the original record batch.
-pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize {
-    // 2x may not be enough for some cases, but it's a good start.
+///
+/// This can basically be calculated as the sum of the actual space the batch takes in
+/// memory (which can be larger than its data for a sliced batch) and the size of the actual data.
+pub(crate) fn get_reserved_bytes_for_record_batch_size(
+    record_batch_size: usize,
+    sliced_size: usize,
+) -> usize {
+    // Even 2x may not be enough for some cases, but it's a good enough estimation as a baseline.
     // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes`
     // to compensate for the extra memory needed.
-    record_batch_size * 2
+    record_batch_size + sliced_size
 }
 
 /// Estimate how much memory is needed to sort a `RecordBatch`.
-fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize {
-    get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch))
+/// This will just call `get_reserved_bytes_for_record_batch_size` with the
+/// memory size of the record batch and its sliced size.
+pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
+    Ok(get_reserved_bytes_for_record_batch_size(
+        get_record_batch_memory_size(batch),
+        batch.get_sliced_size()?,
+    ))
 }
 
 impl Debug for ExternalSorter {
@@ -853,15 +879,7 @@ pub fn sort_batch(
         .collect::<Result<Vec<_>>>()?;
 
     let indices = lexsort_to_indices(&sort_columns, fetch)?;
-    let mut columns = take_arrays(batch.columns(), &indices, None)?;
-
-    // The columns may be larger than the unsorted columns in `batch` especially for variable length
-    // data types due to exponential growth when building the sort columns. We shrink the columns
-    // to prevent memory reservation failures, as well as excessive memory allocation when running
-    // merges in `SortPreservingMergeStream`.
-    columns.iter_mut().for_each(|c| {
-        c.shrink_to_fit();
-    });
+    let columns = take_arrays(batch.columns(), &indices, None)?;
 
     let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
     Ok(RecordBatch::try_new_with_options(
@@ -871,6 +889,48 @@
     )?)
 }
 
+/// Sort a batch and return the result as multiple batches of size `batch_size`.
+/// This is useful when you want to avoid creating one large sorted batch in memory,
+/// and instead want to process the sorted data in smaller chunks.
+pub fn sort_batch_chunked(
+    batch: &RecordBatch,
+    expressions: &LexOrdering,
+    batch_size: usize,
+) -> Result<Vec<RecordBatch>> {
+    let sort_columns = expressions
+        .iter()
+        .map(|expr| expr.evaluate_to_sort_column(batch))
+        .collect::<Result<Vec<_>>>()?;
+
+    let indices = lexsort_to_indices(&sort_columns, None)?;
+
+    // Split indices into chunks of batch_size
+    let num_rows = indices.len();
+    let num_chunks = num_rows.div_ceil(batch_size);
+
+    let result_batches = (0..num_chunks)
+        .map(|chunk_idx| {
+            let start = chunk_idx * batch_size;
+            let end = (start + batch_size).min(num_rows);
+            let chunk_len = end - start;
+
+            // Create a slice of indices for this chunk
+            let chunk_indices = indices.slice(start, chunk_len);
+
+            // Take the columns using this chunk of indices
+            let columns = take_arrays(batch.columns(), &chunk_indices, None)?;
+
+            let options = RecordBatchOptions::new().with_row_count(Some(chunk_len));
+            let chunk_batch =
+                RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;
+
+            Ok(chunk_batch)
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(result_batches)
+}
+
 /// Sort execution plan.
 ///
 /// Support sorting datasets that are larger than the memory allotted
@@ -1173,10 +1233,7 @@ impl ExecutionPlan for SortExec {
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut new_sort = self.cloned();
-        assert!(
-            children.len() == 1,
-            "SortExec should have exactly one child"
-        );
+        assert_eq!(children.len(), 1, "SortExec should have exactly one child");
         new_sort.input = Arc::clone(&children[0]);
         // Recompute the properties based on the new input since they may have changed
         let (cache, sort_prefix) = Self::compute_properties(
@@ -1623,13 +1680,24 @@ mod tests {
     #[tokio::test]
     async fn test_batch_reservation_error() -> Result<()> {
         // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail.
-        // These values assume that the ExternalSorter will reserve 800 bytes for the first batch.
-        let expected_batch_reservation = 800;
         let merge_reservation: usize = 0; // Set to 0 for simplicity
-        let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; // Just short of what we need
 
         let session_config =
             SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation);
+
+        let plan = test::scan_partitioned(1);
+
+        // Read the first record batch to determine the actual memory requirement
+        let expected_batch_reservation = {
+            let temp_ctx = Arc::new(TaskContext::default());
+            let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?;
+            let first_batch = stream.next().await.unwrap()?;
+            get_reserved_bytes_for_record_batch(&first_batch)?
+        };
+
+        // Set memory limit just short of what we need
+        let memory_limit: usize = expected_batch_reservation + merge_reservation - 1;
+
         let runtime = RuntimeEnvBuilder::new()
             .with_memory_limit(memory_limit, 1.0)
             .build_arc()?;
@@ -1639,14 +1707,11 @@ mod tests {
                 .with_runtime(runtime),
         );
 
-        let plan = test::scan_partitioned(1);
-
-        // Read the first record batch to assert that our memory limit and sort_spill_reservation
-        // settings trigger the test scenario.
+        // Verify that our memory limit is insufficient
         {
             let mut stream = plan.execute(0, Arc::clone(&task_ctx))?;
             let first_batch = stream.next().await.unwrap()?;
-            let batch_reservation = get_reserved_byte_for_record_batch(&first_batch);
+            let batch_reservation = get_reserved_bytes_for_record_batch(&first_batch)?;
 
             assert_eq!(batch_reservation, expected_batch_reservation);
             assert!(memory_limit < (merge_reservation + batch_reservation));
@@ -1814,6 +1879,93 @@
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_sort_memory_reduction_per_batch() -> Result<()> {
+        // This test verifies that memory reservation is reduced for every batch emitted
+        // during the sort process. This is important to ensure we don't hold onto
+        // memory longer than necessary.
+
+        // Create a large enough batch that will be split into multiple output batches
+        let batch_size = 50; // Small batch size to force multiple output batches
+        let num_rows = 1000; // Create enough data for multiple batches
+
+        let task_ctx = Arc::new(
+            TaskContext::default().with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(batch_size)
+                    .with_sort_in_place_threshold_bytes(usize::MAX), // Ensure we don't concat batches
+            ),
+        );
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create unsorted data
+        let mut values: Vec<i32> = (0..num_rows).collect();
+        values.reverse();
+
+        let input_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let batches = vec![input_batch];
+
+        let sort_exec = Arc::new(SortExec::new(
+            [PhysicalSortExpr {
+                expr: Arc::new(Column::new("a", 0)),
+                options: SortOptions::default(),
+            }]
+            .into(),
+            TestMemoryExec::try_new_exec(
+                std::slice::from_ref(&batches),
+                Arc::clone(&schema),
+                None,
+            )?,
+        ));
+
+        let mut stream = sort_exec.execute(0, Arc::clone(&task_ctx))?;
+
+        let mut previous_reserved = task_ctx.runtime_env().memory_pool.reserved();
+        let mut batch_count = 0;
+
+        // Collect batches and verify memory is reduced with each batch
+        while let Some(result) = stream.next().await {
+            let batch = result?;
+            batch_count += 1;
+
+            // Verify we got a non-empty batch
+            assert!(batch.num_rows() > 0, "Batch should not be empty");
+
+            let current_reserved = task_ctx.runtime_env().memory_pool.reserved();
+
+            // After the first batch, memory should be decreasing or staying the same
+            // (it should not increase as we emit batches)
+            if batch_count > 1 {
+                assert!(
+                    current_reserved <= previous_reserved,
+                    "Memory reservation should decrease or stay the same as batches are emitted. \
+                     Batch {batch_count}: previous={previous_reserved}, current={current_reserved}"
+                );
+            }
+
+            previous_reserved = current_reserved;
+        }
+
+        assert!(
+            batch_count > 1,
+            "Expected multiple batches to be emitted, got {batch_count}"
+        );
+
+        // Verify all memory is returned at the end
+        assert_eq!(
+            task_ctx.runtime_env().memory_pool.reserved(),
+            0,
+            "All memory should be returned after consuming all batches"
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_sort_metadata() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
@@ -2402,4 +2554,165 @@
 
         Ok((sorted_batches, metrics))
     }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_basic() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a batch with 1000 rows
+        let mut values: Vec<i32> = (0..1000).collect();
+        // Reverse to make it unsorted
+        values.reverse();
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        // Sort with batch_size = 250
+        let result_batches = sort_batch_chunked(&batch, &expressions, 250)?;
+
+        // Verify 4 batches are returned
+        assert_eq!(result_batches.len(), 4);
+
+        // Verify each batch has <= 250 rows
+        let mut total_rows = 0;
+        for (i, batch) in result_batches.iter().enumerate() {
+            assert!(
+                batch.num_rows() <= 250,
+                "Batch {} has {} rows, expected <= 250",
+                i,
+                batch.num_rows()
+            );
+            total_rows += batch.num_rows();
+        }
+
+        // Verify total row count matches input
+        assert_eq!(total_rows, 1000);
+
+        // Verify data is correctly sorted across all chunks
+        let concatenated = concat_batches(&schema, &result_batches)?;
+        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(
+                array.value(i) <= array.value(i + 1),
+                "Array not sorted at position {}: {} > {}",
+                i,
+                array.value(i),
+                array.value(i + 1)
+            );
+        }
+        assert_eq!(array.value(0), 0);
+        assert_eq!(array.value(array.len() - 1), 999);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a batch with 50 rows
+        let values: Vec<i32> = (0..50).rev().collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        // Sort with batch_size = 100
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+        // Should return exactly 1 batch
+        assert_eq!(result_batches.len(), 1);
+        assert_eq!(result_batches[0].num_rows(), 50);
+
+        // Verify it's correctly sorted
+        let array = as_primitive_array::<Int32Type>(result_batches[0].column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(array.value(i) <= array.value(i + 1));
+        }
+        assert_eq!(array.value(0), 0);
+        assert_eq!(array.value(49), 49);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_exact_multiple() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a batch with 1000 rows
+        let values: Vec<i32> = (0..1000).rev().collect();
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(values))],
+        )?;
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        // Sort with batch_size = 100
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+        // Should return exactly 10 batches of 100 rows each
+        assert_eq!(result_batches.len(), 10);
+        for batch in &result_batches {
+            assert_eq!(batch.num_rows(), 100);
+        }
+
+        // Verify sorted correctly across all batches
+        let concatenated = concat_batches(&schema, &result_batches)?;
+        let array = as_primitive_array::<Int32Type>(concatenated.column(0))?;
+        for i in 0..array.len() - 1 {
+            assert!(array.value(i) <= array.value(i + 1));
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_batch_chunked_empty_batch() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        let batch = RecordBatch::new_empty(Arc::clone(&schema));
+
+        let expressions: LexOrdering =
+            [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
+
+        let result_batches = sort_batch_chunked(&batch, &expressions, 100)?;
+
+        // Empty input produces no output batches (0 chunks)
+        assert_eq!(result_batches.len(), 0);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_get_reserved_bytes_for_record_batch_with_sliced_batches() -> Result<()>
+    {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create a larger batch then slice it
+        let large_array = Int32Array::from((0..1000).collect::<Vec<i32>>());
+        let sliced_array = large_array.slice(100, 50); // Take 50 elements starting at 100
+
+        let sliced_batch =
+            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(sliced_array)])?;
+        let batch =
+            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?;
+
+        let sliced_reserved = get_reserved_bytes_for_record_batch(&sliced_batch)?;
+        let reserved = get_reserved_bytes_for_record_batch(&batch)?;
+
+        // The reserved memory for the sliced batch should be less than that of the full batch
+        assert!(reserved > sliced_reserved);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs
index d4600673394b7..89b0276206774 100644
--- a/datafusion/physical-plan/src/spill/spill_manager.rs
+++ b/datafusion/physical-plan/src/spill/spill_manager.rs
@@ -20,12 +20,11 @@ use arrow::array::StringViewArray;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 
-use datafusion_execution::runtime_env::RuntimeEnv;
-use std::sync::Arc;
-
 use datafusion_common::{Result, config::SpillCompression};
 use datafusion_execution::SendableRecordBatchStream;
 use datafusion_execution::disk_manager::RefCountedTempFile;
+use datafusion_execution::runtime_env::RuntimeEnv;
+use std::sync::Arc;
 
 use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile};
 use crate::coop::cooperative;
diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs
index 8b2ea1006893e..80c2233d05db6 100644
--- a/datafusion/physical-plan/src/stream.rs
+++ b/datafusion/physical-plan/src/stream.rs
@@ -27,11 +27,13 @@ use super::metrics::ExecutionPlanMetricsSet;
 use super::metrics::{BaselineMetrics, SplitMetrics};
 use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 use crate::displayable;
+use crate::spill::get_record_batch_memory_size;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
 use datafusion_common::{Result, exec_err};
 use datafusion_common_runtime::JoinSet;
 use datafusion_execution::TaskContext;
+use datafusion_execution::memory_pool::MemoryReservation;
 use futures::ready;
 use futures::stream::BoxStream;
@@ -699,6 +701,70 @@ impl RecordBatchStream for BatchSplitStream {
     }
 }
 
+/// A stream that holds a memory reservation for its lifetime,
+/// shrinking the reservation as batches are consumed.
+/// The original reservation must have its batch sizes calculated using [`get_record_batch_memory_size`].
+/// On error, the reservation is *NOT* freed until the stream is dropped.
+pub(crate) struct ReservationStream {
+    schema: SchemaRef,
+    inner: SendableRecordBatchStream,
+    reservation: MemoryReservation,
+}
+
+impl ReservationStream {
+    pub(crate) fn new(
+        schema: SchemaRef,
+        inner: SendableRecordBatchStream,
+        reservation: MemoryReservation,
+    ) -> Self {
+        Self {
+            schema,
+            inner,
+            reservation,
+        }
+    }
+}
+
+impl Stream for ReservationStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let res = self.inner.poll_next_unpin(cx);
+
+        match res {
+            Poll::Ready(res) => {
+                match res {
+                    Some(Ok(batch)) => {
+                        self.reservation
+                            .shrink(get_record_batch_memory_size(&batch));
+                        Poll::Ready(Some(Ok(batch)))
+                    }
+                    Some(Err(err)) => Poll::Ready(Some(Err(err))),
+                    None => {
+                        // Stream is done so free the reservation completely
+                        self.reservation.free();
+                        Poll::Ready(None)
+                    }
+                }
+            }
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+impl RecordBatchStream for ReservationStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -924,7 +990,126 @@
         assert_eq!(
             number_of_batches, 2,
-            "Should have received exactly one empty batch"
+            "Should have received exactly two empty batches"
         );
     }
 
+    #[tokio::test]
+    async fn test_reservation_stream_shrinks_on_poll() {
+        use arrow::array::Int32Array;
+        use datafusion_execution::memory_pool::MemoryConsumer;
+        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_limit(10 * 1024 * 1024, 1.0)
+            .build_arc()
+            .unwrap();
+
+        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        // Create batches
+        let batch1 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
+        )
+        .unwrap();
+        let batch2 = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))],
+        )
+        .unwrap();
+
+        let batch1_size = get_record_batch_memory_size(&batch1);
+        let batch2_size = get_record_batch_memory_size(&batch2);
+
+        // Reserve memory upfront
+        reservation.try_grow(batch1_size + batch2_size).unwrap();
+        let initial_reserved = runtime.memory_pool.reserved();
+        assert_eq!(initial_reserved, batch1_size + batch2_size);
+
+        // Create stream with batches
+        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+            as SendableRecordBatchStream;
+
+        let mut res_stream =
+            ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+        // Poll first batch
+        let result1 = res_stream.next().await;
+        assert!(result1.is_some());
+
+        // Memory should be reduced by batch1_size
+        let after_first = runtime.memory_pool.reserved();
+        assert_eq!(after_first, batch2_size);
+
+        // Poll second batch
+        let result2 = res_stream.next().await;
+        assert!(result2.is_some());
+
+        // Memory should be reduced by batch2_size
+        let after_second = runtime.memory_pool.reserved();
+        assert_eq!(after_second, 0);
+
+        // Poll None (end of stream)
+        let result3 = res_stream.next().await;
+        assert!(result3.is_none());
+
+        // Memory should still be 0
+        assert_eq!(runtime.memory_pool.reserved(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_reservation_stream_error_handling() {
+        use datafusion_execution::memory_pool::MemoryConsumer;
+        use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+
+        let runtime = RuntimeEnvBuilder::new()
+            .with_memory_limit(10 * 1024 * 1024, 1.0)
+            .build_arc()
+            .unwrap();
+
+        let mut reservation = MemoryConsumer::new("test").register(&runtime.memory_pool);
+
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+
+        reservation.try_grow(1000).unwrap();
+        let initial = runtime.memory_pool.reserved();
+        assert_eq!(initial, 1000);
+
+        // Create a stream that errors
+        let stream = futures::stream::iter(vec![exec_err!("Test error")]);
+        let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream))
+            as SendableRecordBatchStream;
+
+        let mut res_stream =
+            ReservationStream::new(Arc::clone(&schema), inner, reservation);
+
+        // Get the error
+        let result = res_stream.next().await;
+        assert!(result.is_some());
+        assert!(result.unwrap().is_err());
+
+        // Verify reservation is NOT automatically freed on error.
+        // The reservation is only freed when poll_next returns Poll::Ready(None);
+        // after an error, the stream may continue to hold the reservation
+        // until it's explicitly dropped or polled to None.
+        let after_error = runtime.memory_pool.reserved();
+        assert_eq!(
+            after_error, 1000,
+            "Reservation should still be held after error"
+        );
+
+        // Drop the stream to free the reservation
+        drop(res_stream);
+
+        // Now memory should be freed
+        assert_eq!(
+            runtime.memory_pool.reserved(),
+            0,
+            "Memory should be freed when stream is dropped"
+        );
+    }
 }
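Note for reviewers (not part of the patch): the intent of the new two-argument reservation estimate is easiest to see with concrete numbers. A minimal sketch, with a hypothetical local `reserved_bytes` helper mirroring `get_reserved_bytes_for_record_batch_size` from the diff; the byte counts are illustrative only:

```rust
// Mirrors `get_reserved_bytes_for_record_batch_size` from this diff:
// reserve the space the batch's buffers occupy plus the size of the data
// the sorted output will actually contain.
fn reserved_bytes(record_batch_size: usize, sliced_size: usize) -> usize {
    record_batch_size + sliced_size
}

fn main() {
    // Unsliced batch: buffers == data, so this degenerates to the old 2x heuristic.
    assert_eq!(reserved_bytes(4000, 4000), 8000);

    // Sliced batch: 200 bytes of live data over 4000 bytes of backing buffers.
    // The old `record_batch_size * 2` would have reserved 8000 bytes here.
    assert_eq!(reserved_bytes(4000, 200), 4200);
}
```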
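And a minimal, self-contained sketch of driving the new public `sort_batch_chunked` helper, mirroring the tests above. The crate paths are assumptions based on DataFusion's current module layout; only `sort_batch_chunked` itself is defined by this diff:

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, expressions::Column};
use datafusion_physical_plan::sorts::sort::sort_batch_chunked;

fn main() -> Result<()> {
    // 1000 unsorted rows in a single input batch.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from((0..1000).rev().collect::<Vec<i32>>()))],
    )?;

    // Sort by column "a", emitting chunks of at most 250 rows instead of
    // materializing one large sorted batch.
    let ordering: LexOrdering =
        [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into();
    let chunks = sort_batch_chunked(&batch, &ordering, 250)?;

    assert_eq!(chunks.len(), 4);
    assert!(chunks.iter().all(|c| c.num_rows() <= 250));
    Ok(())
}
```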