Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ import org.apache.comet.serde.Metric
case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode])
extends Logging {

/**
 * Returns the deepest descendant reached by repeatedly following the first child. For a
 * native scan plan like FilterExec -> DataSourceExec, this resolves to the DataSourceExec
 * node, which carries the bytes_scanned and output_rows metrics from the Parquet reader.
 */
def leafNode: CometMetricNode = children.headOption.fold(this)(_.leafNode)

/**
* Gets a child node. Called from native.
*/
Expand Down Expand Up @@ -79,6 +89,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
}
}

// Called via JNI from `comet_metric_node.rs`
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that the only place this will ever be called from? Otherwise I'm not sure the comment is necessary.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IDE highlights the method as unused because it is only called via JNI, so it could be accidentally removed during cleanup. Added a comment to clarify.

def set_all_from_bytes(bytes: Array[Byte]): Unit = {
val metricNode = Metric.NativeMetricNode.parseFrom(bytes)
set_all(metricNode)
Expand Down
30 changes: 28 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -558,7 +559,7 @@ abstract class CometNativeExec extends CometExec {

// Unified RDD creation - CometExecRDD handles all cases
val subqueries = collectSubqueries(this)
CometExecRDD(
new CometExecRDD(
sparkContext,
inputs.toSeq,
commonByKey,
Expand All @@ -570,7 +571,32 @@ abstract class CometNativeExec extends CometExec {
subqueries,
broadcastedHadoopConfForEncryption,
encryptedFilePaths,
shuffleScanIndices)
shuffleScanIndices) {
override def compute(
    split: Partition,
    context: TaskContext): Iterator[ColumnarBatch] = {
  // Delegate the actual batch computation to CometExecRDD.
  val res = super.compute(split, context)

  // Report scan input metrics only when the native plan contains a scan.
  if (sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) {
    // context can be null (e.g. when compute is invoked outside a running task);
    // Option(...) guards against that.
    Option(context).foreach { ctx =>
      // Read the metrics in a completion listener: native metric values are only
      // final once the task has fully consumed the iterator.
      ctx.addTaskCompletionListener[Unit] { _ =>
        // The scan node is assumed to be the (first-child) leaf of the native
        // metrics tree -- TODO confirm this holds for multi-child native plans.
        val leaf = nativeMetrics.leafNode
        // NOTE(review): recordsRead is only set when bytes_scanned is present,
        // since both updates live inside this foreach -- confirm that is intended.
        leaf.metrics.get("bytes_scanned").foreach { bs =>
          ctx.taskMetrics().inputMetrics.setBytesRead(bs.value)
          val outputRows =
            leaf.metrics.get("output_rows").map(_.value).getOrElse(0L)
          val prunedRows =
            leaf.metrics.get("pushdown_rows_pruned").map(_.value).getOrElse(0L)
          // Records read = rows produced by the scan plus rows pruned by
          // pushed-down predicates, matching Spark's notion of rows consumed.
          ctx.taskMetrics().inputMetrics.setRecordsRead(outputRows + prunedRows)
        }
      }
    }
  }

  res
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@ package org.apache.spark.sql.comet

import scala.collection.mutable

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.scheduler.SparkListenerTaskEnd
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

import org.apache.comet.CometConf

class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {

/**
 * Enables the Spark UI so the status store (the same source the UI reads from) is
 * populated with the task-level metrics this suite inspects.
 */
override protected def sparkConf: SparkConf =
  super.sparkConf.set("spark.ui.enabled", "true")

import testImplicits._

test("per-task native shuffle metrics") {
Expand Down Expand Up @@ -91,4 +99,80 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("native_datafusion scan reports task-level input metrics matching Spark") {
  val totalRows = 10000
  withTempPath { dir =>
    // Write the input across 5 partitions so the scan is exercised by multiple tasks.
    spark
      .createDataFrame((0 until totalRows).map(i => (i, s"elem_$i")))
      .repartition(5)
      .write
      .parquet(dir.getAbsolutePath)
    spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("tbl")
    // Collect baseline input metrics from vanilla Spark (Comet disabled)
    val (sparkBytes, sparkRecords, _) =
      collectInputMetrics(CometConf.COMET_ENABLED.key -> "false")

    // Collect input metrics from Comet native_datafusion scan.
    val (cometBytes, cometRecords, cometPlan) = collectInputMetrics(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

    // Verify the plan actually used CometNativeScanExec
    assert(
      find(cometPlan)(_.isInstanceOf[CometNativeScanExec]).isDefined,
      s"Expected CometNativeScanExec in plan:\n${cometPlan.treeString}")

    // Both engines must report non-zero record counts before comparing them.
    assert(sparkRecords > 0, s"Spark outputRecords should be > 0, got $sparkRecords")
    assert(cometRecords > 0, s"Comet outputRecords should be > 0, got $cometRecords")

    // Record counts are expected to match exactly since both read the same rows.
    assert(
      cometRecords == sparkRecords,
      s"recordsRead mismatch: comet=$cometRecords, sparkRecords=$sparkRecords")

    // Bytes should be in the same ballpark -- both read the same Parquet file(s),
    // but the exact byte count can differ due to reader implementation details
    // (e.g. footer reads, page headers, buffering granularity).
    assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
    assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
    // Allow up to 20% deviation in either direction between the two readers.
    val ratio = cometBytes.toDouble / sparkBytes.toDouble
    assert(
      ratio >= 0.8 && ratio <= 1.2,
      s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
  }
}

/**
 * Runs `SELECT * FROM tbl WHERE _1 > 2000` with the given SQL config overrides and returns
 * the aggregated (bytesRead, recordsRead) across all tasks, along with the executed plan.
 *
 * Uses AppStatusStore (same source as Spark UI) to read task-level input metrics.
 * AppStatusStore stores immutable snapshots of metric values, unlike SparkListener's
 * InputMetrics which are backed by mutable accumulators that can be reset.
 */
private def collectInputMetrics(confs: (String, String)*): (Long, Long, SparkPlan) = {
  val store = spark.sparkContext.statusStore

  // Record existing stage IDs so we only look at stages from our query
  val stagesBefore = store.stageList(null).map(_.stageId).toSet

  // Use an Option instead of a null-initialized var so a failure to capture the plan
  // surfaces as a clear test failure rather than a NullPointerException downstream.
  var capturedPlan: Option[SparkPlan] = None
  withSQLConf(confs: _*) {
    val df = sql("SELECT * FROM tbl where _1 > 2000")
    df.collect()
    capturedPlan = Some(stripAQEPlan(df.queryExecution.executedPlan))
  }
  val plan = capturedPlan.getOrElse(fail(s"No executed plan captured for confs=$confs"))

  // Wait for listener bus to flush all events into the status store
  spark.sparkContext.listenerBus.waitUntilEmpty()

  // Sum input metrics from stages created by our query
  val newStages = store.stageList(null).filterNot(s => stagesBefore.contains(s.stageId))
  assert(newStages.nonEmpty, s"No new stages found for confs=$confs")

  val totalBytes = newStages.map(_.inputBytes).sum
  val totalRecords = newStages.map(_.inputRecords).sum

  (totalBytes, totalRecords, plan)
}
}
Loading