diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..2a45dae563 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -434,11 +434,11 @@ object CometConf extends ShimCometConf { conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") .category(CATEGORY_SHUFFLE) .doc( - "The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and " + - "snappy are supported. Compression can be disabled by setting " + + "The codec of Comet native shuffle used to compress shuffle data. " + + "Supported codecs: lz4, zstd. Compression can be disabled by setting " + "spark.shuffle.compress=false.") .stringConf - .checkValues(Set("zstd", "lz4", "snappy")) + .checkValues(Set("zstd", "lz4")) .createWithDefault("lz4") val COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL: ConfigEntry[Int] = @@ -523,6 +523,18 @@ object CometConf extends ShimCometConf { "Should not be larger than batch size `spark.comet.batchSize`") .createWithDefault(8192) + val COMET_SHUFFLE_PARTITIONER_MODE: ConfigEntry[String] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.partitionerMode") + .category(CATEGORY_SHUFFLE) + .doc( + "The partitioner mode used by the native shuffle writer. " + + "'immediate' writes partitioned IPC blocks immediately as batches arrive, " + + "reducing memory usage. 
'buffered' buffers all rows before writing, which may " + + "improve performance for small datasets but uses more memory.") + .stringConf + .checkValues(Set("immediate", "buffered")) + .createWithDefault("immediate") + val COMET_SHUFFLE_WRITE_BUFFER_SIZE: ConfigEntry[Long] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.writeBufferSize") .category(CATEGORY_SHUFFLE) diff --git a/docs/source/contributor-guide/native_shuffle.md b/docs/source/contributor-guide/native_shuffle.md index 18e80a90c8..9e1a17d349 100644 --- a/docs/source/contributor-guide/native_shuffle.md +++ b/docs/source/contributor-guide/native_shuffle.md @@ -81,10 +81,18 @@ Native shuffle (`CometExchange`) is selected when all of the following condition └─────────────────────────────────────────────────────────────────────────────┘ │ │ ▼ ▼ -┌───────────────────────────────────┐ ┌───────────────────────────────────┐ -│ MultiPartitionShuffleRepartitioner │ │ SinglePartitionShufflePartitioner │ -│ (hash/range partitioning) │ │ (single partition case) │ -└───────────────────────────────────┘ └───────────────────────────────────┘ +┌───────────────────────────────────────────────────────────────────────┐ +│ Partitioner Selection │ +│ Controlled by spark.comet.exec.shuffle.partitionerMode │ +├───────────────────────────┬───────────────────────────────────────────┤ +│ immediate (default) │ buffered │ +│ ImmediateModePartitioner │ MultiPartitionShuffleRepartitioner │ +│ (hash/range/round-robin) │ (hash/range/round-robin) │ +│ Writes IPC blocks as │ Buffers all rows in memory │ +│ batches arrive │ before writing │ +├───────────────────────────┴───────────────────────────────────────────┤ +│ SinglePartitionShufflePartitioner (single partition case) │ +└───────────────────────────────────────────────────────────────────────┘ │ ▼ ┌───────────────────────────────────┐ @@ -113,11 +121,13 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ### Rust Side -| File | Location | Description | -| 
----------------------- | ------------------------------------ | ------------------------------------------------------------------------------------ | -| `shuffle_writer.rs` | `native/core/src/execution/shuffle/` | `ShuffleWriterExec` plan and partitioners. Main shuffle logic. | -| `codec.rs` | `native/core/src/execution/shuffle/` | `ShuffleBlockWriter` for Arrow IPC encoding with compression. Also handles decoding. | -| `comet_partitioning.rs` | `native/core/src/execution/shuffle/` | `CometPartitioning` enum defining partition schemes (Hash, Range, Single). | +| File | Location | Description | +| ----------------------- | ---------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| `shuffle_writer.rs` | `native/shuffle/src/` | `ShuffleWriterExec` plan. Selects partitioner based on `immediate_mode` flag. | +| `immediate_mode.rs` | `native/shuffle/src/partitioners/` | `ImmediateModePartitioner`. Scatter-writes rows into per-partition Arrow builders and flushes IPC blocks to in-memory buffers eagerly. | +| `multi_partition.rs` | `native/shuffle/src/partitioners/` | `MultiPartitionShuffleRepartitioner`. Buffers all rows in memory, then writes partitions. | +| `codec.rs` | `native/shuffle/src/` | `ShuffleBlockWriter` for Arrow IPC encoding with compression. Also handles decoding. | +| `comet_partitioning.rs` | `native/shuffle/src/` | `CometPartitioning` enum defining partition schemes (Hash, Range, Single). | ## Data Flow @@ -129,23 +139,33 @@ Native shuffle (`CometExchange`) is selected when all of the following condition 2. **Native execution**: `CometExec.getCometIterator()` executes the plan in Rust. -3. 
**Partitioning**: `ShuffleWriterExec` receives batches and routes to the appropriate partitioner: - - `MultiPartitionShuffleRepartitioner`: For hash/range/round-robin partitioning - - `SinglePartitionShufflePartitioner`: For single partition (simpler path) +3. **Partitioning**: `ShuffleWriterExec` receives batches and routes to the appropriate partitioner + based on the `partitionerMode` configuration: + - **Immediate mode** (`ImmediateModePartitioner`): For hash/range/round-robin partitioning. + As each batch arrives, rows are scattered into per-partition Arrow array builders. When a + partition's builder reaches the target batch size, it is flushed as a compressed Arrow IPC + block to an in-memory buffer. Under memory pressure, these buffers are spilled to + per-partition temporary files. This keeps memory usage much lower than buffered mode since + data is encoded into compact IPC format eagerly rather than held as raw Arrow arrays. -4. **Buffering and spilling**: The partitioner buffers rows per partition. When memory pressure - exceeds the threshold, partitions spill to temporary files. + - **Buffered mode** (`MultiPartitionShuffleRepartitioner`): For hash/range/round-robin + partitioning. Buffers all input `RecordBatch`es in memory, then partitions and writes + them in a single pass. When memory pressure exceeds the threshold, partitions spill to + temporary files. -5. **Encoding**: `ShuffleBlockWriter` encodes each partition's data as compressed Arrow IPC: + - `SinglePartitionShufflePartitioner`: For single partition (simpler path, used regardless + of partitioner mode). + +4. **Encoding**: `ShuffleBlockWriter` encodes each partition's data as compressed Arrow IPC: - Writes compression type header - Writes field count header - Writes compressed IPC stream -6. **Output files**: Two files are produced: +5. 
**Output files**: Two files are produced: - **Data file**: Concatenated partition data - **Index file**: Array of 8-byte little-endian offsets marking partition boundaries -7. **Commit**: Back in JVM, `CometNativeShuffleWriter` reads the index file to get partition +6. **Commit**: Back in JVM, `CometNativeShuffleWriter` reads the index file to get partition lengths and commits via Spark's `IndexShuffleBlockResolver`. ### Read Path @@ -201,10 +221,31 @@ sizes. ## Memory Management -Native shuffle uses DataFusion's memory management with spilling support: +Native shuffle uses DataFusion's memory management. The memory characteristics differ +between the two partitioner modes: + +### Immediate Mode + +Immediate mode keeps memory usage low by partitioning and encoding data eagerly as it arrives, +rather than buffering all input rows before writing: + +- **Per-partition builders**: Each partition has a set of Arrow array builders sized to the + target batch size. When a builder fills up, it is flushed as a compressed IPC block to an + in-memory buffer. +- **Memory footprint**: Proportional to `num_partitions × batch_size` for the builders, plus + the accumulated IPC buffers. This is typically much smaller than buffered mode since IPC + encoding is more compact than raw Arrow arrays. +- **Spilling**: When memory pressure is detected via DataFusion's `MemoryConsumer` trait, + partition builders are flushed and all IPC buffers are drained to per-partition temporary + files on disk. + +### Buffered Mode + +Buffered mode holds all input data in memory before writing: -- **Memory pool**: Tracks memory usage across the shuffle operation. -- **Spill threshold**: When buffered data exceeds the threshold, partitions spill to disk. +- **Buffered batches**: All incoming `RecordBatch`es are accumulated in a `Vec`. +- **Spill threshold**: When buffered data exceeds the memory threshold, partitions spill to + temporary files on disk. 
- **Per-partition spilling**: Each partition has its own spill file. Multiple spills for a partition are concatenated when writing the final output. - **Scratch space**: Reusable buffers for partition ID computation to reduce allocations. @@ -232,14 +273,15 @@ independently compressed, allowing parallel decompression during reads. ## Configuration -| Config | Default | Description | -| ------------------------------------------------- | ------- | ---------------------------------------- | -| `spark.comet.exec.shuffle.enabled` | `true` | Enable Comet shuffle | -| `spark.comet.exec.shuffle.mode` | `auto` | Shuffle mode: `native`, `jvm`, or `auto` | -| `spark.comet.exec.shuffle.compression.codec` | `zstd` | Compression codec | -| `spark.comet.exec.shuffle.compression.zstd.level` | `1` | Zstd compression level | -| `spark.comet.shuffle.write.buffer.size` | `1MB` | Write buffer size | -| `spark.comet.columnar.shuffle.batch.size` | `8192` | Target rows per batch | +| Config | Default | Description | +| ------------------------------------------------- | ----------- | ------------------------------------------- | +| `spark.comet.exec.shuffle.enabled` | `true` | Enable Comet shuffle | +| `spark.comet.exec.shuffle.mode` | `auto` | Shuffle mode: `native`, `jvm`, or `auto` | +| `spark.comet.exec.shuffle.partitionerMode` | `immediate` | Partitioner mode: `immediate` or `buffered` | +| `spark.comet.exec.shuffle.compression.codec` | `zstd` | Compression codec | +| `spark.comet.exec.shuffle.compression.zstd.level` | `1` | Zstd compression level | +| `spark.comet.shuffle.write.buffer.size` | `1MB` | Write buffer size | +| `spark.comet.columnar.shuffle.batch.size` | `8192` | Target rows per batch | ## Comparison with JVM Shuffle diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md index 5939e89ef3..c47fe0a644 100644 --- a/docs/source/user-guide/latest/tuning.md +++ b/docs/source/user-guide/latest/tuning.md @@ -144,6 +144,17 @@ Comet 
provides a fully native shuffle implementation, which generally provides t supports `HashPartitioning`, `RangePartitioning` and `SinglePartitioning` but currently only supports primitive type partitioning keys. Columns that are not partitioning keys may contain complex types like maps, structs, and arrays. +Native shuffle has two partitioner modes, configured via +`spark.comet.exec.shuffle.partitionerMode`: + +- **`immediate`** (default): Writes partitioned Arrow IPC blocks to disk immediately as each batch + arrives. This mode uses less memory because it does not need to buffer the entire input before + writing. It is recommended for most workloads, especially large datasets. + +- **`buffered`**: Buffers all input rows in memory before partitioning and writing to disk. This + may improve performance for small datasets that fit in memory, but uses significantly more + memory. + #### Columnar (JVM) Shuffle Comet Columnar shuffle is JVM-based and supports `HashPartitioning`, `RoundRobinPartitioning`, `RangePartitioning`, and diff --git a/native/Cargo.lock b/native/Cargo.lock index 0cf1f20318..067b511fe1 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -96,12 +96,56 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.60.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.60.2", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -288,7 +332,8 @@ dependencies = [ "arrow-schema", "arrow-select", "flatbuffers", - "lz4_flex 0.12.1", + "lz4_flex", + "zstd", ] [[package]] @@ -1331,6 +1376,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -1339,8 +1385,22 @@ version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" dependencies = [ + "anstream", "anstyle", "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] @@ -1358,6 +1418,12 @@ dependencies = [ "cc", ] +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + [[package]] name = "combine" version = "4.6.7" @@ -1953,6 +2019,7 @@ 
dependencies = [ "arrow", "async-trait", "bytes", + "clap", "crc32c", "crc32fast", "criterion", @@ -1964,12 +2031,10 @@ dependencies = [ "itertools 0.14.0", "jni", "log", - "lz4_flex 0.13.0", + "parquet", "simd-adler32", - "snap", "tempfile", "tokio", - "zstd", ] [[package]] @@ -3632,6 +3697,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.13.0" @@ -4007,15 +4078,6 @@ dependencies = [ "twox-hash", ] -[[package]] -name = "lz4_flex" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db9a0d582c2874f68138a16ce1867e0ffde6c0bb0a0df85e1f36d04146db488a" -dependencies = [ - "twox-hash", -] - [[package]] name = "md-5" version = "0.10.6" @@ -4312,6 +4374,12 @@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "oorandom" version = "11.1.5" @@ -4439,7 +4507,7 @@ dependencies = [ "futures", "half", "hashbrown 0.16.1", - "lz4_flex 0.12.1", + "lz4_flex", "num-bigint", "num-integer", "num-traits", @@ -6362,6 +6430,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.23.0" diff --git a/native/Cargo.toml 
b/native/Cargo.toml index c626743be1..3fb087e443 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -34,7 +34,7 @@ edition = "2021" rust-version = "1.88" [workspace.dependencies] -arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz"] } +arrow = { version = "57.3.0", features = ["prettyprint", "ffi", "chrono-tz", "ipc_compression"] } async-trait = { version = "0.1" } bytes = { version = "1.11.1" } parquet = { version = "57.3.0", default-features = false, features = ["experimental"] } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e0a395ebbf..2ac89b2669 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -61,7 +61,6 @@ use datafusion_spark::function::string::space::SparkSpace; use futures::poll; use futures::stream::StreamExt; use futures::FutureExt; -use jni::objects::JByteBuffer; use jni::sys::{jlongArray, JNI_FALSE}; use jni::{ errors::Result as JNIResult, @@ -83,7 +82,7 @@ use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; use crate::execution::operators::{ScanExec, ShuffleScanExec}; -use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; +use crate::execution::shuffle::{CompressionCodec, ShuffleStreamReader}; use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; @@ -809,7 +808,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative let compression_codec = match compression_codec.as_str() { "zstd" => CompressionCodec::Zstd(compression_level), "lz4" => CompressionCodec::Lz4Frame, - "snappy" => CompressionCodec::Snappy, _ => CompressionCodec::Lz4Frame, }; @@ -876,29 +874,73 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( } #[no_mangle] -/// Used by Comet native shuffle reader +/// Open a shuffle stream 
reader over a JVM InputStream. +/// Returns an opaque handle (pointer) to a `ShuffleStreamReader`. /// # Safety /// This function is inherently unsafe since it deals with raw pointers passed from JNI. -pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( +pub unsafe extern "system" fn Java_org_apache_comet_Native_openShuffleStream( e: JNIEnv, _class: JClass, - byte_buffer: JByteBuffer, - length: jint, + input_stream: JObject, +) -> jlong { + try_unwrap_or_throw(&e, |mut env| { + let reader = + ShuffleStreamReader::new(&mut env, &input_stream).map_err(CometError::Internal)?; + let handle = Box::into_raw(Box::new(reader)); + Ok(handle as jlong) + }) +} + +#[no_mangle] +/// Read the next batch from a shuffle stream, exporting via Arrow FFI. +/// Returns the row count, or -1 if the stream is exhausted. +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. +pub unsafe extern "system" fn Java_org_apache_comet_Native_nextShuffleStreamBatch( + e: JNIEnv, + _class: JClass, + handle: jlong, array_addrs: JLongArray, schema_addrs: JLongArray, - tracing_enabled: jboolean, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { - with_trace("decodeShuffleBlock", tracing_enabled != JNI_FALSE, || { - let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; - let length = length as usize; - let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - let batch = read_ipc_compressed(slice)?; - prepare_output(&mut env, array_addrs, schema_addrs, batch, false) - }) + let reader = unsafe { &mut *(handle as *mut ShuffleStreamReader) }; + match reader.next_batch().map_err(CometError::Internal)? { + Some(batch) => prepare_output(&mut env, array_addrs, schema_addrs, batch, false), + None => Ok(-1_i64), + } }) } +#[no_mangle] +/// Get the number of fields in the shuffle stream's schema. +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. 
+pub unsafe extern "system" fn Java_org_apache_comet_Native_shuffleStreamNumFields( + _e: JNIEnv, + _class: JClass, + handle: jlong, +) -> jlong { + let reader = unsafe { &*(handle as *mut ShuffleStreamReader) }; + reader.num_fields() as jlong +} + +#[no_mangle] +/// Close and drop a shuffle stream reader. +/// # Safety +/// This function is inherently unsafe since it deals with raw pointers passed from JNI. +pub unsafe extern "system" fn Java_org_apache_comet_Native_closeShuffleStream( + _e: JNIEnv, + _class: JClass, + handle: jlong, +) { + if handle != 0 { + unsafe { + let _ = Box::from_raw(handle as *mut ShuffleStreamReader); + } + } +} + #[no_mangle] /// # Safety /// This function is inherently unsafe since it deals with raw pointers passed from JNI. diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index a1ad52310c..9a5d41af79 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -18,9 +18,9 @@ use crate::{ errors::CometError, execution::{ - operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, shuffle::ipc::read_ipc_compressed, + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, shuffle::ShuffleStreamReader, }, - jvm_bridge::{jni_call, JVMClasses}, + jvm_bridge::JVMClasses, }; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -35,7 +35,7 @@ use datafusion::{ physical_plan::{ExecutionPlan, *}, }; use futures::Stream; -use jni::objects::{GlobalRef, JByteBuffer, JObject}; +use jni::objects::GlobalRef; use std::{ any::Any, pin::Pin, @@ -45,14 +45,13 @@ use std::{ use super::scan::InputBatch; -/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. -/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed -/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed(). 
-#[derive(Debug, Clone)] +/// ShuffleScanExec reads Arrow IPC streams from JVM via JNI and decodes them natively. +/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives a raw +/// InputStream from JVM and reads Arrow IPC streams using ShuffleStreamReader. pub struct ShuffleScanExec { /// The ID of the execution context that owns this subquery. pub exec_context_id: i64, - /// The input source: a global reference to a JVM CometShuffleBlockIterator object. + /// The input source: a global reference to a JVM InputStream object. pub input_source: Option>, /// The data types of columns in the shuffle output. pub data_types: Vec, @@ -60,16 +59,48 @@ pub struct ShuffleScanExec { pub schema: SchemaRef, /// The current input batch, populated by get_next_batch() before poll_next(). pub batch: Arc>>, + /// Cached ShuffleStreamReader, created lazily on first get_next call. + stream_reader: Option, /// Cache of plan properties. cache: PlanProperties, /// Metrics collector. metrics: ExecutionPlanMetricsSet, /// Baseline metrics. baseline_metrics: BaselineMetrics, - /// Time spent decoding compressed shuffle blocks. + /// Time spent decoding shuffle batches. decode_time: Time, } +impl std::fmt::Debug for ShuffleScanExec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ShuffleScanExec") + .field("exec_context_id", &self.exec_context_id) + .field("data_types", &self.data_types) + .field("schema", &self.schema) + .field("stream_reader", &self.stream_reader.is_some()) + .finish() + } +} + +impl Clone for ShuffleScanExec { + fn clone(&self) -> Self { + Self { + exec_context_id: self.exec_context_id, + input_source: self.input_source.clone(), + data_types: self.data_types.clone(), + schema: Arc::clone(&self.schema), + batch: Arc::clone(&self.batch), + // stream_reader is not cloneable; cloned instances start without one + // and will lazily create their own if needed. 
+ stream_reader: None, + cache: self.cache.clone(), + metrics: self.metrics.clone(), + baseline_metrics: self.baseline_metrics.clone(), + decode_time: self.decode_time.clone(), + } + } +} + impl ShuffleScanExec { pub fn new( exec_context_id: i64, @@ -94,6 +125,7 @@ impl ShuffleScanExec { input_source, data_types, batch: Arc::new(Mutex::new(None)), + stream_reader: None, cache, metrics: metrics_set, baseline_metrics, @@ -114,90 +146,86 @@ impl ShuffleScanExec { // Unit test mode - no JNI calls needed. return Ok(()); } - let mut timer = self.baseline_metrics.elapsed_compute().timer(); - let mut current_batch = self.batch.try_lock().unwrap(); - if current_batch.is_none() { - let next_batch = Self::get_next( - self.exec_context_id, - self.input_source.as_ref().unwrap().as_obj(), - &self.data_types, - &self.decode_time, - )?; + // Check if a batch is already pending without holding the lock during get_next + let needs_batch = { + let current_batch = self.batch.try_lock().unwrap(); + current_batch.is_none() + }; + + if needs_batch { + let start = std::time::Instant::now(); + let next_batch = self.get_next()?; + self.baseline_metrics + .elapsed_compute() + .add_duration(start.elapsed()); + let mut current_batch = self.batch.try_lock().unwrap(); *current_batch = Some(next_batch); } - timer.stop(); - Ok(()) } - /// Invokes JNI calls to get the next compressed shuffle block and decode it. - fn get_next( - exec_context_id: i64, - iter: &JObject, - data_types: &[DataType], - decode_time: &Time, - ) -> Result { - if exec_context_id == TEST_EXEC_CONTEXT_ID { + /// Reads the next batch from the ShuffleStreamReader, creating it lazily on first call. + fn get_next(&mut self) -> Result { + if self.exec_context_id == TEST_EXEC_CONTEXT_ID { return Ok(InputBatch::EOF); } - if iter.is_null() { - return Err(CometError::from(ExecutionError::GeneralError(format!( - "Null shuffle block iterator object. 
Plan id: {exec_context_id}" - )))); + // Lazily create the ShuffleStreamReader on first call + if self.stream_reader.is_none() { + let input_source = self.input_source.as_ref().ok_or_else(|| { + CometError::from(ExecutionError::GeneralError(format!( + "Null shuffle input source. Plan id: {}", + self.exec_context_id + ))) + })?; + let mut env = JVMClasses::get_env()?; + let reader = + ShuffleStreamReader::new(&mut env, input_source.as_obj()).map_err(|e| { + CometError::from(ExecutionError::GeneralError(format!( + "Failed to create ShuffleStreamReader: {e}" + ))) + })?; + self.stream_reader = Some(reader); } - let mut env = JVMClasses::get_env()?; + let reader = self.stream_reader.as_mut().unwrap(); + + let mut decode_timer = self.decode_time.timer(); + let batch_opt = reader.next_batch().map_err(|e| { + CometError::from(ExecutionError::GeneralError(format!( + "Failed to read shuffle batch: {e}" + ))) + })?; + decode_timer.stop(); + + match batch_opt { + None => Ok(InputBatch::EOF), + Some(batch) => { + let num_rows = batch.num_rows(); + + // Extract column arrays, unpacking any dictionary-encoded columns. + // Native shuffle may dictionary-encode string/binary columns for efficiency, + // but downstream DataFusion operators expect the value types declared in the + // schema (e.g. Utf8, not Dictionary). + let columns: Vec = batch + .columns() + .iter() + .map(|col| unpack_dictionary(col)) + .collect(); - // has_next() reads the next block and returns its length, or -1 if EOF - let block_length: i32 = unsafe { - jni_call!(&mut env, - comet_shuffle_block_iterator(iter).has_next() -> i32)? 
- }; + debug_assert_eq!( + columns.len(), + self.data_types.len(), + "Shuffle block column count mismatch: got {} but expected {}", + columns.len(), + self.data_types.len() + ); - if block_length == -1 { - return Ok(InputBatch::EOF); + Ok(InputBatch::new(columns, Some(num_rows))) + } } - - // Get the DirectByteBuffer containing the compressed shuffle block - let buffer: JObject = unsafe { - jni_call!(&mut env, - comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? - }; - - let byte_buffer = JByteBuffer::from(buffer); - let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; - let length = block_length as usize; - let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; - - // Decode the compressed IPC data - let mut timer = decode_time.timer(); - let batch = read_ipc_compressed(slice)?; - timer.stop(); - - let num_rows = batch.num_rows(); - - // Extract column arrays, unpacking any dictionary-encoded columns. - // Native shuffle may dictionary-encode string/binary columns for efficiency, - // but downstream DataFusion operators expect the value types declared in the - // schema (e.g. Utf8, not Dictionary). 
- let columns: Vec = batch - .columns() - .iter() - .map(|col| unpack_dictionary(col)) - .collect(); - - debug_assert_eq!( - columns.len(), - data_types.len(), - "Shuffle block column count mismatch: got {} but expected {}", - columns.len(), - data_types.len() - ); - - Ok(InputBatch::new(columns, Some(num_rows))) } } @@ -351,16 +379,15 @@ impl RecordBatchStream for ShuffleScanStream { #[cfg(test)] mod tests { - use crate::execution::shuffle::{CompressionCodec, ShuffleBlockWriter}; + use crate::execution::shuffle::CompressionCodec; use arrow::array::{Int32Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; + use arrow::ipc::reader::StreamReader; + use arrow::ipc::writer::StreamWriter; use arrow::record_batch::RecordBatch; - use datafusion::physical_plan::metrics::Time; use std::io::Cursor; use std::sync::Arc; - use crate::execution::shuffle::ipc::read_ipc_compressed; - #[test] #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) fn test_read_compressed_ipc_block() { @@ -377,18 +404,18 @@ mod tests { ) .unwrap(); - // Write as compressed IPC - let writer = - ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); - let mut buf = Cursor::new(Vec::new()); - let ipc_time = Time::new(); - writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); - - // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) - let bytes = buf.into_inner(); - let body = &bytes[16..]; - - let decoded = read_ipc_compressed(body).unwrap(); + // Write as Arrow IPC stream with compression + let write_options = CompressionCodec::Zstd(1).ipc_write_options().unwrap(); + let mut buf = Vec::new(); + let mut writer = + StreamWriter::try_new_with_options(&mut buf, &batch.schema(), write_options).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + + // Read back using standard StreamReader + let cursor = Cursor::new(&buf); + let mut reader = StreamReader::try_new(cursor, None).unwrap(); + let decoded = 
reader.next().unwrap().unwrap(); assert_eq!(decoded.num_rows(), 3); assert_eq!(decoded.num_columns(), 2); @@ -404,9 +431,6 @@ mod tests { } /// Tests that ShuffleScanExec correctly unpacks dictionary-encoded columns. - /// Native shuffle may dictionary-encode string/binary columns, but the schema - /// declares value types (e.g. Utf8). Without unpacking, RecordBatch creation - /// fails with a schema mismatch. #[test] #[cfg_attr(miri, ignore)] fn test_dictionary_encoded_shuffle_block_is_unpacked() { @@ -416,15 +440,12 @@ mod tests { use datafusion::physical_plan::ExecutionPlan; use futures::StreamExt; - // Build a batch with a dictionary-encoded string column (simulating what - // the native shuffle writer produces for string columns). let mut dict_builder = StringDictionaryBuilder::::new(); dict_builder.append_value("hello"); dict_builder.append_value("world"); - dict_builder.append_value("hello"); // repeated value, good for dictionary + dict_builder.append_value("hello"); let dict_array = dict_builder.finish(); - // The IPC schema includes the dictionary type let dict_schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new( @@ -442,19 +463,19 @@ mod tests { ) .unwrap(); - // Write as compressed IPC (preserves dictionary encoding) - let writer = - ShuffleBlockWriter::try_new(&dict_batch.schema(), CompressionCodec::Zstd(1)).unwrap(); - let mut buf = Cursor::new(Vec::new()); - let ipc_time = Time::new(); - writer - .write_batch(&dict_batch, &mut buf, &ipc_time) - .unwrap(); - let bytes = buf.into_inner(); - let body = &bytes[16..]; - - // Confirm that read_ipc_compressed returns dictionary-encoded arrays - let decoded = read_ipc_compressed(body).unwrap(); + // Write as Arrow IPC stream with compression + let write_options = CompressionCodec::Zstd(1).ipc_write_options().unwrap(); + let mut buf = Vec::new(); + let mut writer = + StreamWriter::try_new_with_options(&mut buf, &dict_batch.schema(), write_options) + .unwrap(); + 
writer.write(&dict_batch).unwrap(); + writer.finish().unwrap(); + + // Read back using standard StreamReader + let cursor = Cursor::new(&buf); + let mut reader = StreamReader::try_new(cursor, None).unwrap(); + let decoded = reader.next().unwrap().unwrap(); assert!( matches!(decoded.column(1).data_type(), DataType::Dictionary(_, _)), "Expected dictionary-encoded column from IPC, got {:?}", diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0f96c829e7..c150fb480a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1354,7 +1354,6 @@ impl PhysicalPlanner { let codec = match writer.codec.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), - Ok(SparkCompressionCodec::Snappy) => Ok(CompressionCodec::Snappy), Ok(SparkCompressionCodec::Zstd) => { Ok(CompressionCodec::Zstd(writer.compression_level)) } @@ -1374,6 +1373,7 @@ impl PhysicalPlanner { writer.output_index_file.clone(), writer.tracing_enabled, write_buffer_size, + writer.immediate_mode, )?); Ok(( diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index a2e25c3e2f..fdc69c527e 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -181,12 +181,9 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; -mod shuffle_block_iterator; - use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; -use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -212,8 +209,6 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, - /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. 
- pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -267,7 +262,6 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), - comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/jni-bridge/src/shuffle_block_iterator.rs b/native/jni-bridge/src/shuffle_block_iterator.rs deleted file mode 100644 index c3bb5af5fb..0000000000 --- a/native/jni-bridge/src/shuffle_block_iterator.rs +++ /dev/null @@ -1,62 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use jni::signature::Primitive; -use jni::{ - errors::Result as JniResult, - objects::{JClass, JMethodID}, - signature::ReturnType, - JNIEnv, -}; - -/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. 
-#[allow(dead_code)] // we need to keep references to Java items to prevent GC -pub struct CometShuffleBlockIterator<'a> { - pub class: JClass<'a>, - pub method_has_next: JMethodID, - pub method_has_next_ret: ReturnType, - pub method_get_buffer: JMethodID, - pub method_get_buffer_ret: ReturnType, - pub method_get_current_block_length: JMethodID, - pub method_get_current_block_length_ret: ReturnType, -} - -impl<'a> CometShuffleBlockIterator<'a> { - pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; - - pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { - let class = env.find_class(Self::JVM_CLASS)?; - - Ok(CometShuffleBlockIterator { - class, - method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, - method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_get_buffer: env.get_method_id( - Self::JVM_CLASS, - "getBuffer", - "()Ljava/nio/ByteBuffer;", - )?, - method_get_buffer_ret: ReturnType::Object, - method_get_current_block_length: env.get_method_id( - Self::JVM_CLASS, - "getCurrentBlockLength", - "()I", - )?, - method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), - }) - } -} diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 344b9f0f21..5726484c0f 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,6 +294,10 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // Whether to use immediate mode partitioner. When true, partitioned IPC blocks + // are written immediately as batches arrive. When false, rows are buffered + // before writing (the original behavior). 
+ bool immediate_mode = 9; } message ParquetWriter { diff --git a/native/shuffle/Cargo.toml b/native/shuffle/Cargo.toml index 5cd7cd43ef..9528a66727 100644 --- a/native/shuffle/Cargo.toml +++ b/native/shuffle/Cargo.toml @@ -32,6 +32,7 @@ publish = false arrow = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } +clap = { version = "4", features = ["derive"], optional = true } crc32c = "0.6.8" crc32fast = "1.3.2" datafusion = { workspace = true } @@ -42,11 +43,10 @@ futures = { workspace = true } itertools = "0.14.0" jni = "0.21" log = "0.4" -lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] } +# parquet is only used by the shuffle_bench binary (shuffle-bench feature) +parquet = { workspace = true, optional = true } simd-adler32 = "0.3.9" -snap = "1.1" tokio = { version = "1", features = ["rt-multi-thread"] } -zstd = "0.13.3" [dev-dependencies] criterion = { version = "0.7", features = ["async", "async_tokio", "async_std"] } @@ -54,10 +54,18 @@ datafusion = { workspace = true, features = ["parquet_encryption", "sql"] } itertools = "0.14.0" tempfile = "3.26.0" +[features] +shuffle-bench = ["clap", "parquet"] + [lib] name = "datafusion_comet_shuffle" path = "src/lib.rs" +[[bin]] +name = "shuffle_bench" +path = "src/bin/shuffle_bench.rs" +required-features = ["shuffle-bench"] + [[bench]] name = "shuffle_writer" harness = false diff --git a/native/shuffle/README.md b/native/shuffle/README.md index 8fba6b0323..74b8dbe656 100644 --- a/native/shuffle/README.md +++ b/native/shuffle/README.md @@ -23,3 +23,46 @@ This crate provides the shuffle writer and reader implementation for Apache Data of the [Apache DataFusion Comet] subproject. [Apache DataFusion Comet]: https://github.com/apache/datafusion-comet/ + +## Shuffle Benchmark Tool + +A standalone benchmark binary (`shuffle_bench`) is included for profiling shuffle write and read +performance outside of Spark. It streams input data directly from Parquet files. 
+
+### Basic usage
+
+```sh
+cargo run --release --features shuffle-bench --bin shuffle_bench -- \
+  --input /data/tpch-sf100/lineitem/ \
+  --partitions 200 \
+  --codec zstd --zstd-level 1 \
+  --hash-columns 0,3
+```
+
+### Options
+
+| Option                   | Default                    | Description                                                   |
+| ------------------------ | -------------------------- | ------------------------------------------------------------ |
+| `--input`                | _(required)_               | Path to a Parquet file or directory of Parquet files          |
+| `--partitions`           | `200`                      | Number of output shuffle partitions                           |
+| `--partitioning`         | `hash`                     | Partitioning scheme: `hash`, `single`, `round-robin`          |
+| `--hash-columns`         | `0`                        | Comma-separated column indices to hash on (e.g. `0,3`)        |
+| `--codec`                | `zstd`                     | Compression codec: `none`, `lz4`, `zstd`                      |
+| `--zstd-level`           | `1`                        | Zstd compression level (1–22)                                 |
+| `--batch-size`           | `8192`                     | Batch size for reading Parquet data                           |
+| `--memory-limit`         | _(none)_                   | Memory limit in bytes; triggers spilling when exceeded        |
+| `--mode`                 | `immediate`                | Shuffle mode: `immediate` (default) or `buffered`             |
+| `--write-buffer-size`    | `1048576`                  | Write buffer size in bytes                                    |
+| `--limit`                | `0`                        | Limit rows processed per iteration (0 = no limit)             |
+| `--iterations`           | `1`                        | Number of timed iterations                                    |
+| `--warmup`               | `0`                        | Number of warmup iterations before timing                     |
+| `--read-back`            | `false`                    | Also benchmark reading back the shuffle output                |
+| `--output-dir`           | `/tmp/comet_shuffle_bench` | Directory for temporary shuffle output files                  |
+
+### Profiling with flamegraph
+
+```sh
+cargo flamegraph --release --features shuffle-bench --bin shuffle_bench -- \
+  --input /data/tpch-sf100/lineitem/ \
+  --partitions 200 --codec zstd --zstd-level 1
+```
diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs
index 27abd919fa..e71f83f387 100644
--- a/native/shuffle/benches/shuffle_writer.rs
+++ b/native/shuffle/benches/shuffle_writer.rs
@@ -18,22 +18,19 @@ 
use arrow::array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; use arrow::array::{builder::StringBuilder, Array, Int32Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; +use arrow::ipc::writer::StreamWriter; use arrow::row::{RowConverter, SortField}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::physical_expr::expressions::{col, Column}; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::Time; use datafusion::{ physical_plan::{common::collect, ExecutionPlan}, prelude::SessionContext, }; -use datafusion_comet_shuffle::{ - CometPartitioning, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, -}; +use datafusion_comet_shuffle::{CometPartitioning, CompressionCodec, ShuffleWriterExec}; use itertools::Itertools; -use std::io::Cursor; use std::sync::Arc; use tokio::runtime::Runtime; @@ -43,20 +40,22 @@ fn criterion_benchmark(c: &mut Criterion) { for compression_codec in &[ CompressionCodec::None, CompressionCodec::Lz4Frame, - CompressionCodec::Snappy, CompressionCodec::Zstd(1), CompressionCodec::Zstd(6), ] { let name = format!("shuffle_writer: write encoded (compression={compression_codec:?})"); group.bench_function(name, |b| { - let mut buffer = vec![]; - let ipc_time = Time::default(); - let w = - ShuffleBlockWriter::try_new(&batch.schema(), compression_codec.clone()).unwrap(); + let write_options = compression_codec.ipc_write_options().unwrap(); b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - w.write_batch(&batch, &mut cursor, &ipc_time).unwrap(); + let mut buffer = Vec::new(); + let mut writer = StreamWriter::try_new_with_options( + &mut buffer, + &batch.schema(), + write_options.clone(), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); }); }); } @@ -64,7 +63,6 @@ fn 
criterion_benchmark(c: &mut Criterion) { for compression_codec in [ CompressionCodec::None, CompressionCodec::Lz4Frame, - CompressionCodec::Snappy, CompressionCodec::Zstd(1), CompressionCodec::Zstd(6), ] { @@ -153,6 +151,7 @@ fn create_shuffle_writer_exec( "/tmp/index.out".to_string(), false, 1024 * 1024, + false, // immediate_mode ) .unwrap() } diff --git a/native/shuffle/src/bin/shuffle_bench.rs b/native/shuffle/src/bin/shuffle_bench.rs new file mode 100644 index 0000000000..78d4072d22 --- /dev/null +++ b/native/shuffle/src/bin/shuffle_bench.rs @@ -0,0 +1,780 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Standalone shuffle benchmark tool for profiling Comet shuffle write and read +//! outside of Spark. Streams input directly from Parquet files. +//! +//! # Usage +//! +//! ```sh +//! cargo run --release --bin shuffle_bench -- \ +//! --input /data/tpch-sf100/lineitem/ \ +//! --partitions 200 \ +//! --codec zstd --zstd-level 1 \ +//! --hash-columns 0,3 \ +//! --read-back +//! ``` +//! +//! Profile with flamegraph: +//! ```sh +//! cargo flamegraph --release --bin shuffle_bench -- \ +//! --input /data/tpch-sf100/lineitem/ \ +//! --partitions 200 --codec zstd --zstd-level 1 +//! 
``` + +use arrow::datatypes::{DataType, SchemaRef}; +use arrow::ipc::reader::StreamReader; +use clap::Parser; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::metrics::{MetricValue, MetricsSet}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_comet_shuffle::{CometPartitioning, CompressionCodec, ShuffleWriterExec}; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Instant; + +#[derive(Parser, Debug)] +#[command( + name = "shuffle_bench", + about = "Standalone benchmark for Comet shuffle write and read performance" +)] +struct Args { + /// Path to input Parquet file or directory of Parquet files + #[arg(long)] + input: PathBuf, + + /// Batch size for reading Parquet data + #[arg(long, default_value_t = 8192)] + batch_size: usize, + + /// Number of output shuffle partitions + #[arg(long, default_value_t = 200)] + partitions: usize, + + /// Partitioning scheme: hash, single, round-robin + #[arg(long, default_value = "hash")] + partitioning: String, + + /// Column indices to hash on (comma-separated, e.g. 
"0,3") + #[arg(long, default_value = "0")] + hash_columns: String, + + /// Compression codec: none, lz4, zstd + #[arg(long, default_value = "zstd")] + codec: String, + + /// Zstd compression level (1-22) + #[arg(long, default_value_t = 1)] + zstd_level: i32, + + /// Memory limit in bytes (triggers spilling when exceeded) + #[arg(long)] + memory_limit: Option, + + /// Also benchmark reading back the shuffle output + #[arg(long, default_value_t = false)] + read_back: bool, + + /// Number of iterations to run + #[arg(long, default_value_t = 1)] + iterations: usize, + + /// Number of warmup iterations before timing + #[arg(long, default_value_t = 0)] + warmup: usize, + + /// Output directory for shuffle data/index files + #[arg(long, default_value = "/tmp/comet_shuffle_bench")] + output_dir: PathBuf, + + /// Write buffer size in bytes + #[arg(long, default_value_t = 1048576)] + write_buffer_size: usize, + + /// Limit rows processed per iteration (0 = no limit) + #[arg(long, default_value_t = 0)] + limit: usize, + + /// Number of concurrent shuffle tasks to simulate executor parallelism. + /// Each task reads the same input and writes to its own output files. + #[arg(long, default_value_t = 1)] + concurrent_tasks: usize, + + /// Shuffle mode: 'immediate' writes IPC blocks per batch as they arrive, + /// 'buffered' buffers all rows before writing (original behavior). 
+ #[arg(long, default_value = "immediate")] + mode: String, +} + +fn main() { + let args = Args::parse(); + + // Create output directory + fs::create_dir_all(&args.output_dir).expect("Failed to create output directory"); + let data_file = args.output_dir.join("data.out"); + let index_file = args.output_dir.join("index.out"); + + let (schema, total_rows) = read_parquet_metadata(&args.input, args.limit); + + let codec = parse_codec(&args.codec, args.zstd_level); + let hash_col_indices = parse_hash_columns(&args.hash_columns); + + println!("=== Shuffle Benchmark ==="); + println!("Input: {}", args.input.display()); + println!( + "Schema: {} columns ({})", + schema.fields().len(), + describe_schema(&schema) + ); + println!("Total rows: {}", format_number(total_rows as usize)); + println!("Batch size: {}", format_number(args.batch_size)); + println!("Partitioning: {}", args.partitioning); + println!("Partitions: {}", args.partitions); + println!("Codec: {:?}", codec); + println!("Mode: {}", args.mode); + println!("Hash columns: {:?}", hash_col_indices); + if let Some(mem_limit) = args.memory_limit { + println!("Memory limit: {}", format_bytes(mem_limit)); + } + if args.concurrent_tasks > 1 { + println!("Concurrent: {} tasks", args.concurrent_tasks); + } + println!( + "Iterations: {} (warmup: {})", + args.iterations, args.warmup + ); + println!(); + + let total_iters = args.warmup + args.iterations; + let mut write_times = Vec::with_capacity(args.iterations); + let mut read_times = Vec::with_capacity(args.iterations); + let mut data_file_sizes = Vec::with_capacity(args.iterations); + let mut last_metrics: Option = None; + let mut last_input_metrics: Option = None; + + for i in 0..total_iters { + let is_warmup = i < args.warmup; + let label = if is_warmup { + format!("warmup {}/{}", i + 1, args.warmup) + } else { + format!("iter {}/{}", i - args.warmup + 1, args.iterations) + }; + + let (write_elapsed, metrics, input_metrics) = if args.concurrent_tasks > 1 { + let elapsed 
= run_concurrent_shuffle_writes( + &args.input, + &schema, + &codec, + &hash_col_indices, + &args, + ); + (elapsed, None, None) + } else { + run_shuffle_write( + &args.input, + &schema, + &codec, + &hash_col_indices, + &args, + data_file.to_str().unwrap(), + index_file.to_str().unwrap(), + ) + }; + let data_size = fs::metadata(&data_file).map(|m| m.len()).unwrap_or(0); + + if !is_warmup { + write_times.push(write_elapsed); + data_file_sizes.push(data_size); + last_metrics = metrics; + last_input_metrics = input_metrics; + } + + print!(" [{label}] write: {:.3}s", write_elapsed); + if args.concurrent_tasks <= 1 { + print!(" output: {}", format_bytes(data_size as usize)); + } + + if args.read_back && args.concurrent_tasks <= 1 { + let read_elapsed = run_shuffle_read( + data_file.to_str().unwrap(), + index_file.to_str().unwrap(), + args.partitions, + ); + if !is_warmup { + read_times.push(read_elapsed); + } + print!(" read: {:.3}s", read_elapsed); + } + println!(); + + // Remove output files after each iteration to avoid filling disk + let _ = fs::remove_file(&data_file); + let _ = fs::remove_file(&index_file); + } + + if args.iterations > 0 { + println!(); + println!("=== Results ==="); + + let avg_write = write_times.iter().sum::() / write_times.len() as f64; + let write_throughput_rows = (total_rows as f64 * args.concurrent_tasks as f64) / avg_write; + + println!("Write:"); + println!(" avg time: {:.3}s", avg_write); + if write_times.len() > 1 { + let min = write_times.iter().cloned().fold(f64::INFINITY, f64::min); + let max = write_times + .iter() + .cloned() + .fold(f64::NEG_INFINITY, f64::max); + println!(" min/max: {:.3}s / {:.3}s", min, max); + } + println!( + " throughput: {} rows/s (total across {} tasks)", + format_number(write_throughput_rows as usize), + args.concurrent_tasks + ); + if args.concurrent_tasks <= 1 { + let avg_data_size = data_file_sizes.iter().sum::() / data_file_sizes.len() as u64; + println!( + " output size: {}", + 
format_bytes(avg_data_size as usize) + ); + } + + if !read_times.is_empty() { + let avg_data_size = data_file_sizes.iter().sum::() / data_file_sizes.len() as u64; + let avg_read = read_times.iter().sum::() / read_times.len() as f64; + let read_throughput_bytes = avg_data_size as f64 / avg_read; + + println!("Read:"); + println!(" avg time: {:.3}s", avg_read); + if read_times.len() > 1 { + let min = read_times.iter().cloned().fold(f64::INFINITY, f64::min); + let max = read_times.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + println!(" min/max: {:.3}s / {:.3}s", min, max); + } + println!( + " throughput: {}/s (from compressed)", + format_bytes(read_throughput_bytes as usize) + ); + } + + if let Some(ref metrics) = last_input_metrics { + println!(); + println!("Input Metrics (last iteration):"); + print_input_metrics(metrics); + } + + if let Some(ref metrics) = last_metrics { + println!(); + println!("Shuffle Metrics (last iteration):"); + print_shuffle_metrics(metrics, avg_write); + } + } + + let _ = fs::remove_file(&data_file); + let _ = fs::remove_file(&index_file); +} + +fn print_shuffle_metrics(metrics: &MetricsSet, total_wall_time_secs: f64) { + let get_metric = |name: &str| -> Option { + metrics + .iter() + .find(|m| m.value().name() == name) + .map(|m| m.value().as_usize()) + }; + + let total_ns = (total_wall_time_secs * 1e9) as u64; + let fmt_time = |nanos: usize| -> String { + let secs = nanos as f64 / 1e9; + let pct = if total_ns > 0 { + (nanos as f64 / total_ns as f64) * 100.0 + } else { + 0.0 + }; + format!("{:.3}s ({:.1}%)", secs, pct) + }; + + if let Some(input_batches) = get_metric("input_batches") { + println!(" input batches: {}", format_number(input_batches)); + } + if let Some(nanos) = get_metric("repart_time") { + println!(" repart time: {}", fmt_time(nanos)); + } + if let Some(nanos) = get_metric("encode_time") { + println!(" encode time: {}", fmt_time(nanos)); + } + if let Some(nanos) = get_metric("write_time") { + println!(" write time: 
{}", fmt_time(nanos)); + } + if let Some(nanos) = get_metric("interleave_time") { + println!(" interleave time: {}", fmt_time(nanos)); + } + if let Some(nanos) = get_metric("coalesce_time") { + println!(" coalesce time: {}", fmt_time(nanos)); + } + if let Some(nanos) = get_metric("memcopy_time") { + println!(" memcopy time: {}", fmt_time(nanos)); + } + + if let Some(spill_count) = get_metric("spill_count") { + if spill_count > 0 { + println!(" spill count: {}", format_number(spill_count)); + } + } + if let Some(spilled_bytes) = get_metric("spilled_bytes") { + if spilled_bytes > 0 { + println!(" spilled bytes: {}", format_bytes(spilled_bytes)); + } + } + if let Some(data_size) = get_metric("data_size") { + if data_size > 0 { + println!(" data size: {}", format_bytes(data_size)); + } + } +} + +fn print_input_metrics(metrics: &MetricsSet) { + let aggregated = metrics.aggregate_by_name(); + for m in aggregated.iter() { + let value = m.value(); + let name = value.name(); + let v = value.as_usize(); + if v == 0 { + continue; + } + // Format time metrics as seconds, everything else as a number + // Skip start/end timestamps — not useful in benchmark output + if matches!( + value, + MetricValue::StartTimestamp(_) | MetricValue::EndTimestamp(_) + ) { + continue; + } + let is_time = matches!( + value, + MetricValue::ElapsedCompute(_) | MetricValue::Time { .. } + ); + if is_time { + println!(" {name}: {:.3}s", v as f64 / 1e9); + } else if name.contains("bytes") || name.contains("size") { + println!(" {name}: {}", format_bytes(v)); + } else { + println!(" {name}: {}", format_number(v)); + } + } +} + +/// Read schema and total row count from Parquet metadata without loading any data. 
+fn read_parquet_metadata(path: &Path, limit: usize) -> (SchemaRef, u64) { + let paths = collect_parquet_paths(path); + let mut schema = None; + let mut total_rows = 0u64; + + for file_path in &paths { + let file = fs::File::open(file_path) + .unwrap_or_else(|e| panic!("Failed to open {}: {}", file_path.display(), e)); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap_or_else(|e| { + panic!( + "Failed to read Parquet metadata from {}: {}", + file_path.display(), + e + ) + }); + if schema.is_none() { + schema = Some(Arc::clone(builder.schema())); + } + total_rows += builder.metadata().file_metadata().num_rows() as u64; + if limit > 0 && total_rows >= limit as u64 { + total_rows = total_rows.min(limit as u64); + break; + } + } + + (schema.expect("No parquet files found"), total_rows) +} + +fn collect_parquet_paths(path: &Path) -> Vec { + if path.is_dir() { + let mut files: Vec = fs::read_dir(path) + .unwrap_or_else(|e| panic!("Failed to read directory {}: {}", path.display(), e)) + .filter_map(|entry| { + let p = entry.ok()?.path(); + if p.extension().and_then(|e| e.to_str()) == Some("parquet") { + Some(p) + } else { + None + } + }) + .collect(); + files.sort(); + if files.is_empty() { + panic!("No .parquet files found in {}", path.display()); + } + files + } else { + vec![path.to_path_buf()] + } +} + +fn run_shuffle_write( + input_path: &Path, + schema: &SchemaRef, + codec: &CompressionCodec, + hash_col_indices: &[usize], + args: &Args, + data_file: &str, + index_file: &str, +) -> (f64, Option, Option) { + let partitioning = build_partitioning( + &args.partitioning, + args.partitions, + hash_col_indices, + schema, + ); + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = Instant::now(); + let (shuffle_metrics, input_metrics) = execute_shuffle_write( + input_path.to_str().unwrap(), + codec.clone(), + partitioning, + args.batch_size, + args.memory_limit, + args.write_buffer_size, + args.limit, + 
data_file.to_string(), + index_file.to_string(), + args.mode == "immediate", + ) + .await + .unwrap(); + ( + start.elapsed().as_secs_f64(), + Some(shuffle_metrics), + Some(input_metrics), + ) + }) +} + +/// Core async shuffle write logic shared by single and concurrent paths. +#[allow(clippy::too_many_arguments)] +async fn execute_shuffle_write( + input_path: &str, + codec: CompressionCodec, + partitioning: CometPartitioning, + batch_size: usize, + memory_limit: Option, + write_buffer_size: usize, + limit: usize, + data_file: String, + index_file: String, + immediate_mode: bool, +) -> datafusion::common::Result<(MetricsSet, MetricsSet)> { + let config = SessionConfig::new().with_batch_size(batch_size); + let mut runtime_builder = RuntimeEnvBuilder::new(); + if let Some(mem_limit) = memory_limit { + runtime_builder = runtime_builder.with_memory_limit(mem_limit, 1.0); + } + let runtime_env = Arc::new(runtime_builder.build().unwrap()); + let ctx = SessionContext::new_with_config_rt(config, runtime_env); + + let mut df = ctx + .read_parquet(input_path, ParquetReadOptions::default()) + .await + .expect("Failed to create Parquet scan"); + if limit > 0 { + df = df.limit(0, Some(limit)).unwrap(); + } + + let parquet_plan = df + .create_physical_plan() + .await + .expect("Failed to create physical plan"); + + let input: Arc = if parquet_plan + .properties() + .output_partitioning() + .partition_count() + > 1 + { + Arc::new(CoalescePartitionsExec::new(parquet_plan)) + } else { + parquet_plan + }; + + let exec = ShuffleWriterExec::try_new( + input, + partitioning, + codec, + data_file, + index_file, + false, + write_buffer_size, + immediate_mode, + ) + .expect("Failed to create ShuffleWriterExec"); + + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + collect(stream).await.unwrap(); + + // Collect metrics from the input plan (Parquet scan + optional coalesce) + let input_metrics = collect_input_metrics(&exec); + + 
Ok((exec.metrics().unwrap_or_default(), input_metrics)) +} + +/// Walk the plan tree and aggregate all metrics from input plans (everything below shuffle writer). +fn collect_input_metrics(exec: &ShuffleWriterExec) -> MetricsSet { + let mut all_metrics = MetricsSet::new(); + fn gather(plan: &dyn ExecutionPlan, out: &mut MetricsSet) { + if let Some(metrics) = plan.metrics() { + for m in metrics.iter() { + out.push(Arc::clone(m)); + } + } + for child in plan.children() { + gather(child.as_ref(), out); + } + } + for child in exec.children() { + gather(child.as_ref(), &mut all_metrics); + } + all_metrics +} + +/// Run N concurrent shuffle writes to simulate executor parallelism. +/// Returns wall-clock time for all tasks to complete. +fn run_concurrent_shuffle_writes( + input_path: &Path, + schema: &SchemaRef, + codec: &CompressionCodec, + hash_col_indices: &[usize], + args: &Args, +) -> f64 { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = Instant::now(); + + let mut handles = Vec::with_capacity(args.concurrent_tasks); + for task_id in 0..args.concurrent_tasks { + let task_dir = args.output_dir.join(format!("task_{task_id}")); + fs::create_dir_all(&task_dir).expect("Failed to create task output directory"); + let data_file = task_dir.join("data.out").to_str().unwrap().to_string(); + let index_file = task_dir.join("index.out").to_str().unwrap().to_string(); + + let input_str = input_path.to_str().unwrap().to_string(); + let codec = codec.clone(); + let partitioning = build_partitioning( + &args.partitioning, + args.partitions, + hash_col_indices, + schema, + ); + let batch_size = args.batch_size; + let memory_limit = args.memory_limit; + let write_buffer_size = args.write_buffer_size; + let limit = args.limit; + let immediate_mode = args.mode == "immediate"; + + handles.push(tokio::spawn(async move { + execute_shuffle_write( + &input_str, + codec, + partitioning, + batch_size, + memory_limit, + write_buffer_size, + limit, + 
data_file, + index_file, + immediate_mode, + ) + .await + .unwrap() + })); + } + + for handle in handles { + handle.await.expect("Task panicked"); + } + + for task_id in 0..args.concurrent_tasks { + let task_dir = args.output_dir.join(format!("task_{task_id}")); + let _ = fs::remove_dir_all(&task_dir); + } + + start.elapsed().as_secs_f64() + }) +} + +fn run_shuffle_read(data_file: &str, index_file: &str, num_partitions: usize) -> f64 { + let start = Instant::now(); + + let index_bytes = fs::read(index_file).expect("Failed to read index file"); + let num_offsets = index_bytes.len() / 8; + let offsets: Vec = (0..num_offsets) + .map(|i| { + let bytes: [u8; 8] = index_bytes[i * 8..(i + 1) * 8].try_into().unwrap(); + i64::from_le_bytes(bytes) + }) + .collect(); + + let data_bytes = fs::read(data_file).expect("Failed to read data file"); + + let mut total_rows = 0usize; + let mut total_batches = 0usize; + + for p in 0..num_partitions.min(offsets.len().saturating_sub(1)) { + let start_offset = offsets[p] as usize; + let end_offset = offsets[p + 1] as usize; + + if start_offset >= end_offset { + continue; + } + + // Each partition's data contains one or more complete IPC streams + let partition_data = &data_bytes[start_offset..end_offset]; + let mut reader = + StreamReader::try_new(partition_data, None).expect("Failed to open IPC stream"); + while let Some(batch_result) = reader.next() { + let batch = batch_result.expect("Failed to decode record batch"); + total_rows += batch.num_rows(); + total_batches += 1; + } + } + + let elapsed = start.elapsed().as_secs_f64(); + eprintln!( + " read back {} rows in {} batches from {} partitions", + format_number(total_rows), + total_batches, + num_partitions + ); + elapsed +} + +fn build_partitioning( + scheme: &str, + num_partitions: usize, + hash_col_indices: &[usize], + schema: &SchemaRef, +) -> CometPartitioning { + match scheme { + "single" => CometPartitioning::SinglePartition, + "round-robin" => 
CometPartitioning::RoundRobin(num_partitions, 0), + "hash" => { + let exprs: Vec> = hash_col_indices + .iter() + .map(|&idx| { + let field = schema.field(idx); + Arc::new(Column::new(field.name(), idx)) + as Arc + }) + .collect(); + CometPartitioning::Hash(exprs, num_partitions) + } + other => { + eprintln!("Unknown partitioning scheme: {other}. Using hash."); + build_partitioning("hash", num_partitions, hash_col_indices, schema) + } + } +} + +fn parse_codec(codec: &str, zstd_level: i32) -> CompressionCodec { + match codec.to_lowercase().as_str() { + "none" => CompressionCodec::None, + "lz4" => CompressionCodec::Lz4Frame, + "zstd" => CompressionCodec::Zstd(zstd_level), + other => { + eprintln!("Unknown codec: {other}. Using zstd."); + CompressionCodec::Zstd(zstd_level) + } + } +} + +fn parse_hash_columns(s: &str) -> Vec { + s.split(',') + .filter(|s| !s.is_empty()) + .map(|s| s.trim().parse::().expect("Invalid column index")) + .collect() +} + +fn describe_schema(schema: &arrow::datatypes::Schema) -> String { + let mut counts = std::collections::HashMap::new(); + for field in schema.fields() { + let type_name = match field.data_type() { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => "int", + DataType::Float16 | DataType::Float32 | DataType::Float64 => "float", + DataType::Utf8 | DataType::LargeUtf8 => "string", + DataType::Boolean => "bool", + DataType::Date32 | DataType::Date64 => "date", + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => "decimal", + DataType::Timestamp(_, _) => "timestamp", + DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => "binary", + _ => "other", + }; + *counts.entry(type_name).or_insert(0) += 1; + } + let mut parts: Vec = counts + .into_iter() + .map(|(k, v)| format!("{}x{}", v, k)) + .collect(); + parts.sort(); + parts.join(", ") +} + +fn format_number(n: usize) -> String { + let s = 
n.to_string(); + let mut result = String::new(); + for (i, c) in s.chars().rev().enumerate() { + if i > 0 && i % 3 == 0 { + result.push(','); + } + result.push(c); + } + result.chars().rev().collect() +} + +fn format_bytes(bytes: usize) -> String { + if bytes >= 1024 * 1024 * 1024 { + format!("{:.2} GiB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } else if bytes >= 1024 * 1024 { + format!("{:.2} MiB", bytes as f64 / (1024.0 * 1024.0)) + } else if bytes >= 1024 { + format!("{:.2} KiB", bytes as f64 / 1024.0) + } else { + format!("{bytes} B") + } +} diff --git a/native/shuffle/src/ipc.rs b/native/shuffle/src/ipc.rs index 81ee41332a..bd41deb4d8 100644 --- a/native/shuffle/src/ipc.rs +++ b/native/shuffle/src/ipc.rs @@ -17,36 +17,255 @@ use arrow::array::RecordBatch; use arrow::ipc::reader::StreamReader; -use datafusion::common::DataFusionError; -use datafusion::error::Result; - -pub fn read_ipc_compressed(bytes: &[u8]) -> Result { - match &bytes[0..4] { - b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) +use jni::objects::{GlobalRef, JObject, JValue}; +use jni::JavaVM; +use std::io::Read; + +/// Size of the internal read-ahead buffer (64 KB). +const READ_AHEAD_BUF_SIZE: usize = 64 * 1024; + +/// A Rust `Read` implementation that pulls bytes from a JVM `java.io.InputStream` +/// via JNI callbacks, using an internal read-ahead buffer to minimize JNI crossings. +pub struct JniInputStream { + /// Handle to the JVM for attaching threads. + vm: JavaVM, + /// Global reference to the JVM InputStream object. + input_stream: GlobalRef, + /// Global reference to the JVM byte[] used for bulk reads. + jbuf: GlobalRef, + /// Internal Rust-side buffer holding bytes read from JVM. + buf: Vec, + /// Current read position within `buf`. + pos: usize, + /// Number of valid bytes in `buf`. 
+ len: usize, +} + +impl JniInputStream { + /// Create a new `JniInputStream` wrapping a JVM InputStream. + pub fn new(env: &mut jni::JNIEnv, input_stream: &JObject) -> jni::errors::Result { + let vm = env.get_java_vm()?; + let input_stream = env.new_global_ref(input_stream)?; + let jbuf_local = env.new_byte_array(READ_AHEAD_BUF_SIZE as i32)?; + let jbuf = env.new_global_ref(&jbuf_local)?; + Ok(Self { + vm, + input_stream, + jbuf, + buf: vec![0u8; READ_AHEAD_BUF_SIZE], + pos: 0, + len: 0, + }) + } + + /// Refill the internal buffer by calling `InputStream.read(byte[], 0, len)` via JNI. + fn refill(&mut self) -> std::io::Result { + let mut env = self + .vm + .attach_current_thread_as_daemon() + .map_err(|e| std::io::Error::other(e.to_string()))?; + + // Get a local reference from the global ref for the byte array + let jbuf_local = env + .new_local_ref(self.jbuf.as_obj()) + .map_err(|e| std::io::Error::other(e.to_string()))?; + + let n = env + .call_method( + &self.input_stream, + "read", + "([BII)I", + &[ + JValue::Object(&jbuf_local), + JValue::Int(0), + JValue::Int(READ_AHEAD_BUF_SIZE as i32), + ], + ) + .map_err(|e| std::io::Error::other(e.to_string()))? + .i() + .map_err(|e| std::io::Error::other(e.to_string()))?; + + if n <= 0 { + // -1 means end of stream + self.pos = 0; + self.len = 0; + return Ok(0); } - b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + + let n = n as usize; + + // Copy bytes from JVM byte[] into our Rust buffer. + // jbyte is i8; we read into a temporary i8 slice then reinterpret as u8. 
+ let mut i8_buf = vec![0i8; n]; + let jbuf_array = unsafe { jni::objects::JByteArray::from_raw(jbuf_local.as_raw()) }; + env.get_byte_array_region(&jbuf_array, 0, &mut i8_buf) + .map_err(|e| std::io::Error::other(e.to_string()))?; + + // Don't let the JByteArray drop free the local ref — it was created from + // a local ref that we don't own (it came from new_local_ref). + // Actually, JByteArray::from_raw takes ownership conceptually, but the local + // ref table manages it. We need to forget it so the underlying JObject local + // ref doesn't get deleted twice. The new_local_ref created it, and from_raw + // wrapped it. We should not drop jbuf_array since that would call + // DeleteLocalRef on the same raw jobject that jbuf_local already points to. + // However, JByteArray doesn't impl Drop with DeleteLocalRef — jni objects + // are plain wrappers. So this is fine. + + let src = unsafe { std::slice::from_raw_parts(i8_buf.as_ptr() as *const u8, n) }; + self.buf[..n].copy_from_slice(src); + self.pos = 0; + self.len = n; + + Ok(n) + } +} + +impl Read for JniInputStream { + fn read(&mut self, out: &mut [u8]) -> std::io::Result { + if self.pos >= self.len { + // Buffer is empty, refill + let filled = self.refill()?; + if filled == 0 { + return Ok(0); // EOF + } } - b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + + let available = self.len - self.pos; + let to_copy = available.min(out.len()); + out[..to_copy].copy_from_slice(&self.buf[self.pos..self.pos + to_copy]); + self.pos += to_copy; + Ok(to_copy) + } +} + +/// A wrapper around `JniInputStream` that allows `StreamReader` to borrow +/// it while still being able to create new `StreamReader` instances for +/// concatenated IPC streams. 
+/// +/// Uses a raw pointer to the `JniInputStream` stored in a `Box` so that +/// the `StreamReader` can take a `Read` impl without taking ownership. +struct SharedJniStream { + inner: *mut JniInputStream, +} + +impl SharedJniStream { + fn new(stream: JniInputStream) -> Self { + Self { + inner: Box::into_raw(Box::new(stream)), } - b"NONE" => { - let mut reader = - unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) + } + + /// Create a Read adapter that delegates to the inner stream. + fn reader(&self) -> StreamReadAdapter { + StreamReadAdapter { inner: self.inner } + } +} + +impl Drop for SharedJniStream { + fn drop(&mut self) { + unsafe { drop(Box::from_raw(self.inner)) }; + } +} + +// SAFETY: SharedJniStream owns the JniInputStream exclusively via a raw pointer. +// It is only accessed from a single thread at a time (the JNI thread that calls +// get_next_batch). The raw pointer is used to allow multiple sequential StreamReader +// instances to borrow the same underlying stream. +unsafe impl Send for SharedJniStream {} +unsafe impl Sync for SharedJniStream {} + +// SAFETY: StreamReadAdapter borrows from the same raw pointer as SharedJniStream. +// Same single-threaded access guarantees apply. +unsafe impl Send for StreamReadAdapter {} +unsafe impl Sync for StreamReadAdapter {} + +/// A Read adapter that delegates to a raw pointer to JniInputStream. +/// Multiple StreamReader instances can be created from this adapter +/// (sequentially, not concurrently). +struct StreamReadAdapter { + inner: *mut JniInputStream, +} + +impl Read for StreamReadAdapter { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + unsafe { (*self.inner).read(buf) } + } +} + +/// Manages reading potentially concatenated Arrow IPC streams from a JVM +/// InputStream. 
A single partition's data may contain multiple IPC streams +/// (e.g., from spills), so when one stream reaches EOS we attempt to open +/// the next one from the same underlying InputStream. +pub struct ShuffleStreamReader { + /// Shared ownership of the JniInputStream. + jni_stream: SharedJniStream, + /// Current Arrow IPC stream reader. `None` when all streams are exhausted. + reader: Option>, + num_fields: usize, +} + +impl ShuffleStreamReader { + /// Create a new `ShuffleStreamReader` over a JVM InputStream. + /// Returns a reader that yields no batches if the stream is empty. + pub fn new(env: &mut jni::JNIEnv, input_stream: &JObject) -> Result { + let jni_stream = SharedJniStream::new( + JniInputStream::new(env, input_stream).map_err(|e| format!("JNI error: {e}"))?, + ); + match StreamReader::try_new(jni_stream.reader(), None) { + Ok(reader) => { + let reader = unsafe { reader.with_skip_validation(true) }; + let num_fields = reader.schema().fields().len(); + Ok(Self { + jni_stream, + reader: Some(reader), + num_fields, + }) + } + Err(_) => { + // Empty stream — no data for this partition + Ok(Self { + jni_stream, + reader: None, + num_fields: 0, + }) + } } - other => Err(DataFusionError::Execution(format!( - "Failed to decode batch: invalid compression codec: {other:?}" - ))), + } + + /// Read the next batch from the stream. Returns `None` when all + /// concatenated IPC streams are exhausted. + pub fn next_batch(&mut self) -> Result, String> { + loop { + let reader = match &mut self.reader { + Some(r) => r, + None => return Ok(None), + }; + + match reader.next() { + Some(Ok(batch)) => return Ok(Some(batch)), + Some(Err(e)) => return Err(format!("Arrow IPC read error: {e}")), + None => { + // Current IPC stream exhausted. Drop the old reader and try + // to open the next concatenated stream. 
+ self.reader = None; + + match StreamReader::try_new(self.jni_stream.reader(), None) { + Ok(new_reader) => { + self.reader = Some(unsafe { new_reader.with_skip_validation(true) }); + // Loop back to read from the new reader + } + Err(_) => { + // No more streams — the InputStream is exhausted + return Ok(None); + } + } + } + } + } + } + + /// Return the number of fields in the stream's schema. + pub fn num_fields(&self) -> usize { + self.num_fields } } diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs index dd3b900272..1c31bda5ef 100644 --- a/native/shuffle/src/lib.rs +++ b/native/shuffle/src/lib.rs @@ -25,6 +25,6 @@ pub mod spark_unsafe; pub(crate) mod writers; pub use comet_partitioning::CometPartitioning; -pub use ipc::read_ipc_compressed; +pub use ipc::{JniInputStream, ShuffleStreamReader}; pub use shuffle_writer::ShuffleWriterExec; -pub use writers::{CompressionCodec, ShuffleBlockWriter}; +pub use writers::CompressionCodec; diff --git a/native/shuffle/src/partitioners/immediate_mode.rs b/native/shuffle/src/partitioners/immediate_mode.rs new file mode 100644 index 0000000000..c9dc24b754 --- /dev/null +++ b/native/shuffle/src/partitioners/immediate_mode.rs @@ -0,0 +1,1070 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::ShufflePartitioner; +use crate::{comet_partitioning, CometPartitioning, CompressionCodec}; +use arrow::array::builder::{ + make_builder, ArrayBuilder, BinaryBuilder, BinaryViewBuilder, BooleanBuilder, + LargeBinaryBuilder, LargeStringBuilder, NullBuilder, PrimitiveBuilder, StringBuilder, + StringViewBuilder, +}; +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryViewArray, RecordBatch, StringViewArray, UInt32Array, +}; +use arrow::compute::take; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, + UInt64Type, UInt8Type, +}; +use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; +use datafusion::common::{DataFusionError, Result}; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryLimit, MemoryReservation}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Seek, Write}; +use std::sync::Arc; +use tokio::time::Instant; + +macro_rules! scatter_byte_array { + ($builder:expr, $source:expr, $indices:expr, $offset_type:ty, $builder_type:ty, $cast:ident) => {{ + let src = $source.$cast::<$offset_type>(); + let dst = $builder + .as_any_mut() + .downcast_mut::<$builder_type>() + .expect("builder type mismatch"); + if src.null_count() == 0 { + for &idx in $indices { + dst.append_value(src.value(idx)); + } + } else { + for &idx in $indices { + dst.append_option(src.is_valid(idx).then(|| src.value(idx))); + } + } + }}; +} + +macro_rules! 
scatter_byte_view { + ($builder:expr, $source:expr, $indices:expr, $array_type:ty, $builder_type:ty) => {{ + let src = $source + .as_any() + .downcast_ref::<$array_type>() + .expect("array type mismatch"); + let dst = $builder + .as_any_mut() + .downcast_mut::<$builder_type>() + .expect("builder type mismatch"); + if src.null_count() == 0 { + for &idx in $indices { + dst.append_value(src.value(idx)); + } + } else { + for &idx in $indices { + dst.append_option(src.is_valid(idx).then(|| src.value(idx))); + } + } + }}; +} + +macro_rules! scatter_primitive { + ($builder:expr, $source:expr, $indices:expr, $arrow_type:ty) => {{ + let src = $source.as_primitive::<$arrow_type>(); + let dst = $builder + .as_any_mut() + .downcast_mut::>() + .expect("builder type mismatch"); + if src.null_count() == 0 { + for &idx in $indices { + dst.append_value(src.value(idx)); + } + } else { + for &idx in $indices { + dst.append_option(src.is_valid(idx).then(|| src.value(idx))); + } + } + }}; +} + +/// Scatter-append selected rows from `source` into `builder`. 
+fn scatter_append( + builder: &mut dyn ArrayBuilder, + source: &dyn Array, + indices: &[usize], +) -> Result<()> { + use DataType::*; + match source.data_type() { + Boolean => { + let src = source.as_boolean(); + let dst = builder + .as_any_mut() + .downcast_mut::() + .unwrap(); + if src.null_count() == 0 { + for &idx in indices { + dst.append_value(src.value(idx)); + } + } else { + for &idx in indices { + dst.append_option(src.is_valid(idx).then(|| src.value(idx))); + } + } + } + Int8 => scatter_primitive!(builder, source, indices, Int8Type), + Int16 => scatter_primitive!(builder, source, indices, Int16Type), + Int32 => scatter_primitive!(builder, source, indices, Int32Type), + Int64 => scatter_primitive!(builder, source, indices, Int64Type), + UInt8 => scatter_primitive!(builder, source, indices, UInt8Type), + UInt16 => scatter_primitive!(builder, source, indices, UInt16Type), + UInt32 => scatter_primitive!(builder, source, indices, UInt32Type), + UInt64 => scatter_primitive!(builder, source, indices, UInt64Type), + Float32 => scatter_primitive!(builder, source, indices, Float32Type), + Float64 => scatter_primitive!(builder, source, indices, Float64Type), + Date32 => scatter_primitive!(builder, source, indices, Date32Type), + Date64 => scatter_primitive!(builder, source, indices, Date64Type), + Timestamp(TimeUnit::Second, _) => { + scatter_primitive!(builder, source, indices, TimestampSecondType) + } + Timestamp(TimeUnit::Millisecond, _) => { + scatter_primitive!(builder, source, indices, TimestampMillisecondType) + } + Timestamp(TimeUnit::Microsecond, _) => { + scatter_primitive!(builder, source, indices, TimestampMicrosecondType) + } + Timestamp(TimeUnit::Nanosecond, _) => { + scatter_primitive!(builder, source, indices, TimestampNanosecondType) + } + Decimal128(_, _) => scatter_primitive!(builder, source, indices, Decimal128Type), + Decimal256(_, _) => scatter_primitive!(builder, source, indices, Decimal256Type), + Utf8 => scatter_byte_array!(builder, source, 
indices, i32, StringBuilder, as_string), + LargeUtf8 => { + scatter_byte_array!(builder, source, indices, i64, LargeStringBuilder, as_string) + } + Binary => scatter_byte_array!(builder, source, indices, i32, BinaryBuilder, as_binary), + LargeBinary => { + scatter_byte_array!(builder, source, indices, i64, LargeBinaryBuilder, as_binary) + } + Utf8View => { + scatter_byte_view!(builder, source, indices, StringViewArray, StringViewBuilder) + } + BinaryView => { + scatter_byte_view!(builder, source, indices, BinaryViewArray, BinaryViewBuilder) + } + Null => { + let dst = builder.as_any_mut().downcast_mut::().unwrap(); + dst.append_nulls(indices.len()); + } + dt => { + return Err(DataFusionError::NotImplemented(format!( + "Scatter append not implemented for {dt}" + ))); + } + } + Ok(()) +} + +/// Per-column strategy: scatter-write via builder for primitive/string types, +/// or accumulate taken sub-arrays for complex types (List, Map, Struct, etc.). +enum ColumnBuffer { + /// Fast path: direct scatter into a pre-allocated builder. + Builder(Box), + /// Fallback for complex types: accumulate `take`-produced sub-arrays, + /// concatenate at flush time. + Accumulator(Vec), +} + +/// Returns true if `scatter_append` can handle this data type directly. 
+fn has_scatter_support(dt: &DataType) -> bool { + use DataType::*; + matches!( + dt, + Boolean + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Date32 + | Date64 + | Timestamp(_, _) + | Decimal128(_, _) + | Decimal256(_, _) + | Utf8 + | LargeUtf8 + | Binary + | LargeBinary + | Utf8View + | BinaryView + | Null + ) +} + +struct PartitionBuffer { + columns: Vec, + schema: SchemaRef, + num_rows: usize, + target_batch_size: usize, +} + +impl PartitionBuffer { + fn new(schema: &SchemaRef, target_batch_size: usize) -> Self { + let columns = schema + .fields() + .iter() + .map(|f| { + if has_scatter_support(f.data_type()) { + ColumnBuffer::Builder(make_builder(f.data_type(), target_batch_size)) + } else { + ColumnBuffer::Accumulator(Vec::new()) + } + }) + .collect(); + Self { + columns, + schema: Arc::clone(schema), + num_rows: 0, + target_batch_size, + } + } + + fn is_full(&self) -> bool { + self.num_rows >= self.target_batch_size + } + + /// Finish all columns into a RecordBatch. Builders are reset (retaining + /// capacity); accumulators are concatenated and cleared. + fn flush(&mut self) -> Result { + let arrays: Vec = self + .columns + .iter_mut() + .map(|col| match col { + ColumnBuffer::Builder(b) => b.finish(), + ColumnBuffer::Accumulator(chunks) => { + let refs: Vec<&dyn Array> = chunks.iter().map(|a| a.as_ref()).collect(); + let result = arrow::compute::concat(&refs) + .expect("concat failed for accumulated arrays"); + chunks.clear(); + result + } + }) + .collect(); + let batch = RecordBatch::try_new(Arc::clone(&self.schema), arrays) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + self.num_rows = 0; + Ok(batch) + } + + fn has_data(&self) -> bool { + self.num_rows > 0 + } +} + +/// Per-partition output stream that writes batches into a persistent Arrow IPC +/// `StreamWriter>`. The schema is written once when the writer is lazily +/// created. 
Arrow IPC body compression handles LZ4/ZSTD internally via `IpcWriteOptions`. +pub(crate) struct PartitionOutputStream { + schema: SchemaRef, + write_options: IpcWriteOptions, + /// Lazily created IPC stream writer over an in-memory buffer + writer: Option>>, + /// Accumulated spill data (bytes from finished IPC streams that were drained) + spilled_bytes: Vec, +} + +impl PartitionOutputStream { + pub(crate) fn try_new(schema: SchemaRef, write_options: IpcWriteOptions) -> Result { + Ok(Self { + schema, + write_options, + writer: None, + spilled_bytes: Vec::new(), + }) + } + + /// Ensure the writer exists (lazy creation), write the batch, and return bytes written. + fn write_batch(&mut self, batch: &RecordBatch) -> Result { + let before = self.current_buffer_len(); + let writer = match &mut self.writer { + Some(w) => w, + None => { + let w = StreamWriter::try_new_with_options( + Vec::new(), + &self.schema, + self.write_options.clone(), + )?; + self.writer = Some(w); + self.writer.as_mut().unwrap() + } + }; + writer.write(batch)?; + let after = self.current_buffer_len(); + Ok(after.saturating_sub(before)) + } + + /// Finish the current IPC stream (if any), return all accumulated bytes + /// (spilled + current stream), and reset the writer to None. + fn drain_buffer(&mut self) -> Result> { + if let Some(mut writer) = self.writer.take() { + writer.finish()?; + let buf = writer.into_inner()?; + self.spilled_bytes.extend_from_slice(&buf); + } + Ok(std::mem::take(&mut self.spilled_bytes)) + } + + /// Finish the current IPC stream and move its bytes into spilled_bytes, + /// resetting the writer to None so a new stream can be started later. 
+ fn finish_current_stream(&mut self) -> Result<()> { + if let Some(mut writer) = self.writer.take() { + writer.finish()?; + let buf = writer.into_inner()?; + self.spilled_bytes.extend_from_slice(&buf); + } + Ok(()) + } + + fn current_buffer_len(&self) -> usize { + let writer_len = self.writer.as_ref().map(|w| w.get_ref().len()).unwrap_or(0); + self.spilled_bytes.len() + writer_len + } +} + +struct SpillFile { + _temp_file: datafusion::execution::disk_manager::RefCountedTempFile, + file: File, +} + +/// A partitioner that scatter-writes incoming rows directly into pre-allocated +/// per-partition column builders. When a partition's builders reach +/// `target_batch_size`, the batch is flushed to a compressed IPC block. +/// No intermediate sub-batches or coalescers are created. +pub(crate) struct ImmediateModePartitioner { + output_data_file: String, + output_index_file: String, + partition_buffers: Vec, + streams: Vec, + spill_files: Vec>, + partitioning: CometPartitioning, + runtime: Arc, + reservation: MemoryReservation, + metrics: ShufflePartitionerMetrics, + hashes_buf: Vec, + partition_ids: Vec, + /// Reusable per-partition row index scratch space. + partition_row_indices: Vec>, + /// Maximum bytes this partitioner will reserve from the memory pool. + /// Computed as memory_pool_size * memory_fraction at construction. 
+ memory_limit: usize, +} + +impl ImmediateModePartitioner { + #[allow(clippy::too_many_arguments)] + pub(crate) fn try_new( + partition: usize, + output_data_file: String, + output_index_file: String, + schema: SchemaRef, + partitioning: CometPartitioning, + metrics: ShufflePartitionerMetrics, + runtime: Arc, + batch_size: usize, + codec: CompressionCodec, + ) -> Result { + let num_output_partitions = partitioning.partition_count(); + let write_options = codec.ipc_write_options()?; + + let partition_buffers = (0..num_output_partitions) + .map(|_| PartitionBuffer::new(&schema, batch_size)) + .collect(); + + let streams = (0..num_output_partitions) + .map(|_| PartitionOutputStream::try_new(Arc::clone(&schema), write_options.clone())) + .collect::>>()?; + + let spill_files: Vec> = + (0..num_output_partitions).map(|_| None).collect(); + + let hashes_buf = match &partitioning { + CometPartitioning::Hash(_, _) | CometPartitioning::RoundRobin(_, _) => { + vec![0u32; batch_size] + } + _ => vec![], + }; + + let memory_limit = match runtime.memory_pool.memory_limit() { + MemoryLimit::Finite(pool_size) => pool_size, + _ => usize::MAX, + }; + + let reservation = MemoryConsumer::new(format!("ImmediateModePartitioner[{partition}]")) + .with_can_spill(true) + .register(&runtime.memory_pool); + + let partition_row_indices = (0..num_output_partitions).map(|_| Vec::new()).collect(); + + Ok(Self { + output_data_file, + output_index_file, + partition_buffers, + streams, + spill_files, + partitioning, + runtime, + reservation, + metrics, + hashes_buf, + partition_ids: vec![0u32; batch_size], + partition_row_indices, + memory_limit, + }) + } + + fn compute_partition_ids(&mut self, batch: &RecordBatch) -> Result { + let num_rows = batch.num_rows(); + + // Ensure scratch buffers are large enough for this batch + if self.hashes_buf.len() < num_rows { + self.hashes_buf.resize(num_rows, 0); + } + if self.partition_ids.len() < num_rows { + self.partition_ids.resize(num_rows, 0); + } + + 
match &self.partitioning { + CometPartitioning::Hash(exprs, num_output_partitions) => { + let num_output_partitions = *num_output_partitions; + let arrays = exprs + .iter() + .map(|expr| expr.evaluate(batch)?.into_array(num_rows)) + .collect::>>()?; + let hashes_buf = &mut self.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + create_murmur3_hashes(&arrays, hashes_buf)?; + let partition_ids = &mut self.partition_ids[..num_rows]; + for (idx, hash) in hashes_buf.iter().enumerate() { + partition_ids[idx] = + comet_partitioning::pmod(*hash, num_output_partitions) as u32; + } + Ok(num_output_partitions) + } + CometPartitioning::RoundRobin(num_output_partitions, max_hash_columns) => { + let num_output_partitions = *num_output_partitions; + let max_hash_columns = *max_hash_columns; + let num_columns_to_hash = if max_hash_columns == 0 { + batch.num_columns() + } else { + max_hash_columns.min(batch.num_columns()) + }; + let columns_to_hash: Vec = (0..num_columns_to_hash) + .map(|i| Arc::clone(batch.column(i))) + .collect(); + let hashes_buf = &mut self.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + create_murmur3_hashes(&columns_to_hash, hashes_buf)?; + let partition_ids = &mut self.partition_ids[..num_rows]; + for (idx, hash) in hashes_buf.iter().enumerate() { + partition_ids[idx] = + comet_partitioning::pmod(*hash, num_output_partitions) as u32; + } + Ok(num_output_partitions) + } + CometPartitioning::RangePartitioning( + lex_ordering, + num_output_partitions, + row_converter, + bounds, + ) => { + let num_output_partitions = *num_output_partitions; + let arrays = lex_ordering + .iter() + .map(|expr| expr.expr.evaluate(batch)?.into_array(num_rows)) + .collect::>>()?; + let row_batch = row_converter.convert_columns(arrays.as_slice())?; + let partition_ids = &mut self.partition_ids[..num_rows]; + for (row_idx, row) in row_batch.iter().enumerate() { + partition_ids[row_idx] = bounds + .as_slice() + .partition_point(|bound| bound.row() <= row) + as u32; + } + 
Ok(num_output_partitions) + } + other => Err(DataFusionError::NotImplemented(format!( + "Unsupported shuffle partitioning scheme {other:?}" + ))), + } + } + + /// Scatter-write rows from batch into per-partition builders, flushing + /// any partition that reaches target_batch_size. Returns + /// `(flushed_builder_bytes, ipc_bytes_written)`. + /// + /// Uses column-first iteration so each column's type dispatch happens once + /// per batch (num_columns times) rather than once per partition per column + /// (num_columns × num_partitions times). + fn repartition_batch(&mut self, batch: &RecordBatch) -> Result<(usize, usize)> { + let num_partitions = self.partition_buffers.len(); + let num_rows = batch.num_rows(); + + // Build per-partition row indices, reusing scratch vecs + for indices in self.partition_row_indices.iter_mut() { + indices.clear(); + } + for row_idx in 0..num_rows { + let pid = self.partition_ids[row_idx] as usize; + self.partition_row_indices[pid].push(row_idx); + } + + // Column-first scatter: resolve each column's type once, then + // scatter across all partitions with the same typed path. 
+ for col_idx in 0..batch.num_columns() { + let source = batch.column(col_idx); + for pid in 0..num_partitions { + let indices = &self.partition_row_indices[pid]; + if indices.is_empty() { + continue; + } + match &mut self.partition_buffers[pid].columns[col_idx] { + ColumnBuffer::Builder(builder) => { + scatter_append(builder.as_mut(), source.as_ref(), indices)?; + } + ColumnBuffer::Accumulator(chunks) => { + let idx_array = + UInt32Array::from_iter_values(indices.iter().map(|&i| i as u32)); + let taken = take(source.as_ref(), &idx_array, None) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + chunks.push(taken); + } + } + } + } + + // Update row counts and flush full partitions + let mut flushed_builder_bytes = 0usize; + let mut ipc_bytes = 0usize; + for pid in 0..num_partitions { + let added = self.partition_row_indices[pid].len(); + if added == 0 { + continue; + } + self.partition_buffers[pid].num_rows += added; + if self.partition_buffers[pid].is_full() { + let (builder_bytes, written) = self.flush_partition(pid)?; + flushed_builder_bytes += builder_bytes; + ipc_bytes += written; + } + } + + Ok((flushed_builder_bytes, ipc_bytes)) + } + + /// Flush a partition's builders to the IPC stream in its output stream. + /// Returns `(flushed_batch_memory, ipc_bytes_written)`. + fn flush_partition(&mut self, pid: usize) -> Result<(usize, usize)> { + let output_batch = self.partition_buffers[pid].flush()?; + let batch_mem = output_batch.get_array_memory_size(); + let mut encode_timer = self.metrics.encode_time.timer(); + let ipc_bytes = self.streams[pid].write_batch(&output_batch)?; + encode_timer.stop(); + Ok((batch_mem, ipc_bytes)) + } + + /// Spill all partition IPC buffers to per-partition temp files. 
+ fn spill_all(&mut self) -> Result<()> { + let mut spilled_bytes = 0usize; + + // Flush any partially-filled partition builders + for pid in 0..self.partition_buffers.len() { + if self.partition_buffers[pid].has_data() { + self.flush_partition(pid)?; + } + } + + // Finish current IPC streams and drain buffers to disk + for pid in 0..self.streams.len() { + // Finish the current IPC stream so it can be read back later + self.streams[pid].finish_current_stream()?; + + let buf = self.streams[pid].drain_buffer()?; + if buf.is_empty() { + continue; + } + + if self.spill_files[pid].is_none() { + let temp_file = self + .runtime + .disk_manager + .create_tmp_file(&format!("imm_shuffle_p{pid}"))?; + let path = temp_file.path().to_owned(); + let file = OpenOptions::new().append(true).open(&path).map_err(|e| { + DataFusionError::Execution(format!("Failed to open spill file: {e}")) + })?; + self.spill_files[pid] = Some(SpillFile { + _temp_file: temp_file, + file, + }); + } + + if let Some(spill) = &mut self.spill_files[pid] { + spill.file.write_all(&buf).map_err(|e| { + DataFusionError::Execution(format!("Failed to write spill: {e}")) + })?; + spilled_bytes += buf.len(); + } + } + + for spill in self.spill_files.iter_mut().flatten() { + spill.file.flush()?; + } + + self.reservation.free(); + if spilled_bytes > 0 { + self.metrics.spill_count.add(1); + self.metrics.spilled_bytes.add(spilled_bytes); + } + + Ok(()) + } +} + +#[async_trait::async_trait] +impl ShufflePartitioner for ImmediateModePartitioner { + async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { + if batch.num_rows() == 0 { + return Ok(()); + } + + let start_time = Instant::now(); + + let batch_mem = batch.get_array_memory_size(); + self.metrics.data_size.add(batch_mem); + self.metrics.baseline.record_output(batch.num_rows()); + + let repart_start = Instant::now(); + self.compute_partition_ids(&batch)?; + self.metrics + .repart_time + .add_duration(repart_start.elapsed()); + + let 
(flushed_builder_bytes, ipc_growth) = self.repartition_batch(&batch)?; + let builder_growth = batch_mem; + + // Net memory change: data entered builders, some was flushed to IPC + let net_growth = (builder_growth + ipc_growth).saturating_sub(flushed_builder_bytes); + + if net_growth > 0 { + // Use our own memory limit rather than relying solely on the pool, + // since the pool doesn't see builder allocations directly. + if self.reservation.size() + net_growth > self.memory_limit + || self.reservation.try_grow(net_growth).is_err() + { + self.spill_all()?; + } + } + + self.metrics.input_batches.add(1); + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + + Ok(()) + } + + fn shuffle_write(&mut self) -> Result<()> { + let start_time = Instant::now(); + let num_output_partitions = self.streams.len(); + let mut offsets = vec![0i64; num_output_partitions + 1]; + + let mut output_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&self.output_data_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; + + #[allow(clippy::needless_range_loop)] + for pid in 0..num_output_partitions { + offsets[pid] = output_data.stream_position()? 
as i64; + + if let Some(spill) = &self.spill_files[pid] { + let path = spill._temp_file.path().to_owned(); + let spill_reader = File::open(&path).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to open spill file for reading: {e}" + )) + })?; + let mut write_timer = self.metrics.write_time.timer(); + std::io::copy(&mut &spill_reader, &mut output_data)?; + write_timer.stop(); + } + + if self.partition_buffers[pid].has_data() { + self.flush_partition(pid)?; + } + + let buf = self.streams[pid].drain_buffer()?; + if !buf.is_empty() { + let mut write_timer = self.metrics.write_time.timer(); + output_data.write_all(&buf)?; + write_timer.stop(); + } + } + + for spill in self.spill_files.iter_mut() { + *spill = None; + } + + offsets[num_output_partitions] = output_data.stream_position()? as i64; + + let mut write_timer = self.metrics.write_time.timer(); + let mut output_index = BufWriter::new( + File::create(&self.output_index_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?, + ); + for offset in &offsets { + output_index.write_all(&offset.to_le_bytes())?; + } + output_index.flush()?; + write_timer.stop(); + + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::ipc::reader::StreamReader; + use datafusion::execution::memory_pool::GreedyMemoryPool; + use datafusion::execution::runtime_env::RuntimeEnvBuilder; + use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; + + fn make_test_batch(values: &[i32]) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let array = Int32Array::from(values.to_vec()); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() + } + + #[test] + fn test_scatter_append_primitives() { + let array: ArrayRef = 
Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50]));
+        let mut builder = make_builder(&DataType::Int32, 8);
+        scatter_append(builder.as_mut(), array.as_ref(), &[0, 2, 4]).unwrap();
+        let result = builder.finish();
+        let result = result.as_primitive::<arrow::datatypes::Int32Type>();
+        assert_eq!(result.values().as_ref(), &[10, 30, 50]);
+    }
+
+    #[test]
+    fn test_scatter_append_strings() {
+        let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
+        let mut builder = make_builder(&DataType::Utf8, 4);
+        scatter_append(builder.as_mut(), array.as_ref(), &[1, 3]).unwrap();
+        let result = builder.finish();
+        let result = result.as_string::<i32>();
+        assert_eq!(result.value(0), "b");
+        assert_eq!(result.value(1), "d");
+    }
+
+    #[test]
+    fn test_scatter_append_nulls() {
+        let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
+        let mut builder = make_builder(&DataType::Int32, 4);
+        scatter_append(builder.as_mut(), array.as_ref(), &[0, 1, 2]).unwrap();
+        let result = builder.finish();
+        let result = result.as_primitive::<arrow::datatypes::Int32Type>();
+        assert!(result.is_valid(0));
+        assert!(result.is_null(1));
+        assert!(result.is_valid(2));
+    }
+
+    #[test]
+    fn test_partition_buffer_flush_reuse() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
+        let batch = make_test_batch(&[1, 2, 3, 4, 5]);
+
+        let mut buf = PartitionBuffer::new(&schema, 3);
+        match &mut buf.columns[0] {
+            ColumnBuffer::Builder(b) => {
+                scatter_append(b.as_mut(), batch.column(0).as_ref(), &[0, 1, 2]).unwrap()
+            }
+            _ => panic!("expected Builder"),
+        }
+        buf.num_rows += 3;
+        assert!(buf.is_full());
+
+        let flushed = buf.flush().unwrap();
+        assert_eq!(flushed.num_rows(), 3);
+        assert_eq!(buf.num_rows, 0);
+
+        // Builders are reused after flush
+        match &mut buf.columns[0] {
+            ColumnBuffer::Builder(b) => {
+                scatter_append(b.as_mut(), batch.column(0).as_ref(), &[3, 4]).unwrap()
+            }
+            _ => panic!("expected Builder"),
+        }
+        buf.num_rows += 2;
+        assert_eq!(buf.num_rows, 2);
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)]
+    fn test_partition_output_stream_write_and_read() {
+        let batch = make_test_batch(&[1, 2, 3, 4, 5]);
+        let schema = batch.schema();
+
+        for codec in [
+            CompressionCodec::None,
+            CompressionCodec::Lz4Frame,
+            CompressionCodec::Zstd(1),
+        ] {
+            let write_options = codec.ipc_write_options().unwrap();
+            let mut stream =
+                PartitionOutputStream::try_new(Arc::clone(&schema), write_options).unwrap();
+            stream.write_batch(&batch).unwrap();
+
+            let buf = stream.drain_buffer().unwrap();
+            assert!(!buf.is_empty());
+
+            // Read back using standard Arrow StreamReader
+            let mut reader = StreamReader::try_new(&buf[..], None).unwrap();
+            let batch2 = reader.next().unwrap().unwrap();
+            assert_eq!(batch2.num_rows(), 5);
+        }
+    }
+
+    fn make_hash_partitioning(col_name: &str, num_partitions: usize) -> CometPartitioning {
+        use datafusion::physical_expr::expressions::Column;
+        let expr: Arc<dyn datafusion::physical_expr::PhysicalExpr> =
+            Arc::new(Column::new(col_name, 0));
+        CometPartitioning::Hash(vec![expr], num_partitions)
+    }
+
+    #[tokio::test]
+    async fn test_immediate_mode_partitioner_hash() {
+        let batch = make_test_batch(&[1, 2, 3, 4, 5, 6, 7, 8]);
+        let schema = batch.schema();
+        let dir = tempfile::tempdir().unwrap();
+        let data_path = dir.path().join("data").to_str().unwrap().to_string();
+        let index_path = dir.path().join("index").to_str().unwrap().to_string();
+
+        let metrics = ShufflePartitionerMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
+        let runtime = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
+
+        let mut partitioner = ImmediateModePartitioner::try_new(
+            0,
+            data_path,
+            index_path,
+            schema,
+            make_hash_partitioning("a", 4),
+            metrics,
+            runtime,
+            8192,
+            CompressionCodec::None,
+        )
+        .unwrap();
+
+        partitioner.insert_batch(batch).await.unwrap();
+
+        let total_rows: usize = partitioner
+            .partition_buffers
+            .iter()
+            .map(|b| b.num_rows)
+            .sum();
+        assert_eq!(total_rows, 8);
+    }
+
+    #[tokio::test]
+    async fn test_immediate_mode_shuffle_write() {
+        let batch1 =
make_test_batch(&[1, 2, 3, 4, 5, 6]); + let batch2 = make_test_batch(&[7, 8, 9, 10, 11, 12]); + let schema = batch1.schema(); + let dir = tempfile::tempdir().unwrap(); + let data_path = dir.path().join("data").to_str().unwrap().to_string(); + let index_path = dir.path().join("index").to_str().unwrap().to_string(); + + let num_partitions = 3; + let metrics = ShufflePartitionerMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let runtime = Arc::new(RuntimeEnvBuilder::new().build().unwrap()); + + let mut partitioner = ImmediateModePartitioner::try_new( + 0, + data_path.clone(), + index_path.clone(), + schema, + make_hash_partitioning("a", num_partitions), + metrics, + runtime, + 8192, + CompressionCodec::None, + ) + .unwrap(); + + partitioner.insert_batch(batch1).await.unwrap(); + partitioner.insert_batch(batch2).await.unwrap(); + partitioner.shuffle_write().unwrap(); + + let index_data = std::fs::read(&index_path).unwrap(); + assert_eq!(index_data.len(), (num_partitions + 1) * 8); + + let first_offset = i64::from_le_bytes(index_data[0..8].try_into().unwrap()); + assert_eq!(first_offset, 0); + + let data_file_size = std::fs::metadata(&data_path).unwrap().len(); + let last_offset = i64::from_le_bytes( + index_data[num_partitions * 8..(num_partitions + 1) * 8] + .try_into() + .unwrap(), + ); + assert_eq!(last_offset as u64, data_file_size); + assert!(data_file_size > 0); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // spill uses std::io::copy which triggers copy_file_range + async fn test_immediate_mode_spill() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let dir = tempfile::tempdir().unwrap(); + let data_path = dir.path().join("data").to_str().unwrap().to_string(); + let index_path = dir.path().join("index").to_str().unwrap().to_string(); + + let num_partitions = 2; + let metrics = ShufflePartitionerMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + + let runtime = Arc::new( + RuntimeEnvBuilder::new() + 
.with_memory_pool(Arc::new(GreedyMemoryPool::new(256)))
+                .build()
+                .unwrap(),
+        );
+
+        let mut partitioner = ImmediateModePartitioner::try_new(
+            0,
+            data_path.clone(),
+            index_path.clone(),
+            Arc::clone(&schema),
+            make_hash_partitioning("a", num_partitions),
+            metrics,
+            runtime,
+            8192,
+            CompressionCodec::None,
+        )
+        .unwrap();
+
+        for i in 0..10 {
+            let values: Vec<i32> = ((i * 10)..((i + 1) * 10)).collect();
+            let batch = make_test_batch(&values);
+            partitioner.insert_batch(batch).await.unwrap();
+        }
+
+        partitioner.shuffle_write().unwrap();
+
+        let index_data = std::fs::read(&index_path).unwrap();
+        assert_eq!(index_data.len(), (num_partitions + 1) * 8);
+
+        let data_file_size = std::fs::metadata(&data_path).unwrap().len();
+        let last_offset = i64::from_le_bytes(
+            index_data[num_partitions * 8..(num_partitions + 1) * 8]
+                .try_into()
+                .unwrap(),
+        );
+        assert_eq!(last_offset as u64, data_file_size);
+        assert!(data_file_size > 0);
+    }
+
+    #[tokio::test]
+    async fn test_ipc_stream_format_roundtrip() {
+        let batch = make_test_batch(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
+        let schema = batch.schema();
+        let dir = tempfile::tempdir().unwrap();
+        let data_path = dir.path().join("data").to_str().unwrap().to_string();
+        let index_path = dir.path().join("index").to_str().unwrap().to_string();
+
+        let num_partitions = 2;
+        let metrics = ShufflePartitionerMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
+        let runtime = Arc::new(RuntimeEnvBuilder::new().build().unwrap());
+
+        // Small target to trigger flush during insert
+        let mut partitioner = ImmediateModePartitioner::try_new(
+            0,
+            data_path.clone(),
+            index_path.clone(),
+            Arc::clone(&schema),
+            make_hash_partitioning("a", num_partitions),
+            metrics,
+            runtime,
+            4,
+            CompressionCodec::Lz4Frame,
+        )
+        .unwrap();
+
+        partitioner.insert_batch(batch).await.unwrap();
+        partitioner.shuffle_write().unwrap();
+
+        let index_data = std::fs::read(&index_path).unwrap();
+        let mut offsets = Vec::new();
+        for i in
0..=num_partitions {
+            let offset = i64::from_le_bytes(index_data[i * 8..(i + 1) * 8].try_into().unwrap());
+            offsets.push(offset as usize);
+        }
+
+        let data = std::fs::read(&data_path).unwrap();
+        let mut total_rows = 0;
+        for pid in 0..num_partitions {
+            let (start, end) = (offsets[pid], offsets[pid + 1]);
+            if start == end {
+                continue;
+            }
+            // Each partition's data is one or more complete IPC streams.
+            // Use StreamReader to decode them.
+            let partition_data = &data[start..end];
+            let reader = StreamReader::try_new(partition_data, None).unwrap();
+            for batch_result in reader {
+                let decoded = batch_result.unwrap();
+                assert_eq!(decoded.num_columns(), 1);
+                assert!(decoded.num_rows() > 0);
+                let col = decoded
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .unwrap();
+                for i in 0..col.len() {
+                    assert!((1..=10).contains(&col.value(i)));
+                }
+                total_rows += decoded.num_rows();
+            }
+        }
+        assert_eq!(total_rows, 10);
+    }
+}
diff --git a/native/shuffle/src/partitioners/mod.rs b/native/shuffle/src/partitioners/mod.rs
index 3eedef62c7..a47666b2a7 100644
--- a/native/shuffle/src/partitioners/mod.rs
+++ b/native/shuffle/src/partitioners/mod.rs
@@ -15,11 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
+mod immediate_mode; mod multi_partition; mod partitioned_batch_iterator; mod single_partition; mod traits; +pub(crate) use immediate_mode::ImmediateModePartitioner; pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner; pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator; pub(crate) use single_partition::SinglePartitionShufflePartitioner; diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 7de9314f54..1801a52d3e 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -21,9 +21,10 @@ use crate::partitioners::partitioned_batch_iterator::{ }; use crate::partitioners::ShufflePartitioner; use crate::writers::{BufBatchWriter, PartitionWriter}; -use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; +use crate::{comet_partitioning, CometPartitioning, CompressionCodec}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; +use arrow::ipc::writer::IpcWriteOptions; use datafusion::common::utils::proxy::VecAllocExt; use datafusion::common::DataFusionError; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -111,7 +112,10 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { buffered_batches: Vec, partition_indices: Vec>, partition_writers: Vec, - shuffle_block_writer: ShuffleBlockWriter, + /// Schema of the input data + schema: SchemaRef, + /// IPC write options (includes compression settings) + write_options: IpcWriteOptions, /// Partitioning scheme to use partitioning: CometPartitioning, runtime: Arc, @@ -123,8 +127,6 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { /// Reservation for repartitioning reservation: MemoryReservation, tracing_enabled: bool, - /// Size of the write buffer in bytes - write_buffer_size: usize, } impl MultiPartitionShuffleRepartitioner { @@ -140,7 +142,7 @@ impl 
MultiPartitionShuffleRepartitioner { batch_size: usize, codec: CompressionCodec, tracing_enabled: bool, - write_buffer_size: usize, + _write_buffer_size: usize, ) -> datafusion::common::Result { let num_output_partitions = partitioning.partition_count(); assert_ne!( @@ -165,10 +167,10 @@ impl MultiPartitionShuffleRepartitioner { partition_starts: vec![0; num_output_partitions + 1], }; - let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; + let write_options = codec.ipc_write_options()?; let partition_writers = (0..num_output_partitions) - .map(|_| PartitionWriter::try_new(shuffle_block_writer.clone())) + .map(|_| PartitionWriter::try_new(Arc::clone(&schema), write_options.clone())) .collect::>>()?; let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{partition}]")) @@ -181,7 +183,8 @@ impl MultiPartitionShuffleRepartitioner { buffered_batches: vec![], partition_indices: vec![vec![]; num_output_partitions], partition_writers, - shuffle_block_writer, + schema: Arc::clone(&schema), + write_options, partitioning, runtime, metrics, @@ -189,7 +192,6 @@ impl MultiPartitionShuffleRepartitioner { batch_size, reservation, tracing_enabled, - write_buffer_size, }) } @@ -436,24 +438,23 @@ impl MultiPartitionShuffleRepartitioner { fn shuffle_write_partition( partition_iter: &mut PartitionedBatchIterator, - shuffle_block_writer: &mut ShuffleBlockWriter, + schema: &SchemaRef, + write_options: &IpcWriteOptions, output_data: &mut BufWriter, encode_time: &Time, - write_time: &Time, - write_buffer_size: usize, batch_size: usize, ) -> datafusion::common::Result<()> { - let mut buf_batch_writer = BufBatchWriter::new( - shuffle_block_writer, + let mut buf_batch_writer = BufBatchWriter::try_new( output_data, - write_buffer_size, + Arc::clone(schema), + write_options.clone(), batch_size, - ); + )?; for batch in partition_iter { let batch = batch?; - buf_batch_writer.write(&batch, encode_time, write_time)?; + 
buf_batch_writer.write(&batch, encode_time)?; } - buf_batch_writer.flush(encode_time, write_time)?; + buf_batch_writer.flush(encode_time)?; Ok(()) } @@ -507,13 +508,7 @@ impl MultiPartitionShuffleRepartitioner { for partition_id in 0..num_output_partitions { let partition_writer = &mut self.partition_writers[partition_id]; let mut iter = partitioned_batches.produce(partition_id); - spilled_bytes += partition_writer.spill( - &mut iter, - &self.runtime, - &self.metrics, - self.write_buffer_size, - self.batch_size, - )?; + spilled_bytes += partition_writer.spill(&mut iter, &self.runtime, &self.metrics)?; } self.reservation.free(); @@ -594,11 +589,10 @@ impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { let mut partition_iter = partitioned_batches.produce(i); Self::shuffle_write_partition( &mut partition_iter, - &mut self.shuffle_block_writer, + &self.schema, + &self.write_options, &mut output_data, &self.metrics.encode_time, - &self.metrics.write_time, - self.write_buffer_size, self.batch_size, )?; } diff --git a/native/shuffle/src/partitioners/single_partition.rs b/native/shuffle/src/partitioners/single_partition.rs index 5801ef613b..d487d9310b 100644 --- a/native/shuffle/src/partitioners/single_partition.rs +++ b/native/shuffle/src/partitioners/single_partition.rs @@ -18,7 +18,7 @@ use crate::metrics::ShufflePartitionerMetrics; use crate::partitioners::ShufflePartitioner; use crate::writers::BufBatchWriter; -use crate::{CompressionCodec, ShuffleBlockWriter}; +use crate::CompressionCodec; use arrow::array::RecordBatch; use arrow::datatypes::SchemaRef; use datafusion::common::DataFusionError; @@ -26,19 +26,15 @@ use std::fs::{File, OpenOptions}; use std::io::{BufWriter, Write}; use tokio::time::Instant; -/// A partitioner that writes all shuffle data to a single file and a single index file +/// A partitioner that writes all shuffle data to a single file and a single index file. 
+/// Uses a persistent Arrow IPC StreamWriter via BufBatchWriter, so the schema is written +/// once and batches are appended with built-in body compression. pub(crate) struct SinglePartitionShufflePartitioner { - // output_data_file: File, - output_data_writer: BufBatchWriter, + output_data_writer: BufBatchWriter, + output_data_path: String, output_index_path: String, - /// Batches that are smaller than the batch size and to be concatenated - buffered_batches: Vec, - /// Number of rows in the concatenating batches - num_buffered_rows: usize, /// Metrics for the repartitioner metrics: ShufflePartitionerMetrics, - /// The configured batch size - batch_size: usize, } impl SinglePartitionShufflePartitioner { @@ -49,63 +45,26 @@ impl SinglePartitionShufflePartitioner { metrics: ShufflePartitionerMetrics, batch_size: usize, codec: CompressionCodec, - write_buffer_size: usize, + _write_buffer_size: usize, ) -> datafusion::common::Result { - let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; + let write_options = codec.ipc_write_options()?; let output_data_file = OpenOptions::new() .write(true) .create(true) .truncate(true) - .open(output_data_path)?; + .open(&output_data_path)?; - let output_data_writer = BufBatchWriter::new( - shuffle_block_writer, - output_data_file, - write_buffer_size, - batch_size, - ); + let output_data_writer = + BufBatchWriter::try_new(output_data_file, schema, write_options, batch_size)?; Ok(Self { output_data_writer, + output_data_path, output_index_path, - buffered_batches: vec![], - num_buffered_rows: 0, metrics, - batch_size, }) } - - /// Add a batch to the buffer of the partitioner, these buffered batches will be concatenated - /// and written to the output data file when the number of rows in the buffer reaches the batch size. 
- fn add_buffered_batch(&mut self, batch: RecordBatch) { - self.num_buffered_rows += batch.num_rows(); - self.buffered_batches.push(batch); - } - - /// Consumes buffered batches and return a concatenated batch if successful - fn concat_buffered_batches(&mut self) -> datafusion::common::Result> { - if self.buffered_batches.is_empty() { - Ok(None) - } else if self.buffered_batches.len() == 1 { - let batch = self.buffered_batches.remove(0); - self.num_buffered_rows = 0; - Ok(Some(batch)) - } else { - let schema = &self.buffered_batches[0].schema(); - match arrow::compute::concat_batches(schema, self.buffered_batches.iter()) { - Ok(concatenated) => { - self.buffered_batches.clear(); - self.num_buffered_rows = 0; - Ok(Some(concatenated)) - } - Err(e) => Err(DataFusionError::ArrowError( - Box::from(e), - Some(DataFusionError::get_back_trace()), - )), - } - } - } } #[async_trait::async_trait] @@ -118,32 +77,8 @@ impl ShufflePartitioner for SinglePartitionShufflePartitioner { self.metrics.data_size.add(batch.get_array_memory_size()); self.metrics.baseline.record_output(num_rows); - if num_rows >= self.batch_size || num_rows + self.num_buffered_rows > self.batch_size { - let concatenated_batch = self.concat_buffered_batches()?; - - // Write the concatenated buffered batch - if let Some(batch) = concatenated_batch { - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } - - if num_rows >= self.batch_size { - // Write the new batch - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } else { - // Add the new batch to the buffer - self.add_buffered_batch(batch); - } - } else { - self.add_buffered_batch(batch); - } + self.output_data_writer + .write(&batch, &self.metrics.encode_time)?; } self.metrics.input_batches.add(1); @@ -156,29 +91,23 @@ impl ShufflePartitioner for SinglePartitionShufflePartitioner { fn shuffle_write(&mut self) -> datafusion::common::Result<()> { 
let start_time = Instant::now(); - let concatenated_batch = self.concat_buffered_batches()?; - // Write the concatenated buffered batch - if let Some(batch) = concatenated_batch { - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } - self.output_data_writer - .flush(&self.metrics.encode_time, &self.metrics.write_time)?; + self.output_data_writer.flush(&self.metrics.encode_time)?; + + // Get data file length via filesystem metadata + let data_file_length = std::fs::metadata(&self.output_data_path) + .map(|m| m.len()) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; // Write index file. It should only contain 2 entries: 0 and the total number of bytes written let index_file = OpenOptions::new() .write(true) .create(true) .truncate(true) - .open(self.output_index_path.clone()) + .open(&self.output_index_path) .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; let mut index_buf_writer = BufWriter::new(index_file); - let data_file_length = self.output_data_writer.writer_stream_position()?; - for offset in [0, data_file_length] { + for offset in [0u64, data_file_length] { index_buf_writer.write_all(&(offset as i64).to_le_bytes()[..])?; } index_buf_writer.flush()?; diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index e649aaac69..63458bfeb1 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -19,7 +19,8 @@ use crate::metrics::ShufflePartitionerMetrics; use crate::partitioners::{ - MultiPartitionShuffleRepartitioner, ShufflePartitioner, SinglePartitionShufflePartitioner, + ImmediateModePartitioner, MultiPartitionShuffleRepartitioner, ShufflePartitioner, + SinglePartitionShufflePartitioner, }; use crate::{CometPartitioning, CompressionCodec}; use async_trait::async_trait; @@ -68,6 +69,8 @@ pub struct ShuffleWriterExec { tracing_enabled: bool, /// Size of the write 
buffer in bytes write_buffer_size: usize, + /// When true, use ImmediateModePartitioner instead of MultiPartitionShuffleRepartitioner + immediate_mode: bool, } impl ShuffleWriterExec { @@ -81,6 +84,7 @@ impl ShuffleWriterExec { output_index_file: String, tracing_enabled: bool, write_buffer_size: usize, + immediate_mode: bool, ) -> Result { let cache = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&input.schema())), @@ -99,6 +103,7 @@ impl ShuffleWriterExec { codec, tracing_enabled, write_buffer_size, + immediate_mode, }) } } @@ -163,6 +168,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.output_index_file.clone(), self.tracing_enabled, self.write_buffer_size, + self.immediate_mode, )?)), _ => panic!("ShuffleWriterExec wrong number of children"), } @@ -190,6 +196,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.codec.clone(), self.tracing_enabled, self.write_buffer_size, + self.immediate_mode, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -210,6 +217,7 @@ async fn external_shuffle( codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + immediate_mode: bool, ) -> Result { with_trace_async("external_shuffle", tracing_enabled, || async { let schema = input.schema(); @@ -226,6 +234,17 @@ async fn external_shuffle( write_buffer_size, )?) 
} + _ if immediate_mode => Box::new(ImmediateModePartitioner::try_new( + partition, + output_data_file, + output_index_file, + Arc::clone(&schema), + partitioning, + metrics, + context.runtime_env(), + context.session_config().batch_size(), + codec, + )?), _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( partition, output_data_file, @@ -265,9 +284,9 @@ async fn external_shuffle( #[cfg(test)] mod test { use super::*; - use crate::{read_ipc_compressed, ShuffleBlockWriter}; use arrow::array::{Array, StringArray, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; + use arrow::ipc::reader::StreamReader; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use datafusion::datasource::memory::MemorySourceConfig; @@ -280,30 +299,36 @@ mod test { use datafusion::physical_plan::metrics::Time; use datafusion::prelude::SessionContext; use itertools::Itertools; - use std::io::Cursor; use tokio::runtime::Runtime; #[test] #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` fn roundtrip_ipc() { + use crate::writers::BufBatchWriter; + let batch = create_batch(8192); for codec in &[ CompressionCodec::None, CompressionCodec::Zstd(1), - CompressionCodec::Snappy, CompressionCodec::Lz4Frame, ] { - let mut output = vec![]; - let mut cursor = Cursor::new(&mut output); - let writer = - ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone()).unwrap(); - let length = writer - .write_batch(&batch, &mut cursor, &Time::default()) - .unwrap(); - assert_eq!(length, output.len()); + let write_options = codec.ipc_write_options().unwrap(); + let mut output = Vec::new(); + let encode_time = Time::default(); + + { + let mut writer = + BufBatchWriter::try_new(&mut output, batch.schema(), write_options, 8192) + .unwrap(); + writer.write(&batch, &encode_time).unwrap(); + writer.flush(&encode_time).unwrap(); + } - let ipc_without_length_prefix = &output[16..]; - let batch2 = 
read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert!(!output.is_empty()); + + // Read back using standard Arrow StreamReader + let mut reader = StreamReader::try_new(&output[..], None).unwrap(); + let batch2 = reader.next().unwrap().unwrap(); assert_eq!(batch, batch2); } } @@ -466,6 +491,7 @@ mod test { "/tmp/index.out".to_string(), false, 1024 * 1024, // write_buffer_size: 1MB default + false, // immediate_mode ) .unwrap(); @@ -525,6 +551,7 @@ mod test { index_file.clone(), false, 1024 * 1024, + false, // immediate_mode ) .unwrap(); @@ -586,15 +613,17 @@ mod test { let _ = fs::remove_file("/tmp/rr_index_1.out"); } - /// Test that batch coalescing in BufBatchWriter reduces output size by - /// writing fewer, larger IPC blocks instead of many small ones. + /// Test that batch coalescing in BufBatchWriter produces correct output. + /// With the new persistent StreamWriter format, schema is written once per stream + /// regardless of coalescing, but coalescing still reduces the number of record batch + /// messages in the stream. 
#[test] #[cfg_attr(miri, ignore)] - fn test_batch_coalescing_reduces_size() { + fn test_batch_coalescing_correct_output() { use crate::writers::BufBatchWriter; use arrow::array::Int32Array; - // Create a wide schema to amplify per-block schema overhead + // Create a wide schema to amplify per-batch message overhead let fields: Vec = (0..20) .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false)) .collect(); @@ -616,52 +645,44 @@ mod test { .collect(); let codec = CompressionCodec::Lz4Frame; + let write_options = codec.ipc_write_options().unwrap(); let encode_time = Time::default(); - let write_time = Time::default(); // Write with coalescing (batch_size=8192) let mut coalesced_output = Vec::new(); { - let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); - let mut buf_writer = BufBatchWriter::new( - &mut writer, - Cursor::new(&mut coalesced_output), - 1024 * 1024, + let mut buf_writer = BufBatchWriter::try_new( + &mut coalesced_output, + Arc::clone(&schema), + write_options.clone(), 8192, - ); + ) + .unwrap(); for batch in &small_batches { - buf_writer.write(batch, &encode_time, &write_time).unwrap(); + buf_writer.write(batch, &encode_time).unwrap(); } - buf_writer.flush(&encode_time, &write_time).unwrap(); + buf_writer.flush(&encode_time).unwrap(); } // Write without coalescing (batch_size=1) let mut uncoalesced_output = Vec::new(); { - let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); - let mut buf_writer = BufBatchWriter::new( - &mut writer, - Cursor::new(&mut uncoalesced_output), - 1024 * 1024, + let mut buf_writer = BufBatchWriter::try_new( + &mut uncoalesced_output, + Arc::clone(&schema), + write_options, 1, - ); + ) + .unwrap(); for batch in &small_batches { - buf_writer.write(batch, &encode_time, &write_time).unwrap(); + buf_writer.write(batch, &encode_time).unwrap(); } - buf_writer.flush(&encode_time, &write_time).unwrap(); + buf_writer.flush(&encode_time).unwrap(); } - // 
Coalesced output should be smaller due to fewer IPC schema blocks - assert!( - coalesced_output.len() < uncoalesced_output.len(), - "Coalesced output ({} bytes) should be smaller than uncoalesced ({} bytes)", - coalesced_output.len(), - uncoalesced_output.len() - ); - - // Verify both roundtrip correctly by reading all IPC blocks - let coalesced_rows = read_all_ipc_blocks(&coalesced_output); - let uncoalesced_rows = read_all_ipc_blocks(&uncoalesced_output); + // Verify both roundtrip correctly by reading all batches via StreamReader + let coalesced_rows = read_all_ipc_stream_rows(&coalesced_output); + let uncoalesced_rows = read_all_ipc_stream_rows(&uncoalesced_output); assert_eq!( coalesced_rows, 5000, "Coalesced should contain all 5000 rows" @@ -672,24 +693,12 @@ mod test { ); } - /// Read all IPC blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, - /// returning the total number of rows. - fn read_all_ipc_blocks(data: &[u8]) -> usize { - let mut offset = 0; + /// Read all record batches from an Arrow IPC stream, returning total row count. 
+ fn read_all_ipc_stream_rows(data: &[u8]) -> usize { + let reader = StreamReader::try_new(data, None).unwrap(); let mut total_rows = 0; - while offset < data.len() { - // First 8 bytes are the IPC length (little-endian u64) - let ipc_length = - u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // Skip the 8-byte length prefix; the next 8 bytes are field_count + codec header - let block_start = offset + 8; - let block_end = block_start + ipc_length; - // read_ipc_compressed expects data starting after the 16-byte header - // (i.e., after length + field_count), at the codec tag - let ipc_data = &data[block_start + 8..block_end]; - let batch = read_ipc_compressed(ipc_data).unwrap(); - total_rows += batch.num_rows(); - offset = block_end; + for batch in reader { + total_rows += batch.unwrap().num_rows(); } total_rows } diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index 3c98677199..0accf61daf 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -23,7 +23,7 @@ use crate::spark_unsafe::{ map::{append_map_elements, get_map_key_value_fields}, }; use crate::writers::Checksum; -use crate::writers::ShuffleBlockWriter; +use crate::CompressionCodec; use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, @@ -37,7 +37,6 @@ use arrow::array::{ use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::error::ArrowError; -use datafusion::physical_plan::metrics::Time; use datafusion_comet_jni_bridge::errors::CometError; use jni::sys::{jint, jlong}; use std::{ @@ -197,7 +196,6 @@ macro_rules! get_field_builder { } // Expose the macro for other modules. -use crate::CompressionCodec; pub(crate) use downcast_builder_ref; /// Appends field of row to the given struct builder. `dt` is the data type of the field. 
@@ -1313,8 +1311,6 @@ pub fn process_sorted_row_partition( ) -> Result<(i64, Option), CometError> { // The current row number we are reading let mut current_row = 0; - // Total number of bytes written - let mut written = 0; // The current checksum value. This is updated incrementally in the following loop. let mut current_checksum = if checksum_enabled { Some(Checksum::try_new(checksum_algo, initial_checksum)?) @@ -1337,9 +1333,14 @@ pub fn process_sorted_row_partition( .append(true) .open(&output_path)?; - // Reusable buffer for serialized batch data + // Buffer that accumulates all IPC bytes across the single stream let mut frozen: Vec = Vec::new(); + // Build a schema from the first batch's datatypes so we can create the StreamWriter + // up front. We need a placeholder schema; we'll create it from the first batch. + let mut stream_writer: Option>> = None; + let write_options = codec.ipc_write_options()?; + while current_row < row_num { let n = std::cmp::min(batch_size, row_num - current_row); @@ -1368,22 +1369,33 @@ pub fn process_sorted_row_partition( .collect(); let batch = make_batch(array_refs?, n)?; - frozen.clear(); - let mut cursor = Cursor::new(&mut frozen); - - // we do not collect metrics in Native_writeSortedFileNative - let ipc_time = Time::default(); - let block_writer = ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone())?; - written += block_writer.write_batch(&batch, &mut cursor, &ipc_time)?; - - if let Some(checksum) = &mut current_checksum { - checksum.update(&mut cursor)?; + // Create the StreamWriter on the first batch (we need the schema) + if stream_writer.is_none() { + stream_writer = Some(arrow::ipc::writer::StreamWriter::try_new_with_options( + &mut frozen, + &batch.schema(), + write_options.clone(), + )?); } - output_data.write_all(&frozen)?; + stream_writer.as_mut().unwrap().write(&batch)?; current_row += n; } + // Finish the IPC stream and flush remaining bytes + if let Some(mut writer) = stream_writer { + 
writer.finish()?; + } + + let written = frozen.len(); + + if let Some(checksum) = &mut current_checksum { + let mut cursor = Cursor::new(&mut frozen); + checksum.update(&mut cursor)?; + } + + output_data.write_all(&frozen)?; + Ok((written as i64, current_checksum.map(|c| c.finalize()))) } diff --git a/native/shuffle/src/writers/buf_batch_writer.rs b/native/shuffle/src/writers/buf_batch_writer.rs index cfddb46539..0ca1b9a1d7 100644 --- a/native/shuffle/src/writers/buf_batch_writer.rs +++ b/native/shuffle/src/writers/buf_batch_writer.rs @@ -15,128 +15,65 @@ // specific language governing permissions and limitations // under the License. -use super::ShuffleBlockWriter; use arrow::array::RecordBatch; use arrow::compute::kernels::coalesce::BatchCoalescer; +use arrow::datatypes::SchemaRef; +use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; use datafusion::physical_plan::metrics::Time; -use std::borrow::Borrow; -use std::io::{Cursor, Seek, SeekFrom, Write}; +use std::io::Write; -/// Write batches to writer while using a buffer to avoid frequent system calls. -/// The record batches were first written by ShuffleBlockWriter into an internal buffer. -/// Once the buffer exceeds the max size, the buffer will be flushed to the writer. -/// -/// Small batches are coalesced using Arrow's [`BatchCoalescer`] before serialization, -/// producing exactly `batch_size`-row output batches to reduce per-block IPC schema overhead. -/// The coalescer is lazily initialized on the first write. -pub(crate) struct BufBatchWriter, W: Write> { - shuffle_block_writer: S, - writer: W, - buffer: Vec, - buffer_max_size: usize, +/// Writes batches to a persistent Arrow IPC `StreamWriter`. The schema is written once +/// when the writer is created. Small batches are coalesced via [`BatchCoalescer`] before +/// serialization, producing `batch_size`-row output batches. 
+pub(crate) struct BufBatchWriter { + writer: StreamWriter, /// Coalesces small batches into target_batch_size before serialization. - /// Lazily initialized on first write to capture the schema. - coalescer: Option, - /// Target batch size for coalescing - batch_size: usize, + coalescer: BatchCoalescer, } -impl, W: Write> BufBatchWriter { - pub(crate) fn new( - shuffle_block_writer: S, - writer: W, - buffer_max_size: usize, +impl BufBatchWriter { + pub(crate) fn try_new( + target: W, + schema: SchemaRef, + write_options: IpcWriteOptions, batch_size: usize, - ) -> Self { - Self { - shuffle_block_writer, - writer, - buffer: vec![], - buffer_max_size, - coalescer: None, - batch_size, - } + ) -> datafusion::common::Result { + let writer = StreamWriter::try_new_with_options(target, &schema, write_options)?; + let coalescer = BatchCoalescer::new(schema, batch_size); + Ok(Self { writer, coalescer }) } pub(crate) fn write( &mut self, batch: &RecordBatch, encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result { - let coalescer = self - .coalescer - .get_or_insert_with(|| BatchCoalescer::new(batch.schema(), self.batch_size)); - coalescer.push_batch(batch.clone())?; + ) -> datafusion::common::Result<()> { + self.coalescer.push_batch(batch.clone())?; - // Drain completed batches into a local vec so the coalescer borrow ends - // before we call write_batch_to_buffer (which borrows &mut self). let mut completed = Vec::new(); - while let Some(batch) = coalescer.next_completed_batch() { + while let Some(batch) = self.coalescer.next_completed_batch() { completed.push(batch); } - let mut bytes_written = 0; for batch in &completed { - bytes_written += self.write_batch_to_buffer(batch, encode_time, write_time)?; + let mut timer = encode_time.timer(); + self.writer.write(batch)?; + timer.stop(); } - Ok(bytes_written) - } - - /// Serialize a single batch into the byte buffer, flushing to the writer if needed. 
- fn write_batch_to_buffer( - &mut self, - batch: &RecordBatch, - encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result { - let mut cursor = Cursor::new(&mut self.buffer); - cursor.seek(SeekFrom::End(0))?; - let bytes_written = - self.shuffle_block_writer - .borrow() - .write_batch(batch, &mut cursor, encode_time)?; - let pos = cursor.position(); - if pos >= self.buffer_max_size as u64 { - let mut write_timer = write_time.timer(); - self.writer.write_all(&self.buffer)?; - write_timer.stop(); - self.buffer.clear(); - } - Ok(bytes_written) + Ok(()) } - pub(crate) fn flush( - &mut self, - encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result<()> { + pub(crate) fn flush(&mut self, encode_time: &Time) -> datafusion::common::Result<()> { // Finish any remaining buffered rows in the coalescer - let mut remaining = Vec::new(); - if let Some(coalescer) = &mut self.coalescer { - coalescer.finish_buffered_batch()?; - while let Some(batch) = coalescer.next_completed_batch() { - remaining.push(batch); - } - } - for batch in &remaining { - self.write_batch_to_buffer(batch, encode_time, write_time)?; + self.coalescer.finish_buffered_batch()?; + while let Some(batch) = self.coalescer.next_completed_batch() { + let mut timer = encode_time.timer(); + self.writer.write(&batch)?; + timer.stop(); } - // Flush the byte buffer to the underlying writer - let mut write_timer = write_time.timer(); - if !self.buffer.is_empty() { - self.writer.write_all(&self.buffer)?; - } - self.writer.flush()?; - write_timer.stop(); - self.buffer.clear(); + // Finish the IPC stream (writes the end-of-stream marker) + self.writer.finish()?; Ok(()) } } - -impl, W: Write + Seek> BufBatchWriter { - pub(crate) fn writer_stream_position(&mut self) -> datafusion::common::Result { - self.writer.stream_position().map_err(Into::into) - } -} diff --git a/native/shuffle/src/writers/codec.rs b/native/shuffle/src/writers/codec.rs new file mode 100644 index 0000000000..5e6dc88772 
--- /dev/null +++ b/native/shuffle/src/writers/codec.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::ipc::writer::IpcWriteOptions; +use arrow::ipc::CompressionType; + +/// Compression algorithm applied to shuffle IPC streams and Parquet output. +#[derive(Debug, Clone)] +pub enum CompressionCodec { + None, + Lz4Frame, + Zstd(i32), + /// Snappy is only used for Parquet output, not for shuffle IPC. 
+ Snappy, +} + +impl CompressionCodec { + pub fn ipc_write_options(&self) -> datafusion::error::Result { + let compression = match self { + CompressionCodec::None => None, + CompressionCodec::Lz4Frame => Some(CompressionType::LZ4_FRAME), + CompressionCodec::Zstd(_) => Some(CompressionType::ZSTD), + CompressionCodec::Snappy => { + return Err(datafusion::common::DataFusionError::Execution( + "Snappy is not supported for Arrow IPC compression".to_string(), + )); + } + }; + let options = IpcWriteOptions::try_new(8, false, arrow::ipc::MetadataVersion::V5) + .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::from(e), None))?; + options + .try_with_compression(compression) + .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::from(e), None)) + } +} diff --git a/native/shuffle/src/writers/mod.rs b/native/shuffle/src/writers/mod.rs index 75caf9f3a3..ed57562856 100644 --- a/native/shuffle/src/writers/mod.rs +++ b/native/shuffle/src/writers/mod.rs @@ -17,10 +17,10 @@ mod buf_batch_writer; mod checksum; -mod shuffle_block_writer; +mod codec; mod spill; pub(crate) use buf_batch_writer::BufBatchWriter; pub(crate) use checksum::Checksum; -pub use shuffle_block_writer::{CompressionCodec, ShuffleBlockWriter}; +pub use codec::CompressionCodec; pub(crate) use spill::PartitionWriter; diff --git a/native/shuffle/src/writers/shuffle_block_writer.rs b/native/shuffle/src/writers/shuffle_block_writer.rs deleted file mode 100644 index 5ed5330e3a..0000000000 --- a/native/shuffle/src/writers/shuffle_block_writer.rs +++ /dev/null @@ -1,146 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::RecordBatch; -use arrow::datatypes::Schema; -use arrow::ipc::writer::StreamWriter; -use datafusion::common::DataFusionError; -use datafusion::error::Result; -use datafusion::physical_plan::metrics::Time; -use std::io::{Cursor, Seek, SeekFrom, Write}; - -/// Compression algorithm applied to shuffle IPC blocks. -#[derive(Debug, Clone)] -pub enum CompressionCodec { - None, - Lz4Frame, - Zstd(i32), - Snappy, -} - -/// Writes a record batch as a length-prefixed, compressed Arrow IPC block. -#[derive(Clone)] -pub struct ShuffleBlockWriter { - codec: CompressionCodec, - header_bytes: Vec, -} - -impl ShuffleBlockWriter { - pub fn try_new(schema: &Schema, codec: CompressionCodec) -> Result { - let header_bytes = Vec::with_capacity(20); - let mut cursor = Cursor::new(header_bytes); - - // leave space for compressed message length - cursor.seek_relative(8)?; - - // write number of columns because JVM side needs to know how many addresses to allocate - let field_count = schema.fields().len(); - cursor.write_all(&field_count.to_le_bytes())?; - - // write compression codec to header - let codec_header = match &codec { - CompressionCodec::Snappy => b"SNAP", - CompressionCodec::Lz4Frame => b"LZ4_", - CompressionCodec::Zstd(_) => b"ZSTD", - CompressionCodec::None => b"NONE", - }; - cursor.write_all(codec_header)?; - - let header_bytes = cursor.into_inner(); - - Ok(Self { - codec, - header_bytes, - }) - } - - /// Writes given record batch as Arrow IPC bytes into given writer. - /// Returns number of bytes written. 
- pub fn write_batch( - &self, - batch: &RecordBatch, - output: &mut W, - ipc_time: &Time, - ) -> Result { - if batch.num_rows() == 0 { - return Ok(0); - } - - let mut timer = ipc_time.timer(); - let start_pos = output.stream_position()?; - - // write header - output.write_all(&self.header_bytes)?; - - let output = match &self.codec { - CompressionCodec::None => { - let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - arrow_writer.into_inner()? - } - CompressionCodec::Lz4Frame => { - let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.finish().map_err(|e| { - DataFusionError::Execution(format!("lz4 compression error: {e}")) - })? - } - - CompressionCodec::Zstd(level) => { - let encoder = zstd::Encoder::new(output, *level)?; - let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - let zstd_encoder = arrow_writer.into_inner()?; - zstd_encoder.finish()? - } - - CompressionCodec::Snappy => { - let mut wtr = snap::write::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.into_inner().map_err(|e| { - DataFusionError::Execution(format!("snappy compression error: {e}")) - })? - } - }; - - // fill ipc length - let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; - let max_size = i32::MAX as u64; - if ipc_length > max_size { - return Err(DataFusionError::Execution(format!( - "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. 
\ - Try reducing batch size or increasing compression level" - ))); - } - - // fill ipc length - output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes())?; - output.seek(SeekFrom::Start(end_pos))?; - - timer.stop(); - - Ok((end_pos - start_pos) as usize) - } -} diff --git a/native/shuffle/src/writers/spill.rs b/native/shuffle/src/writers/spill.rs index c16caddbf9..c6feb34764 100644 --- a/native/shuffle/src/writers/spill.rs +++ b/native/shuffle/src/writers/spill.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::ShuffleBlockWriter; use crate::metrics::ShufflePartitionerMetrics; use crate::partitioners::PartitionedBatchIterator; -use crate::writers::buf_batch_writer::BufBatchWriter; +use arrow::datatypes::SchemaRef; +use arrow::ipc::writer::{IpcWriteOptions, StreamWriter}; use datafusion::common::DataFusionError; use datafusion::execution::disk_manager::RefCountedTempFile; use datafusion::execution::runtime_env::RuntimeEnv; @@ -36,17 +36,21 @@ pub(crate) struct PartitionWriter { /// will append to this file and the contents will be copied to the shuffle file at /// the end of processing. 
spill_file: Option, - /// Writer that performs encoding and compression - shuffle_block_writer: ShuffleBlockWriter, + /// Schema used for creating IPC stream writers + schema: SchemaRef, + /// IPC write options (includes compression settings) + write_options: IpcWriteOptions, } impl PartitionWriter { pub(crate) fn try_new( - shuffle_block_writer: ShuffleBlockWriter, + schema: SchemaRef, + write_options: IpcWriteOptions, ) -> datafusion::common::Result { Ok(Self { spill_file: None, - shuffle_block_writer, + schema, + write_options, }) } @@ -80,34 +84,42 @@ impl PartitionWriter { iter: &mut PartitionedBatchIterator, runtime: &RuntimeEnv, metrics: &ShufflePartitionerMetrics, - write_buffer_size: usize, - batch_size: usize, ) -> datafusion::common::Result { if let Some(batch) = iter.next() { self.ensure_spill_file_created(runtime)?; - let total_bytes_written = { - let mut buf_batch_writer = BufBatchWriter::new( - &mut self.shuffle_block_writer, - &mut self.spill_file.as_mut().unwrap().file, - write_buffer_size, - batch_size, - ); - let mut bytes_written = - buf_batch_writer.write(&batch?, &metrics.encode_time, &metrics.write_time)?; - for batch in iter { - let batch = batch?; - bytes_written += buf_batch_writer.write( - &batch, - &metrics.encode_time, - &metrics.write_time, - )?; - } - buf_batch_writer.flush(&metrics.encode_time, &metrics.write_time)?; - bytes_written - }; + let file = &mut self.spill_file.as_mut().unwrap().file; + let start_pos = file.metadata().map(|m| m.len()).unwrap_or(0); - Ok(total_bytes_written) + let mut writer = + StreamWriter::try_new_with_options(file, &self.schema, self.write_options.clone())?; + + let batch = batch?; + let mut encode_timer = metrics.encode_time.timer(); + writer.write(&batch)?; + encode_timer.stop(); + + for batch in iter { + let batch = batch?; + let mut encode_timer = metrics.encode_time.timer(); + writer.write(&batch)?; + encode_timer.stop(); + } + + let mut write_timer = metrics.write_time.timer(); + writer.finish()?; 
+ write_timer.stop(); + + let end_pos = self + .spill_file + .as_ref() + .unwrap() + .file + .metadata() + .map(|m| m.len()) + .unwrap_or(0); + + Ok((end_pos - start_pos) as usize) } else { Ok(0) } diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java deleted file mode 100644 index 9f72b20f51..0000000000 --- a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import java.io.Closeable; -import java.io.EOFException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; - -/** - * Provides raw compressed shuffle blocks to native code via JNI. - * - *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the - * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and - * getBuffer(). - * - *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. - * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the - * decompressed data) before pulling the next block. - */ -public class CometShuffleBlockIterator implements Closeable { - - private static final int INITIAL_BUFFER_SIZE = 128 * 1024; - - private final ReadableByteChannel channel; - private final InputStream inputStream; - private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); - private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - private boolean closed = false; - private int currentBlockLength = 0; - - public CometShuffleBlockIterator(InputStream in) { - this.inputStream = in; - this.channel = Channels.newChannel(in); - } - - /** - * Reads the next block header and loads the compressed body into the internal buffer. Called by - * native code via JNI. - * - *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte - * fieldCount (discarded, schema comes from protobuf). - * - * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF - */ - public int hasNext() throws IOException { - if (closed) { - return -1; - } - - // Read 16-byte header: clear() resets position=0, limit=capacity, - // preparing the buffer for channel.read() to fill it - headerBuf.clear(); - while (headerBuf.hasRemaining()) { - int bytesRead = channel.read(headerBuf); - if (bytesRead < 0) { - if (headerBuf.position() == 0) { - close(); - return -1; - } - throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); - } - } - headerBuf.flip(); - long compressedLength = headerBuf.getLong(); - // Field count discarded - schema determined by ShuffleScan protobuf fields - headerBuf.getLong(); - - // Subtract 8 because compressedLength includes the 8-byte field count we already read - long bytesToRead = compressedLength - 8; - if (bytesToRead > Integer.MAX_VALUE) { - throw new IllegalStateException( - "Native shuffle block size of " - + bytesToRead - + " exceeds maximum of " - + Integer.MAX_VALUE - + ". Try reducing spark.comet.columnar.shuffle.batch.size."); - } - - currentBlockLength = (int) bytesToRead; - - if (dataBuf.capacity() < currentBlockLength) { - int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); - dataBuf = ByteBuffer.allocateDirect(newCapacity); - } - - dataBuf.clear(); - dataBuf.limit(currentBlockLength); - while (dataBuf.hasRemaining()) { - int bytesRead = channel.read(dataBuf); - if (bytesRead < 0) { - throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); - } - } - // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, - // not the buffer's position/limit. No flip needed. 
- - return currentBlockLength; - } - - /** - * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec - * prefix + compressed IPC data). Called by native code via JNI. - */ - public ByteBuffer getBuffer() { - return dataBuf; - } - - /** Returns the length of the current block in bytes. Called by native code via JNI. */ - public int getCurrentBlockLength() { - return currentBlockLength; - } - - @Override - public void close() throws IOException { - if (!closed) { - closed = true; - inputStream.close(); - } - } -} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index e198ac99ff..e9e1968ccd 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -68,7 +68,7 @@ class CometExecIterator( partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) + shuffleInputStreams: Map[Int, java.io.InputStream] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -79,11 +79,11 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle - // scan indices, CometBatchIterator for regular scan indices. + // Build a mixed array of iterators: InputStream for shuffle scan indices, + // CometBatchIterator for regular scan indices. 
private val inputIterators: Array[Object] = inputs.zipWithIndex.map { - case (_, idx) if shuffleBlockIterators.contains(idx) => - shuffleBlockIterators(idx).asInstanceOf[Object] + case (_, idx) if shuffleInputStreams.contains(idx) => + shuffleInputStreams(idx).asInstanceOf[Object] case (iterator, _) => new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray @@ -235,7 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() - shuffleBlockIterators.values.foreach(_.close()) + shuffleInputStreams.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index f6800626d6..abb321b683 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -19,8 +19,6 @@ package org.apache.comet -import java.nio.ByteBuffer - import org.apache.spark.CometTaskMemoryManager import org.apache.spark.sql.comet.CometMetricNode @@ -172,12 +170,16 @@ class Native extends NativeBase { * @param size * the size of the array. */ - @native def decodeShuffleBlock( - shuffleBlock: ByteBuffer, - length: Int, + @native def openShuffleStream(inputStream: java.io.InputStream): Long + + @native def nextShuffleStreamBatch( + handle: Long, arrayAddrs: Array[Long], - schemaAddrs: Array[Long], - tracingEnabled: Boolean): Long + schemaAddrs: Array[Long]): Long + + @native def shuffleStreamNumFields(handle: Long): Long + + @native def closeShuffleStream(handle: Long): Unit /** * Log the beginning of an event. 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index c5014818c4..963505dcf2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -111,11 +111,11 @@ private[spark] class CometExecRDD( serializedPlan } - // Create shuffle block iterators for inputs that are CometShuffledBatchRDD - val shuffleBlockIters = shuffleScanIndices.flatMap { idx => + // Create raw InputStreams for inputs that are CometShuffledBatchRDD + val shuffleStreams = shuffleScanIndices.flatMap { idx => inputRDDs(idx) match { case rdd: CometShuffledBatchRDD => - Some(idx -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(idx), context)) + Some(idx -> rdd.computeAsRawStream(partition.inputPartitions(idx), context)) case _ => None } }.toMap @@ -130,7 +130,7 @@ private[spark] class CometExecRDD( partition.index, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleBlockIters) + shuffleStreams) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 3fc222bd19..96c140300b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -192,6 +192,8 @@ class CometNativeShuffleWriter[K, V]( CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) shuffleWriterBuilder.setWriteBufferSize( CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) + shuffleWriterBuilder.setImmediateMode( + CometConf.COMET_SHUFFLE_PARTITIONER_MODE.get() == 
"immediate") outputPartitioning match { case p if isSinglePartitioning(p) => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index 7604910b06..45677d93fb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometShuffleBlockIterator - /** * Different from [[org.apache.spark.sql.execution.ShuffledRowRDD]], this RDD is specialized for * reading shuffled data through [[CometBlockStoreShuffleReader]]. The shuffled data is read in an @@ -149,14 +147,12 @@ class CometShuffledBatchRDD( } /** - * Creates a CometShuffleBlockIterator that provides raw compressed shuffle blocks for direct - * consumption by native code, bypassing Arrow FFI. + * Returns the raw InputStream of concatenated Arrow IPC streams for direct consumption by + * native code via ShuffleStreamReader. 
*/ - def computeAsShuffleBlockIterator( - split: Partition, - context: TaskContext): CometShuffleBlockIterator = { + def computeAsRawStream(split: Partition, context: TaskContext): java.io.InputStream = { val reader = createReader(split, context) - new CometShuffleBlockIterator(reader.readAsRawStream()) + reader.readAsRawStream() } override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala index f96c8f16dd..22fc14df97 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.comet.execution.shuffle -import java.io.{EOFException, InputStream} -import java.nio.{ByteBuffer, ByteOrder} -import java.nio.channels.{Channels, ReadableByteChannel} +import java.io.InputStream import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch @@ -43,27 +41,32 @@ case class NativeBatchDecoderIterator( extends Iterator[ColumnarBatch] { private var isClosed = false - private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) private var currentBatch: ColumnarBatch = null - private var batch = fetchNext() - import NativeBatchDecoderIterator._ + // Open the native stream reader + private val handle: Long = if (in != null) { + nativeLib.openShuffleStream(in) + } else { + 0L + } - private val channel: ReadableByteChannel = if (in != null) { - Channels.newChannel(in) + // Get field count from the native reader (it parsed the schema on open) + private val numFields: Int = if (handle != 0L) { + nativeLib.shuffleStreamNumFields(handle).toInt } else { - null + 0 } + private var batch = 
fetchNext() + def hasNext(): Boolean = { - if (channel == null || isClosed) { + if (handle == 0L || isClosed) { return false } if (batch.isDefined) { return true } - // Release the previous batch. if (currentBatch != null) { currentBatch.close() currentBatch = null @@ -81,89 +84,24 @@ case class NativeBatchDecoderIterator( if (!hasNext) { throw new NoSuchElementException } - val nextBatch = batch.get - currentBatch = nextBatch batch = None currentBatch } private def fetchNext(): Option[ColumnarBatch] = { - if (channel == null || isClosed) { + if (handle == 0L || isClosed) { return None } - // read compressed batch size from header - try { - longBuf.clear() - while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} - } catch { - case _: EOFException => - close() - return None - } - - // If we reach the end of the stream, we are done, or if we read partial length - // then the stream is corrupted. - if (longBuf.hasRemaining) { - if (longBuf.position() == 0) { - close() - return None - } - throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") - } - - // get compressed length (including headers) - longBuf.flip() - val compressedLength = longBuf.getLong - - // read field count from header - longBuf.clear() - while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} - if (longBuf.hasRemaining) { - throw new EOFException("Data corrupt: unexpected EOF while reading field count") - } - longBuf.flip() - val fieldCount = longBuf.getLong.toInt - - // read body - val bytesToRead = compressedLength - 8 - if (bytesToRead > Integer.MAX_VALUE) { - // very unlikely that shuffle block will reach 2GB - throw new IllegalStateException( - s"Native shuffle block size of $bytesToRead exceeds " + - s"maximum of ${Integer.MAX_VALUE}. 
Try reducing shuffle batch size.") - } - var dataBuf = threadLocalDataBuf.get() - if (dataBuf.capacity() < bytesToRead) { - // it is unlikely that we would overflow here since it would - // require a 1GB compressed shuffle block but we check anyway - val newCapacity = (bytesToRead * 2L).min(Integer.MAX_VALUE).toInt - dataBuf = ByteBuffer.allocateDirect(newCapacity) - threadLocalDataBuf.set(dataBuf) - } - dataBuf.clear() - dataBuf.limit(bytesToRead.toInt) - while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {} - if (dataBuf.hasRemaining) { - throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch") - } - - // make native call to decode batch val startTime = System.nanoTime() val batch = nativeUtil.getNextBatch( - fieldCount, + numFields, (arrayAddrs, schemaAddrs) => { - nativeLib.decodeShuffleBlock( - dataBuf, - bytesToRead.toInt, - arrayAddrs, - schemaAddrs, - tracingEnabled) + nativeLib.nextShuffleStreamBatch(handle, arrayAddrs, schemaAddrs) }) decodeTime.add(System.nanoTime() - startTime) - batch } @@ -174,25 +112,14 @@ case class NativeBatchDecoderIterator( currentBatch.close() currentBatch = null } - in.close() - resetDataBuf() + if (handle != 0L) { + nativeLib.closeShuffleStream(handle) + } + if (in != null) { + in.close() + } isClosed = true } } } } - -object NativeBatchDecoderIterator { - - private val INITIAL_BUFFER_SIZE = 128 * 1024 - - private val threadLocalDataBuf: ThreadLocal[ByteBuffer] = ThreadLocal.withInitial(() => { - ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE) - }) - - private def resetDataBuf(): Unit = { - if (threadLocalDataBuf.get().capacity() > INITIAL_BUFFER_SIZE) { - threadLocalDataBuf.set(ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE)) - } - } -}