Skip to content
13 changes: 13 additions & 0 deletions native/shuffle/src/partitioners/multi_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,19 @@ impl MultiPartitionShuffleRepartitioner {
return Ok(());
}

// For zero-column schemas (e.g. COUNT queries), assign all rows to partition 0.
// The actual index values don't matter — the consumer only uses indices.len()
// as the row count for zero-column batches.
if input.num_columns() == 0 {
let num_rows = input.num_rows();
self.metrics.baseline.record_output(num_rows);
let batch_idx = self.buffered_batches.len() as u32;
self.buffered_batches.push(input);
let indices = &mut self.partition_indices[0];
indices.resize(indices.len() + num_rows, (batch_idx, 0));
return Ok(());
}

if input.num_rows() > self.batch_size {
return Err(DataFusionError::Internal(
"Input batch size exceeds configured batch size. Call `insert_batch` instead."
Expand Down
27 changes: 16 additions & 11 deletions native/shuffle/src/partitioners/partitioned_batch_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::RecordBatch;
use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::compute::interleave_record_batch;
use datafusion::common::DataFusionError;

Expand Down Expand Up @@ -97,15 +97,20 @@ impl Iterator for PartitionedBatchIterator<'_> {

let indices_end = std::cmp::min(self.pos + self.batch_size, self.indices.len());
let indices = &self.indices[self.pos..indices_end];
match interleave_record_batch(&self.record_batches, indices) {
Ok(batch) => {
self.pos = indices_end;
Some(Ok(batch))
}
Err(e) => Some(Err(DataFusionError::ArrowError(
Box::from(e),
Some(DataFusionError::get_back_trace()),
))),
}

// interleave_record_batch requires at least one column or an explicit row count.
// For zero-column batches (e.g. COUNT queries), create the batch directly.
let schema = self.record_batches[0].schema();
let result = if schema.fields().is_empty() {
let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
RecordBatch::try_new_with_options(schema, vec![], &options)
} else {
interleave_record_batch(&self.record_batches, indices)
};

self.pos = indices_end;
Some(result.map_err(|e| {
DataFusionError::ArrowError(Box::from(e), Some(DataFusionError::get_back_trace()))
}))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -474,4 +474,34 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
}
}
}

test("native datafusion scan - repartition count") {
withTempPath { dir =>
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
spark
.range(1000)
.selectExpr("id", "concat('name_', id) as name")
.repartition(100)
.write
.parquet(dir.toString)
}
withSQLConf(
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION,
CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_ENABLED.key -> "true") {
Comment on lines +488 to +490
Copy link
Copy Markdown
Member

@andygrove andygrove Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the issue specific to this combination of scan and shuffle?

interleave_record_batch is used in other parts of the shuffle codebase so those may also need updating?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like native_datafusion is used here just to easily force native shuffle.

I am confused by the comment For zero-column batches (e.g. COUNT queries) when the test isn't using a count.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to reproduce the crash with both native_datafusion and native_iceberg_compat in combination with native shuffle. The sample query used for the repro and the test case is:

spark.read.parquet("hdfs://location").repartition(50).count()

Perhaps the test can be slightly improved if it is confusing.

val testDF = spark.read.parquet(dir.toString).repartition(10)
// Verify CometShuffleExchangeExec is in the plan
assert(
find(testDF.queryExecution.executedPlan) {
case _: CometShuffleExchangeExec => true
case _ => false
}.isDefined,
"Expected CometShuffleExchangeExec in the plan")
// Actual validation, no crash
val count = testDF.count()
assert(count == 1000)
// Ensure test df evaluated by Comet
checkSparkAnswerAndOperator(testDF)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no usage of count() here. Is this intentional?
An alternative could be something like:

val testDF = spark.read.parquet(dir.toString).repartition(10)
val countDF = testDF.selectExpr("count(*) as cnt")
val count = countDF.collect().head.getLong(0)
assert(count == 1000)
checkSparkAnswerAndOperator(countDF)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is intentional, yes. count() returns just a Long, so I can't really inject a check of the native plan in the middle; instead I verify that at least everything before the count is executed natively, which works for this case.

}
}
}
}
Loading