diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..c60202a96f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -41,6 +41,16 @@ import org.apache.comet.serde.Metric case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode]) extends Logging { + /** + * Returns the leaf node, reached by following the first child at each level. For a native + * scan plan like FilterExec -> DataSourceExec, this returns the DataSourceExec node which + * has the bytes_scanned and output_rows metrics from the Parquet reader. + */ + def leafNode: CometMetricNode = { + if (children.isEmpty) this + else children.head.leafNode + } + /** * Gets a child node. Called from native. */ @@ -79,6 +89,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM } } + // Called via JNI from `comet_metric_node.rs` def set_all_from_bytes(bytes: Array[Byte]): Unit = { val metricNode = Metric.NativeMetricNode.parseFrom(bytes) set_all(metricNode) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2965e46988..4ab101fea8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -558,7 +559,7 @@ abstract class CometNativeExec extends CometExec { // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) - 
CometExecRDD( + new CometExecRDD( sparkContext, inputs.toSeq, commonByKey, @@ -570,7 +571,32 @@ abstract class CometNativeExec extends CometExec { subqueries, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleScanIndices) + shuffleScanIndices) { + override def compute( + split: Partition, + context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + + // Report scan input metrics only when the native plan contains a scan. + if (sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) { + Option(context).foreach { ctx => + ctx.addTaskCompletionListener[Unit] { _ => + val leaf = nativeMetrics.leafNode + leaf.metrics.get("bytes_scanned").foreach { bs => + ctx.taskMetrics().inputMetrics.setBytesRead(bs.value) + val outputRows = + leaf.metrics.get("output_rows").map(_.value).getOrElse(0L) + val prunedRows = + leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L) + ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows) + } + } + } + } + + res + } + } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 3946aab184..5b6225b720 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -21,6 +21,7 @@ package org.apache.spark.sql.comet import scala.collection.mutable +import org.apache.spark.SparkConf import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.SparkListener @@ -28,10 +29,17 @@ import org.apache.spark.scheduler.SparkListenerTaskEnd import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.SparkPlan 
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.comet.CometConf + class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { + override protected def sparkConf: SparkConf = { + super.sparkConf.set("spark.ui.enabled", "true") + } + import testImplicits._ test("per-task native shuffle metrics") { @@ -91,4 +99,80 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("native_datafusion scan reports task-level input metrics matching Spark") { + val totalRows = 10000 + withTempPath { dir => + spark + .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i"))) + .repartition(5) + .write + .parquet(dir.getAbsolutePath) + spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl") + // Collect baseline input metrics from vanilla Spark (Comet disabled) + val (sparkBytes, sparkRecords, _) = + collectInputMetrics(CometConf.COMET_ENABLED.key -> "false") + + // Collect input metrics from Comet native_datafusion scan. + val (cometBytes, cometRecords, cometPlan) = collectInputMetrics( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) + + // Verify the plan actually used CometNativeScanExec + assert( + find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined, + s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}") + + assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords") + assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords") + + assert( + cometRecords == sparkRecords, + s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords") + + // Bytes should be in the same ballpark -- both read the same Parquet file(s), + // but the exact byte count can differ due to reader implementation details + // (e.g. footer reads, page headers, buffering granularity). 
+ assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes") + assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes") + val ratio = cometBytes.toDouble / sparkBytes.toDouble + assert( + ratio >= 0.8 && ratio <= 1.2, + s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio") + } + } + + /** + * Runs `SELECT * FROM tbl WHERE _1 > 2000` with the given SQL config overrides and returns the + * aggregated (bytesRead, recordsRead) across all tasks, along with the executed plan. + * + * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics. + * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's + * InputMetrics which are backed by mutable accumulators that can be reset. + */ + private def collectInputMetrics(confs: (String, String)*): (Long, Long, SparkPlan) = { + val store = spark.sparkContext.statusStore + + // Record existing stage IDs so we only look at stages from our query + val stagesBefore = store.stageList(null).map(_.stageId).toSet + + var plan: SparkPlan = null + withSQLConf(confs: _*) { + val df = sql("SELECT * FROM tbl where _1 > 2000") + df.collect() + plan = stripAQEPlan(df.queryExecution.executedPlan) + } + + // Wait for listener bus to flush all events into the status store + spark.sparkContext.listenerBus.waitUntilEmpty() + + // Sum input metrics from stages created by our query + val newStages = store.stageList(null).filterNot(s => stagesBefore.contains(s.stageId)) + assert(newStages.nonEmpty, s"No new stages found for confs=$confs") + + val totalBytes = newStages.map(_.inputBytes).sum + val totalRecords = newStages.map(_.inputRecords).sum + + (totalBytes, totalRecords, plan) + } }