diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..3e58707748 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -280,6 +280,15 @@ object CometConf extends ShimCometConf { createExecEnabledConfig("hashJoin", defaultValue = true) val COMET_EXEC_SORT_MERGE_JOIN_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("sortMergeJoin", defaultValue = true) + val COMET_EXEC_SMJ_USE_NATIVE: ConfigEntry[Boolean] = + conf("spark.comet.exec.sortMergeJoin.useNative") + .category(CATEGORY_EXEC) + .doc( + "When true, use Comet's native sort merge join implementation. " + + "When false, use DataFusion's SortMergeJoinExec. " + + "This is useful for benchmarking the two implementations.") + .booleanConf + .createWithDefault(true) val COMET_EXEC_AGGREGATE_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("aggregate", defaultValue = true) val COMET_EXEC_COLLECT_LIMIT_ENABLED: ConfigEntry[Boolean] = diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e0a395ebbf..0e02099aed 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -179,6 +179,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Spark configuration map passed from JVM + pub spark_config: HashMap, } /// Accept serialized query plan and return the address of the native query plan. 
@@ -327,6 +329,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + spark_config, }); Ok(Box::into_raw(exec_context) as i64) @@ -545,7 +548,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_spark_config(exec_context.spark_config.clone()); let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/joins/buffered_batch.rs b/native/core/src/execution/joins/buffered_batch.rs new file mode 100644 index 0000000000..ebd725233a --- /dev/null +++ b/native/core/src/execution/joins/buffered_batch.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Buffered batch management for the sort merge join operator. +//! +//! [`BufferedMatchGroup`] holds all rows from the buffered (right) side that +//! share the current join key. When memory is tight, individual batches are +//! 
spilled to Arrow IPC files on disk and reloaded on demand. + +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::record_batch::RecordBatch; +use datafusion::common::utils::memory::get_record_batch_memory_size; +use datafusion::common::{DataFusionError, Result}; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::memory_pool::MemoryReservation; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::spill::SpillManager; + +use super::metrics::SortMergeJoinMetrics; + +/// State of a single buffered batch — either held in memory or spilled to disk. +#[derive(Debug)] +enum BatchState { + /// The batch is available in memory. + InMemory(RecordBatch), + /// The batch has been spilled to an Arrow IPC file. + Spilled(RefCountedTempFile), +} + +/// A single batch in a [`BufferedMatchGroup`]. +/// +/// Tracks the batch data (in-memory or spilled), pre-evaluated join key arrays, +/// row count, estimated memory size, and per-row match flags for outer joins. +#[derive(Debug)] +pub(super) struct BufferedBatch { + /// The batch data, either in memory or spilled to disk. + state: BatchState, + /// Pre-evaluated join key column arrays. `None` when the batch has been spilled. + #[allow(dead_code)] + join_arrays: Option>, + /// Number of rows in this batch (cached so we don't need the batch to know). + pub num_rows: usize, + /// Estimated memory footprint in bytes (batch + join arrays). + pub size_estimate: usize, + /// For full/right outer joins: tracks which rows have been matched. + matched: Option>, +} + +impl BufferedBatch { + /// Mark a buffered row as matched (for full outer join tracking). + pub fn mark_matched(&mut self, row_idx: usize) { + if let Some(ref mut matched) = self.matched { + matched[row_idx] = true; + } + } + + /// Iterate over unmatched row indices. 
+ pub fn unmatched_indices(&self) -> impl Iterator + '_ { + self.matched.as_ref().into_iter().flat_map(|m| { + m.iter() + .enumerate() + .filter(|(_, &matched)| !matched) + .map(|(idx, _)| idx) + }) + } + + /// Create a new in-memory buffered batch. + /// + /// `full_outer` controls whether per-row match tracking is allocated. + fn new_in_memory(batch: RecordBatch, join_arrays: Vec, full_outer: bool) -> Self { + let num_rows = batch.num_rows(); + let mut size_estimate = get_record_batch_memory_size(&batch); + for arr in &join_arrays { + size_estimate += arr.get_array_memory_size(); + } + let matched = if full_outer { + Some(vec![false; num_rows]) + } else { + None + }; + Self { + state: BatchState::InMemory(batch), + join_arrays: Some(join_arrays), + num_rows, + size_estimate, + matched, + } + } + + /// Return the batch. If it was spilled, read it back from disk via the spill manager. + pub fn get_batch(&self, spill_manager: &SpillManager) -> Result { + match &self.state { + BatchState::InMemory(batch) => Ok(batch.clone()), + BatchState::Spilled(file) => { + let reader = spill_manager.read_spill_as_stream(file.clone(), None)?; + let batches = tokio::task::block_in_place(|| { + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + use futures::StreamExt; + let mut stream = reader; + let mut batches = Vec::new(); + while let Some(batch) = stream.next().await { + batches.push(batch?); + } + Ok::<_, DataFusionError>(batches) + }) + })?; + // A single batch was spilled per file, but concatenate just in case. + if batches.len() == 1 { + Ok(batches.into_iter().next().unwrap()) + } else { + arrow::compute::concat_batches(&Arc::clone(spill_manager.schema()), &batches) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + } + } + } + } + + /// Return join key arrays. If in memory, returns the cached arrays directly. + /// If spilled, deserializes the batch and re-evaluates the join expressions. 
+ #[allow(dead_code)] + pub fn get_join_arrays( + &self, + spill_manager: &SpillManager, + join_exprs: &[PhysicalExprRef], + ) -> Result> { + if let Some(ref arrays) = self.join_arrays { + return Ok(arrays.clone()); + } + // Spilled — reload and re-evaluate + let batch = self.get_batch(spill_manager)?; + evaluate_join_keys(&batch, join_exprs) + } +} + +/// A group of buffered batches that share the same join key values. +/// +/// Batches may be held in memory or spilled to disk when the memory reservation +/// cannot accommodate them. When spilled, they can be loaded back on demand via +/// the spill manager. +#[derive(Debug)] +pub(super) struct BufferedMatchGroup { + /// All batches in this match group. + pub batches: Vec, + /// Total number of rows across all batches. + pub num_rows: usize, + /// Total estimated memory usage of in-memory batches. + pub memory_size: usize, +} + +impl BufferedMatchGroup { + /// Create a new empty match group. + pub fn new() -> Self { + Self { + batches: Vec::new(), + num_rows: 0, + memory_size: 0, + } + } + + /// Add a batch to this match group. + /// + /// First attempts to grow the memory reservation to hold the batch in memory. + /// If that fails, the batch is spilled to disk via the spill manager and the + /// spill metrics are updated accordingly. 
+ pub fn add_batch( + &mut self, + batch: RecordBatch, + join_arrays: Vec, + full_outer: bool, + reservation: &mut MemoryReservation, + spill_manager: &SpillManager, + metrics: &SortMergeJoinMetrics, + ) -> Result<()> { + let buffered = BufferedBatch::new_in_memory(batch.clone(), join_arrays, full_outer); + let size = buffered.size_estimate; + let num_rows = buffered.num_rows; + + if reservation.try_grow(size).is_ok() { + // Fits in memory + self.memory_size += size; + self.num_rows += num_rows; + self.batches.push(buffered); + } else { + // Spill to disk + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "SortMergeJoin buffered batch")?; + match spill_file { + Some(file) => { + metrics.spill_count.add(1); + metrics.spilled_bytes.add( + std::fs::metadata(file.path()) + .map(|m| m.len() as usize) + .unwrap_or(0), + ); + metrics.spilled_rows.add(num_rows); + let matched = if full_outer { + Some(vec![false; num_rows]) + } else { + None + }; + self.num_rows += num_rows; + self.batches.push(BufferedBatch { + state: BatchState::Spilled(file), + join_arrays: None, + num_rows, + size_estimate: 0, // not consuming memory + matched, + }); + } + None => { + // Empty batch, nothing to do + } + } + } + Ok(()) + } + + /// Clear all batches and release the memory reservation. + pub fn clear(&mut self, reservation: &mut MemoryReservation) { + self.batches.clear(); + reservation.shrink(self.memory_size); + self.num_rows = 0; + self.memory_size = 0; + } + + /// Get a batch by index. If the batch was spilled, it is read back from disk. + pub fn get_batch(&self, batch_idx: usize, spill_manager: &SpillManager) -> Result { + self.batches[batch_idx].get_batch(spill_manager) + } + + /// Returns `true` if this group contains no batches. + pub fn is_empty(&self) -> bool { + self.batches.is_empty() + } +} + +/// Evaluate join key physical expressions against a record batch and return the +/// resulting column arrays. 
+pub(super) fn evaluate_join_keys( + batch: &RecordBatch, + join_exprs: &[PhysicalExprRef], +) -> Result> { + join_exprs + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|cv| cv.into_array(batch.num_rows())) + }) + .collect() +} diff --git a/native/core/src/execution/joins/filter.rs b/native/core/src/execution/joins/filter.rs new file mode 100644 index 0000000000..e6ef4b416f --- /dev/null +++ b/native/core/src/execution/joins/filter.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Join filter evaluation with corrected masks for outer, semi, and anti joins. +//! +//! In outer joins, if all candidate pairs for a streamed row fail the filter, +//! the streamed row must be null-joined (not dropped). Semi joins emit a +//! streamed row if ANY pair passes. Anti joins emit if NO pair passes. This +//! module groups filter results by streamed row to implement these semantics. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt32Array}; +use arrow::compute::take; +use datafusion::common::{internal_err, JoinSide, Result}; +use datafusion::logical_expr::JoinType; +use datafusion::physical_plan::joins::utils::JoinFilter; + +use super::output_builder::JoinIndex; + +/// Result of applying a join filter to a set of candidate pairs. +pub(super) struct FilteredOutput { + /// Pairs that passed the filter (or were selected for null-join in outer). + pub passed_indices: Vec<JoinIndex>, + /// Streamed row indices that had no passing pair and should be null-joined + /// (applies to outer and anti joins). + pub streamed_null_joins: Vec<usize>, + /// (batch_idx, buffered_idx) pairs that passed the filter, used for + /// tracking matched buffered rows in full outer joins. + pub buffered_matched: Vec<(usize, usize)>, +} + +/// Evaluate a join filter on candidate pairs and return corrected results +/// based on the join type. +/// +/// `pair_indices` contains the candidate pairs as `JoinIndex` values. +/// `candidate_batch` is the intermediate batch built for filter evaluation. +pub(super) fn apply_join_filter( + filter: &JoinFilter, + candidate_batch: &RecordBatch, + pair_indices: &[JoinIndex], + join_type: &JoinType, +) -> Result<FilteredOutput> { + // Evaluate the filter expression on the candidate batch + let filter_result = filter + .expression() + .evaluate(candidate_batch)?
+ .into_array(candidate_batch.num_rows())?; + + let mask = filter_result + .as_any() + .downcast_ref::<BooleanArray>() + .expect("join filter expression must return BooleanArray"); + + match join_type { + JoinType::Inner => Ok(apply_inner_filter(mask, pair_indices)), + JoinType::Left | JoinType::Right => Ok(apply_outer_filter(mask, pair_indices)), + JoinType::Full => Ok(apply_full_outer_filter(mask, pair_indices)), + JoinType::LeftSemi | JoinType::RightSemi => Ok(apply_semi_filter(mask, pair_indices)), + JoinType::LeftAnti | JoinType::RightAnti => Ok(apply_anti_filter(mask, pair_indices)), + _ => Ok(apply_inner_filter(mask, pair_indices)), + } +} + +/// Build the intermediate batch used for filter evaluation. +/// +/// For each column in the filter's `column_indices`, we take the appropriate +/// rows from either the streamed or buffered batch using the provided index +/// arrays. +pub(super) fn build_filter_candidate_batch( + filter: &JoinFilter, + streamed_batch: &RecordBatch, + buffered_batch: &RecordBatch, + streamed_indices: &UInt32Array, + buffered_indices: &UInt32Array, +) -> Result<RecordBatch> { + let columns: Vec<ArrayRef> = filter + .column_indices() + .iter() + .map(|col_idx| { + let (batch, indices) = match col_idx.side { + JoinSide::Left => (streamed_batch, streamed_indices), + JoinSide::Right => (buffered_batch, buffered_indices), + JoinSide::None => { + return internal_err!("unexpected JoinSide::None in join filter column index"); + } + }; + let column = batch.column(col_idx.index); + Ok(take(column.as_ref(), indices, None)?) + }) + .collect::<Result<Vec<_>>>()?; + + Ok(RecordBatch::try_new(Arc::clone(filter.schema()), columns)?) +} + +/// Returns true if the mask value at `i` is true and not null. +#[inline] +fn mask_passed(mask: &BooleanArray, i: usize) -> bool { + mask.value(i) && !mask.is_null(i) +} + +/// Inner join: keep rows where mask is true (and not null).
+fn apply_inner_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let passed_indices: Vec<JoinIndex> = indices + .iter() + .enumerate() + .filter(|(i, _)| mask_passed(mask, *i)) + .map(|(_, idx)| *idx) + .collect(); + + let buffered_matched = passed_indices + .iter() + .map(|idx| (idx.batch_idx, idx.buffered_idx)) + .collect(); + + FilteredOutput { + passed_indices, + streamed_null_joins: Vec::new(), + buffered_matched, + } +} + +/// Outer join (Left/Right): group by streamed_idx. If any pair passes for a +/// streamed row, keep those passing pairs. If none pass, add streamed row to +/// null-joins. +fn apply_outer_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap<usize, Vec<(usize, JoinIndex)>> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push((i, *idx)); + } + + let mut passed_indices = Vec::new(); + let mut streamed_null_joins = Vec::new(); + let mut buffered_matched = Vec::new(); + + for (streamed_idx, pairs) in &groups { + let passing: Vec<JoinIndex> = pairs + .iter() + .filter(|(i, _)| mask_passed(mask, *i)) + .map(|(_, idx)| *idx) + .collect(); + + if passing.is_empty() { + streamed_null_joins.push(*streamed_idx); + } else { + for idx in &passing { + buffered_matched.push((idx.batch_idx, idx.buffered_idx)); + } + passed_indices.extend(passing); + } + } + + FilteredOutput { + passed_indices, + streamed_null_joins, + buffered_matched, + } +} + +/// Full outer join: same grouping logic as outer, but buffered tracking is +/// done via the matched bitvector on BufferedBatch (caller uses +/// `buffered_matched`). +fn apply_full_outer_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + // Same logic as outer — the caller handles buffered-side null joins + // via the BufferedBatch matched bitvector. + apply_outer_filter(mask, indices) +} + +/// Semi join: group by streamed_idx.
If any pair passes for a streamed row, +/// emit one JoinIndex for that row (the first passing pair). +fn apply_semi_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap<usize, Vec<(usize, JoinIndex)>> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push((i, *idx)); + } + + let mut passed_indices = Vec::new(); + let mut buffered_matched = Vec::new(); + + for pairs in groups.values() { + if let Some((_, idx)) = pairs.iter().find(|(i, _)| mask_passed(mask, *i)) { + passed_indices.push(*idx); + buffered_matched.push((idx.batch_idx, idx.buffered_idx)); + } + } + + FilteredOutput { + passed_indices, + streamed_null_joins: Vec::new(), + buffered_matched, + } +} + +/// Anti join: group by streamed_idx. If no pair passes for a streamed row, +/// add it to streamed_null_joins. +fn apply_anti_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap<usize, Vec<usize>> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push(i); + } + + let mut streamed_null_joins = Vec::new(); + + for (streamed_idx, mask_indices) in &groups { + let any_passed = mask_indices.iter().any(|i| mask_passed(mask, *i)); + + if !any_passed { + streamed_null_joins.push(*streamed_idx); + } + } + + FilteredOutput { + passed_indices: Vec::new(), + streamed_null_joins, + buffered_matched: Vec::new(), + } +} diff --git a/native/core/src/execution/joins/metrics.rs b/native/core/src/execution/joins/metrics.rs new file mode 100644 index 0000000000..3248a7c949 --- /dev/null +++ b/native/core/src/execution/joins/metrics.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::physical_plan::metrics::{ + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, Time, +}; + +/// Metrics for CometSortMergeJoinExec, matching CometMetricNode.scala definitions. +#[derive(Debug, Clone)] +pub(super) struct SortMergeJoinMetrics { + pub input_rows: Count, + pub input_batches: Count, + pub output_rows: Count, + pub output_batches: Count, + pub join_time: Time, + pub peak_mem_used: Gauge, + pub spill_count: Count, + pub spilled_bytes: Count, + pub spilled_rows: Count, +} + +impl SortMergeJoinMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + input_rows: MetricBuilder::new(metrics).counter("input_rows", partition), + input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), + output_rows: MetricBuilder::new(metrics).output_rows(partition), + output_batches: MetricBuilder::new(metrics).counter("output_batches", partition), + join_time: MetricBuilder::new(metrics).subset_time("join_time", partition), + peak_mem_used: MetricBuilder::new(metrics).gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(metrics).counter("spilled_rows", partition), + } + } + + pub fn update_peak_mem(&self, current_mem: usize) { + if current_mem > 
self.peak_mem_used.value() { + self.peak_mem_used.set(current_mem); + } + } +} diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs new file mode 100644 index 0000000000..7dc4d73b67 --- /dev/null +++ b/native/core/src/execution/joins/mod.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod buffered_batch; +mod filter; +mod metrics; +mod output_builder; +mod sort_merge_join; +mod sort_merge_join_stream; + +pub(crate) use sort_merge_join::CometSortMergeJoinExec; + +#[cfg(test)] +mod tests; diff --git a/native/core/src/execution/joins/output_builder.rs b/native/core/src/execution/joins/output_builder.rs new file mode 100644 index 0000000000..7f04de438c --- /dev/null +++ b/native/core/src/execution/joins/output_builder.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Output batch builder for the sort merge join operator. +//! +//! The [`OutputBuilder`] accumulates matched and null-joined index pairs during +//! the join's Joining state and materializes them into Arrow [`RecordBatch`]es +//! during the OutputReady state. + +use std::sync::Arc; + +use arrow::array::{new_null_array, ArrayRef, RecordBatch, UInt32Array}; +use arrow::compute::kernels::concat::concat; +use arrow::compute::kernels::take::take; +use arrow::datatypes::SchemaRef; +use datafusion::common::{JoinType, Result}; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::spill::SpillManager; + +use super::buffered_batch::BufferedMatchGroup; + +/// An index pair representing a matched row from the streamed and buffered sides. +#[derive(Debug, Clone, Copy)] +pub(super) struct JoinIndex { + pub streamed_idx: usize, + pub batch_idx: usize, + pub buffered_idx: usize, +} + +/// Accumulates join output indices and materializes them into Arrow record batches. 
+pub(super) struct OutputBuilder { + output_schema: SchemaRef, + buffered_schema: SchemaRef, + join_type: JoinType, + target_batch_size: usize, + indices: Vec, + streamed_null_joins: Vec, + buffered_null_joins: Vec<(usize, usize)>, +} + +impl OutputBuilder { + pub fn new( + output_schema: SchemaRef, + _streamed_schema: SchemaRef, + buffered_schema: SchemaRef, + join_type: JoinType, + target_batch_size: usize, + ) -> Self { + Self { + output_schema, + buffered_schema, + join_type, + target_batch_size, + indices: Vec::new(), + streamed_null_joins: Vec::new(), + buffered_null_joins: Vec::new(), + } + } + + pub fn add_match(&mut self, streamed_idx: usize, batch_idx: usize, buffered_idx: usize) { + self.indices.push(JoinIndex { + streamed_idx, + batch_idx, + buffered_idx, + }); + } + + pub fn add_streamed_null_join(&mut self, streamed_idx: usize) { + self.streamed_null_joins.push(streamed_idx); + } + + pub fn add_buffered_null_join(&mut self, batch_idx: usize, buffered_idx: usize) { + self.buffered_null_joins.push((batch_idx, buffered_idx)); + } + + pub fn pending_count(&self) -> usize { + self.indices.len() + self.streamed_null_joins.len() + self.buffered_null_joins.len() + } + + pub fn should_flush(&self) -> bool { + self.pending_count() >= self.target_batch_size + } + + pub fn has_pending(&self) -> bool { + self.pending_count() > 0 + } + + /// Materialize the accumulated indices into a [`RecordBatch`]. + /// + /// After building, all accumulated indices are cleared. 
+ pub fn build( + &mut self, + streamed_batch: &RecordBatch, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + _buffered_join_exprs: &[PhysicalExprRef], + ) -> Result { + let result = match self.join_type { + JoinType::LeftSemi | JoinType::LeftAnti => self.build_semi_anti(streamed_batch), + _ => self.build_full(streamed_batch, match_group, spill_manager), + }; + + self.indices.clear(); + self.streamed_null_joins.clear(); + self.buffered_null_joins.clear(); + + result + } + + fn build_semi_anti(&self, streamed_batch: &RecordBatch) -> Result { + let indices: Vec = self + .indices + .iter() + .map(|idx| idx.streamed_idx as u32) + .chain(self.streamed_null_joins.iter().map(|&idx| idx as u32)) + .collect(); + + let indices_array = UInt32Array::from(indices); + + let columns: Vec = streamed_batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None).map_err(Into::into)) + .collect::>()?; + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + columns, + )?) + } + + fn build_full( + &self, + streamed_batch: &RecordBatch, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result { + let streamed_columns = self.build_streamed_columns(streamed_batch)?; + let buffered_columns = self.build_buffered_columns(match_group, spill_manager)?; + + let mut columns = streamed_columns; + columns.extend(buffered_columns); + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + columns, + )?) 
+ } + + fn build_streamed_columns(&self, streamed_batch: &RecordBatch) -> Result> { + let total_rows = self.pending_count(); + let num_buffered_nulls = self.buffered_null_joins.len(); + + let indices: Vec> = self + .indices + .iter() + .map(|idx| Some(idx.streamed_idx as u32)) + .chain(self.streamed_null_joins.iter().map(|&idx| Some(idx as u32))) + .chain(std::iter::repeat_n(None, num_buffered_nulls)) + .collect(); + + debug_assert_eq!(indices.len(), total_rows); + + let indices_array = UInt32Array::from(indices); + + streamed_batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None).map_err(Into::into)) + .collect() + } + + /// Build all buffered columns at once, loading each batch only once across all columns. + fn build_buffered_columns( + &self, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result> { + let num_cols = self.buffered_schema.fields().len(); + let num_streamed_nulls = self.streamed_null_joins.len(); + + // Pre-compute which batches we need and their grouped row indices. + // This avoids loading the same spilled batch N times (once per column). 
+ let matched_groups = group_by_batch(&self.indices); + let null_join_groups = group_by_batch_tuple(&self.buffered_null_joins); + + // Load each referenced batch once + let mut batch_cache: std::collections::HashMap = + std::collections::HashMap::new(); + for &(batch_idx, _) in matched_groups.iter().chain(null_join_groups.iter()) { + if let std::collections::hash_map::Entry::Vacant(e) = batch_cache.entry(batch_idx) { + e.insert(match_group.batches[batch_idx].get_batch(spill_manager)?); + } + } + + // Build all columns using the cached batches + let mut result: Vec = Vec::with_capacity(num_cols); + for col_idx in 0..num_cols { + let data_type = self.buffered_schema.field(col_idx).data_type(); + let mut parts: Vec = Vec::with_capacity(3); + + // Matched pairs + if !matched_groups.is_empty() { + parts.push(take_from_groups(col_idx, &matched_groups, &batch_cache)?); + } + + // Null arrays for streamed null joins + if num_streamed_nulls > 0 { + parts.push(new_null_array(data_type, num_streamed_nulls)); + } + + // Buffered null joins + if !null_join_groups.is_empty() { + parts.push(take_from_groups(col_idx, &null_join_groups, &batch_cache)?); + } + + result.push(concat_parts(parts, data_type)?); + } + + Ok(result) + } +} + +/// Group JoinIndex entries by batch_idx into (batch_idx, row_indices) pairs. +fn group_by_batch(indices: &[JoinIndex]) -> Vec<(usize, Vec)> { + let mut groups: Vec<(usize, Vec)> = Vec::new(); + for idx in indices { + if let Some(last) = groups.last_mut() { + if last.0 == idx.batch_idx { + last.1.push(idx.buffered_idx as u32); + continue; + } + } + groups.push((idx.batch_idx, vec![idx.buffered_idx as u32])); + } + groups +} + +/// Group (batch_idx, row_idx) tuples by batch_idx into (batch_idx, row_indices) pairs. 
+fn group_by_batch_tuple(indices: &[(usize, usize)]) -> Vec<(usize, Vec<u32>)> { + let mut groups: Vec<(usize, Vec<u32>)> = Vec::new(); + for &(batch_idx, row_idx) in indices { + if let Some(last) = groups.last_mut() { + if last.0 == batch_idx { + last.1.push(row_idx as u32); + continue; + } + } + groups.push((batch_idx, vec![row_idx as u32])); + } + groups +} + +/// Take a single column from pre-loaded batches using grouped indices. +fn take_from_groups( + col_idx: usize, + groups: &[(usize, Vec<u32>)], + batch_cache: &std::collections::HashMap<usize, RecordBatch>, +) -> Result<ArrayRef> { + let mut parts: Vec<ArrayRef> = Vec::with_capacity(groups.len()); + for (batch_idx, row_indices) in groups { + let batch = &batch_cache[batch_idx]; + let col = batch.column(col_idx); + let index_array = UInt32Array::from(row_indices.clone()); + parts.push(take(col.as_ref(), &index_array, None)?); + } + concat_parts( + parts, + batch_cache + .values() + .next() + .unwrap() + .column(col_idx) + .data_type(), + ) +} + +/// Concat array parts, handling empty and single-element cases. +fn concat_parts(parts: Vec<ArrayRef>, data_type: &arrow::datatypes::DataType) -> Result<ArrayRef> { + if parts.is_empty() { + return Ok(new_null_array(data_type, 0)); + } + if parts.len() == 1 { + return Ok(parts.into_iter().next().expect("checked len == 1")); + } + let refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); + Ok(concat(&refs)?) +} diff --git a/native/core/src/execution/joins/sort_merge_join.rs b/native/core/src/execution/joins/sort_merge_join.rs new file mode 100644 index 0000000000..6909f260b0 --- /dev/null +++ b/native/core/src/execution/joins/sort_merge_join.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::fmt::Formatter;
+use std::sync::Arc;
+
+use arrow::compute::SortOptions;
+use arrow::datatypes::SchemaRef;
+use datafusion::common::{NullEquality, Result};
+use datafusion::execution::memory_pool::MemoryConsumer;
+use datafusion::execution::TaskContext;
+use datafusion::logical_expr::JoinType;
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion::physical_plan::joins::utils::{build_join_schema, check_join_is_valid, JoinFilter};
+use datafusion::physical_plan::joins::JoinOn;
+use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics};
+use datafusion::physical_plan::spill::SpillManager;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
+    SendableRecordBatchStream,
+};
+
+use super::metrics::SortMergeJoinMetrics;
+use super::sort_merge_join_stream::SortMergeJoinStream;
+
+/// A Comet-specific sort merge join operator that replaces DataFusion's
+/// `SortMergeJoinExec` with Spark-compatible semantics.
+#[derive(Debug)]
+pub(crate) struct CometSortMergeJoinExec {
+    /// Left input plan (sorted on the left join keys).
+    left: Arc,
+    /// Right input plan (sorted on the right join keys).
+    right: Arc,
+    /// Equi-join key pairs (left expr, right expr).
+    join_on: JoinOn,
+    /// Optional non-equi post-join filter.
+    join_filter: Option,
+    join_type: JoinType,
+    /// Sort options per join key column; must match the inputs' actual ordering.
+    sort_options: Vec,
+    /// Whether null keys compare equal to each other.
+    null_equality: NullEquality,
+    /// Output schema built from both input schemas and the join type.
+    schema: SchemaRef,
+    properties: PlanProperties,
+    metrics: ExecutionPlanMetricsSet,
+}
+
+impl CometSortMergeJoinExec {
+    /// Create a new `CometSortMergeJoinExec`.
+    ///
+    /// Returns an error if the join keys are not valid for the input schemas.
+    pub fn try_new(
+        left: Arc,
+        right: Arc,
+        join_on: JoinOn,
+        join_filter: Option,
+        join_type: JoinType,
+        sort_options: Vec,
+        null_equality: NullEquality,
+    ) -> Result {
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+
+        check_join_is_valid(&left_schema, &right_schema, &join_on)?;
+
+        let (schema, _column_indices) = build_join_schema(&left_schema, &right_schema, &join_type);
+        let schema = Arc::new(schema);
+
+        // NOTE(review): partition count is always taken from the LEFT child,
+        // even though execute() streams the RIGHT child for Right joins —
+        // confirm both children always have the same partitioning here.
+        let properties = PlanProperties::new(
+            EquivalenceProperties::new(Arc::clone(&schema)),
+            Partitioning::UnknownPartitioning(
+                left.properties().output_partitioning().partition_count(),
+            ),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        );
+
+        Ok(Self {
+            left,
+            right,
+            join_on,
+            join_filter,
+            join_type,
+            sort_options,
+            null_equality,
+            schema,
+            properties,
+            metrics: ExecutionPlanMetricsSet::default(),
+        })
+    }
+}
+
+impl DisplayAs for CometSortMergeJoinExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "CometSortMergeJoinExec: join_type={:?}", self.join_type)
+            }
+            // NOTE(review): unimplemented!() panics if a TreeRender explain is
+            // ever requested for this operator — consider falling back to the
+            // Default formatting instead of panicking.
+            DisplayFormatType::TreeRender => unimplemented!(),
+        }
+    }
+}
+
+impl ExecutionPlan for CometSortMergeJoinExec {
+    fn name(&self) -> &str {
+        "CometSortMergeJoinExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+
+    fn children(&self) -> Vec<&Arc> {
+        vec![&self.left, &self.right]
+    }
+
+    fn with_new_children(
+        self: Arc,
+        children: Vec>,
+    ) -> Result> {
+        Ok(Arc::new(CometSortMergeJoinExec::try_new(
+            Arc::clone(&children[0]),
+            Arc::clone(&children[1]),
+            self.join_on.clone(),
+            self.join_filter.clone(),
+            self.join_type,
+            self.sort_options.clone(),
+            self.null_equality,
+        )?))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc,
+    ) -> Result {
+        // Determine streamed/buffered assignment based on join type.
+        // RightOuter: right is streamed, left is buffered.
+        // All others: left is streamed, right is buffered.
+        let (streamed_child, buffered_child, streamed_join_exprs, buffered_join_exprs) =
+            match self.join_type {
+                JoinType::Right => (
+                    Arc::clone(&self.right),
+                    Arc::clone(&self.left),
+                    self.join_on
+                        .iter()
+                        .map(|(_, r)| Arc::clone(r))
+                        .collect::>(),
+                    self.join_on
+                        .iter()
+                        .map(|(l, _)| Arc::clone(l))
+                        .collect::>(),
+                ),
+                _ => (
+                    Arc::clone(&self.left),
+                    Arc::clone(&self.right),
+                    self.join_on
+                        .iter()
+                        .map(|(l, _)| Arc::clone(l))
+                        .collect::>(),
+                    self.join_on
+                        .iter()
+                        .map(|(_, r)| Arc::clone(r))
+                        .collect::>(),
+                ),
+            };
+
+        let streamed_schema = streamed_child.schema();
+        let buffered_schema = buffered_child.schema();
+
+        let streamed_input = streamed_child.execute(partition, Arc::clone(&context))?;
+        let buffered_input = buffered_child.execute(partition, Arc::clone(&context))?;
+
+        // Create memory reservation. Only the buffered side is held in memory,
+        // and it may spill (can_spill = true).
+        let reservation = MemoryConsumer::new("CometSortMergeJoin")
+            .with_can_spill(true)
+            .register(context.memory_pool());
+
+        // Create spill manager.
+        let spill_metrics = SpillMetrics::new(&self.metrics, partition);
+        let spill_manager = SpillManager::new(
+            context.runtime_env(),
+            spill_metrics,
+            Arc::clone(&buffered_schema),
+        );
+
+        let metrics = SortMergeJoinMetrics::new(&self.metrics, partition);
+        let target_batch_size = context.session_config().batch_size();
+
+        Ok(Box::pin(SortMergeJoinStream::try_new(
+            Arc::clone(&self.schema),
+            streamed_schema,
+            buffered_schema,
+            self.join_type,
+            self.null_equality,
+            self.join_filter.clone(),
+            self.sort_options.clone(),
+            streamed_input,
+            buffered_input,
+            streamed_join_exprs,
+            buffered_join_exprs,
+            reservation,
+            spill_manager,
+            metrics,
+            target_batch_size,
+        )?))
+    }
+
+    fn metrics(&self) -> Option {
+        Some(self.metrics.clone_inner())
+    }
+}
diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs
new file mode 100644
index 0000000000..ace92651b0
--- /dev/null
+++ b/native/core/src/execution/joins/sort_merge_join_stream.rs
@@ -0,0 +1,1138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Streaming state machine for the sort merge join operator.
+//!
+//!
The [`SortMergeJoinStream`] drives two sorted input streams (streamed and
+//! buffered), compares join keys, collects matching buffered rows into a
+//! [`BufferedMatchGroup`], and produces joined output batches via the
+//! [`OutputBuilder`].
+
+use std::cmp::Ordering;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::{ArrayRef, RecordBatch, UInt32Array};
+use arrow::compute::SortOptions;
+use arrow::datatypes::SchemaRef;
+use arrow::row::{OwnedRow, RowConverter, Rows, SortField};
+use datafusion::common::{NullEquality, Result};
+use datafusion::execution::memory_pool::MemoryReservation;
+use datafusion::logical_expr::JoinType;
+use datafusion::physical_expr::PhysicalExprRef;
+use datafusion::physical_plan::joins::utils::{compare_join_arrays, JoinFilter};
+use datafusion::physical_plan::spill::SpillManager;
+use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+
+use futures::{Stream, StreamExt};
+
+use super::buffered_batch::{evaluate_join_keys, BufferedMatchGroup};
+use super::filter::{apply_join_filter, build_filter_candidate_batch};
+use super::metrics::SortMergeJoinMetrics;
+use super::output_builder::{JoinIndex, OutputBuilder};
+
+/// States of the sort merge join state machine.
+#[derive(Debug, PartialEq, Eq)]
+enum JoinState {
+    /// Need to poll the next streamed row.
+    PollStreamed,
+    /// Need to poll the next buffered batch.
+    PollBuffered,
+    /// Initial state: decide what to poll next.
+    Init,
+    /// Compare the current streamed key with the current buffered key.
+    Comparing,
+    /// Collecting more buffered batches into the match group (key spans batches).
+    CollectingBuffered,
+    /// Produce join output for the current streamed row against the match group.
+    Joining,
+    /// Flush accumulated output.
+    OutputReady,
+    /// Drain unmatched rows after one side is exhausted.
+    DrainUnmatched,
+    /// Drain remaining buffered rows as null-joined (Full/Right outer).
+    DrainBuffered,
+    /// No more output.
+    Exhausted,
+}
+
+/// A streaming sort merge join that merges two sorted inputs by join keys.
+pub(super) struct SortMergeJoinStream {
+    /// The type of join (Inner, Left, Right, Full, LeftSemi, LeftAnti).
+    join_type: JoinType,
+    /// How nulls compare during key matching.
+    null_equality: NullEquality,
+    /// Optional post-join filter.
+    join_filter: Option,
+    /// Sort options for each join key column.
+    sort_options: Vec,
+
+    /// The streamed (driving) input.
+    streamed_input: SendableRecordBatchStream,
+    /// The buffered (probe) input.
+    buffered_input: SendableRecordBatchStream,
+    /// Expressions to evaluate join keys on the streamed side.
+    streamed_join_exprs: Vec,
+    /// Expressions to evaluate join keys on the buffered side.
+    buffered_join_exprs: Vec,
+
+    /// Current streamed batch.
+    streamed_batch: Option,
+    /// Pre-evaluated join key arrays for the current streamed batch.
+    streamed_join_arrays: Option>,
+    /// Current row index within the streamed batch.
+    streamed_idx: usize,
+    /// Whether the streamed input is exhausted.
+    streamed_exhausted: bool,
+
+    /// Pending buffered batch (batch + join arrays) not yet consumed.
+    buffered_pending: Option<(RecordBatch, Vec)>,
+    /// Whether the buffered input is exhausted.
+    buffered_exhausted: bool,
+    /// The current match group of buffered rows sharing the same join key.
+    match_group: BufferedMatchGroup,
+
+    /// Converts join keys to comparable row format for key-reuse optimization.
+    row_converter: RowConverter,
+    /// Row-format conversion of current streamed batch's join keys (computed once per batch).
+    streamed_rows: Option,
+    /// Cached key of the previous streamed row (for key-reuse detection).
+    cached_streamed_key: Option,
+
+    /// Accumulates output indices and builds result batches.
+    output_builder: OutputBuilder,
+    /// Schema of the output record batches.
+    output_schema: SchemaRef,
+
+    /// Memory reservation for buffered data.
+    reservation: MemoryReservation,
+    /// Manages spilling buffered batches to disk when memory is tight.
+    spill_manager: SpillManager,
+    /// Metrics for this join operator.
+    metrics: SortMergeJoinMetrics,
+
+    /// Current state of the state machine.
+    state: JoinState,
+}
+
+impl SortMergeJoinStream {
+    /// Create a new `SortMergeJoinStream`.
+    ///
+    /// Fails if a join-key expression cannot be resolved against the streamed
+    /// schema or the row converter rejects a key data type.
+    #[allow(clippy::too_many_arguments)]
+    pub fn try_new(
+        output_schema: SchemaRef,
+        streamed_schema: SchemaRef,
+        buffered_schema: SchemaRef,
+        join_type: JoinType,
+        null_equality: NullEquality,
+        join_filter: Option,
+        sort_options: Vec,
+        streamed_input: SendableRecordBatchStream,
+        buffered_input: SendableRecordBatchStream,
+        streamed_join_exprs: Vec,
+        buffered_join_exprs: Vec,
+        reservation: MemoryReservation,
+        spill_manager: SpillManager,
+        metrics: SortMergeJoinMetrics,
+        target_batch_size: usize,
+    ) -> Result {
+        // Build SortFields from the streamed join key data types.
+        // NOTE(review): key types are taken from the streamed side only —
+        // assumes both sides' key expressions have identical types; confirm.
+        let sort_fields: Vec = streamed_join_exprs
+            .iter()
+            .zip(sort_options.iter())
+            .map(|(expr, opts)| {
+                let dt = expr.data_type(&streamed_schema)?;
+                Ok(SortField::new_with_options(dt, *opts))
+            })
+            .collect::>()?;
+
+        let row_converter = RowConverter::new(sort_fields)
+            .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?;
+
+        let output_builder = OutputBuilder::new(
+            Arc::clone(&output_schema),
+            streamed_schema,
+            buffered_schema,
+            join_type,
+            target_batch_size,
+        );
+
+        Ok(Self {
+            join_type,
+            null_equality,
+            join_filter,
+            sort_options,
+            streamed_input,
+            buffered_input,
+            streamed_join_exprs,
+            buffered_join_exprs,
+            streamed_batch: None,
+            streamed_join_arrays: None,
+            streamed_idx: 0,
+            streamed_exhausted: false,
+            buffered_pending: None,
+            buffered_exhausted: false,
+            match_group: BufferedMatchGroup::new(),
+            row_converter,
+            streamed_rows: None,
+            cached_streamed_key: None,
+            output_builder,
+            output_schema,
+            reservation,
+            spill_manager,
+            metrics,
+            state: JoinState::Init,
+        })
+    }
+
+    ///
Drive the state machine, returning `Poll::Ready(Some(batch))` when a
+    /// batch is available, `Poll::Ready(None)` when done, or `Poll::Pending`
+    /// if waiting on input.
+    fn poll_next_inner(&mut self, cx: &mut Context<'_>) -> Poll>> {
+        loop {
+            match self.state {
+                JoinState::Init => {
+                    // Decide what to poll based on current state.
+                    if self.streamed_batch.is_none() && !self.streamed_exhausted {
+                        self.state = JoinState::PollStreamed;
+                    } else if self.buffered_pending.is_none() && !self.buffered_exhausted {
+                        self.state = JoinState::PollBuffered;
+                    } else if self.streamed_exhausted {
+                        self.state = JoinState::DrainUnmatched;
+                    } else {
+                        // Have streamed data; compare regardless of buffered state.
+                        self.state = JoinState::Comparing;
+                    }
+                }
+
+                JoinState::PollStreamed => {
+                    match self.streamed_input.poll_next_unpin(cx) {
+                        Poll::Ready(Some(Ok(batch))) => {
+                            if batch.num_rows() == 0 {
+                                // Skip empty batches.
+                                continue;
+                            }
+                            self.metrics.input_batches.add(1);
+                            self.metrics.input_rows.add(batch.num_rows());
+                            let join_arrays =
+                                evaluate_join_keys(&batch, &self.streamed_join_exprs)?;
+                            // Convert join keys to row format once per batch for key-reuse checks
+                            let rows =
+                                self.row_converter
+                                    .convert_columns(&join_arrays)
+                                    .map_err(|e| {
+                                        datafusion::common::DataFusionError::ArrowError(
+                                            Box::new(e),
+                                            None,
+                                        )
+                                    })?;
+                            self.streamed_rows = Some(rows);
+                            self.streamed_batch = Some(batch);
+                            self.streamed_join_arrays = Some(join_arrays);
+                            self.streamed_idx = 0;
+                            // Now ensure we have buffered data too.
+                            if self.buffered_pending.is_none() && !self.buffered_exhausted {
+                                self.state = JoinState::PollBuffered;
+                            } else {
+                                self.state = JoinState::Comparing;
+                            }
+                        }
+                        Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                        Poll::Ready(None) => {
+                            self.streamed_exhausted = true;
+                            self.state = JoinState::DrainUnmatched;
+                        }
+                        Poll::Pending => return Poll::Pending,
+                    }
+                }
+
+                JoinState::PollBuffered => match self.buffered_input.poll_next_unpin(cx) {
+                    Poll::Ready(Some(Ok(batch))) => {
+                        if batch.num_rows() == 0 {
+                            continue;
+                        }
+                        self.metrics.input_batches.add(1);
+                        self.metrics.input_rows.add(batch.num_rows());
+                        let join_arrays = evaluate_join_keys(&batch, &self.buffered_join_exprs)?;
+                        self.buffered_pending = Some((batch, join_arrays));
+                        if self.streamed_batch.is_some() {
+                            self.state = JoinState::Comparing;
+                        } else if !self.streamed_exhausted {
+                            self.state = JoinState::PollStreamed;
+                        } else {
+                            self.state = JoinState::DrainUnmatched;
+                        }
+                    }
+                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                    Poll::Ready(None) => {
+                        self.buffered_exhausted = true;
+                        if self.streamed_batch.is_some() {
+                            self.state = JoinState::Comparing;
+                        } else {
+                            self.state = JoinState::DrainUnmatched;
+                        }
+                    }
+                    Poll::Pending => return Poll::Pending,
+                },
+
+                JoinState::Comparing => {
+                    // We have a streamed row. Compare its key against the
+                    // buffered key (first row of buffered_pending).
+                    let streamed_idx = self.streamed_idx;
+
+                    // Check if the streamed key has nulls.
+                    let streamed_has_null = self
+                        .streamed_join_arrays
+                        .as_ref()
+                        .unwrap()
+                        .iter()
+                        .any(|a| a.is_null(streamed_idx));
+
+                    // For inner/semi joins, skip null keys entirely.
+                    if streamed_has_null && self.null_equality == NullEquality::NullEqualsNothing {
+                        match self.join_type {
+                            JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => {
+                                self.advance_streamed()?;
+                                self.determine_next_state();
+                                continue;
+                            }
+                            // Outer/anti joins still emit the row, null-joined.
+                            JoinType::Left
+                            | JoinType::Right
+                            | JoinType::Full
+                            | JoinType::LeftAnti
+                            | JoinType::RightAnti => {
+                                self.output_builder.add_streamed_null_join(streamed_idx);
+                                self.advance_streamed()?;
+                                if self.output_builder.should_flush() {
+                                    self.state = JoinState::OutputReady;
+                                } else {
+                                    self.determine_next_state();
+                                }
+                                continue;
+                            }
+                            _ => {
+                                self.advance_streamed()?;
+                                self.determine_next_state();
+                                continue;
+                            }
+                        }
+                    }
+
+                    // Check if the streamed key matches the cached key (reuse
+                    // the existing match group).
+                    if self.try_reuse_match_group()? {
+                        self.state = JoinState::Joining;
+                        continue;
+                    }
+
+                    // Before clearing the match group, flush any pending output
+                    // that references the current match group's batches.
+                    if self.output_builder.has_pending() {
+                        self.state = JoinState::OutputReady;
+                        continue;
+                    }
+
+                    // Clear old match group.
+                    self.match_group.clear(&mut self.reservation);
+                    self.cached_streamed_key = None;
+
+                    if self.buffered_exhausted && self.buffered_pending.is_none() {
+                        // No buffered data at all. Streamed row is unmatched.
+                        self.emit_streamed_unmatched(streamed_idx);
+                        self.advance_streamed()?;
+                        if self.output_builder.should_flush() {
+                            self.state = JoinState::OutputReady;
+                        } else {
+                            self.determine_next_state();
+                        }
+                        continue;
+                    }
+
+                    // Compare streamed key with buffered key.
+                    let ordering = {
+                        let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap();
+                        let (_buffered_batch, buffered_arrays) =
+                            self.buffered_pending.as_ref().unwrap();
+                        compare_join_arrays(
+                            streamed_arrays,
+                            streamed_idx,
+                            buffered_arrays,
+                            0,
+                            &self.sort_options,
+                            self.null_equality,
+                        )?
+                    };
+
+                    match ordering {
+                        Ordering::Less => {
+                            // Streamed key < buffered key: streamed row has no match.
+                            self.emit_streamed_unmatched(streamed_idx);
+                            self.advance_streamed()?;
+                            if self.output_builder.should_flush() {
+                                self.state = JoinState::OutputReady;
+                            } else {
+                                self.determine_next_state();
+                            }
+                        }
+                        Ordering::Greater => {
+                            // Streamed key > buffered key: advance past the
+                            // first buffered row. If the pending batch has
+                            // more rows, slice it; otherwise discard and poll
+                            // the next batch.
+                            // NOTE(review): buffered rows skipped here never
+                            // enter a match group; for Full joins, verify they
+                            // are still emitted null-joined somewhere — the
+                            // DrainBuffered state only runs after the streamed
+                            // side is exhausted.
+                            let (batch, arrays) = self.buffered_pending.take().unwrap();
+                            if batch.num_rows() > 1 {
+                                let remaining = batch.slice(1, batch.num_rows() - 1);
+                                let remaining_arrays: Vec =
+                                    arrays.iter().map(|a| a.slice(1, a.len() - 1)).collect();
+                                self.buffered_pending = Some((remaining, remaining_arrays));
+                                // Re-compare with the next buffered row.
+                                self.state = JoinState::Comparing;
+                            } else if self.buffered_exhausted {
+                                self.emit_streamed_unmatched(streamed_idx);
+                                self.advance_streamed()?;
+                                if self.output_builder.should_flush() {
+                                    self.state = JoinState::OutputReady;
+                                } else {
+                                    self.determine_next_state();
+                                }
+                            } else {
+                                self.state = JoinState::PollBuffered;
+                            }
+                        }
+                        Ordering::Equal => {
+                            // Keys match. Build the match group.
+                            let needs_more = self.build_match_group()?;
+                            self.cache_streamed_key()?;
+                            if needs_more {
+                                self.state = JoinState::CollectingBuffered;
+                            } else {
+                                self.state = JoinState::Joining;
+                            }
+                        }
+                    }
+                }
+
+                JoinState::CollectingBuffered => {
+                    // We consumed an entire buffered batch into the match group
+                    // and need to check if more buffered rows have the same key.
+                    if self.buffered_exhausted {
+                        // No more buffered data. Match group is complete.
+                        self.state = JoinState::Joining;
+                        continue;
+                    }
+                    match self.buffered_input.poll_next_unpin(cx) {
+                        Poll::Ready(Some(Ok(batch))) => {
+                            if batch.num_rows() == 0 {
+                                continue;
+                            }
+                            self.metrics.input_batches.add(1);
+                            self.metrics.input_rows.add(batch.num_rows());
+                            let join_arrays =
+                                evaluate_join_keys(&batch, &self.buffered_join_exprs)?;
+                            self.buffered_pending = Some((batch, join_arrays));
+                            let needs_more = self.build_match_group()?;
+                            if needs_more {
+                                // Still consuming; keep collecting.
+                                continue;
+                            }
+                            self.state = JoinState::Joining;
+                        }
+                        Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                        Poll::Ready(None) => {
+                            self.buffered_exhausted = true;
+                            self.state = JoinState::Joining;
+                        }
+                        Poll::Pending => return Poll::Pending,
+                    }
+                }
+
+                JoinState::Joining => {
+                    // Produce join pairs for the current streamed row against
+                    // all rows in the match group.
+                    let streamed_idx = self.streamed_idx;
+                    self.produce_join_pairs(streamed_idx)?;
+
+                    self.advance_streamed()?;
+
+                    if self.output_builder.should_flush() {
+                        self.state = JoinState::OutputReady;
+                    } else {
+                        self.determine_next_state();
+                    }
+                }
+
+                JoinState::OutputReady => {
+                    let batch = self.flush_output()?;
+                    if batch.num_rows() > 0 {
+                        self.metrics.output_rows.add(batch.num_rows());
+                        self.metrics.output_batches.add(1);
+                        // After flushing, figure out what to do next.
+                        self.determine_next_state();
+                        return Poll::Ready(Some(Ok(batch)));
+                    }
+                    self.determine_next_state();
+                }
+
+                JoinState::DrainUnmatched => {
+                    // Drain remaining streamed rows as null-joined (for outer/anti).
+                    self.drain_remaining()?;
+
+                    if self.output_builder.has_pending() {
+                        let batch = self.flush_output()?;
+                        // For Full/Right outer, we may still need to drain buffered rows.
+                        if matches!(self.join_type, JoinType::Full | JoinType::Right) {
+                            self.state = JoinState::DrainBuffered;
+                        } else {
+                            self.state = JoinState::Exhausted;
+                        }
+                        if batch.num_rows() > 0 {
+                            self.metrics.output_rows.add(batch.num_rows());
+                            self.metrics.output_batches.add(1);
+                            return Poll::Ready(Some(Ok(batch)));
+                        }
+                    }
+                    if matches!(self.join_type, JoinType::Full | JoinType::Right) {
+                        self.state = JoinState::DrainBuffered;
+                    } else {
+                        self.state = JoinState::Exhausted;
+                    }
+                }
+
+                JoinState::DrainBuffered => {
+                    // For Full/Right outer: emit remaining buffered rows as null-joined.
+                    // First, clear the match group so we can reuse it for pending rows.
+                    // NOTE(review): this arm is re-entered via the `continue`
+                    // below after each polled batch (state stays DrainBuffered),
+                    // so this clear() runs again and appears to drop batches
+                    // already added while output_builder still holds
+                    // add_buffered_null_join indices into them — confirm, or
+                    // restructure so the clear happens only on first entry.
+                    // Also note output is never flushed incrementally while
+                    // draining, so a large buffered remainder accumulates
+                    // unbounded in the output builder.
+                    self.match_group.clear(&mut self.reservation);
+
+                    // Add buffered_pending rows to the match group.
+                    if let Some((batch, arrays)) = self.buffered_pending.take() {
+                        let num_rows = batch.num_rows();
+                        self.match_group.add_batch(
+                            batch,
+                            arrays,
+                            true, // track matched status
+                            &mut self.reservation,
+                            &self.spill_manager,
+                            &self.metrics,
+                        )?;
+                        // All these rows are unmatched.
+                        for row_idx in 0..num_rows {
+                            self.output_builder.add_buffered_null_join(0, row_idx);
+                        }
+                    }
+
+                    // Poll remaining buffered batches.
+                    if !self.buffered_exhausted {
+                        match self.buffered_input.poll_next_unpin(cx) {
+                            Poll::Ready(Some(Ok(batch))) => {
+                                if batch.num_rows() > 0 {
+                                    let num_rows = batch.num_rows();
+                                    let join_arrays =
+                                        evaluate_join_keys(&batch, &self.buffered_join_exprs)?;
+                                    let batch_idx = self.match_group.batches.len();
+                                    self.match_group.add_batch(
+                                        batch,
+                                        join_arrays,
+                                        true,
+                                        &mut self.reservation,
+                                        &self.spill_manager,
+                                        &self.metrics,
+                                    )?;
+                                    for row_idx in 0..num_rows {
+                                        self.output_builder
+                                            .add_buffered_null_join(batch_idx, row_idx);
+                                    }
+                                }
+                                // Stay in DrainBuffered to poll more.
+                                continue;
+                            }
+                            Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                            Poll::Ready(None) => {
+                                self.buffered_exhausted = true;
+                            }
+                            Poll::Pending => return Poll::Pending,
+                        }
+                    }
+
+                    // All buffered rows drained. Flush and finish.
+                    if self.output_builder.has_pending() {
+                        let batch = self.flush_output()?;
+                        self.state = JoinState::Exhausted;
+                        if batch.num_rows() > 0 {
+                            self.metrics.output_rows.add(batch.num_rows());
+                            self.metrics.output_batches.add(1);
+                            return Poll::Ready(Some(Ok(batch)));
+                        }
+                    }
+                    self.state = JoinState::Exhausted;
+                }
+
+                JoinState::Exhausted => {
+                    self.metrics.update_peak_mem(self.reservation.size());
+                    return Poll::Ready(None);
+                }
+            }
+        }
+    }
+
+    /// Determine the next state after processing a row.
+    ///
+    /// When the current streamed batch is exhausted, any pending output must be
+    /// flushed first (the output builder holds row indices into the batch).
+    /// After the flush the batch is cleared and we move on to poll new data.
+    fn determine_next_state(&mut self) {
+        // Check if the current streamed batch is exhausted.
+        if let Some(batch) = &self.streamed_batch {
+            if self.streamed_idx < batch.num_rows() {
+                // More rows in current streamed batch.
+                self.state = JoinState::Comparing;
+                return;
+            }
+            // Batch exhausted. If there are pending output rows that reference
+            // indices in this batch we must flush before clearing.
+            if self.output_builder.has_pending() {
+                self.state = JoinState::OutputReady;
+                return;
+            }
+            // Safe to clear — no pending references.
+            self.streamed_batch = None;
+            self.streamed_join_arrays = None;
+            self.streamed_rows = None;
+            self.streamed_idx = 0;
+        }
+
+        if !self.streamed_exhausted {
+            self.state = JoinState::Init;
+        } else {
+            self.state = JoinState::DrainUnmatched;
+        }
+    }
+
+    /// Advance the streamed side to the next row.
+    ///
+    /// Note: we do NOT clear the streamed batch here even when all rows have
+    /// been consumed, because the output builder may still hold index references
+    /// into the batch that need to be materialized during the next flush.
+    /// The batch is cleared lazily in `determine_next_state` when we transition
+    /// to polling a new batch.
+    fn advance_streamed(&mut self) -> Result<()> {
+        self.streamed_idx += 1;
+        Ok(())
+    }
+
+    /// Try to reuse the existing match group if the current streamed key
+    /// matches the cached key. Uses pre-computed row-format keys (once per batch).
+    ///
+    /// Row-format equality respects the sort options the converter was built
+    /// with, so duplicate streamed keys skip re-collecting the same group.
+    fn try_reuse_match_group(&self) -> Result {
+        if self.cached_streamed_key.is_none() || self.match_group.is_empty() {
+            return Ok(false);
+        }
+
+        let rows = self.streamed_rows.as_ref().unwrap();
+        let current_key = rows.row(self.streamed_idx);
+
+        if let Some(ref cached) = self.cached_streamed_key {
+            if current_key == cached.row() {
+                return Ok(true);
+            }
+        }
+
+        Ok(false)
+    }
+
+    /// Cache the current streamed key as an OwnedRow.
+    fn cache_streamed_key(&mut self) -> Result<()> {
+        let rows = self.streamed_rows.as_ref().unwrap();
+        self.cached_streamed_key = Some(rows.row(self.streamed_idx).owned());
+        Ok(())
+    }
+
+    /// Build a match group by collecting all buffered rows with the same key
+    /// as the current streamed row.
+    ///
+    /// Returns `true` if the entire buffered batch was consumed and more data
+    /// may need to be polled to complete the match group.
+    fn build_match_group(&mut self) -> Result {
+        let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap();
+        let streamed_idx = self.streamed_idx;
+        // Matched-status tracking is only needed for Full joins, where
+        // unmatched buffered rows must later be emitted null-joined.
+        let full_outer = matches!(self.join_type, JoinType::Full);
+
+        // Take the pending buffered batch.
+        let (batch, arrays) = self.buffered_pending.take().unwrap();
+
+        // Find how many rows from this batch have the same key.
+        let boundary = find_key_boundary(
+            streamed_arrays,
+            streamed_idx,
+            &arrays,
+            &self.sort_options,
+            self.null_equality,
+        )?;
+
+        // If the key run extends to the end of the batch, the next batch may
+        // continue it — caller must keep polling (CollectingBuffered state).
+        let needs_more = boundary == batch.num_rows();
+
+        if needs_more {
+            // Entire batch matches. Add it to the group.
+            self.match_group.add_batch(
+                batch,
+                arrays,
+                full_outer,
+                &mut self.reservation,
+                &self.spill_manager,
+                &self.metrics,
+            )?;
+            // buffered_pending remains None; caller should poll more.
+        } else {
+            // Split the batch: rows [0..boundary) match, [boundary..) don't.
+            if boundary > 0 {
+                let matching = batch.slice(0, boundary);
+                let matching_arrays: Vec =
+                    arrays.iter().map(|a| a.slice(0, boundary)).collect();
+                self.match_group.add_batch(
+                    matching,
+                    matching_arrays,
+                    full_outer,
+                    &mut self.reservation,
+                    &self.spill_manager,
+                    &self.metrics,
+                )?;
+            }
+
+            // Keep the remaining rows as the new pending batch.
+            let remaining = batch.slice(boundary, batch.num_rows() - boundary);
+            let remaining_arrays: Vec = arrays
+                .iter()
+                .map(|a| a.slice(boundary, a.len() - boundary))
+                .collect();
+            self.buffered_pending = Some((remaining, remaining_arrays));
+        }
+
+        self.metrics.update_peak_mem(self.reservation.size());
+        Ok(needs_more)
+    }
+
+    /// Emit a streamed row as unmatched (null-joined) for outer/anti joins,
+    /// or skip it for inner/semi joins.
+    fn emit_streamed_unmatched(&mut self, streamed_idx: usize) {
+        match self.join_type {
+            JoinType::Left | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => {
+                self.output_builder.add_streamed_null_join(streamed_idx);
+            }
+            // For Right outer: the streamed side is actually the right side,
+            // so unmatched streamed rows need null-joining.
+            JoinType::Right => {
+                self.output_builder.add_streamed_null_join(streamed_idx);
+            }
+            // Inner and semi joins: unmatched rows are dropped.
+            _ => {}
+        }
+    }
+
+    /// Produce join pairs for the current streamed row against all rows in
+    /// the match group, applying the join filter if present.
+    fn produce_join_pairs(&mut self, streamed_idx: usize) -> Result<()> {
+        if self.match_group.is_empty() {
+            return Ok(());
+        }
+
+        // Dispatch by join type: semi joins emit at most one row per streamed
+        // row, anti joins emit only when no pair survives the filter, and all
+        // other types emit one row per surviving pair.
+        match self.join_type {
+            JoinType::LeftSemi | JoinType::RightSemi => {
+                self.produce_semi_pairs(streamed_idx)?;
+            }
+            JoinType::LeftAnti | JoinType::RightAnti => {
+                self.produce_anti_pairs(streamed_idx)?;
+            }
+            _ => {
+                self.produce_standard_pairs(streamed_idx)?;
+            }
+        }
+        Ok(())
+    }
+
+    /// Produce pairs for inner/outer joins.
+    fn produce_standard_pairs(&mut self, streamed_idx: usize) -> Result<()> {
+        if let Some(ref filter) = self.join_filter {
+            // Build candidate pairs and apply filter.
+            let mut pair_indices = Vec::new();
+            for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() {
+                for row_idx in 0..buffered_batch.num_rows {
+                    pair_indices.push(JoinIndex {
+                        streamed_idx,
+                        batch_idx,
+                        buffered_idx: row_idx,
+                    });
+                }
+            }
+
+            if pair_indices.is_empty() {
+                self.emit_streamed_unmatched(streamed_idx);
+                return Ok(());
+            }
+
+            // Build candidate batch for filter evaluation.
+            let streamed_batch = self.streamed_batch.as_ref().unwrap();
+            let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?;
+
+            let filtered =
+                apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?;
+
+            // Apply filtered results.
+            for idx in &filtered.passed_indices {
+                self.output_builder
+                    .add_match(idx.streamed_idx, idx.batch_idx, idx.buffered_idx);
+            }
+            for &si in &filtered.streamed_null_joins {
+                self.output_builder.add_streamed_null_join(si);
+            }
+            // Record matched buffered rows so Full joins can later emit the
+            // still-unmatched ones null-joined.
+            for &(batch_idx, buffered_idx) in &filtered.buffered_matched {
+                self.match_group.batches[batch_idx].mark_matched(buffered_idx);
+            }
+        } else {
+            // No filter: all pairs match.
+            for (batch_idx, buffered_batch) in self.match_group.batches.iter_mut().enumerate() {
+                for row_idx in 0..buffered_batch.num_rows {
+                    self.output_builder
+                        .add_match(streamed_idx, batch_idx, row_idx);
+                    buffered_batch.mark_matched(row_idx);
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Produce pairs for semi joins: emit the streamed row if any match passes.
+    ///
+    /// NOTE(review): a buffered index is passed to `add_match` even though semi
+    /// join output contains only streamed columns — presumably OutputBuilder
+    /// ignores the buffered side for semi joins; confirm.
+    fn produce_semi_pairs(&mut self, streamed_idx: usize) -> Result<()> {
+        if let Some(ref filter) = self.join_filter {
+            let mut pair_indices = Vec::new();
+            for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() {
+                for row_idx in 0..buffered_batch.num_rows {
+                    pair_indices.push(JoinIndex {
+                        streamed_idx,
+                        batch_idx,
+                        buffered_idx: row_idx,
+                    });
+                }
+            }
+
+            if pair_indices.is_empty() {
+                return Ok(());
+            }
+
+            let streamed_batch = self.streamed_batch.as_ref().unwrap();
+            let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?;
+
+            let filtered =
+                apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?;
+
+            // Semi: emit the streamed row if any pair passed.
+            if !filtered.passed_indices.is_empty() {
+                let idx = &filtered.passed_indices[0];
+                self.output_builder
+                    .add_match(idx.streamed_idx, idx.batch_idx, idx.buffered_idx);
+            }
+        } else {
+            // No filter: key match is sufficient for semi join.
+            if !self.match_group.is_empty() {
+                self.output_builder.add_match(streamed_idx, 0, 0);
+            }
+        }
+        Ok(())
+    }
+
+    /// Produce pairs for anti joins: emit the streamed row if no match passes.
+    fn produce_anti_pairs(&mut self, streamed_idx: usize) -> Result<()> {
+        if let Some(ref filter) = self.join_filter {
+            let mut pair_indices = Vec::new();
+            for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() {
+                for row_idx in 0..buffered_batch.num_rows {
+                    pair_indices.push(JoinIndex {
+                        streamed_idx,
+                        batch_idx,
+                        buffered_idx: row_idx,
+                    });
+                }
+            }
+
+            if pair_indices.is_empty() {
+                // No buffered matches at all => emit for anti.
+                self.output_builder.add_streamed_null_join(streamed_idx);
+                return Ok(());
+            }
+
+            let streamed_batch = self.streamed_batch.as_ref().unwrap();
+            let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?;
+
+            let filtered =
+                apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?;
+
+            // Anti: emit streamed rows that had no passing pair.
+            for &si in &filtered.streamed_null_joins {
+                self.output_builder.add_streamed_null_join(si);
+            }
+        } else {
+            // No filter: key match means the streamed row is NOT emitted (anti).
+            // Do nothing.
+        }
+        Ok(())
+    }
+
+    /// Build a filter candidate batch for the given pairs.
+    fn build_filter_batch(
+        &self,
+        filter: &JoinFilter,
+        streamed_batch: &RecordBatch,
+        pair_indices: &[JoinIndex],
+    ) -> Result<RecordBatch> {
+        // We need to combine rows from potentially multiple buffered batches
+        // into a single batch for filter evaluation.
+        // First, build streamed and buffered index arrays.
+        let streamed_indices: Vec<u32> = pair_indices
+            .iter()
+            .map(|idx| idx.streamed_idx as u32)
+            .collect();
+        let streamed_idx_array = UInt32Array::from(streamed_indices);
+
+        // For the buffered side, we need to build a single batch containing
+        // all referenced rows.
+        let buffered_batch = self.collect_buffered_rows(pair_indices)?;
+        let buffered_indices: Vec<u32> = (0..pair_indices.len() as u32).collect();
+        let buffered_idx_array = UInt32Array::from(buffered_indices);
+
+        build_filter_candidate_batch(
+            filter,
+            streamed_batch,
+            &buffered_batch,
+            &streamed_idx_array,
+            &buffered_idx_array,
+        )
+    }
+
+    /// Collect all referenced buffered rows into a single batch.
+    fn collect_buffered_rows(&self, pair_indices: &[JoinIndex]) -> Result<RecordBatch> {
+        if pair_indices.is_empty() {
+            // Return an empty batch with the correct schema.
+            let schema = Arc::clone(self.spill_manager.schema());
+            return Ok(RecordBatch::new_empty(schema));
+        }
+
+        // Group indices by batch_idx, then take rows and concatenate.
+        let schema = Arc::clone(self.spill_manager.schema());
+        let num_cols = schema.fields().len();
+        let mut result_columns: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_cols];
+
+        let mut current_batch_idx: Option<usize> = None;
+        let mut current_row_indices: Vec<u32> = Vec::new();
+
+        let flush = |batch_idx: usize,
+                     row_indices: &[u32],
+                     result_columns: &mut Vec<Vec<ArrayRef>>,
+                     match_group: &BufferedMatchGroup,
+                     spill_manager: &SpillManager|
+         -> Result<()> {
+            let batch = match_group.get_batch(batch_idx, spill_manager)?;
+            let idx_array = UInt32Array::from(row_indices.to_vec());
+            for (col_idx, col_parts) in result_columns.iter_mut().enumerate() {
+                let col = batch.column(col_idx);
+                let taken = arrow::compute::take(col.as_ref(), &idx_array, None)?;
+                col_parts.push(taken);
+            }
+            Ok(())
+        };
+
+        for idx in pair_indices {
+            if current_batch_idx == Some(idx.batch_idx) {
+                current_row_indices.push(idx.buffered_idx as u32);
+            } else {
+                if let Some(bi) = current_batch_idx {
+                    flush(
+                        bi,
+                        &current_row_indices,
+                        &mut result_columns,
+                        &self.match_group,
+                        &self.spill_manager,
+                    )?;
+                    current_row_indices.clear();
+                }
+                current_batch_idx = Some(idx.batch_idx);
+                current_row_indices.push(idx.buffered_idx as u32);
+            }
+        }
+        if let Some(bi) = current_batch_idx {
+            flush(
+                bi,
+                &current_row_indices,
+                &mut result_columns,
+                &self.match_group,
+                &self.spill_manager,
+            )?;
+        }
+
+        // Concatenate column parts.
+        let columns: Vec<ArrayRef> = result_columns
+            .into_iter()
+            .map(|parts| {
+                if parts.len() == 1 {
+                    Ok(parts.into_iter().next().unwrap())
+                } else {
+                    let refs: Vec<&dyn arrow::array::Array> =
+                        parts.iter().map(|a| a.as_ref()).collect();
+                    Ok(arrow::compute::concat(&refs)?)
+                }
+            })
+            .collect::<Result<_>>()?;
+
+        Ok(RecordBatch::try_new(schema, columns)?)
+    }
+
+    /// Flush accumulated output indices into a RecordBatch.
+    fn flush_output(&mut self) -> Result<RecordBatch> {
+        let streamed_batch = match &self.streamed_batch {
+            Some(b) => b.clone(),
+            None => {
+                // If the streamed batch has been consumed, we might still
+                // have pending output from DrainUnmatched (buffered null joins).
+                // Create an empty streamed batch.
+                let schema = self.output_builder_streamed_schema();
+                RecordBatch::new_empty(schema)
+            }
+        };
+
+        self.output_builder.build(
+            &streamed_batch,
+            &self.match_group,
+            &self.spill_manager,
+            &self.buffered_join_exprs,
+        )
+    }
+
+    /// Get the streamed schema from the output builder (needed for empty batch creation).
+    fn output_builder_streamed_schema(&self) -> SchemaRef {
+        // We can derive this from the output schema and join type,
+        // but for simplicity use the streamed input's schema.
+        self.streamed_input.schema()
+    }
+
+    /// Drain remaining rows after one side is exhausted.
+    fn drain_remaining(&mut self) -> Result<()> {
+        match self.join_type {
+            JoinType::Left | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => {
+                // Drain remaining streamed rows as null-joined.
+                if let Some(batch) = &self.streamed_batch {
+                    let num_rows = batch.num_rows();
+                    for idx in self.streamed_idx..num_rows {
+                        self.output_builder.add_streamed_null_join(idx);
+                    }
+                }
+            }
+            JoinType::Right => {
+                // Right outer: streamed side is right, so drain remaining.
+                if let Some(batch) = &self.streamed_batch {
+                    let num_rows = batch.num_rows();
+                    for idx in self.streamed_idx..num_rows {
+                        self.output_builder.add_streamed_null_join(idx);
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        // For full outer: drain unmatched buffered rows.
+        if matches!(self.join_type, JoinType::Full) {
+            for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() {
+                for row_idx in buffered_batch.unmatched_indices() {
+                    self.output_builder
+                        .add_buffered_null_join(batch_idx, row_idx);
+                }
+            }
+        }
+
+        // Clear streamed state.
+        self.streamed_batch = None;
+        self.streamed_join_arrays = None;
+        self.streamed_idx = 0;
+
+        Ok(())
+    }
+}
+
+/// Find the boundary index in a buffered batch where the key changes relative
+/// to the streamed key. Returns the number of rows from the start that have
+/// the same key as the streamed row at `streamed_idx`.
+///
+/// Uses binary search since the buffered side is sorted.
+fn find_key_boundary(
+    streamed_arrays: &[ArrayRef],
+    streamed_idx: usize,
+    buffered_arrays: &[ArrayRef],
+    sort_options: &[SortOptions],
+    null_equality: NullEquality,
+) -> Result<usize> {
+    let num_rows = buffered_arrays[0].len();
+    if num_rows == 0 {
+        return Ok(0);
+    }
+
+    // Quick check: if the last row also matches, the entire batch is in the group.
+    let last_cmp = compare_join_arrays(
+        streamed_arrays,
+        streamed_idx,
+        buffered_arrays,
+        num_rows - 1,
+        sort_options,
+        null_equality,
+    )?;
+    if last_cmp == Ordering::Equal {
+        return Ok(num_rows);
+    }
+
+    // Binary search for the boundary.
+    let mut lo = 0usize;
+    let mut hi = num_rows;
+    while lo < hi {
+        let mid = lo + (hi - lo) / 2;
+        let cmp = compare_join_arrays(
+            streamed_arrays,
+            streamed_idx,
+            buffered_arrays,
+            mid,
+            sort_options,
+            null_equality,
+        )?;
+        if cmp == Ordering::Equal {
+            lo = mid + 1;
+        } else {
+            hi = mid;
+        }
+    }
+    Ok(lo)
+}
+
+impl Stream for SortMergeJoinStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let join_time = self.metrics.join_time.clone();
+        let timer = join_time.timer();
+        let result = self.poll_next_inner(cx);
+        timer.done();
+        result
+    }
+}
+
+impl RecordBatchStream for SortMergeJoinStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.output_schema)
+    }
+}
diff --git a/native/core/src/execution/joins/tests.rs b/native/core/src/execution/joins/tests.rs
new file mode 100644
index 0000000000..d916f67747
--- /dev/null
+++ b/native/core/src/execution/joins/tests.rs
@@ -0,0 +1,370 @@
+// Licensed to the Apache Software
Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{Int32Array, StringArray};
+use arrow::compute::SortOptions;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{NullEquality, Result};
+use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::datasource::source::DataSourceExec;
+use datafusion::logical_expr::JoinType;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionContext;
+use futures::StreamExt;
+
+use super::CometSortMergeJoinExec;
+
+fn left_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![
+        Field::new("l_key", DataType::Int32, true),
+        Field::new("l_val", DataType::Utf8, true),
+    ]))
+}
+
+fn right_schema() -> SchemaRef {
+    Arc::new(Schema::new(vec![
+        Field::new("r_key", DataType::Int32, true),
+        Field::new("r_val", DataType::Utf8, true),
+    ]))
+}
+
+fn make_sorted_batches(
+    schema: SchemaRef,
+    keys: Vec<Option<i32>>,
+    vals: Vec<Option<&str>>,
+) -> Vec<RecordBatch> {
+    vec![RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(Int32Array::from(keys)),
+            Arc::new(StringArray::from(vals)),
+        ],
+    )
+    .unwrap()]
+}
+
+async fn execute_join(
+    join_type: JoinType,
+    left_batches: Vec<RecordBatch>,
+    right_batches: Vec<RecordBatch>,
+) -> Result<Vec<RecordBatch>>
{
+    let l_schema = left_batches[0].schema();
+    let r_schema = right_batches[0].schema();
+
+    let left = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+        &[left_batches],
+        l_schema,
+        None,
+    )?)));
+    let right = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+        &[right_batches],
+        r_schema,
+        None,
+    )?)));
+
+    let on = vec![(
+        Arc::new(datafusion::physical_expr::expressions::Column::new(
+            "l_key", 0,
+        )) as Arc<dyn datafusion::physical_expr::PhysicalExpr>,
+        Arc::new(datafusion::physical_expr::expressions::Column::new(
+            "r_key", 0,
+        )) as Arc<dyn datafusion::physical_expr::PhysicalExpr>,
+    )];
+
+    let join = CometSortMergeJoinExec::try_new(
+        left,
+        right,
+        on,
+        None,
+        join_type,
+        vec![SortOptions::default()],
+        NullEquality::NullEqualsNothing,
+    )?;
+
+    let ctx = SessionContext::new();
+    let task_ctx = ctx.task_ctx();
+    let stream = join.execute(0, task_ctx)?;
+
+    let mut results = Vec::new();
+    let mut stream = stream;
+    while let Some(batch) = stream.next().await {
+        results.push(batch?);
+    }
+    Ok(results)
+}
+
+fn total_row_count(batches: &[RecordBatch]) -> usize {
+    batches.iter().map(|b| b.num_rows()).sum()
+}
+
+#[tokio::test]
+async fn test_inner_join_basic() -> Result<()> {
+    let left = make_sorted_batches(
+        left_schema(),
+        vec![Some(1), Some(2), Some(3)],
+        vec![Some("a"), Some("b"), Some("c")],
+    );
+    let right = make_sorted_batches(
+        right_schema(),
+        vec![Some(2), Some(3), Some(4)],
+        vec![Some("x"), Some("y"), Some("z")],
+    );
+
+    let result = execute_join(JoinType::Inner, left, right).await?;
+    assert_eq!(total_row_count(&result), 2);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_inner_join_with_duplicates() -> Result<()> {
+    let left = make_sorted_batches(
+        left_schema(),
+        vec![Some(1), Some(1), Some(2)],
+        vec![Some("a"), Some("b"), Some("c")],
+    );
+    let right = make_sorted_batches(
+        right_schema(),
+        vec![Some(1), Some(1), Some(3)],
+        vec![Some("x"), Some("y"), Some("z")],
+    );
+
+    let result = execute_join(JoinType::Inner, left, right).await?;
+    
assert_eq!(total_row_count(&result), 4); + Ok(()) +} + +#[tokio::test] +async fn test_inner_join_null_keys_skipped() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![None, Some(1), Some(2)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![None, Some(1), Some(2)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +#[tokio::test] +async fn test_inner_join_empty_result() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(3), Some(4)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 0); + Ok(()) +} + +// --- Outer join tests --- + +#[tokio::test] +async fn test_left_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(4)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Left, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +#[tokio::test] +async fn test_left_outer_null_keys() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![None, Some(1)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(1), Some(2)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Left, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +#[tokio::test] +async fn test_right_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(3)], + vec![Some("a"), Some("b")], + 
); + let right = make_sorted_batches( + right_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Right, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +#[tokio::test] +async fn test_full_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(3)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Full, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +// --- Semi/Anti join tests --- + +#[tokio::test] +async fn test_left_semi_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(3), Some(4)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::LeftSemi, left, right).await?; + assert_eq!(total_row_count(&result), 2); + // Semi join should only output left columns + assert_eq!(result[0].num_columns(), 2); + Ok(()) +} + +#[tokio::test] +async fn test_left_anti_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(4)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::LeftAnti, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +// --- Spill test --- + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_inner_join_with_spill() -> Result<()> { + use datafusion::execution::runtime_env::RuntimeEnvBuilder; + + let l_schema = left_schema(); + let r_schema = right_schema(); + + let left_batches = 
make_sorted_batches(
+        Arc::clone(&l_schema),
+        vec![Some(1), Some(1), Some(1), Some(2), Some(2)],
+        vec![Some("a"), Some("b"), Some("c"), Some("d"), Some("e")],
+    );
+    let right_batches = make_sorted_batches(
+        Arc::clone(&r_schema),
+        vec![Some(1), Some(1), Some(1), Some(2)],
+        vec![Some("w"), Some("x"), Some("y"), Some("z")],
+    );
+
+    let left_exec = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+        &[left_batches],
+        l_schema,
+        None,
+    )?)));
+    let right_exec = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new(
+        &[right_batches],
+        r_schema,
+        None,
+    )?)));
+
+    let on = vec![(
+        Arc::new(datafusion::physical_expr::expressions::Column::new(
+            "l_key", 0,
+        )) as Arc<dyn datafusion::physical_expr::PhysicalExpr>,
+        Arc::new(datafusion::physical_expr::expressions::Column::new(
+            "r_key", 0,
+        )) as Arc<dyn datafusion::physical_expr::PhysicalExpr>,
+    )];
+
+    let join = CometSortMergeJoinExec::try_new(
+        left_exec,
+        right_exec,
+        on,
+        None,
+        JoinType::Inner,
+        vec![SortOptions::default()],
+        NullEquality::NullEqualsNothing,
+    )?;
+
+    let config = datafusion::prelude::SessionConfig::new().with_batch_size(2);
+    let runtime = Arc::new(
+        RuntimeEnvBuilder::new()
+            .with_memory_limit(1024, 1.0)
+            .build()?,
+    );
+    let ctx = SessionContext::new_with_config_rt(config, runtime);
+    let task_ctx = ctx.task_ctx();
+    let mut stream = join.execute(0, task_ctx)?;
+
+    let mut results = Vec::new();
+    while let Some(batch) = stream.next().await {
+        results.push(batch?);
+    }
+    // 3*3 for key=1 + 2*1 for key=2 = 11
+    assert_eq!(total_row_count(&results), 11);
+    Ok(())
+}
diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs
index f556fce41c..01aedfc62b 100644
--- a/native/core/src/execution/mod.rs
+++ b/native/core/src/execution/mod.rs
@@ -20,6 +20,7 @@ pub mod columnar_to_row;
 pub mod expressions;
 pub mod jni_api;
 pub(crate) mod metrics;
