diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 655bee3511..c07978337b 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -203,6 +203,19 @@ impl MultiPartitionShuffleRepartitioner { return Ok(()); } + // For zero-column schemas (e.g. COUNT queries), assign all rows to partition 0. + // The actual index values don't matter — the consumer only uses indices.len() + // as the row count for zero-column batches. + if input.num_columns() == 0 { + let num_rows = input.num_rows(); + self.metrics.baseline.record_output(num_rows); + let batch_idx = self.buffered_batches.len() as u32; + self.buffered_batches.push(input); + let indices = &mut self.partition_indices[0]; + indices.resize(indices.len() + num_rows, (batch_idx, 0)); + return Ok(()); + } + if input.num_rows() > self.batch_size { return Err(DataFusionError::Internal( "Input batch size exceeds configured batch size. Call `insert_batch` instead." diff --git a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs index 8309a8ed4a..b97b6f6923 100644 --- a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs +++ b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::RecordBatch; +use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::compute::interleave_record_batch; use datafusion::common::DataFusionError; @@ -97,15 +97,20 @@ impl Iterator for PartitionedBatchIterator<'_> { let indices_end = std::cmp::min(self.pos + self.batch_size, self.indices.len()); let indices = &self.indices[self.pos..indices_end]; - match interleave_record_batch(&self.record_batches, indices) { - Ok(batch) => { - self.pos = indices_end; - Some(Ok(batch)) - } - Err(e) => Some(Err(DataFusionError::ArrowError( - Box::from(e), - Some(DataFusionError::get_back_trace()), - ))), - } + + // interleave_record_batch requires at least one column or an explicit row count. + // For zero-column batches (e.g. COUNT queries), create the batch directly. + let schema = self.record_batches[0].schema(); + let result = if schema.fields().is_empty() { + let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); + RecordBatch::try_new_with_options(schema, vec![], &options) + } else { + interleave_record_batch(&self.record_batches, indices) + }; + + self.pos = indices_end; + Some(result.map_err(|e| { + DataFusionError::ArrowError(Box::from(e), Some(DataFusionError::get_back_trace())) + })) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index efb5fbca8a..02bf03fa2c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -474,4 +474,34 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("native datafusion scan - repartition count") { + withTempPath { dir => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark + .range(1000) + .selectExpr("id", "concat('name_', id) as name") + .repartition(100) + .write + .parquet(dir.toString) + } + withSQLConf( + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION, + CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_ENABLED.key -> "true") { + val testDF = spark.read.parquet(dir.toString).repartition(10) + // Verify CometShuffleExchangeExec is in the plan + assert( + find(testDF.queryExecution.executedPlan) { + case _: CometShuffleExchangeExec => true + case _ => false + }.isDefined, + "Expected CometShuffleExchangeExec in the plan") + // Actual validation, no crash + val count = testDF.count() + assert(count == 1000) + // Ensure test df evaluated by Comet + checkSparkAnswerAndOperator(testDF) + } + } + } }