From 376337058c411eead1dbeaee334ed30086360fd7 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 21 May 2026 16:44:09 -0400 Subject: [PATCH 01/12] enable CometLocalTableScanExec by default --- spark/src/main/scala/org/apache/comet/CometConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index fdd1ae2073..faee23a8eb 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -273,7 +273,7 @@ object CometConf extends ShimCometConf { val COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("takeOrderedAndProject", defaultValue = true) val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] = - createExecEnabledConfig("localTableScan", defaultValue = false) + createExecEnabledConfig("localTableScan", defaultValue = true) val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled") From 810e5d5c38d106fae4a3bff6563137e3a5fcfd01 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 07:46:49 -0400 Subject: [PATCH 02/12] add NullType to toArrowType --- .../main/scala/org/apache/spark/sql/comet/util/Utils.scala | 1 + .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 783367c054..4605e641f1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -148,6 +148,7 @@ object Utils extends CometTypeShim with Logging { } case TimestampNTZType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE case dt if isTimeType(dt) => new ArrowType.Time(TimeUnit.NANOSECOND, 64) case _ => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 16601d056b..8bf00de20c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3925,6 +3925,13 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec handles NullType column") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val df = spark.sql("SELECT * FROM VALUES ('a', null), ('b', null) AS t(x, y)") + checkSparkAnswer(df) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 From 174c939540a0be46d184f8e8b9d57a52ceae722a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 08:02:09 -0400 Subject: [PATCH 03/12] add NullType to shuffles --- native/shuffle/src/spark_unsafe/row.rs | 15 ++++++++++-- .../shuffle/CometShuffleExchangeExec.scala | 6 ++--- .../exec/CometColumnarShuffleSuite.scala | 23 ++++--------------- .../comet/exec/CometNativeShuffleSuite.scala | 6 +++++ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index ec0903bc56..6ffe9d0b6e 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -28,8 +28,8 @@ use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, - Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder, - StructBuilder, TimestampMicrosecondBuilder, + Int64Builder, Int8Builder, ListBuilder, MapBuilder, NullBuilder, StringBuilder, + StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder, }, types::Int32Type, Array, ArrayRef, RecordBatch, RecordBatchOptions, @@ -267,6 +267,10 @@ pub(super) fn append_field( append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder .append_value(row.get_date(idx))); } + DataType::Null => { + let field_builder = get_field_builder!(struct_builder, NullBuilder, idx); + field_builder.append_null(); + } DataType::Timestamp(TimeUnit::Microsecond, _) => { append_field_to_builder!( TimestampMicrosecondBuilder, @@ -1148,6 +1152,12 @@ fn append_columns( .append_value(row.get_date(idx)) ); } + DataType::Null => { + let null_builder = downcast_builder_ref!(NullBuilder, builder); + for _ in row_start..row_end { + null_builder.append_null(); + } + } DataType::Timestamp(TimeUnit::Microsecond, _) => { append_column_to_builder!( TimestampMicrosecondBuilder, @@ -1252,6 +1262,7 @@ fn make_builders( } } DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)), + DataType::Null => Box::new(NullBuilder::new()), DataType::Timestamp(TimeUnit::Microsecond, _) => { Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 493c20f8b7..16e7a8b774 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -364,7 +364,7 @@ object CometShuffleExchangeExec def supportedSerializableDataType(dt: DataType): Boolean = dt match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | - _: TimestampNTZType | _: DecimalType | _: DateType => + _: TimestampNTZType | _: DecimalType | _: DateType | _: NullType => true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) @@ -487,7 +487,7 @@ object CometShuffleExchangeExec def supportedSerializableDataType(dt: DataType): Boolean = dt match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | - _: TimestampNTZType | _: DecimalType | _: DateType => + _: TimestampNTZType | _: DecimalType | _: DateType | _: NullType => true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) && diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 86c6a6aa4b..70d427972a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -22,14 +22,13 @@ package org.apache.comet.exec import java.nio.file.{Files, Paths} import scala.reflect.runtime.universe._ -import scala.util.Random import org.scalactic.source.Position import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} -import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row} +import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec @@ -94,22 +93,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar """.stripMargin)) } - test("Fallback to Spark for unsupported input besides ordering") { - val dataGenerator = RandomDataGenerator - .forType( - dataType = NullType, - nullable = true, - new Random(System.nanoTime()), - validJulianDatetime = false) - .get - - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", NullType, nullable = true) - val rdd = - spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator()))) - val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - checkSparkAnswer(df) + test("columnar shuffle with NullType passthrough column") { + val df = sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)") + val shuffled = df.repartition(2, $"x") + checkShuffleAnswer(shuffled, 1) } test("columnar shuffle on nested struct including nulls") { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index e0ef1df1f4..60637102f0 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -218,6 +218,12 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } + test("native shuffle with NullType passthrough column") { + val df = spark.sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)") + val shuffled = df.repartition(2, $"x") + checkShuffleAnswer(shuffled, 1) + } + test("fix: Comet native shuffle with binary data") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl") From 3790c1022b71cbe47cbde8e4553304373cfde1f4 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 09:12:23 -0400 Subject: [PATCH 04/12] fix windowexec test and nulltype. fix timetype issues --- .../sql/comet/CometLocalTableScanExec.scala | 32 +++++++++++++++++-- .../apache/comet/exec/CometExecSuite.scala | 18 +++++++++++ .../comet/exec/CometWindowExecSuite.scala | 10 +----- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 622168bcc9..0a836bc389 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.comet +import scala.collection.mutable.ListBuffer + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -27,11 +29,13 @@ import org.apache.spark.sql.comet.CometLocalTableScanExec.createMetricsIterator import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types.{DataType, NullType} import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects -import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport} +import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink @@ -104,7 +108,7 @@ case class CometLocalTableScanExec( override def hashCode(): Int = Objects.hashCode(originalPlan, originalPlan.schema, output) } -object CometLocalTableScanExec extends CometSink[LocalTableScanExec] { +object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTypeSupport { // uses CometArrowConverters, which re-uses arrays override def isFfiSafe: Boolean = false @@ -112,6 +116,30 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] { override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) + // CometArrowConverters / ArrowWriter support NullType (via Utils.toArrowType + + // NullWriter). Other types not on DataTypeSupport's allow list (e.g. TimeType, + // intervals) lack ArrowWriter coverage and must fall back to Spark. + override def isTypeSupported( + dt: DataType, + name: String, + fallbackReasons: ListBuffer[String]): Boolean = dt match { + case _: NullType => true + case _ => super.isTypeSupported(dt, name, fallbackReasons) + } + + override def convert( + op: LocalTableScanExec, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = { + val fallbackReasons = new ListBuffer[String]() + if (!isSchemaSupported(op.schema, fallbackReasons)) { + withInfo(op, fallbackReasons.mkString("; ")) + None + } else { + super.convert(op, builder, childOp: _*) + } + } + override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 8bf00de20c..c3e903f883 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3932,6 +3932,24 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec handles NullType nested in struct/array/map") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + checkSparkAnswer( + spark.sql("SELECT named_struct('a', 1, 'b', null) AS s, array(null, null) AS a, " + + "map('k', null) AS m")) + } + } + + test("CometLocalTableScanExec falls back when schema contains TimeType") { + assume( + org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, + "TimeType requires Spark 4.1+") + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val df = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id") + checkSparkAnswer(df) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala index 544cd91bd2..a9fdc96231 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala @@ -108,15 +108,7 @@ class CometWindowExecSuite extends CometTestBase { val cometShuffles = collect(df2.queryExecution.executedPlan) { case _: CometShuffleExchangeExec => true } - if (shuffleMode == "jvm" || shuffleMode == "auto") { - assert(cometShuffles.length == 1) - } else { - // we fall back to Spark for shuffle because we do not support - // native shuffle with a LocalTableScan input, and we do not fall - // back to Comet columnar shuffle due to - // https://github.com/apache/datafusion-comet/issues/1248 - assert(cometShuffles.isEmpty) - } + assert(cometShuffles.length == 1) } } } From 18cd14b0940eeb9a95e1a22e681b5a1c84f5b650 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 10:03:50 -0400 Subject: [PATCH 05/12] Fix TimeType test. --- .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index c3e903f883..71a30adecd 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3945,8 +3945,12 @@ class CometExecSuite extends CometTestBase { org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, "TimeType requires Spark 4.1+") withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { - val df = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id") - checkSparkAnswer(df) + // Spark 4.1's row encoder cannot serialize TIME columns to the JVM, so we cannot + // collect rows. count() exercises the LocalRelation -> scan path without materializing + // the TIME value, which is sufficient to verify the fallback (without the fallback the + // CometLocalTableScanExec ArrowWriter would crash on TimeType). + val cnt = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id").count() + assert(cnt == 1) } } From fc40d59a81abfc70c91ab1fb5d2453e00fa11e8e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 16:59:33 -0400 Subject: [PATCH 06/12] fix null value type in map in native shuffle --- native/shuffle/src/spark_unsafe/list.rs | 8 +++++++- .../comet/exec/CometColumnarShuffleSuite.scala | 6 ++++++ .../org/apache/comet/exec/CometExecSuite.scala | 16 +++++++++------- .../comet/exec/CometNativeShuffleSuite.scala | 6 ++++++ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs index 3fea3fadeb..14f9feb843 100644 --- a/native/shuffle/src/spark_unsafe/list.rs +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -24,7 +24,7 @@ use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, - ListBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, + ListBuilder, NullBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, }, MapBuilder, }; @@ -393,6 +393,12 @@ pub fn append_to_builder( let builder = downcast_builder_ref!(Date32Builder, builder); array.append_dates_to_builder::(builder); } + DataType::Null => { + let builder = downcast_builder_ref!(NullBuilder, builder); + for _ in 0..array.get_num_elements() { + builder.append_null(); + } + } DataType::Binary => { add_values!( BinaryBuilder, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 70d427972a..b0be2b90ac 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -99,6 +99,12 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar checkShuffleAnswer(shuffled, 1) } + test("columnar shuffle with Map[_, NullType] column") { + val df = sql("SELECT id, map(id, null) AS m FROM VALUES (1), (2), (3) AS t(id)") + val shuffled = df.repartition(2, $"id") + checkShuffleAnswer(shuffled, 1) + } + test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 71a30adecd..8c8c19bb9c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3944,13 +3944,15 @@ class CometExecSuite extends CometTestBase { assume( org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, "TimeType requires Spark 4.1+") - withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { - // Spark 4.1's row encoder cannot serialize TIME columns to the JVM, so we cannot - // collect rows. count() exercises the LocalRelation -> scan path without materializing - // the TIME value, which is sufficient to verify the fallback (without the fallback the - // CometLocalTableScanExec ArrowWriter would crash on TimeType). - val cnt = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id").count() - assert(cnt == 1) + // spark.sql.timeType.enabled defaults to Utils.isTesting; enable explicitly so the + // row encoder accepts TIME (matches Spark's own TimeFunctionsSuiteBase setup). + withSQLConf( + "spark.sql.timeType.enabled" -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + // VALUES folds to a LocalRelation, exercising the CometLocalTableScanExec convert + // path; the TimeType column should drive the schema-level fallback. + val df = spark.sql("SELECT * FROM VALUES (TIME '12:34:56'), (TIME '01:02:03') AS t(c)") + checkSparkAnswer(df) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 60637102f0..b34e75d137 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -224,6 +224,12 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper checkShuffleAnswer(shuffled, 1) } + test("native shuffle with Map[_, NullType] column") { + val df = spark.sql("SELECT id, map(id, null) AS m FROM VALUES (1), (2), (3) AS t(id)") + val shuffled = df.repartition(2, $"id") + checkShuffleAnswer(shuffled, 1) + } + test("fix: Comet native shuffle with binary data") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl") From 8c088a732cb9301e3d1832926c0c676094a65ba8 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 26 May 2026 08:29:31 -0400 Subject: [PATCH 07/12] avoid reuse in LocalTableScanExec --- .../sql/comet/CometLocalTableScanExec.scala | 41 +++++++------------ .../arrow/CometArrowConverters.scala | 33 +++++++++++++++ 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 0a836bc389..6ada259e82 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -25,7 +25,6 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.comet.CometLocalTableScanExec.createMetricsIterator import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -68,19 +67,24 @@ case class CometLocalTableScanExec( } override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numInputRows = longMetric("numOutputRows") + val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC to match native side expectations. See CometSparkToColumnarExec. - val timeZoneId = "UTC" - rdd.mapPartitionsInternal { sparkBatches => + val schema = originalPlan.schema + // Native side asserts Timestamp(Microsecond, Some("UTC")). See COMET-2720. + rdd.mapPartitionsInternal { rowIter => val context = TaskContext.get() - val batches = CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - originalPlan.schema, + // Non-Comet JVM consumers (e.g. Iceberg writers) may retain batches across next() + // calls, so each batch must own independent Arrow buffers. + val batches = CometArrowConverters.rowToArrowBatchIterNoReuse( + rowIter, + schema, maxRecordsPerBatch, - timeZoneId, + "UTC", context) - createMetricsIterator(batches, numInputRows) + batches.map { batch => + numOutputRows.add(batch.numRows()) + batch + } } } @@ -110,9 +114,6 @@ case class CometLocalTableScanExec( object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTypeSupport { - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false - override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) @@ -143,18 +144,4 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTy override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) } - - private def createMetricsIterator( - it: Iterator[ColumnarBatch], - numInputRows: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - override def hasNext: Boolean = it.hasNext - - override def next(): ColumnarBatch = { - val batch = it.next() - numInputRows.add(batch.numRows()) - batch - } - } - } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 6d52078181..0e7d0d9a8d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -140,6 +140,39 @@ object CometArrowConverters extends Logging { new RowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) } + /** + * Use when the downstream consumer may retain batches across `next()` calls (e.g. non-Comet JVM + * columnar sinks like Iceberg writers). Each batch owns independent Arrow buffers. + */ + def rowToArrowBatchIterNoReuse( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + context: TaskContext): Iterator[ColumnarBatch] = { + val arrowSchema = Utils.toArrowSchema(schema, timeZoneId) + val allocator = + CometArrowAllocator.newChildAllocator("rowToArrowBatchIterNoReuse", 0, Long.MaxValue) + Option(context).foreach(_.addTaskCompletionListener[Unit](_ => allocator.close())) + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = rowIter.hasNext + + override def next(): ColumnarBatch = { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = ArrowWriter.create(root) + var rowCount = 0L + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) + rowCount += 1 + } + writer.finish() + NativeUtil.rootAsBatch(root) + } + } + } + private[sql] class ColumnBatchToArrowBatchIter( colBatch: ColumnarBatch, schema: StructType, From bd04fb4d3a0c9039d3e020493f5dc450ef1de839 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 12:18:56 -0400 Subject: [PATCH 08/12] Replace Comet's bespoke CometBatchIterator JNI input path with the canonical Arrow C Stream Interface (JVM Data.exportArrayStream <-> native ArrowArrayStreamReader), eliminating the per-batch FFI deep copy and the arrow_ffi_safe flag. --- native/core/src/execution/operators/scan.rs | 185 ++++------------- native/core/src/execution/planner.rs | 51 +++-- native/core/src/execution/utils.rs | 35 +--- ...atch_iterator.rs => arrow_array_stream.rs} | 39 ++-- native/jni-bridge/src/lib.rs | 11 +- native/proto/src/proto/operator.proto | 2 - .../org/apache/comet/CometBatchIterator.java | 93 --------- .../org/apache/comet/CometExecIterator.scala | 15 +- .../operator/CometDataWritingCommand.scala | 1 - .../comet/serde/operator/CometSink.scala | 4 - .../apache/spark/sql/comet/CometExecRDD.scala | 65 +++--- .../spark/sql/comet/CometExecUtils.scala | 15 +- .../sql/comet/CometLocalTableScanExec.scala | 66 +++++-- .../sql/comet/CometNativeWriteExec.scala | 10 +- .../sql/comet/CometSparkToColumnarExec.scala | 138 +++++++------ .../CometTakeOrderedAndProjectExec.scala | 23 ++- .../arrow/ColumnarBatchArrowReader.scala | 84 ++++++++ .../arrow/CometArrowConverters.scala | 186 ++++-------------- .../arrow/CometNativeArrowSource.scala | 184 +++++++++++++++++ .../execution/arrow/RowArrowReader.scala | 69 +++++++ .../arrow/SparkColumnarArrowReader.scala | 97 +++++++++ .../shuffle/CometNativeShuffleWriter.scala | 10 +- .../apache/spark/sql/comet/operators.scala | 114 ++++++++--- .../apache/spark/sql/comet/util/Utils.scala | 10 + .../org/apache/comet/CometNativeSuite.scala | 18 +- .../apache/comet/exec/CometExecSuite.scala | 23 +++ .../exec/CometNativeColumnarToRowSuite.scala | 9 +- 27 files changed, 924 insertions(+), 633 deletions(-) rename native/jni-bridge/src/{batch_iterator.rs => arrow_array_stream.rs} (57%) delete mode 100644 spark/src/main/java/org/apache/comet/CometBatchIterator.java create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index e318d9e66b..bad349e4d7 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,19 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::operators::{copy_array, copy_or_unpack_array, CopyMode}; -use crate::{ - errors::CometError, - execution::{ - operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, utils::SparkArrowConvert, - }, - jvm_bridge::JVMClasses, -}; -use arrow::array::{make_array, ArrayData, ArrayRef, RecordBatch, RecordBatchOptions}; +use crate::execution::operators::{copy_or_unpack_array, CopyMode}; +use crate::{errors::CometError, execution::planner::TEST_EXEC_CONTEXT_ID}; +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::ffi::FFI_ArrowArray; -use arrow::ffi::FFI_ArrowSchema; +use arrow::ffi_stream::ArrowArrayStreamReader; use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ @@ -40,8 +33,6 @@ use datafusion::{ }; use futures::Stream; use itertools::Itertools; -use jni::objects::{Global, JObject, JValue}; -use std::rc::Rc; use std::{ any::Any, pin::Pin, @@ -49,43 +40,34 @@ use std::{ task::{Context, Poll}, }; -/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file -/// scan or the result of reading a broadcast or shuffle exchange. ScanExec isn't invoked -/// until the data is already available in the JVM. When CometExecIterator invokes -/// Native.executePlan, it passes in the memory addresses of the input batches. +/// `ScanExec` reads batches of data from Spark over the Arrow C Stream Interface. The +/// `input_source` is moved out of the JVM-exported `ArrowArrayStream` at plan-construction time; +/// dropping the reader (when this exec drops) fires the stream's release callback, which closes +/// the JVM-side `ArrowReader` and its `VectorSchemaRoot`. #[derive(Debug, Clone)] pub struct ScanExec { - /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM - /// environment `JNIEnv` from the execution context. + /// JVM execution-context id used to look up the `JNIEnv` for callbacks. pub exec_context_id: i64, - /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. - pub input_source: Option>>>, - /// A description of the input source for informational purposes + /// The C Stream Interface reader. `None` only in unit tests that seed input via + /// `set_input_batch`. + pub input_source: Option>>, pub input_source_description: String, - /// The data types of columns of the input batch. Converted from Spark schema. pub data_types: Vec, - /// Schema of first batch pub schema: SchemaRef, - /// The input batch of input data. Used to determine the schema of the input data. - /// It is also used in unit test to mock the input data from JVM. + /// Used in unit tests to mock the input batch; otherwise written by `pull_next` on each + /// poll. pub batch: Arc>>, - /// Cache of expensive-to-compute plan properties cache: Arc, - /// Metrics collector metrics: ExecutionPlanMetricsSet, - /// Baseline metrics baseline_metrics: BaselineMetrics, - /// Whether native code can assume ownership of batches that it receives - arrow_ffi_safe: bool, } impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>>>, + input_source: Option>>, input_source_description: &str, data_types: Vec, - arrow_ffi_safe: bool, ) -> Result { let metrics_set = ExecutionPlanMetricsSet::default(); let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); @@ -112,7 +94,6 @@ impl ScanExec { metrics: metrics_set, baseline_metrics, schema, - arrow_ffi_safe, }) } @@ -131,22 +112,18 @@ impl ScanExec { *self.batch.try_lock().unwrap() = Some(input); } - /// Pull next input batch from JVM. + /// Pull next input batch from the upstream `ArrowArrayStreamReader`. pub fn get_next_batch(&mut self) -> Result<(), CometError> { if self.input_source.is_none() { - // This is a unit test. We don't need to call JNI. + // This is a unit test. Input batches are seeded via `set_input_batch`. return Ok(()); } let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut current_batch = self.batch.try_lock().unwrap(); if current_batch.is_none() { - let next_batch = ScanExec::get_next( - self.exec_context_id, - self.input_source.as_ref().unwrap().as_obj(), - self.data_types.len(), - self.arrow_ffi_safe, - )?; + let next_batch = + ScanExec::pull_next(self.exec_context_id, self.input_source.as_ref().unwrap())?; *current_batch = Some(next_batch); } @@ -155,119 +132,35 @@ impl ScanExec { Ok(()) } - /// Invokes JNI call to get next batch. - fn get_next( + /// Pull the next `RecordBatch` from the stream and convert it to an `InputBatch`. Dictionary + /// columns are unpacked because Comet's downstream operators do not handle them. + fn pull_next( exec_context_id: i64, - iter: &JObject, - num_cols: usize, - arrow_ffi_safe: bool, + reader: &Arc>, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { - // This is a unit test. We don't need to call JNI. + // Unit test path; input batches are seeded directly. return Ok(InputBatch::EOF); } - if iter.is_null() { - return Err(CometError::from(ExecutionError::GeneralError(format!( - "Null batch iterator object. Plan id: {exec_context_id}" - )))); - } - - JVMClasses::with_env(|env| { - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).has_next() -> i32)? - }; - - if num_rows == -1 { - return Ok(InputBatch::EOF); - } - - // fetch batch data from JVM via FFI - let (num_rows, array_addrs, schema_addrs) = - Self::allocate_and_fetch_batch(env, iter, num_cols)?; - - let mut inputs: Vec = Vec::with_capacity(num_cols); - - // Process each column - for i in 0..num_cols { - let array_ptr = array_addrs[i]; - let schema_ptr = schema_addrs[i]; - let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; - - // TODO: validate array input data - // array_data.validate_full()?; - - let array = make_array(array_data); - - let array = if arrow_ffi_safe { - // ownership of this array has been transferred to native - // but we still need to unpack dictionary arrays - copy_or_unpack_array(&array, &CopyMode::UnpackOrClone)? - } else { - // it is necessary to copy the array because the contents may be - // overwritten on the JVM side in the future - copy_array(&array) - }; - - inputs.push(array); - - // Drop the Arcs to avoid memory leak - unsafe { - Rc::from_raw(array_ptr as *const FFI_ArrowArray); - Rc::from_raw(schema_ptr as *const FFI_ArrowSchema); + let mut reader = reader + .try_lock() + .map_err(|_| CometError::Internal("ArrowArrayStreamReader contended".to_string()))?; + + let next = reader.next(); + match next { + None => Ok(InputBatch::EOF), + Some(Err(e)) => Err(CometError::from(e)), + Some(Ok(record_batch)) => { + let num_rows = record_batch.num_rows(); + let columns = record_batch.columns(); + let mut inputs: Vec = Vec::with_capacity(columns.len()); + for col in columns { + inputs.push(copy_or_unpack_array(col, &CopyMode::UnpackOrClone)?); } + Ok(InputBatch::new(inputs, Some(num_rows))) } - - Ok(InputBatch::new(inputs, Some(num_rows as usize))) - }) - } - - /// Allocates Arrow FFI structures and calls JNI to get the next batch data. - /// Returns the number of rows and the allocated array/schema addresses. - fn allocate_and_fetch_batch( - env: &mut jni::Env, - iter: &JObject, - num_cols: usize, - ) -> Result<(i32, Vec, Vec), CometError> { - let mut array_addrs = Vec::with_capacity(num_cols); - let mut schema_addrs = Vec::with_capacity(num_cols); - - for _ in 0..num_cols { - let arrow_array = Rc::new(FFI_ArrowArray::empty()); - let arrow_schema = Rc::new(FFI_ArrowSchema::empty()); - let (array_ptr, schema_ptr) = ( - Rc::into_raw(arrow_array) as i64, - Rc::into_raw(arrow_schema) as i64, - ); - - array_addrs.push(array_ptr); - schema_addrs.push(schema_ptr); } - - // Prepare the java array parameters - let long_array_addrs = env.new_long_array(num_cols)?; - let long_schema_addrs = env.new_long_array(num_cols)?; - - long_array_addrs.set_region(env, 0, &array_addrs)?; - long_schema_addrs.set_region(env, 0, &schema_addrs)?; - - let array_obj = JObject::from(long_array_addrs); - let schema_obj = JObject::from(long_schema_addrs); - - let array_obj = JValue::Object(array_obj.as_ref()); - let schema_obj = JValue::Object(schema_obj.as_ref()); - - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)? - }; - - // we already checked for end of results on call to has_next() so should always - // have a valid row count when calling next() - assert!(num_rows != -1); - - Ok((num_rows, array_addrs, schema_addrs)) } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 542c3d9536..77b174ea8d 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,6 +21,7 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use crate::errors::CometError; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; use crate::execution::{ @@ -32,6 +33,7 @@ use crate::execution::{ serde::to_arrow_datatype, shuffle::ShuffleWriterExec, }; +use crate::jvm_bridge::{jni_call, JVMClasses}; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; @@ -1447,23 +1449,36 @@ impl PhysicalPlanner { return Err(GeneralError("No input for scan".to_string())); } - // Consumes the first input source for the scan - let input_source = - if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { - // For unit test, we will set input batch to scan directly by `set_input_batch`. - None - } else { - Some(inputs.remove(0)) - }; + // Consumes the first input source for the scan. The Java side passes an + // `org.apache.arrow.c.ArrowArrayStream` whose `memoryAddress` points at the C + // struct; native takes ownership via `ArrowArrayStreamReader::from_raw`. + let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID + && inputs.is_empty() + { + // For unit test, we will set input batch to scan directly by `set_input_batch`. + None + } else { + let java_stream = inputs.remove(0); + let address: i64 = JVMClasses::with_env(|env| -> Result { + let addr = unsafe { + jni_call!(env, arrow_array_stream(java_stream.as_obj()).memory_address() -> i64)? + }; + Ok(addr) + })?; + let reader = unsafe { + arrow::ffi_stream::ArrowArrayStreamReader::from_raw( + address as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + ) + } + .map_err(|e| { + GeneralError(format!("Failed to import ArrowArrayStream from JVM: {e}")) + })?; + Some(Arc::new(std::sync::Mutex::new(reader))) + }; // The `ScanExec` operator will take actual arrays from Spark during execution - let scan = ScanExec::new( - self.exec_context_id, - input_source, - &scan.source, - data_types, - scan.arrow_ffi_safe, - )?; + let scan = + ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?; Ok(( vec![scan.clone()], @@ -3980,7 +3995,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4046,7 +4060,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4256,7 +4269,6 @@ mod tests { op_struct: Some(OpStruct::Scan(spark_operator::Scan { fields: vec![create_proto_datatype()], source: "".to_string(), - arrow_ffi_safe: false, })), } } @@ -4299,7 +4311,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4422,7 +4433,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4905,7 +4915,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 2fe6f8758f..6195e3f0ae 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -19,48 +19,15 @@ use crate::execution::operators::ExecutionError; use arrow::{ array::ArrayData, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, }; pub trait SparkArrowConvert { - /// Build Arrow Arrays from C data interface passed from Spark. - /// It accepts a tuple (ArrowArray address, ArrowSchema address). - fn from_spark(addresses: (i64, i64)) -> Result - where - Self: Sized; - /// Move Arrow Arrays to C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>; } impl SparkArrowConvert for ArrayData { - fn from_spark(addresses: (i64, i64)) -> Result { - let (array_ptr, schema_ptr) = addresses; - - let array_ptr = array_ptr as *mut FFI_ArrowArray; - let schema_ptr = schema_ptr as *mut FFI_ArrowSchema; - - if array_ptr.is_null() || schema_ptr.is_null() { - return Err(ExecutionError::ArrowError( - "At least one of passed pointers is null".to_string(), - )); - }; - - // `ArrowArray` will convert raw pointers back to `Arc`. No worries - // about memory leak. - let mut ffi_array = unsafe { - let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty()); - let schema_data = std::ptr::replace(schema_ptr, FFI_ArrowSchema::empty()); - - from_ffi(array_data, &schema_data)? - }; - - // Align imported buffers from Java. - ffi_array.align_buffers(); - - Ok(ffi_array) - } - /// Move this ArrowData to pointers of Arrow C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { let array_ptr = array as *mut FFI_ArrowArray; diff --git a/native/jni-bridge/src/batch_iterator.rs b/native/jni-bridge/src/arrow_array_stream.rs similarity index 57% rename from native/jni-bridge/src/batch_iterator.rs rename to native/jni-bridge/src/arrow_array_stream.rs index addda133fa..0b285607ff 100644 --- a/native/jni-bridge/src/batch_iterator.rs +++ b/native/jni-bridge/src/arrow_array_stream.rs @@ -15,45 +15,38 @@ // specific language governing permissions and limitations // under the License. -use jni::signature::Primitive; use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, - signature::ReturnType, + signature::{Primitive, ReturnType}, strings::JNIString, Env, }; -/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class. +/// A struct that holds all the JNI methods and fields for JVM `org.apache.arrow.c.ArrowArrayStream` +/// class. `memoryAddress()` is read once per partition so native can take ownership of the +/// underlying C struct via `ArrowArrayStreamReader::from_raw`. #[allow(dead_code)] // we need to keep references to Java items to prevent GC -pub struct CometBatchIterator<'a> { +pub struct ArrowArrayStream<'a> { pub class: JClass<'a>, - pub method_has_next: JMethodID, - pub method_has_next_ret: ReturnType, - pub method_next: JMethodID, - pub method_next_ret: ReturnType, + pub method_memory_address: JMethodID, + pub method_memory_address_ret: ReturnType, } -impl<'a> CometBatchIterator<'a> { - pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator"; +impl<'a> ArrowArrayStream<'a> { + pub const JVM_CLASS: &'static str = "org/apache/arrow/c/ArrowArrayStream"; - pub fn new(env: &mut Env<'a>) -> JniResult> { + pub fn new(env: &mut Env<'a>) -> JniResult> { let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; - Ok(CometBatchIterator { - class, - method_has_next: env.get_method_id( - JNIString::new(Self::JVM_CLASS), - jni::jni_str!("hasNext"), - jni::jni_sig!("()I"), - )?, - method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_next: env.get_method_id( + Ok(ArrowArrayStream { + method_memory_address: env.get_method_id( JNIString::new(Self::JVM_CLASS), - jni::jni_str!("next"), - jni::jni_sig!("([J[J)I"), + jni::jni_str!("memoryAddress"), + jni::jni_sig!("()J"), )?, - method_next_ret: ReturnType::Primitive(Primitive::Int), + method_memory_address_ret: ReturnType::Primitive(Primitive::Long), + class, }) } } diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index d72323c961..c8bb7cd02d 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -189,13 +189,13 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { mod comet_exec; pub use comet_exec::*; -mod batch_iterator; +mod arrow_array_stream; mod comet_metric_node; mod comet_task_memory_manager; mod comet_udf_bridge; mod shuffle_block_iterator; -use batch_iterator::CometBatchIterator; +use arrow_array_stream::ArrowArrayStream; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; use comet_udf_bridge::CometUdfBridge; @@ -223,8 +223,9 @@ pub struct JVMClasses<'a> { pub comet_metric_node: CometMetricNode<'a>, /// The static CometExec class. Used for getting the subquery result. pub comet_exec: CometExec<'a>, - /// The CometBatchIterator class. Used for iterating over the batches. - pub comet_batch_iterator: CometBatchIterator<'a>, + /// The org.apache.arrow.c.ArrowArrayStream class. Used to get the C struct memory address + /// when importing a JVM-exported batch stream into native code. + pub arrow_array_stream: ArrowArrayStream<'a>, /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to @@ -300,7 +301,7 @@ impl JVMClasses<'_> { throwable_get_cause_method, comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), - comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + arrow_array_stream: ArrowArrayStream::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index ed1684b240..b65a215c78 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -85,8 +85,6 @@ message Scan { // is purely for informational purposes when viewing native query plans in // debug mode. string source = 2; - // Whether native code can assume ownership of batches that it receives - bool arrow_ffi_safe = 3; } message ShuffleScan { diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java deleted file mode 100644 index 9b48a47c57..0000000000 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import scala.collection.Iterator; - -import org.apache.spark.sql.vectorized.ColumnarBatch; - -import org.apache.comet.vector.NativeUtil; - -/** - * Iterator for fetching batches from JVM to native code. Usually called via JNI from native - * ScanExec. - * - *

Batches are owned by the JVM. Native code can safely access the batch after calling `next` but - * the native code must not retain references to the batch because the next call to `hasNext` will - * signal to the JVM that the batch can be closed. - */ -public class CometBatchIterator { - private final Iterator input; - private final NativeUtil nativeUtil; - private ColumnarBatch previousBatch = null; - private ColumnarBatch currentBatch = null; - - CometBatchIterator(Iterator input, NativeUtil nativeUtil) { - this.input = input; - this.nativeUtil = nativeUtil; - } - - /** - * Fetch the next input batch and allow the previous batch to be closed (this may not happen - * immediately). - * - * @return Number of rows in next batch or -1 if no batches left. - */ - public int hasNext() { - - // release reference to previous batch - previousBatch = null; - - if (currentBatch == null) { - if (input.hasNext()) { - currentBatch = input.next(); - } - } - if (currentBatch == null) { - return -1; - } else { - return currentBatch.numRows(); - } - } - - /** - * Get the next batch of Arrow arrays. - * - * @param arrayAddrs The addresses of the ArrowArray structures. - * @param schemaAddrs The addresses of the ArrowSchema structures. - * @return the number of rows of the current batch. -1 if there is no more batch. - */ - public int next(long[] arrayAddrs, long[] schemaAddrs) { - if (currentBatch == null) { - return -1; - } - - // export the batch using the Arrow C Data Interface - int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); - - // keep a reference to the exported batch so that it doesn't get garbage collected - // while the native code is still processing it - previousBatch = currentBatch; - - currentBatch = null; - - return numRows; - } -} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 6140eca553..d17735a560 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -60,7 +60,7 @@ import org.apache.comet.vector.NativeUtil */ class CometExecIterator( val id: Long, - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode, @@ -79,14 +79,11 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle - // scan indices, CometBatchIterator for regular scan indices. - private val inputIterators: Array[Object] = inputs.zipWithIndex.map { - case (_, idx) if shuffleBlockIterators.contains(idx) => - shuffleBlockIterators(idx).asInstanceOf[Object] - case (iterator, _) => - new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] - }.toArray + // Each input slot is either an org.apache.arrow.c.ArrowArrayStream (consumed natively via + // ArrowArrayStreamReader::from_raw against its memoryAddress) or a CometShuffleBlockIterator + // (consumed via the existing JNI block-iteration protocol). The slot index matches the scan + // input index in the serialized native plan. + private val inputIterators: Array[Object] = inputObjects private val plan = { val conf = SparkEnv.get.conf diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 69b9bd5f85..4a8ae4d2ac 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -96,7 +96,6 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec val scanOp = OperatorOuterClass.Scan .newBuilder() .setSource(cmd.query.nodeName) - .setArrowFfiSafe(false) // Add fields from the query output schema val scanTypes = cmd.query.output.flatMap { attr => diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index 845803d133..b7caeb43c2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -40,9 +40,6 @@ import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataTy */ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { - /** Whether the data produced by the Comet operator is FFI safe */ - def isFfiSafe: Boolean = true - override def enabledConfig: Option[ConfigEntry[Boolean]] = None override def convert( @@ -65,7 +62,6 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { } else { scanBuilder.setSource(source) } - scanBuilder.setArrowFfiSafe(isFfiSafe) val scanTypes = op.output.flatten { attr => serializeDataType(attr.dataType) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 47eda98a11..4b411d87f7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.comet +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,7 +28,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometRuntimeException, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -40,23 +41,14 @@ private[spark] class CometExecPartition( extends Partition /** - * Unified RDD for Comet native execution. - * - * Solves the closure capture problem: instead of capturing all partitions' data in the closure - * (which gets serialized to every task), each Partition object carries only its own data. - * - * Handles three cases: - * - With inputs + per-partition data: injects planning data into operator tree - * - With inputs + no per-partition data: just zips inputs (no injection overhead) - * - No inputs: uses numPartitions to create partitions - * - * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in - * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles - * ScalarSubquery expressions by registering them with CometScalarSubquery before execution. + * Unified RDD for Comet native execution. Non-shuffle input slots are `RDD[ArrowArrayStream]` + * (consumed natively via the C Stream Interface); shuffle input slots are `CometShuffledBatchRDD` + * (consumed via `CometShuffleBlockIterator`). Slot order matches the scan-input order in the + * serialized native plan. */ private[spark] class CometExecRDD( sc: SparkContext, - var inputRDDs: Seq[RDD[ColumnarBatch]], + var inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], @transient perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], @@ -97,9 +89,31 @@ private[spark] class CometExecRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometExecPartition] - val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => - rdd.iterator(part, context) - } + val shuffleBlockIters = scala.collection.mutable.Map.empty[Int, CometShuffleBlockIterator] + val inputObjects: Array[Object] = inputRDDs + .zip(partition.inputPartitions) + .zipWithIndex + .map { case ((rdd, part), idx) => + if (shuffleScanIndices.contains(idx)) { + rdd match { + case shuffleRDD: CometShuffledBatchRDD => + val it = shuffleRDD.computeAsShuffleBlockIterator(part, context) + shuffleBlockIters(idx) = it + it.asInstanceOf[Object] + case other => + throw new CometRuntimeException( + s"Slot $idx is marked as a shuffle scan but the input RDD is " + + s"${other.getClass.getName}, expected CometShuffledBatchRDD") + } + } else { + val streams = rdd.iterator(part, context).asInstanceOf[Iterator[ArrowArrayStream]] + if (!streams.hasNext) { + throw new CometRuntimeException(s"Empty ArrowArrayStream RDD partition for slot $idx") + } + streams.next().asInstanceOf[Object] + } + } + .toArray // Only inject if we have per-partition planning data val actualPlan = if (commonByKey.nonEmpty) { @@ -111,18 +125,9 @@ private[spark] class CometExecRDD( serializedPlan } - // Create shuffle block iterators for inputs that are CometShuffledBatchRDD - val shuffleBlockIters = shuffleScanIndices.flatMap { idx => - inputRDDs(idx) match { - case rdd: CometShuffledBatchRDD => - Some(idx -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(idx), context)) - case _ => None - } - }.toMap - val it = new CometExecIterator( CometExec.newIterId, - inputs, + inputObjects, numOutputCols, actualPlan, nativeMetrics, @@ -130,7 +135,7 @@ private[spark] class CometExecRDD( partition.index, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleBlockIters) + shuffleBlockIters.toMap) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -169,7 +174,7 @@ object CometExecRDD { // scalastyle:off def apply( sc: SparkContext, - inputRDDs: Seq[RDD[ColumnarBatch]], + inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index a2af60142b..e632190f0a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -25,6 +25,8 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.vectorized.ColumnarBatch @@ -56,8 +58,19 @@ object CometExecUtils { // Serialize the plan once before mapping to avoid repeated serialization per partition val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get val serializedPlan = CometExec.serializeNativePlan(limitOp) + val inputSchema = Utils.fromAttributes(outputAttribute) childPlan.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), outputAttribute.length, serializedPlan, numParts, idx) + val stream = CometArrowStream.fromColumnarBatchIter( + iter, + inputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometExecUtils-getNativeLimit") + CometExec.getCometIterator( + Array(stream.asInstanceOf[Object]), + outputAttribute.length, + serializedPlan, + numParts, + idx) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 6ada259e82..32b8933872 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -21,11 +21,12 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer -import org.apache.spark.TaskContext +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{DataType, NullType} @@ -43,7 +44,8 @@ case class CometLocalTableScanExec( @transient rows: Seq[InternalRow], override val output: Seq[Attribute]) extends CometExec - with LeafExecNode { + with LeafExecNode + with CometNativeArrowSource { override lazy val metrics: Map[String, SQLMetric] = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -66,25 +68,47 @@ case class CometLocalTableScanExec( } } + private def countingRows( + iter: Iterator[InternalRow], + numOutputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numOutputRows.add(1) + row + } + } + override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - val schema = originalPlan.schema - // Native side asserts Timestamp(Microsecond, Some("UTC")). See COMET-2720. + val sparkSchema = originalPlan.schema + rdd.mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometLocalTableScan", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numOutputRows), + maxRecordsPerBatch)) + } + } + + override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { + val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) + val sparkSchema = originalPlan.schema + val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsInternal { rowIter => - val context = TaskContext.get() - // Non-Comet JVM consumers (e.g. Iceberg writers) may retain batches across next() - // calls, so each batch must own independent Arrow buffers. - val batches = CometArrowConverters.rowToArrowBatchIterNoReuse( - rowIter, - schema, - maxRecordsPerBatch, - "UTC", - context) - batches.map { batch => - numOutputRows.add(batch.numRows()) - batch - } + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometLocalTableScan", + allocator => + new RowArrowReader( + allocator, + arrowSchema, + countingRows(rowIter, numOutputRows), + maxRecordsPerBatch)) } } @@ -117,9 +141,9 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTy override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) - // CometArrowConverters / ArrowWriter support NullType (via Utils.toArrowType + - // NullWriter). Other types not on DataTypeSupport's allow list (e.g. TimeType, - // intervals) lack ArrowWriter coverage and must fall back to Spark. + // ArrowWriter (used by RowArrowReader) handles NullType via Utils.toArrowType + NullWriter; + // other types off DataTypeSupport's allow list (TimeType, intervals, ...) have no ArrowWriter + // coverage and must fall back to Spark. override def isTypeSupported( dt: DataType, name: String, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 4fb8af39e8..7ba281a666 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -28,6 +28,8 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -210,9 +212,15 @@ case class CometNativeWriteExec( modifiedNativeOp.writeTo(codedOutput) codedOutput.checkNoSpaceLeft() + val arrowStream = CometArrowStream.fromColumnarBatchIter( + iter, + CometUtils.fromAttributes(child.output), + CometArrowStream.NATIVE_TIMEZONE, + "CometNativeWriteExec") + val execIterator = new CometExecIterator( CometExec.newIterId, - Seq(iter), + Array(arrowStream.asInstanceOf[Object]), numOutputCols, planBytes, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index efe6a97d40..00e13bcbde 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -21,13 +21,14 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer -import org.apache.spark.TaskContext +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader, SparkColumnarArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types._ @@ -39,7 +40,8 @@ import org.apache.comet.serde.operator.CometSink case class CometSparkToColumnarExec(child: SparkPlan) extends RowToColumnarTransition - with CometPlan { + with CometPlan + with CometNativeArrowSource { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning @@ -69,72 +71,99 @@ case class CometSparkToColumnarExec(child: SparkPlan) sparkContext, "time converting Spark batches to Arrow batches")) - // The conversion happens in next(), so wrap the call to measure time spent. - private def createTimingIter( + private def countingBatches( iter: Iterator[ColumnarBatch], numInputRows: SQLMetric, - numOutputBatches: SQLMetric, - conversionTime: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { + numOutputBatches: SQLMetric): Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = iter.hasNext + override def next(): ColumnarBatch = { + val batch = iter.next() + numInputRows += batch.numRows() + numOutputBatches += 1 + batch + } + } - override def hasNext: Boolean = { - iter.hasNext - } + private def countingRows( + iter: Iterator[InternalRow], + numInputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numInputRows += 1 + row + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + val conversionTime = longMetric("conversionTime") + val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) + val sparkSchema = child.schema - override def next(): ColumnarBatch = { - val startNs = System.nanoTime() - val batch = iter.next() - conversionTime += System.nanoTime() - startNs - numInputRows += batch.numRows() - numOutputBatches += 1 - batch + if (child.supportsColumnar) { + val maxBatchInt = maxRecordsPerBatch.toInt + child.executeColumnar().mapPartitionsInternal { sparkBatches => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometSparkColumnarToColumnar", + new SparkColumnarArrowReader( + _, + arrowSchema, + countingBatches(sparkBatches, numInputRows, numOutputBatches), + maxBatchInt, + ns => conversionTime += ns)) + } + } else { + child.execute().mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometSparkRowToColumnar", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numInputRows), + maxRecordsPerBatch, + ns => conversionTime += ns)) } } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { + override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC for Arrow schema timezone to match the native side, which always - // deserializes Timestamp as Timestamp(Microsecond, Some("UTC")). Spark's internal - // timestamp representation is always UTC microseconds, so the timezone here is - // purely schema metadata. Using session timezone would cause Arrow RowConverter - // schema mismatch errors in non-UTC sessions. See COMET-2720. - val timeZoneId = "UTC" - val schema = child.schema + val sparkSchema = child.schema if (child.supportsColumnar) { - child - .executeColumnar() - .mapPartitionsInternal { sparkBatches => - val arrowBatches = - sparkBatches.flatMap { sparkBatch => - val context = TaskContext.get() - CometArrowConverters.columnarBatchToArrowBatchIter( - sparkBatch, - schema, - maxRecordsPerBatch, - timeZoneId, - context) - } - createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) - } + val maxBatchInt = maxRecordsPerBatch.toInt + child.executeColumnar().mapPartitionsInternal { sparkBatches => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometSparkColumnarToColumnar", + allocator => + new SparkColumnarArrowReader( + allocator, + arrowSchema, + countingBatches(sparkBatches, numInputRows, numOutputBatches), + maxBatchInt, + ns => conversionTime += ns)) + } } else { - child - .execute() - .mapPartitionsInternal { sparkBatches => - val context = TaskContext.get() - val arrowBatches = - CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - schema, + child.execute().mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometSparkRowToColumnar", + allocator => + new RowArrowReader( + allocator, + arrowSchema, + countingRows(rowIter, numInputRows), maxRecordsPerBatch, - timeZoneId, - context) - createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) - } + ns => conversionTime += ns)) + } } } @@ -145,9 +174,6 @@ case class CometSparkToColumnarExec(child: SparkPlan) object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport { - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false - override def createExec( nativeOp: OperatorOuterClass.Operator, op: SparkPlan): CometNativeExec = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index a66d1b58d6..e9b178dc6b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -24,7 +24,9 @@ import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -140,8 +142,19 @@ case class CometTakeOrderedAndProjectExec( .get val serializedTopK = CometExec.serializeNativePlan(topK) val numOutputCols = child.output.length + val inputSchema = CometUtils.fromAttributes(child.output) childRDD.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), numOutputCols, serializedTopK, numParts, idx) + val stream = CometArrowStream.fromColumnarBatchIter( + iter, + inputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometTakeOrderedAndProject-topK") + CometExec.getCometIterator( + Array(stream.asInstanceOf[Object]), + numOutputCols, + serializedTopK, + numParts, + idx) } } @@ -163,9 +176,15 @@ case class CometTakeOrderedAndProjectExec( .get val serializedTopKAndProjection = CometExec.serializeNativePlan(topKAndProjection) val finalOutputLength = output.length + val finalInputSchema = CometUtils.fromAttributes(child.output) singlePartitionRDD.mapPartitionsInternal { iter => + val stream = CometArrowStream.fromColumnarBatchIter( + iter, + finalInputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometTakeOrderedAndProject-final") val it = CometExec.getCometIterator( - Seq(iter), + Array(stream.asInstanceOf[Object]), finalOutputLength, serializedTopKAndProjection, 1, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala new file mode 100644 index 0000000000..eaacc3968b --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import java.util.{ArrayList => JArrayList} + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.CometVector + +/** + * `ArrowReader` over an iterator of Arrow-backed `ColumnarBatch`es. Each `loadNextBatch` unloads + * the source's `FieldVector`s into a transient `ArrowRecordBatch` (retains buffers), loads it + * into this reader's stable VSR via `loadFieldBuffers` (release-and-replace), then closes the + * source batch. The unload/load step decouples this reader's VSR ownership from whatever the + * upstream does with its own buffers. + */ +private[comet] class ColumnarBatchArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch]) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!source.hasNext) { + return false + } + + val src = source.next() + try { + val sourceVectors = new JArrayList[FieldVector](src.numCols()) + var i = 0 + while (i < src.numCols()) { + sourceVectors.add( + src.column(i).asInstanceOf[CometVector].getValueVector.asInstanceOf[FieldVector]) + i += 1 + } + val transient = new VectorSchemaRoot(sourceVectors) + transient.setRowCount(src.numRows()) + + val unloader = new VectorUnloader(transient) + val rb = unloader.getRecordBatch + try { + loadRecordBatch(rb) + } finally { + rb.close() + } + // Note: do not close `transient`. It shares FieldVectors with `src`; closing `src` below + // releases the producer-side refs. Closing `transient` would double-release. + } finally { + src.close() + } + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 0e7d0d9a8d..32441029bb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -22,138 +22,41 @@ package org.apache.spark.sql.comet.execution.arrow import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} -import org.apache.comet.CometArrowAllocator import org.apache.comet.vector.NativeUtil +/** + * Convert Spark `InternalRow` / `ColumnarBatch` streams to a stream of independently-owned Arrow + * `ColumnarBatch`es. Each emitted batch owns a fresh `VectorSchemaRoot` with newly allocated + * buffers; the consumer is responsible for closing the batch. + * + * Buffers are allocated from the caller-provided `BufferAllocator`. The caller owns the + * allocator's lifecycle (typically a child allocator closed at task completion). When emitted + * batches reach `ColumnarBatchArrowReader.loadNextBatch`, ownership of their buffers is + * transferred (via `VectorUnloader` / `loadFieldBuffers`) to the reader's allocator, after which + * the source batch is closed and the producer's allocator returns to zero outstanding bytes. + */ object CometArrowConverters extends Logging { - // This is similar how Spark converts internal row to Arrow format except that it is transforming - // the result batch to Comet's ColumnarBatch instead of serialized bytes. - // There's another big difference that Comet may consume the ColumnarBatch by exporting it to - // the native side. Hence, we need to: - // 1. reset the Arrow writer after the ColumnarBatch is consumed - // 2. close the allocator when the task is finished but not when the iterator is all consumed - // The reason for the second point is that when ColumnarBatch is exported to the native side, the - // exported process increases the reference count of the Arrow vectors. The reference count is - // only decreased when the native plan is done with the vectors, which is usually longer than - // all the ColumnarBatches are consumed. - - abstract private[sql] class ArrowBatchIterBase( - schema: StructType, - timeZoneId: String, - context: TaskContext) - extends Iterator[ColumnarBatch] - with AutoCloseable { - - protected val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) - // Reuse the same root allocator here. - protected val allocator: BufferAllocator = - CometArrowAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) - protected val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator) - protected val arrowWriter: ArrowWriter = ArrowWriter.create(root) - - protected var currentBatch: ColumnarBatch = null - protected var closed: Boolean = false - - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - close(true) - } - } - - override def close(): Unit = { - close(false) - } - - protected def close(closeAllocator: Boolean): Unit = { - try { - if (!closed) { - if (currentBatch != null) { - arrowWriter.reset() - currentBatch.close() - currentBatch = null - } - root.close() - closed = true - } - } finally { - // the allocator shall be closed when the task is finished - if (closeAllocator) { - allocator.close() - } - } - } - - override def next(): ColumnarBatch = { - currentBatch = nextBatch() - currentBatch - } - - protected def nextBatch(): ColumnarBatch - - } - - private[sql] class RowToArrowBatchIter( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { - - override def hasNext: Boolean = rowIter.hasNext || { - close(false) - false - } - - override protected def nextBatch(): ColumnarBatch = { - if (rowIter.hasNext) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() - var rowCount = 0L - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) - rowCount += 1 - } - arrowWriter.finish() - NativeUtil.rootAsBatch(root) - } else { - null - } - } - } - - def rowToArrowBatchIter( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - new RowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) - } /** - * Use when the downstream consumer may retain batches across `next()` calls (e.g. non-Comet JVM - * columnar sinks like Iceberg writers). Each batch owns independent Arrow buffers. + * Convert an iterator of Spark `InternalRow`s into an iterator of Arrow `ColumnarBatch`es. + * + * Each call to `next()` allocates a fresh `VectorSchemaRoot`, writes up to `maxRecordsPerBatch` + * rows into it, and emits a `ColumnarBatch` wrapping that root. The consumer must close every + * emitted batch. */ - def rowToArrowBatchIterNoReuse( + def rowToArrowBatchIter( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - val arrowSchema = Utils.toArrowSchema(schema, timeZoneId) - val allocator = - CometArrowAllocator.newChildAllocator("rowToArrowBatchIterNoReuse", 0, Long.MaxValue) - Option(context).foreach(_.addTaskCompletionListener[Unit](_ => allocator.close())) + allocator: BufferAllocator): Iterator[ColumnarBatch] = { + val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) new Iterator[ColumnarBatch] { override def hasNext: Boolean = rowIter.hasNext @@ -173,57 +76,46 @@ object CometArrowConverters extends Logging { } } - private[sql] class ColumnBatchToArrowBatchIter( + /** + * Slice a single Spark `ColumnarBatch` into one or more Arrow `ColumnarBatch`es of at most + * `maxRecordsPerBatch` rows each. Each emitted batch owns a fresh `VectorSchemaRoot`. + */ + def columnarBatchToArrowBatchIter( colBatch: ColumnarBatch, schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { + allocator: BufferAllocator): Iterator[ColumnarBatch] = { + val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) + val totalRows = colBatch.numRows() - private var rowsProduced: Int = 0 + new Iterator[ColumnarBatch] { + private var rowsProduced: Int = 0 - override def hasNext: Boolean = rowsProduced < colBatch.numRows() || { - close(false) - false - } + override def hasNext: Boolean = rowsProduced < totalRows - override protected def nextBatch(): ColumnarBatch = { - val rowsInBatch = colBatch.numRows() - if (rowsProduced < rowsInBatch) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() + override def next(): ColumnarBatch = { val rowsToProduce = - if (maxRecordsPerBatch <= 0) rowsInBatch - rowsProduced - else Math.min(maxRecordsPerBatch, rowsInBatch - rowsProduced) + if (maxRecordsPerBatch <= 0) totalRows - rowsProduced + else math.min(maxRecordsPerBatch, totalRows - rowsProduced) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = ArrowWriter.create(root) for (columnIndex <- 0 until colBatch.numCols()) { val column = colBatch.column(columnIndex) val columnArray = new ColumnarArray(column, rowsProduced, rowsToProduce) if (column.hasNull) { - arrowWriter.writeCol(columnArray, columnIndex) + writer.writeCol(columnArray, columnIndex) } else { - arrowWriter.writeColNoNull(columnArray, columnIndex) + writer.writeColNoNull(columnArray, columnIndex) } } rowsProduced += rowsToProduce - - arrowWriter.finish() + writer.finish() NativeUtil.rootAsBatch(root) - } else { - null } } } - - def columnarBatchToArrowBatchIter( - colBatch: ColumnarBatch, - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - new ColumnBatchToArrowBatchIter(colBatch, schema, maxRecordsPerBatch, timeZoneId, context) - } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala new file mode 100644 index 0000000000..14a7a9ed0c --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.c.{ArrowArrayStream, Data} +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometArrowAllocator +import org.apache.comet.vector.NativeUtil + +/** + * Marker for Comet operators that can produce Arrow data destined for a Comet native executor + * directly as the C Stream Interface, skipping the intermediate `RDD[ColumnarBatch]` layer. + */ +trait CometNativeArrowSource extends SparkPlan { + def doExecuteAsArrowStream(): RDD[ArrowArrayStream] +} + +object CometArrowStream { + + /** + * Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone; + * Spark's internal timestamp representation is always UTC microseconds anyway, and a non-UTC + * timezone here would only show up as schema metadata that breaks Arrow RowConverter + * validation. See COMET-2720. + */ + val NATIVE_TIMEZONE: String = "UTC" + + /** + * Wrap an `RDD[ColumnarBatch]` whose batches are Arrow-backed into an `RDD[ArrowArrayStream]`. + */ + def wrapColumnarBatchRDD( + rdd: RDD[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): RDD[ArrowArrayStream] = { + // Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema + // inside the per-task body so the closure cleaner doesn't try to ship a Schema across. + rdd.mapPartitionsInternal { batchIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, batchIter)) + } + } + + /** + * Wrap a single per-partition `Iterator[ColumnarBatch]` (Arrow-backed) and return the exported + * `ArrowArrayStream`. For callers outside `CometExecRDD` that hand a JNI input slot directly to + * a `CometExecIterator`. + */ + def fromColumnarBatchIter( + iter: Iterator[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): ArrowArrayStream = { + val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)).next() + } + + /** + * Allocate a child allocator, build a reader, export it as an `ArrowArrayStream`, and register + * task-completion cleanup. Returns a single-element iterator so this composes with + * `RDD.mapPartitionsInternal`. + * + * Close ordering: when native drops its `ArrowArrayStreamReader`, the C release callback fires + * synchronously into `ExportedArrayStreamPrivateData.close` -> `reader.close` -> the VSR's + * buffers are released. The task-completion listener registered here runs strictly after that + * (Spark fires listeners in reverse registration order, and the listener that drops the native + * plan is registered later by `CometExecIterator`), so `allocator.close` finds zero outstanding + * bytes. + */ + def stream( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ArrowArrayStream] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + var reader: ArrowReader = null + var arrowStream: ArrowArrayStream = null + try { + reader = readerFactory(allocator) + arrowStream = ArrowArrayStream.allocateNew(allocator) + Data.exportArrayStream(allocator, reader, arrowStream) + } catch { + case t: Throwable => + // Roll back partial setup before rethrowing -- nothing has been registered with + // TaskContext yet, so without this the allocator (and possibly the reader/stream) leaks. + if (arrowStream != null) { + try arrowStream.close() + catch { case _: Throwable => () } + } + if (reader != null) { + try reader.close() + catch { case _: Throwable => () } + } + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + val streamRef = arrowStream + context.addTaskCompletionListener[Unit] { _ => + streamRef.close() + allocator.close() + } + } + Iterator.single(arrowStream) + } + + /** + * Drive an `ArrowReader` from a per-task body and emit `ColumnarBatch`es wrapping the reader's + * stable VSR. Lifecycle: the supplied factory builds the reader against a fresh child + * allocator; both close at task completion. This is the non-native consumer path + * (`doExecuteColumnar`) -- the native consumer path uses [[stream]] to export instead. + */ + def readerBatchIter( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ColumnarBatch] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + val reader = + try readerFactory(allocator) + catch { + case t: Throwable => + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + context.addTaskCompletionListener[Unit] { _ => + reader.close() + allocator.close() + } + } + new Iterator[ColumnarBatch] { + // Lazily prefetch one batch so `hasNext` can answer without consuming. + private var loaded: Boolean = false + private var hasMore: Boolean = false + + private def ensureLoaded(): Unit = { + if (!loaded) { + hasMore = reader.loadNextBatch() + loaded = true + } + } + + override def hasNext: Boolean = { + ensureLoaded() + hasMore + } + + override def next(): ColumnarBatch = { + ensureLoaded() + if (!hasMore) { + throw new NoSuchElementException("No more batches") + } + loaded = false + NativeUtil.rootAsBatch(reader.getVectorSchemaRoot) + } + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala new file mode 100644 index 0000000000..e1829eb5c5 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.catalyst.InternalRow + +/** + * `ArrowReader` over an iterator of Spark `InternalRow`s, writing up to `maxRecordsPerBatch` rows + * per call into the reader's stable VSR via `ArrowWriter`. + * + * `ArrowWriter.create(root)` calls `vector.allocateNew()`, which releases any prior buffers and + * allocates fresh ones. This is required for FFI safety: previously-exported batches retain their + * buffers via the C release callback, so reusing those buffers in place would corrupt native + * consumers still holding the prior batch. + */ +private[comet] class RowArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + rowIter: Iterator[InternalRow], + maxRecordsPerBatch: Long, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!rowIter.hasNext) { + return false + } + + val startNs = System.nanoTime() + val writer = ArrowWriter.create(getVectorSchemaRoot) + var rowCount = 0L + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) + rowCount += 1 + } + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala new file mode 100644 index 0000000000..0af940fab3 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} + +/** + * `ArrowReader` over an iterator of Spark-side `ColumnarBatch`es (not Arrow-backed). Slices up to + * `maxRecordsPerBatch` rows per `loadNextBatch` from the current Spark batch into the reader's + * stable VSR via `ArrowWriter.writeCol`. Spark's `ColumnVector` implementations aren't Arrow + * buffers, so this reader necessarily copies element values into Arrow format. + */ +private[comet] class SparkColumnarArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch], + maxRecordsPerBatch: Int, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + private var current: ColumnarBatch = _ + private var rowsConsumedInCurrent: Int = 0 + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + private def advanceToNonEmptyBatch(): Boolean = { + while (current == null || rowsConsumedInCurrent >= current.numRows()) { + if (current != null) { + // We don't own Spark ColumnarBatches; just drop the reference. + current = null + rowsConsumedInCurrent = 0 + } + if (!source.hasNext) { + return false + } + current = source.next() + rowsConsumedInCurrent = 0 + } + true + } + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!advanceToNonEmptyBatch()) { + return false + } + + val startNs = System.nanoTime() + val rowsRemaining = current.numRows() - rowsConsumedInCurrent + val rowsToProduce = + if (maxRecordsPerBatch <= 0) rowsRemaining + else math.min(maxRecordsPerBatch, rowsRemaining) + + val writer = ArrowWriter.create(getVectorSchemaRoot) + var col = 0 + while (col < current.numCols()) { + val column = current.column(col) + val columnArray = new ColumnarArray(column, rowsConsumedInCurrent, rowsToProduce) + if (column.hasNull) { + writer.writeCol(columnArray, col) + } else { + writer.writeColNoNull(columnArray, col) + } + col += 1 + } + rowsConsumedInCurrent += rowsToProduce + + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..8a9fd9019c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch @@ -96,8 +98,14 @@ class CometNativeShuffleWriter[K, V]( // Getting rid of the fake partitionId val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + val arrowStream = CometArrowStream.fromColumnarBatchIter( + newInputs.asInstanceOf[Iterator[ColumnarBatch]], + CometUtils.fromAttributes(outputAttributes), + CometArrowStream.NATIVE_TIMEZONE, + "CometNativeShuffleWriter") + val cometIter = CometExec.getCometIterator( - Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + Array(arrowStream.asInstanceOf[Object]), outputAttributes.length, nativePlan, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 7d5398ae62..aa4ffad19f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ @@ -311,13 +312,13 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, numParts: Int, partitionIdx: Int): CometExecIterator = { getCometIterator( - inputs, + inputObjects, numOutputCols, nativePlan, CometMetricNode(Map.empty), @@ -332,14 +333,14 @@ object CometExec { * executing the same plan across multiple partitions to avoid serializing the plan repeatedly. */ def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, serializedPlan: Array[Byte], numParts: Int, partitionIdx: Int): CometExecIterator = { new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, serializedPlan, CometMetricNode(Map.empty), @@ -350,7 +351,7 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, nativeMetrics: CometMetricNode, @@ -361,7 +362,7 @@ object CometExec { val bytes = serializeNativePlan(nativePlan) new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, bytes, nativeMetrics, @@ -473,10 +474,11 @@ abstract class CometNativeExec extends CometExec { // Find planning data within this stage (stops at shuffle boundaries). val (commonByKey, perPartitionByKey) = findAllPlanData(this) - // Collect the input ColumnarBatches from the child operators and create a CometExecIterator - // to execute the native plan. + // Collect the input batches from the child operators. Non-shuffle inputs become + // RDD[ArrowArrayStream] (one stream per partition, exported via the C Stream Interface + // for native consumption); shuffle inputs stay as CometShuffledBatchRDD. val sparkPlans = ArrayBuffer.empty[SparkPlan] - val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]] + val inputs = ArrayBuffer.empty[RDD[_]] foreachUntilCometInput(this)(sparkPlans += _) @@ -503,15 +505,85 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") } + def isShuffleScanInput(plan: SparkPlan): Boolean = plan match { + case _: CometShuffleExchangeExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec => + true + case ReusedExchangeExec(_, _: CometShuffleExchangeExec) => true + case _ => false + } + + // The protobuf is the source of truth for whether a slot is a ShuffleScan or a regular + // Scan: `CometExchangeSink.shouldUseShuffleScan` only fires for AQE wrappers + // (`ShuffleQueryStageExec`), so a bare non-AQE `CometShuffleExchangeExec` always serializes + // as a regular Scan regardless of `COMET_SHUFFLE_DIRECT_READ_ENABLED`. Driving the JVM + // dispatch from `shuffleScanIndices` instead of the conf keeps the two aligned. + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + + def isBroadcastInput(plan: SparkPlan): Boolean = plan match { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true + case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true + case _ => false + } + + // Unwrap any number of AQE / reuse wrappers to find a CometBroadcastExchangeExec, if + // present. Returns the unwrapped exchange for input wiring -- broadcast partition counts + // are coerced to match firstNonBroadcastPlanNumPartitions, so we always read from the + // underlying exchange directly. + def asBroadcastExchange(plan: SparkPlan): Option[CometBroadcastExchangeExec] = + plan match { + case c: CometBroadcastExchangeExec => Some(c) + case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => Some(c) + case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => Some(c) + case BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, c: CometBroadcastExchangeExec), + _) => + Some(c) + case _ => None + } + + def asArrowStreamRDD(plan: SparkPlan, partitionCount: Int, scanSlot: Int): RDD[_] = + plan match { + case s: CometNativeArrowSource => + s.doExecuteAsArrowStream() + case _ if asBroadcastExchange(plan).isDefined => + val c = asBroadcastExchange(plan).get + CometArrowStream.wrapColumnarBatchRDD( + c.executeColumnar(partitionCount), + c.schema, + CometArrowStream.NATIVE_TIMEZONE, + c.nodeName) + case _ if isShuffleScanInput(plan) && shuffleScanIndices.contains(scanSlot) => + // Direct-read shuffle: `CometShuffledBatchRDD` reaches native via + // CometShuffleBlockIterator. Other shuffle slots fall through and get wrapped. + plan.executeColumnar() + case _ => + CometArrowStream.wrapColumnarBatchRDD( + plan.executeColumnar(), + plan.schema, + CometArrowStream.NATIVE_TIMEZONE, + plan.nodeName) + } + // If the first non broadcast plan is found, we need to adjust the partition number of // the broadcast plans to make sure they have the same partition number as the first non // broadcast plan. + // Walk-order: count how many non-CometNativeExec plans come before the firstNonBroadcast + // plan in `sparkPlans`. That's the slot index it will occupy in `inputs`, and therefore + // the protobuf scan-slot index whose Scan vs ShuffleScan classification governs whether + // it should be wrapped or direct-read. + val firstNonBroadcastSlot = sparkPlans + .take(firstNonBroadcastPlan.get._2) + .count(p => !p.isInstanceOf[CometNativeExec]) + val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = firstNonBroadcastPlan.get._1 match { case plan: CometNativeExec => (null, plan.outputPartitioning.numPartitions) case plan => - val rdd = plan.executeColumnar() + val rdd = asArrowStreamRDD(plan, 0, firstNonBroadcastSlot) (rdd, rdd.getNumPartitions) } @@ -520,24 +592,17 @@ abstract class CometNativeExec extends CometExec { // partition number of Broadcast RDDs to make sure they have the same partition number. sparkPlans.zipWithIndex.foreach { case (plan, idx) => plan match { - case c: CometBroadcastExchangeExec => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec( - _, - ReusedExchangeExec(_, c: CometBroadcastExchangeExec), - _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case _: CometNativeExec => // no-op case _ if idx == firstNonBroadcastPlan.get._2 => inputs += firstNonBroadcastPlanRDD case _ => - val rdd = plan.executeColumnar() - if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { + // Each plan we add to `inputs` corresponds to the next protobuf scan slot, in + // walk order. `inputs.size` is the slot index this plan will occupy. + val scanSlot = inputs.size + val rdd = asArrowStreamRDD(plan, firstNonBroadcastPlanNumPartitions, scanSlot) + if (!isBroadcastInput(plan) && + rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { throw new CometRuntimeException( s"Partition number mismatch: ${rdd.getNumPartitions} != " + s"$firstNonBroadcastPlanNumPartitions") @@ -551,9 +616,6 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } - // Detect ShuffleScan indices for direct read in CometExecRDD - val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) - // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 4605e641f1..a645c7b17a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -205,6 +205,16 @@ object Utils extends CometTypeShim with Logging { }.asJava) } + /** + * Build a `StructType` from a sequence of Spark `Attribute`s. Avoids + * `StructType.fromAttributes` (removed in Spark 4) and `DataTypeUtils.fromAttributes` (only on + * 4) so the same call works across supported Spark versions. + */ + def fromAttributes( + attributes: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]): StructType = + StructType(attributes.map(a => + org.apache.spark.sql.types.StructField(a.name, a.dataType, a.nullable, a.metadata))) + /** * Serializes a list of `ColumnarBatch` into an output stream. This method must be in `spark` * package because `ChunkedByteBufferOutputStream` is spark private class. As it uses Arrow diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala index 9c34b3a3ce..e30a1cf6b3 100644 --- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.PrettyAttribute import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec, CometExecUtils} -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch class CometNativeSuite extends CometTestBase { @@ -31,15 +32,16 @@ class CometNativeSuite extends CometTestBase { val rdd = spark.range(0, 1).rdd.map { value => val limitOp = CometExecUtils.getLimitNativePlan(Seq(PrettyAttribute("test", LongType)), 100).get - val cometIter = CometExec.getCometIterator( - Seq(new Iterator[ColumnarBatch] { + val arrowStream = CometArrowStream.fromColumnarBatchIter( + new Iterator[ColumnarBatch] { override def hasNext: Boolean = true override def next(): ColumnarBatch = throw new NullPointerException() - }), - 1, - limitOp, - 1, - 0) + }, + StructType(Seq(StructField("test", LongType, nullable = false))), + CometArrowStream.NATIVE_TIMEZONE, + "test-npe") + val cometIter = + CometExec.getCometIterator(Array(arrowStream.asInstanceOf[Object]), 1, limitOp, 1, 0) try { cometIter.next() } finally { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 8c8c19bb9c..d6075dac55 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3956,6 +3956,29 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec does not leak Arrow buffers (project consumer)") { + // Forces a CometNativeExec consumer over an ArrowArrayStream input. The producer must not + // leak the Arrow buffers it allocates per batch; if it does, the BaseAllocator + // leak detector fires inside the task completion listener. + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select($"a" + 1)) + } + } + + test("CometLocalTableScanExec does not leak Arrow buffers (collect_list)") { + // Mirrors DataFrameAggregateSuite "collect functions" which is the test that + // surfaced the leak in CI. + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select(collect_list($"a"), collect_list($"b"))) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala index b858fe5c83..a2aac7e6c7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala @@ -492,9 +492,13 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan InternalRow(i, UTF8String.fromString(s"value_$i")) } - // Create batches using rowToArrowBatchIter which handles shading internally + // Each emitted batch needs independent Arrow buffers so the test can hold rows from + // earlier batches while later batches are consumed. CometArrowConverters allocates a + // fresh VSR per batch from the supplied allocator. + val allocator = + org.apache.comet.CometArrowAllocator.newChildAllocator("c2r-test", 0, Long.MaxValue) val batchIter = CometArrowConverters - .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", null) + .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", allocator) val converter = new NativeColumnarToRowConverter(schema, rowsPerBatch) try { @@ -529,6 +533,7 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan "reused UnsafeRow object.") } finally { converter.close() + allocator.close() } } From b6db9969a5e732e309bab0c17a595d534a32dd25 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 14:02:44 -0400 Subject: [PATCH 09/12] Unpack dictionaries. --- .../arrow/ColumnarBatchArrowReader.scala | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala index eaacc3968b..2cb8746107 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -23,11 +23,12 @@ import java.util.{ArrayList => JArrayList} import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.dictionary.DictionaryEncoder import org.apache.arrow.vector.ipc.ArrowReader import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.vector.CometVector +import org.apache.comet.vector.{CometDictionaryVector, CometVector} /** * `ArrowReader` over an iterator of Arrow-backed `ColumnarBatch`es. Each `loadNextBatch` unloads @@ -56,12 +57,29 @@ private[comet] class ColumnarBatchArrowReader( } val src = source.next() + var materialized: JArrayList[FieldVector] = null try { val sourceVectors = new JArrayList[FieldVector](src.numCols()) var i = 0 while (i < src.numCols()) { - sourceVectors.add( - src.column(i).asInstanceOf[CometVector].getValueVector.asInstanceOf[FieldVector]) + val col = src.column(i).asInstanceOf[CometVector] + val fv = col match { + case d: CometDictionaryVector => + // Stable VSR was built from the logical (non-dict) schema, so a dict-encoded + // source's indices layout would mismatch the dest buffer count on load. Native + // unpacks downstream anyway via copy_or_unpack_array. + val indices = d.getValueVector + val dictionary = d.provider.lookup(indices.getField.getDictionary.getId) + val plain = DictionaryEncoder + .decode(indices, dictionary, allocator) + .asInstanceOf[FieldVector] + if (materialized == null) materialized = new JArrayList[FieldVector]() + materialized.add(plain) + plain + case _ => + col.getValueVector.asInstanceOf[FieldVector] + } + sourceVectors.add(fv) i += 1 } val transient = new VectorSchemaRoot(sourceVectors) @@ -74,9 +92,17 @@ private[comet] class ColumnarBatchArrowReader( } finally { rb.close() } - // Note: do not close `transient`. It shares FieldVectors with `src`; closing `src` below + // Do not close `transient`. It shares FieldVectors with `src`; closing `src` below // releases the producer-side refs. Closing `transient` would double-release. } finally { + if (materialized != null) { + var j = 0 + while (j < materialized.size()) { + try materialized.get(j).close() + catch { case _: Throwable => () } + j += 1 + } + } src.close() } true From cf7bb6ed003f5e2a8a617a9aab4ec45e249b5dc6 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 14:45:22 -0400 Subject: [PATCH 10/12] Fix shading issue. --- spark/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/pom.xml b/spark/pom.xml index 6d97ea831f..25c6c34e45 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -476,6 +476,7 @@ under the License. org/apache/arrow/c/jni/PrivateData org/apache/arrow/c/jni/CDataJniException + org/apache/arrow/c/ArrayStreamExporter org/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData From 82c9a1b357da4cc91d65b66f3707b56db104ee74 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 15:25:45 -0400 Subject: [PATCH 11/12] Try again to fix shading issue. --- spark/pom.xml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/spark/pom.xml b/spark/pom.xml index 25c6c34e45..ef1b07abf1 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -469,15 +469,13 @@ under the License. org.apache.arrow ${comet.shade.packageName}.arrow - - org/apache/arrow/c/jni/JniWrapper - org/apache/arrow/c/jni/PrivateData - org/apache/arrow/c/jni/CDataJniException - - org/apache/arrow/c/ArrayStreamExporter - org/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData + + org/apache/arrow/c/** From 6adf124c51dd1a7b7c1d2da82ee97489a4cfb32a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 16:36:44 -0400 Subject: [PATCH 12/12] Fix alignment issue for FFI Decimal128 with ArrowArrayStreamReader --- .../operators/aligned_stream_reader.rs | 110 ++++++++++++++++++ native/core/src/execution/operators/mod.rs | 2 + native/core/src/execution/operators/scan.rs | 11 +- native/core/src/execution/planner.rs | 5 +- 4 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 native/core/src/execution/operators/aligned_stream_reader.rs diff --git a/native/core/src/execution/operators/aligned_stream_reader.rs b/native/core/src/execution/operators/aligned_stream_reader.rs new file mode 100644 index 0000000000..c1d615a79f --- /dev/null +++ b/native/core/src/execution/operators/aligned_stream_reader.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{RecordBatch, RecordBatchOptions, StructArray}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::error::ArrowError; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use std::ffi::CStr; +use std::sync::Arc; + +/// C Stream Interface reader that calls [`arrow::array::ArrayData::align_buffers`] on every +/// imported batch before constructing typed arrays. Stock `ArrowArrayStreamReader` panics +/// when a JVM producer hands us a `Decimal128` buffer at an offset that is 8-byte but not +/// 16-byte aligned, which Java's allocator does not guarantee. Track upstream: +/// . +#[derive(Debug)] +pub struct AlignedArrowStreamReader { + stream: FFI_ArrowArrayStream, + schema: SchemaRef, +} + +impl AlignedArrowStreamReader { + /// # Safety + /// `raw` must point at a valid `FFI_ArrowArrayStream` whose ownership is being transferred + /// to this reader. The stream's release callback fires when the reader is dropped. + pub unsafe fn from_raw(raw: *mut FFI_ArrowArrayStream) -> Result { + let mut stream = FFI_ArrowArrayStream::from_raw(raw); + if stream.release.is_none() { + return Err(ArrowError::CDataInterface( + "input stream is already released".to_string(), + )); + } + let schema = read_schema(&mut stream)?; + Ok(Self { stream, schema }) + } + + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn last_error(&mut self) -> Option { + let get = self.stream.get_last_error?; + let ptr = unsafe { get(&mut self.stream) }; + if ptr.is_null() { + return None; + } + Some( + unsafe { CStr::from_ptr(ptr) } + .to_string_lossy() + .into_owned(), + ) + } +} + +impl Iterator for AlignedArrowStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + let mut array = FFI_ArrowArray::empty(); + let ret = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) }; + if ret != 0 { + let msg = self + .last_error() + .unwrap_or_else(|| format!("get_next returned {ret}")); + return Some(Err(ArrowError::CDataInterface(msg))); + } + if array.is_released() { + return None; + } + + let dt = DataType::Struct(self.schema.fields().clone()); + Some( + unsafe { from_ffi_and_data_type(array, dt) }.and_then(|mut data| { + data.align_buffers(); + let len = data.len(); + RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + StructArray::from(data).into_parts().1, + &RecordBatchOptions::new().with_row_count(Some(len)), + ) + }), + ) + } +} + +fn read_schema(stream: &mut FFI_ArrowArrayStream) -> Result { + let mut schema = FFI_ArrowSchema::empty(); + let ret = unsafe { stream.get_schema.unwrap()(stream, &mut schema) }; + if ret != 0 { + return Err(ArrowError::CDataInterface(format!( + "Cannot get schema from input stream. Error code: {ret}" + ))); + } + Ok(Arc::new(Schema::try_from(&schema)?)) +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 4b2c06575d..d68252bd9b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -19,10 +19,12 @@ pub use crate::errors::ExecutionError; +pub use aligned_stream_reader::*; pub use copy::*; pub use iceberg_scan::*; pub use scan::*; +mod aligned_stream_reader; mod copy; mod expand; pub use expand::ExpandExec; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index bad349e4d7..2ef32f6a13 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::operators::{copy_or_unpack_array, CopyMode}; +use crate::execution::operators::{copy_or_unpack_array, AlignedArrowStreamReader, CopyMode}; use crate::{errors::CometError, execution::planner::TEST_EXEC_CONTEXT_ID}; use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::ffi_stream::ArrowArrayStreamReader; use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ @@ -50,7 +49,7 @@ pub struct ScanExec { pub exec_context_id: i64, /// The C Stream Interface reader. `None` only in unit tests that seed input via /// `set_input_batch`. - pub input_source: Option>>, + pub input_source: Option>>, pub input_source_description: String, pub data_types: Vec, pub schema: SchemaRef, @@ -65,7 +64,7 @@ pub struct ScanExec { impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>>, + input_source: Option>>, input_source_description: &str, data_types: Vec, ) -> Result { @@ -136,7 +135,7 @@ impl ScanExec { /// columns are unpacked because Comet's downstream operators do not handle them. fn pull_next( exec_context_id: i64, - reader: &Arc>, + reader: &Arc>, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { // Unit test path; input batches are seeded directly. @@ -145,7 +144,7 @@ impl ScanExec { let mut reader = reader .try_lock() - .map_err(|_| CometError::Internal("ArrowArrayStreamReader contended".to_string()))?; + .map_err(|_| CometError::Internal("AlignedArrowStreamReader contended".to_string()))?; let next = reader.next(); match next { diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 77b174ea8d..6213ce6b11 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -23,6 +23,7 @@ pub mod operator_registry; use crate::errors::CometError; use crate::execution::operators::init_csv_datasource_exec; +use crate::execution::operators::AlignedArrowStreamReader; use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::list_positions::ListPositionsExpr, @@ -1451,7 +1452,7 @@ impl PhysicalPlanner { // Consumes the first input source for the scan. The Java side passes an // `org.apache.arrow.c.ArrowArrayStream` whose `memoryAddress` points at the C - // struct; native takes ownership via `ArrowArrayStreamReader::from_raw`. + // struct; native takes ownership via `AlignedArrowStreamReader::from_raw`. let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { @@ -1466,7 +1467,7 @@ impl PhysicalPlanner { Ok(addr) })?; let reader = unsafe { - arrow::ffi_stream::ArrowArrayStreamReader::from_raw( + AlignedArrowStreamReader::from_raw( address as *mut arrow::ffi_stream::FFI_ArrowArrayStream, ) }