+pub(crate) mod joins;
 pub mod operators;
 pub(crate) mod planner;
 pub mod serde;
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 0f96c829e7..74524506d1 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -21,8 +21,10 @@
 pub mod expression_registry;
 pub mod macros;
 pub mod operator_registry;
+use crate::execution::joins::CometSortMergeJoinExec;
 use crate::execution::operators::init_csv_datasource_exec;
 use crate::execution::operators::IcebergScanExec;
+use crate::execution::spark_config::{SparkConfig, COMET_EXEC_SMJ_USE_NATIVE};
 use crate::execution::{
     expressions::subquery::Subquery,
     operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec},
@@ -163,6 +165,7 @@ pub struct PhysicalPlanner {
     partition: i32,
     session_ctx: Arc<SessionContext>,
     query_context_registry: Arc,
+    spark_config: HashMap<String, String>,
 }

 impl Default for PhysicalPlanner {
@@ -178,6 +181,7 @@ impl PhysicalPlanner {
             session_ctx,
             partition,
             query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
+            spark_config: HashMap::new(),
         }
     }
@@ -187,6 +191,14 @@
             partition: self.partition,
             session_ctx: Arc::clone(&self.session_ctx),
             query_context_registry: Arc::clone(&self.query_context_registry),
+            spark_config: self.spark_config,
         }
     }
