From c576288c4d358d11e6fc167e5ee2d40b58c4a88f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BE=8A=E5=B7=9D?= Date: Thu, 28 May 2026 23:36:38 +0800 Subject: [PATCH 1/3] fix --- .../fluss/spark/read/FlussScanBuilder.scala | 19 ++++++++------ .../spark/utils/SparkPartitionPredicate.scala | 23 ++++++++++++++--- .../fluss/spark/SparkLogTableReadTest.scala | 1 + .../spark/SparkPrimaryKeyTableReadTest.scala | 25 +++++++++++++++++++ 4 files changed, 57 insertions(+), 11 deletions(-) diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala index 361542d769..875c959269 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala @@ -52,8 +52,10 @@ trait FlussSupportsPushDownPartitionFilters protected var partitionPredicate: Option[FlussPredicate] = None override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { - partitionPredicate = SparkPartitionPredicate.extract(tableInfo, predicates.toSeq) - predicates + val (nonPartitionPred, partitionPred) = + SparkPartitionPredicate.extract(tableInfo, predicates.toSeq) + partitionPredicate = partitionPred + nonPartitionPred.toArray } override def pushedPredicates(): Array[Predicate] = Array.empty @@ -73,12 +75,12 @@ trait FlussSupportsPushDownV2Filters extends FlussSupportsPushDownPartitionFilte } override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { - super.pushPredicates(predicates) + val nonPartitionPredicates = super.pushPredicates(predicates) // Server-side batch filter only supports ARROW; other log formats reject it. if (tableInfo.getTableConfig.getLogFormat == LogFormat.ARROW) { - convertAndStorePredicates(predicates) + convertAndStorePredicates(nonPartitionPredicates) } - predicates + nonPartitionPredicates } override def pushedPredicates(): Array[Predicate] = acceptedPredicates @@ -96,8 +98,11 @@ trait FlussLakeSupportsPushDownV2Filters extends FlussSupportsPushDownV2Filters def flussConfig: FlussConfiguration override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val nonPartitionPredicates = super.pushPredicates(predicates) val pairs = - SparkPredicateConverter.convertPerPredicate(tableInfo.getRowType, predicates.toSeq) + SparkPredicateConverter.convertPerPredicate( + tableInfo.getRowType, + nonPartitionPredicates.toSeq) val (acceptedSpark, acceptedFluss) = if (pairs.isEmpty) { (Seq.empty[Predicate], Seq.empty[FlussPredicate]) } else { @@ -112,7 +117,7 @@ trait FlussLakeSupportsPushDownV2Filters extends FlussSupportsPushDownV2Filters } pushedPredicate = SparkPredicateConverter.combineAnd(acceptedFluss) acceptedPredicates = acceptedSpark.toArray - predicates + nonPartitionPredicates } } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala index 353334c6e4..ca134dd350 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala @@ -28,25 +28,40 @@ import scala.jdk.CollectionConverters._ /** Extracts a partition-key predicate and prunes the partition list at planning time. */ object SparkPartitionPredicate { - def extract(tableInfo: TableInfo, predicates: Seq[Predicate]): Option[FlussPredicate] = { + def extract( + tableInfo: TableInfo, + predicates: Seq[Predicate]): (Seq[Predicate], Option[FlussPredicate]) = { val partitionKeys = tableInfo.getPartitionKeys - if (partitionKeys.isEmpty) return None + if (partitionKeys.isEmpty) { + return (predicates, None) + } val rowType = PartitionUtils.partitionRowType(tableInfo) val onlyPartitionKeys = new PartitionPredicateVisitor(partitionKeys) - val converted = predicates.flatMap { + // Separate predicates: those that can be converted and only touch partition keys + val (partitionPredicates, nonPartitionPredicates) = predicates.partition { + sparkPredicate => + SparkPredicateConverter + .convert(rowType, sparkPredicate) + .exists(_.visit(onlyPartitionKeys)) + } + + // Convert partition predicates to FlussPredicate + val converted = partitionPredicates.flatMap { sparkPredicate => SparkPredicateConverter .convert(rowType, sparkPredicate) .filter(_.visit(onlyPartitionKeys)) } - converted match { + val partitionPredicate = converted match { case Seq() => None case Seq(single) => Some(single) case many => Some(PredicateBuilder.and(many.asJava)) } + + (nonPartitionPredicates, partitionPredicate) } def filterPartitions( diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala index 3276decb77..7c7fb8a5c6 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkLogTableReadTest.scala @@ -487,6 +487,7 @@ class SparkLogTableReadTest extends FlussSparkTestBase { |WHERE dt = '2026-01-02' AND amount > 603 ORDER BY orderId""".stripMargin) checkAnswer(query, Row(900L) :: Nil) assert(partitionPredicate(query).isDefined) + assert(pushedPredicates(query).nonEmpty) } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala index cddca21722..7ad0f41fc4 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala @@ -23,6 +23,8 @@ import org.apache.fluss.metadata.{TableBucket, TablePath} import org.apache.fluss.spark.read.{FlussMetrics, FlussScan, FlussUpsertInputPartition, FlussUpsertScan} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.assertj.core.api.Assertions.assertThat @@ -432,6 +434,29 @@ class SparkPrimaryKeyTableReadTest extends FlussSparkTestBase { } } + test("Spark Read: mixed partition and non-partition filter (PK table)") { + withPkPartitionedTable { + val query = sql( + s"SELECT * FROM $DEFAULT_DATABASE.t WHERE dt = '2026-01-01' AND amount > 601") + checkAnswer(query, Row(700L, 22L, 602, "addr2", "2026-01-01") :: Nil) + // Partition predicate extracted for partition pruning + assert(partitionPredicate(query).isDefined) + // Non-partition predicate (amount > 601) remains as a Filter node in the plan + val executedPlan = query.queryExecution.executedPlan match { + case aqe: AdaptiveSparkPlanExec => aqe.executedPlan + case e: SparkPlan => e + } + assert( + executedPlan.exists(_.isInstanceOf[FilterExec]), + s"Expected Filter node in plan for non-partition predicate, got: $executedPlan") + + val numRowsRead = executedPlan + .collectFirst { case b: BatchScanExec => b.metrics(FlussMetrics.NUM_ROWS_READ).value } + .getOrElse(0L) + assert(numRowsRead == 2L, s"Expected 2 rows read for single partition, got $numRowsRead") + } + } + private def withPkPartitionedTable(body: => Unit): Unit = withTable("t") { sql(s""" |CREATE TABLE $DEFAULT_DATABASE.t ( From c28a45dd762115734097d3de97d70932a1fa464f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BE=8A=E5=B7=9D?= Date: Fri, 29 May 2026 09:43:45 +0800 Subject: [PATCH 2/3] fix style --- .../org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala index 7ad0f41fc4..cc6feeab08 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/SparkPrimaryKeyTableReadTest.scala @@ -436,8 +436,7 @@ class SparkPrimaryKeyTableReadTest extends FlussSparkTestBase { test("Spark Read: mixed partition and non-partition filter (PK table)") { withPkPartitionedTable { - val query = sql( - s"SELECT * FROM $DEFAULT_DATABASE.t WHERE dt = '2026-01-01' AND amount > 601") + val query = sql(s"SELECT * FROM $DEFAULT_DATABASE.t WHERE dt = '2026-01-01' AND amount > 601") checkAnswer(query, Row(700L, 22L, 602, "addr2", "2026-01-01") :: Nil) // Partition predicate extracted for partition pruning assert(partitionPredicate(query).isDefined) From 840022d1f12a3ac7e9760ba8159e8344fe4e3335 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BE=8A=E5=B7=9D?= Date: Sat, 30 May 2026 16:44:17 +0800 Subject: [PATCH 3/3] fix lake batch partition filter pushdown --- .../apache/fluss/spark/read/FlussScan.scala | 4 + .../fluss/spark/read/FlussScanBuilder.scala | 2 + .../read/lake/FlussLakeAppendBatch.scala | 57 ++++++++---- .../read/lake/FlussLakeUpsertBatch.scala | 67 +++++++++---- .../spark/utils/SparkPartitionPredicate.scala | 15 +++ .../lake/SparkLakeLogTableReadTest.scala | 93 +++++++++++++++++++ ...SparkLakePrimaryKeyTableReadTestBase.scala | 5 +- .../lake/SparkLakeTableReadTestBase.scala | 12 ++- 8 files changed, 212 insertions(+), 43 deletions(-) diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScan.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScan.scala index c5379cfd8c..f9d0be1bba 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScan.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScan.scala @@ -106,6 +106,7 @@ case class FlussLakeAppendScan( tableInfo: TableInfo, requiredSchema: Option[StructType], pushedPredicate: Option[FlussPredicate], + override val partitionPredicate: Option[FlussPredicate], override val pushedSparkPredicates: Seq[Predicate], options: CaseInsensitiveStringMap, flussConfig: Configuration) @@ -119,6 +120,7 @@ case class FlussLakeAppendScan( tableInfo, readSchema, pushedPredicate, + partitionPredicate, options, flussConfig) } @@ -167,6 +169,7 @@ case class FlussLakeUpsertScan( tableInfo: TableInfo, requiredSchema: Option[StructType], pushedPredicate: Option[FlussPredicate], + override val partitionPredicate: Option[FlussPredicate], override val pushedSparkPredicates: Seq[Predicate], options: CaseInsensitiveStringMap, flussConfig: Configuration) @@ -180,6 +183,7 @@ case class FlussLakeUpsertScan( tableInfo, readSchema, pushedPredicate, + partitionPredicate, options, flussConfig) } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala index 875c959269..f7f8c7aea6 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/FlussScanBuilder.scala @@ -156,6 +156,7 @@ class FlussLakeAppendScanBuilder( tableInfo, requiredSchema, pushedPredicate, + partitionPredicate, acceptedPredicates.toSeq, options, flussConfig) @@ -189,6 +190,7 @@ class FlussLakeUpsertScanBuilder( tableInfo, requiredSchema, pushedPredicate, + partitionPredicate, acceptedPredicates.toSeq, options, flussConfig) diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeAppendBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeAppendBatch.scala index 0eb701555f..f5b747e306 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeAppendBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeAppendBatch.scala @@ -25,6 +25,7 @@ import org.apache.fluss.lake.source.{LakeSource, LakeSplit} import org.apache.fluss.metadata.{LogFormat, ResolvedPartitionSpec, TableBucket, TableInfo, TablePath} import org.apache.fluss.predicate.{Predicate => FlussPredicate} import org.apache.fluss.spark.read._ +import org.apache.fluss.spark.utils.SparkPartitionPredicate import org.apache.fluss.utils.ExceptionUtils import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} @@ -40,6 +41,7 @@ class FlussLakeAppendBatch( tableInfo: TableInfo, readSchema: StructType, pushedPredicate: Option[FlussPredicate], + partitionPredicate: Option[FlussPredicate], options: CaseInsensitiveStringMap, flussConfig: Configuration) extends FlussLakeBatch(tablePath, tableInfo, readSchema, options, flussConfig) { @@ -146,8 +148,14 @@ class FlussLakeAppendBatch( bucketOffsetsRetriever: BucketOffsetsRetrieverImpl): Array[InputPartition] = { val tableId = tableInfo.getTableId + // Filter Fluss-known partitions using the partition predicate to skip non-matching ones + val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions( + tableInfo, + partitionInfos.asScala.toSeq, + partitionPredicate) + val flussPartitionIdByName = mutable.LinkedHashMap.empty[String, Long] - partitionInfos.asScala.foreach { + filteredPartitionInfos.foreach { pi => flussPartitionIdByName(pi.getPartitionName) = pi.getPartitionId } @@ -155,7 +163,7 @@ class FlussLakeAppendBatch( var lakeSplitPartitionId = -1L val lakeAndLogPartitions = lakeSplitsByPartition.flatMap { - case (partitionName, splits) => + case (partitionName, (partitionValues, splits)) => flussPartitionIdByName.remove(partitionName) match { case Some(partitionId) => // Partition in both lake and Fluss — lake splits + log tail @@ -176,10 +184,18 @@ class FlussLakeAppendBatch( lakePartitions ++ logPartitions case None => - // Partition only in lake (expired in Fluss) — lake splits only - val pid = lakeSplitPartitionId - lakeSplitPartitionId -= 1 - createLakePartitions(splits.toSeq, tableId, Some(pid)) + // Partition only in lake (expired in Fluss). Apply the partition predicate directly + // on the resolved partition values to avoid round-tripping through partition names. + if ( + SparkPartitionPredicate + .matchesPartition(tableInfo, partitionValues, partitionPredicate) + ) { + val pid = lakeSplitPartitionId + lakeSplitPartitionId -= 1 + createLakePartitions(splits.toSeq, tableId, Some(pid)) + } else { + Seq.empty + } } }.toSeq @@ -210,17 +226,20 @@ class FlussLakeAppendBatch( (lakeAndLogPartitions ++ flussOnlyPartitions).toArray } - private def groupLakeSplitsByPartition( - lakeSplits: Seq[LakeSplit]): mutable.LinkedHashMap[String, mutable.ArrayBuffer[LakeSplit]] = { - val grouped = mutable.LinkedHashMap.empty[String, mutable.ArrayBuffer[LakeSplit]] + private def groupLakeSplitsByPartition(lakeSplits: Seq[LakeSplit]) + : mutable.LinkedHashMap[String, (Seq[String], mutable.ArrayBuffer[LakeSplit])] = { + val grouped = + mutable.LinkedHashMap.empty[String, (Seq[String], mutable.ArrayBuffer[LakeSplit])] lakeSplits.foreach { split => - val partitionName = if (split.partition() == null || split.partition().isEmpty) { - "" - } else { - split.partition().asScala.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR) - } - grouped.getOrElseUpdate(partitionName, mutable.ArrayBuffer.empty) += split + val partitionValues = + if (split.partition() == null) Seq.empty[String] else split.partition().asScala.toSeq + val partitionName = + if (partitionValues.isEmpty) "" + else partitionValues.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR) + val (_, buf) = + grouped.getOrElseUpdate(partitionName, (partitionValues, mutable.ArrayBuffer.empty)) + buf += split } grouped } @@ -286,9 +305,11 @@ class FlussLakeAppendBatch( } if (tableInfo.isPartitioned) { - partitionInfos.asScala.flatMap { - pi => createPartitions(Some(pi.getPartitionId), pi.getPartitionName) - }.toArray + val matching = SparkPartitionPredicate.filterPartitions( + tableInfo, + partitionInfos.asScala.toSeq, + partitionPredicate) + matching.flatMap(pi => createPartitions(Some(pi.getPartitionId), pi.getPartitionName)).toArray } else { createPartitions(None, null) } diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala index 1b095751ef..882d941436 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/read/lake/FlussLakeUpsertBatch.scala @@ -25,6 +25,7 @@ import org.apache.fluss.lake.source.LakeSplit import org.apache.fluss.metadata.{ResolvedPartitionSpec, TableBucket, TableInfo, TablePath} import org.apache.fluss.predicate.{Predicate => FlussPredicate} import org.apache.fluss.spark.read._ +import org.apache.fluss.spark.utils.SparkPartitionPredicate import org.apache.fluss.utils.ExceptionUtils import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} @@ -43,6 +44,7 @@ class FlussLakeUpsertBatch( tableInfo: TableInfo, readSchema: StructType, pushedPredicate: Option[FlussPredicate], + partitionPredicate: Option[FlussPredicate], options: CaseInsensitiveStringMap, flussConfig: Configuration) extends FlussLakeBatch(tablePath, tableInfo, readSchema, options, flussConfig) { @@ -141,18 +143,24 @@ class FlussLakeUpsertBatch( val tableId = tableInfo.getTableId val buckets = (0 until tableInfo.getNumBuckets).toSeq + // Filter Fluss-known partitions using the partition predicate to skip non-matching ones + val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions( + tableInfo, + partitionInfos.asScala.toSeq, + partitionPredicate) + val flussPartitionIdByName = mutable.LinkedHashMap.empty[String, Long] - partitionInfos.asScala.foreach { + filteredPartitionInfos.foreach { pi => flussPartitionIdByName(pi.getPartitionName) = pi.getPartitionId } val lakeSplitsByPartition = groupLakeSplitsByPartition(lakeSplits) val lakePartitions = lakeSplitsByPartition.flatMap { - case (partitionName, splitsByBucket) => + case (partitionName, (partitionValues, splitsByBucket)) => flussPartitionIdByName.remove(partitionName) match { case Some(partitionId) => - // Partition in both lake and Fluss + // Partition in both lake and Fluss (already passed the predicate filter above) val stoppingOffsets = getBucketOffsets( stoppingOffsetsInitializer, partitionName, @@ -173,13 +181,21 @@ class FlussLakeUpsertBatch( } case None => - // Partition only in lake (expired in Fluss) - buckets.flatMap { - bucketId => - val tableBucket = new TableBucket(tableId, -1, bucketId) - splitsByBucket.getOrElse(bucketId, Seq.empty).map { - lakeSplit => FlussLakeInputPartition(tableBucket, lakeSplit) - } + // Partition only in lake (expired in Fluss). Apply the partition predicate directly + // on the resolved partition values to avoid round-tripping through partition names. + if ( + SparkPartitionPredicate + .matchesPartition(tableInfo, partitionValues, partitionPredicate) + ) { + buckets.flatMap { + bucketId => + val tableBucket = new TableBucket(tableId, -1, bucketId) + splitsByBucket.getOrElse(bucketId, Seq.empty).map { + lakeSplit => FlussLakeInputPartition(tableBucket, lakeSplit) + } + } + } else { + Seq.empty } } } @@ -216,18 +232,25 @@ class FlussLakeUpsertBatch( (lakePartitions ++ flussOnlyPartitions).toArray } + /** + * Group lake splits by partition. Each entry stores the resolved partition values along with + * splits keyed by bucket id, so callers can both look up Fluss partitions by name and evaluate + * the partition predicate directly on the values without re-parsing the joined name. + */ private def groupLakeSplitsByPartition( - lakeSplits: Seq[LakeSplit]): Map[String, mutable.Map[Int, Seq[LakeSplit]]] = { - val grouped = mutable.LinkedHashMap.empty[String, mutable.Map[Int, Seq[LakeSplit]]] + lakeSplits: Seq[LakeSplit]): Map[String, (Seq[String], mutable.Map[Int, Seq[LakeSplit]])] = { + val grouped = + mutable.LinkedHashMap.empty[String, (Seq[String], mutable.Map[Int, Seq[LakeSplit]])] lakeSplits.foreach { split => - val partitionName = if (split.partition() == null || split.partition().isEmpty) { - "" - } else { - split.partition().asScala.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR) - } + val partitionValues = + if (split.partition() == null) Seq.empty[String] else split.partition().asScala.toSeq + val partitionName = + if (partitionValues.isEmpty) "" + else partitionValues.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR) + val (_, bucketMap) = + grouped.getOrElseUpdate(partitionName, (partitionValues, mutable.Map.empty)) val bucketId = split.bucket() - val bucketMap = grouped.getOrElseUpdate(partitionName, mutable.Map.empty) val splits = bucketMap.getOrElse(bucketId, Seq.empty) bucketMap(bucketId) = splits :+ split } @@ -271,7 +294,13 @@ class FlussLakeUpsertBatch( val bucketOffsetsRetriever = new BucketOffsetsRetrieverImpl(admin, tablePath) if (tableInfo.isPartitioned) { - partitionInfos.asScala.flatMap { + // Filter partitions using partition predicate early to skip non-matching partitions + val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions( + tableInfo, + partitionInfos.asScala.toSeq, + partitionPredicate) + + filteredPartitionInfos.flatMap { pi => val partitionName = pi.getPartitionName val kvSnapshots = admin.getLatestKvSnapshots(tablePath, partitionName).get() diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala index ca134dd350..6536c435ef 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/utils/SparkPartitionPredicate.scala @@ -78,4 +78,19 @@ object SparkPartitionPredicate { PartitionUtils.toPartitionRow(p.getResolvedPartitionSpec.getPartitionValues, rowType)) } } + + /** + * Tests whether a partition (described by its ordered partition values) matches the given + * predicate. Returns true when no predicate is provided. + */ + def matchesPartition( + tableInfo: TableInfo, + partitionValues: Seq[String], + partitionPredicate: Option[FlussPredicate]): Boolean = + partitionPredicate match { + case None => true + case Some(predicate) => + val rowType = PartitionUtils.partitionRowType(tableInfo) + predicate.test(PartitionUtils.toPartitionRow(partitionValues.asJava, rowType)) + } } diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeLogTableReadTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeLogTableReadTest.scala index 210b23e127..b47bdfeeb6 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeLogTableReadTest.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeLogTableReadTest.scala @@ -473,6 +473,99 @@ abstract class SparkLakeLogTableReadTest extends SparkLakeTableReadTestBase { } } + test("Spark Lake Read: log table partition filter pushdown prunes partitions") { + withTable("t_pd_part_lake") { + sql(s""" + |CREATE TABLE $DEFAULT_DATABASE.t_pd_part_lake + | (id INT, name STRING, dt STRING) + | PARTITIONED BY (dt) + | TBLPROPERTIES ( + | '${ConfigOptions.TABLE_DATALAKE_ENABLED.key()}' = true, + | '${ConfigOptions.TABLE_DATALAKE_FRESHNESS.key()}' = '1s', + | '${BUCKET_NUMBER.key()}' = 1) + |""".stripMargin) + + // Lake-only partitions: tier all data, then filter by partition key. + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_pd_part_lake VALUES + |(1, 'alpha', '2026-01-01'), + |(2, 'beta', '2026-01-01'), + |(3, 'gamma', '2026-01-02'), + |(4, 'delta', '2026-01-03') + |""".stripMargin) + tierToLake("t_pd_part_lake") + + val lakeOnlyDf = sql(s""" + |SELECT * FROM $DEFAULT_DATABASE.t_pd_part_lake + |WHERE dt = '2026-01-02' ORDER BY id""".stripMargin) + checkAnswer(lakeOnlyDf, Row(3, "gamma", "2026-01-02") :: Nil) + assert( + lakeInputPartitions(lakeOnlyDf).length == 1, + s"Expected 1 input partition after partition pruning" + ) + + // Append more data after tiering so the planner mixes lake splits and Fluss log tail. + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_pd_part_lake VALUES + |(5, 'epsilon', '2026-01-01'), + |(6, 'zeta', '2026-01-04') + |""".stripMargin) + + val unionDf = sql(s""" + |SELECT * FROM $DEFAULT_DATABASE.t_pd_part_lake + |WHERE dt = '2026-01-01' ORDER BY id""".stripMargin) + checkAnswer( + unionDf, + Row(1, "alpha", "2026-01-01") :: + Row(2, "beta", "2026-01-01") :: + Row(5, "epsilon", "2026-01-01") :: Nil + ) + // Only one partition should be planned: lake split + log tail for dt='2026-01-01'. + val unionParts = lakeInputPartitions(unionDf) + assert( + unionParts.length == 2, + s"Expected 2 input partitions (one lake split + one log tail) after pruning, " + + s"got ${unionParts.length}") + + // Check the description carries the partition filter for visibility in EXPLAIN output. + assert( + unionDf.queryExecution.executedPlan.toString.contains("PartitionFilter"), + s"Plan should contain PartitionFilter:\n${unionDf.queryExecution.executedPlan}" + ) + } + } + + test("Spark Lake Read: log table partition filter pushdown in fallback (no lake snapshot)") { + withTable("t_pd_part_fb") { + sql(s""" + |CREATE TABLE $DEFAULT_DATABASE.t_pd_part_fb + | (id INT, name STRING, dt STRING) + | PARTITIONED BY (dt) + | TBLPROPERTIES ( + | '${ConfigOptions.TABLE_DATALAKE_ENABLED.key()}' = true, + | '${ConfigOptions.TABLE_DATALAKE_FRESHNESS.key()}' = '1s', + | '${BUCKET_NUMBER.key()}' = 1) + |""".stripMargin) + + sql(s""" + |INSERT INTO $DEFAULT_DATABASE.t_pd_part_fb VALUES + |(1, 'alpha', '2026-01-01'), + |(2, 'beta', '2026-01-02'), + |(3, 'gamma', '2026-01-03') + |""".stripMargin) + + // No tiering performed -> falls back to reading directly from Fluss. + val df = sql(s""" + |SELECT * FROM $DEFAULT_DATABASE.t_pd_part_fb + |WHERE dt = '2026-01-02' ORDER BY id""".stripMargin) + checkAnswer(df, Row(2, "beta", "2026-01-02") :: Nil) + assert( + lakeInputPartitions(df).length == 1, + s"Expected fallback to plan 1 input partition after pruning" + ) + } + } + test("Spark Lake Read: filter pushdown — partitioned lake table") { withTable("t_pd_partitioned") { sql(s""" diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala index 0de25f100e..f331576658 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakePrimaryKeyTableReadTestBase.scala @@ -20,6 +20,7 @@ package org.apache.fluss.spark.lake import org.apache.fluss.config.{ConfigOptions, Configuration} import org.apache.fluss.metadata.DataLakeFormat import org.apache.fluss.spark.SparkConnectorOptions.{BUCKET_NUMBER, PRIMARY_KEY} +import org.apache.fluss.spark.read.FlussUpsertInputPartition import org.apache.spark.sql.Row @@ -117,7 +118,7 @@ abstract class SparkLakePrimaryKeyTableReadTestBase extends SparkLakeTableReadTe |""".stripMargin) val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_fb_hybrid ORDER BY id") - val partitions = lakeUpsertInputPartitions(df) + val partitions = lakeInputPartitions(df).map(_.asInstanceOf[FlussUpsertInputPartition]) assert( partitions.exists(_.snapshotId >= 0), s"Expected at least one hybrid partition with snapshotId >= 0, got: ${partitions.mkString(", ")}") @@ -157,7 +158,7 @@ abstract class SparkLakePrimaryKeyTableReadTestBase extends SparkLakeTableReadTe |""".stripMargin) val df = sql(s"SELECT * FROM $DEFAULT_DATABASE.t_fb_hybrid_partitioned ORDER BY id") - val partitions = lakeUpsertInputPartitions(df) + val partitions = lakeInputPartitions(df).map(_.asInstanceOf[FlussUpsertInputPartition]) assert( partitions.exists(_.snapshotId >= 0), s"Expected at least one hybrid partition with snapshotId >= 0, got: ${partitions.mkString(", ")}") diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala index 526e898899..14108636ec 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/lake/SparkLakeTableReadTestBase.scala @@ -22,12 +22,13 @@ import org.apache.fluss.flink.tiering.LakeTieringJobBuilder import org.apache.fluss.flink.tiering.source.TieringSourceOptions import org.apache.fluss.metadata.{DataLakeFormat, TableBucket} import org.apache.fluss.spark.FlussSparkTestBase -import org.apache.fluss.spark.read.{FlussLakeUpsertScan, FlussScan, FlussUpsertInputPartition} +import org.apache.fluss.spark.read.{FlussLakeAppendScan, FlussLakeUpsertScan, FlussScan} import org.apache.flink.api.common.RuntimeExecutionMode import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import java.time.Duration @@ -152,7 +153,7 @@ abstract class SparkLakeTableReadTestBase extends FlussSparkTestBase { s"Expected any of $expected in pushed predicates, got $pushed") } - protected def lakeUpsertInputPartitions(df: DataFrame): Array[FlussUpsertInputPartition] = { + protected def lakeInputPartitions(df: DataFrame): Array[InputPartition] = { val scans = df.queryExecution.executedPlan.collect { case b: BatchScanExec => b.scan @@ -160,8 +161,11 @@ abstract class SparkLakeTableReadTestBase extends FlussSparkTestBase { case DataSourceV2ScanRelation(_, scan, _, _, _) => scan } scans - .collect { case s: FlussLakeUpsertScan => s } - .flatMap(_.toBatch.planInputPartitions().collect { case p: FlussUpsertInputPartition => p }) + .collect { + case upsert: FlussLakeUpsertScan => upsert + case append: FlussLakeAppendScan => append + } + .flatMap(_.toBatch.planInputPartitions()) .toArray } }