+
+    pub fn with_spark_config(self, spark_config: HashMap<String, String>) -> Self {
+        Self {
+            spark_config,
+            ..self
+        }
+    }
@@ -1625,43 +1637,19 @@
                 let left = Arc::clone(&join_params.left.native_plan);
                 let right = Arc::clone(&join_params.right.native_plan);

-                let join = Arc::new(SortMergeJoinExec::try_new(
-                    Arc::clone(&left),
-                    Arc::clone(&right),
-                    join_params.join_on,
-                    join_params.join_filter,
-                    join_params.join_type,
-                    sort_options,
-                    // null doesn't equal to null in Spark join key. If the join key is
-                    // `EqualNullSafe`, Spark will rewrite it during planning.
-                    NullEquality::NullEqualsNothing,
-                )?);
-
-                if join.filter.is_some() {
-                    // SMJ with join filter produces lots of tiny batches
-                    let coalesce_batches: Arc<dyn ExecutionPlan> =
-                        Arc::new(CoalesceBatchesExec::new(
-                            Arc::<SortMergeJoinExec>::clone(&join),
-                            self.session_ctx
-                                .state()
-                                .config_options()
-                                .execution
-                                .batch_size,
-                        ));
-                    Ok((
-                        scans,
-                        shuffle_scans,
-                        Arc::new(SparkPlan::new_with_additional(
-                            spark_plan.plan_id,
-                            coalesce_batches,
-                            vec![
-                                Arc::clone(&join_params.left),
-                                Arc::clone(&join_params.right),
-                            ],
-                            vec![join],
-                        )),
-                    ))
-                } else {
+                let use_native_smj = self.spark_config.get_bool(COMET_EXEC_SMJ_USE_NATIVE);
+
+                if use_native_smj {
+                    let join: Arc<dyn ExecutionPlan> =
+                        Arc::new(CometSortMergeJoinExec::try_new(
+                            Arc::clone(&left),
+                            Arc::clone(&right),
+                            join_params.join_on,
+                            join_params.join_filter,
+                            join_params.join_type,
+                            sort_options,
+                            NullEquality::NullEqualsNothing,
+                        )?);
                     Ok((
                         scans,
                         shuffle_scans,
                         Arc::new(SparkPlan::new(
                             spark_plan.plan_id,
                             join,
                             vec![
                                 Arc::clone(&join_params.left),
                                 Arc::clone(&join_params.right),
                             ],
                         )),
                     ))
+                } else {
+                    let join = Arc::new(SortMergeJoinExec::try_new(
+                        Arc::clone(&left),
+                        Arc::clone(&right),
+                        join_params.join_on,
+                        join_params.join_filter,
+                        join_params.join_type,
+                        sort_options,
+                        // null doesn't equal to null in Spark join key. If the join key is
+                        // `EqualNullSafe`, Spark will rewrite it during planning.
+                        NullEquality::NullEqualsNothing,
+                    )?);
+
+                    if join.filter.is_some() {
+                        // SMJ with join filter produces lots of tiny batches
+                        let coalesce_batches: Arc<dyn ExecutionPlan> =
+                            Arc::new(CoalesceBatchesExec::new(
+                                Arc::<SortMergeJoinExec>::clone(&join),
+                                self.session_ctx
+                                    .state()
+                                    .config_options()
+                                    .execution
+                                    .batch_size,
+                            ));
+                        Ok((
+                            scans,
+                            shuffle_scans,
+                            Arc::new(SparkPlan::new_with_additional(
+                                spark_plan.plan_id,
+                                coalesce_batches,
+                                vec![
+                                    Arc::clone(&join_params.left),
+                                    Arc::clone(&join_params.right),
+                                ],
+                                vec![join],
+                            )),
+                        ))
+                    } else {
+                        Ok((
+                            scans,
+                            shuffle_scans,
+                            Arc::new(SparkPlan::new(
+                                spark_plan.plan_id,
+                                join,
+                                vec![
+                                    Arc::clone(&join_params.left),
+                                    Arc::clone(&join_params.right),
+                                ],
+                            )),
+                        ))
+                    }
+                }
             }
             OpStruct::HashJoin(join) => {
diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs
index 277c0eb43b..36f2fd417f 100644
--- a/native/core/src/execution/spark_config.rs
+++ b/native/core/src/execution/spark_config.rs
@@ -22,6 +22,7 @@ pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled";
 pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled";
 pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize";
 pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory";
+pub(crate) const COMET_EXEC_SMJ_USE_NATIVE: &str = "spark.comet.exec.sortMergeJoin.useNative";
 pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores";

 pub(crate) trait SparkConfig {