Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ case class FlussLakeAppendScan(
tableInfo: TableInfo,
requiredSchema: Option[StructType],
pushedPredicate: Option[FlussPredicate],
override val partitionPredicate: Option[FlussPredicate],
override val pushedSparkPredicates: Seq[Predicate],
options: CaseInsensitiveStringMap,
flussConfig: Configuration)
Expand All @@ -119,6 +120,7 @@ case class FlussLakeAppendScan(
tableInfo,
readSchema,
pushedPredicate,
partitionPredicate,
options,
flussConfig)
}
Expand Down Expand Up @@ -167,6 +169,7 @@ case class FlussLakeUpsertScan(
tableInfo: TableInfo,
requiredSchema: Option[StructType],
pushedPredicate: Option[FlussPredicate],
override val partitionPredicate: Option[FlussPredicate],
override val pushedSparkPredicates: Seq[Predicate],
options: CaseInsensitiveStringMap,
flussConfig: Configuration)
Expand All @@ -180,6 +183,7 @@ case class FlussLakeUpsertScan(
tableInfo,
readSchema,
pushedPredicate,
partitionPredicate,
options,
flussConfig)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ trait FlussSupportsPushDownPartitionFilters
protected var partitionPredicate: Option[FlussPredicate] = None

/**
 * Splits the incoming Spark predicates into partition-key predicates and everything else.
 *
 * The partition-key part is stored in `partitionPredicate` as a single Fluss predicate.
 * The remaining predicates are handed back to the caller.
 *
 * @param predicates
 *   predicates offered by Spark for push-down
 * @return
 *   the predicates that were not folded into `partitionPredicate`
 */
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
  // `extract` returns (remaining Spark predicates, optional combined partition predicate).
  val (remaining, extracted) = SparkPartitionPredicate.extract(tableInfo, predicates.toSeq)
  partitionPredicate = extracted
  remaining.toArray
}

override def pushedPredicates(): Array[Predicate] = Array.empty
Expand All @@ -73,12 +75,12 @@ trait FlussSupportsPushDownV2Filters extends FlussSupportsPushDownPartitionFilte
}

/**
 * Extracts the partition predicate via the parent trait, then offers only the remaining
 * predicates to the server-side filter conversion.
 *
 * @param predicates
 *   predicates offered by Spark for push-down
 * @return
 *   the predicates left over after the parent trait removed the partition-key predicates
 */
override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
  // Call the parent exactly once: it records the partition predicate as a side effect.
  val nonPartitionPredicates = super.pushPredicates(predicates)
  // Server-side batch filter only supports ARROW; other log formats reject it.
  if (tableInfo.getTableConfig.getLogFormat == LogFormat.ARROW) {
    convertAndStorePredicates(nonPartitionPredicates)
  }
  nonPartitionPredicates
}

override def pushedPredicates(): Array[Predicate] = acceptedPredicates
Expand All @@ -96,8 +98,11 @@ trait FlussLakeSupportsPushDownV2Filters extends FlussSupportsPushDownV2Filters
def flussConfig: FlussConfiguration

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val nonPartitionPredicates = super.pushPredicates(predicates)
val pairs =
SparkPredicateConverter.convertPerPredicate(tableInfo.getRowType, predicates.toSeq)
SparkPredicateConverter.convertPerPredicate(
tableInfo.getRowType,
nonPartitionPredicates.toSeq)
val (acceptedSpark, acceptedFluss) = if (pairs.isEmpty) {
(Seq.empty[Predicate], Seq.empty[FlussPredicate])
} else {
Expand All @@ -112,7 +117,7 @@ trait FlussLakeSupportsPushDownV2Filters extends FlussSupportsPushDownV2Filters
}
pushedPredicate = SparkPredicateConverter.combineAnd(acceptedFluss)
acceptedPredicates = acceptedSpark.toArray
predicates
nonPartitionPredicates
}
}

Expand Down Expand Up @@ -151,6 +156,7 @@ class FlussLakeAppendScanBuilder(
tableInfo,
requiredSchema,
pushedPredicate,
partitionPredicate,
acceptedPredicates.toSeq,
options,
flussConfig)
Expand Down Expand Up @@ -184,6 +190,7 @@ class FlussLakeUpsertScanBuilder(
tableInfo,
requiredSchema,
pushedPredicate,
partitionPredicate,
acceptedPredicates.toSeq,
options,
flussConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.fluss.lake.source.{LakeSource, LakeSplit}
import org.apache.fluss.metadata.{LogFormat, ResolvedPartitionSpec, TableBucket, TableInfo, TablePath}
import org.apache.fluss.predicate.{Predicate => FlussPredicate}
import org.apache.fluss.spark.read._
import org.apache.fluss.spark.utils.SparkPartitionPredicate
import org.apache.fluss.utils.ExceptionUtils

import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
Expand All @@ -40,6 +41,7 @@ class FlussLakeAppendBatch(
tableInfo: TableInfo,
readSchema: StructType,
pushedPredicate: Option[FlussPredicate],
partitionPredicate: Option[FlussPredicate],
options: CaseInsensitiveStringMap,
flussConfig: Configuration)
extends FlussLakeBatch(tablePath, tableInfo, readSchema, options, flussConfig) {
Expand Down Expand Up @@ -146,16 +148,22 @@ class FlussLakeAppendBatch(
bucketOffsetsRetriever: BucketOffsetsRetrieverImpl): Array[InputPartition] = {
val tableId = tableInfo.getTableId

// Filter Fluss-known partitions using the partition predicate to skip non-matching ones
val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions(
tableInfo,
partitionInfos.asScala.toSeq,
partitionPredicate)

val flussPartitionIdByName = mutable.LinkedHashMap.empty[String, Long]
partitionInfos.asScala.foreach {
filteredPartitionInfos.foreach {
pi => flussPartitionIdByName(pi.getPartitionName) = pi.getPartitionId
}

val lakeSplitsByPartition = groupLakeSplitsByPartition(lakeSplits)
var lakeSplitPartitionId = -1L

val lakeAndLogPartitions = lakeSplitsByPartition.flatMap {
case (partitionName, splits) =>
case (partitionName, (partitionValues, splits)) =>
flussPartitionIdByName.remove(partitionName) match {
case Some(partitionId) =>
// Partition in both lake and Fluss — lake splits + log tail
Expand All @@ -176,10 +184,18 @@ class FlussLakeAppendBatch(
lakePartitions ++ logPartitions

case None =>
// Partition only in lake (expired in Fluss) — lake splits only
val pid = lakeSplitPartitionId
lakeSplitPartitionId -= 1
createLakePartitions(splits.toSeq, tableId, Some(pid))
// Partition only in lake (expired in Fluss). Apply the partition predicate directly
// on the resolved partition values to avoid round-tripping through partition names.
if (
SparkPartitionPredicate
.matchesPartition(tableInfo, partitionValues, partitionPredicate)
) {
val pid = lakeSplitPartitionId
lakeSplitPartitionId -= 1
createLakePartitions(splits.toSeq, tableId, Some(pid))
} else {
Seq.empty
}
}
}.toSeq

Expand Down Expand Up @@ -210,17 +226,20 @@ class FlussLakeAppendBatch(
(lakeAndLogPartitions ++ flussOnlyPartitions).toArray
}

/**
 * Groups lake splits by partition, preserving the order in which partitions are first seen.
 *
 * Each entry is keyed by the partition name (the partition values joined with
 * `ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR`, or `""` when the split carries no partition
 * values). The entry holds the raw partition values together with the splits of that
 * partition, so callers can evaluate a partition predicate on the values without re-parsing
 * the joined name.
 *
 * @param lakeSplits
 *   splits planned by the lake source
 * @return
 *   insertion-ordered map from partition name to (partition values, splits)
 */
private def groupLakeSplitsByPartition(lakeSplits: Seq[LakeSplit])
    : mutable.LinkedHashMap[String, (Seq[String], mutable.ArrayBuffer[LakeSplit])] = {
  val grouped =
    mutable.LinkedHashMap.empty[String, (Seq[String], mutable.ArrayBuffer[LakeSplit])]
  lakeSplits.foreach {
    split =>
      // A null partition list is treated the same as an empty one.
      val partitionValues =
        Option(split.partition()).map(_.asScala.toSeq).getOrElse(Seq.empty[String])
      // mkString on an empty sequence yields "", which is the key for "no partition".
      val partitionName =
        partitionValues.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR)
      val (_, splits) = grouped.getOrElseUpdate(
        partitionName,
        (partitionValues, mutable.ArrayBuffer.empty[LakeSplit]))
      splits += split
  }
  grouped
}
Expand Down Expand Up @@ -286,9 +305,11 @@ class FlussLakeAppendBatch(
}

if (tableInfo.isPartitioned) {
partitionInfos.asScala.flatMap {
pi => createPartitions(Some(pi.getPartitionId), pi.getPartitionName)
}.toArray
val matching = SparkPartitionPredicate.filterPartitions(
tableInfo,
partitionInfos.asScala.toSeq,
partitionPredicate)
matching.flatMap(pi => createPartitions(Some(pi.getPartitionId), pi.getPartitionName)).toArray
} else {
createPartitions(None, null)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.fluss.lake.source.LakeSplit
import org.apache.fluss.metadata.{ResolvedPartitionSpec, TableBucket, TableInfo, TablePath}
import org.apache.fluss.predicate.{Predicate => FlussPredicate}
import org.apache.fluss.spark.read._
import org.apache.fluss.spark.utils.SparkPartitionPredicate
import org.apache.fluss.utils.ExceptionUtils

import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
Expand All @@ -43,6 +44,7 @@ class FlussLakeUpsertBatch(
tableInfo: TableInfo,
readSchema: StructType,
pushedPredicate: Option[FlussPredicate],
partitionPredicate: Option[FlussPredicate],
options: CaseInsensitiveStringMap,
flussConfig: Configuration)
extends FlussLakeBatch(tablePath, tableInfo, readSchema, options, flussConfig) {
Expand Down Expand Up @@ -141,18 +143,24 @@ class FlussLakeUpsertBatch(
val tableId = tableInfo.getTableId
val buckets = (0 until tableInfo.getNumBuckets).toSeq

// Filter Fluss-known partitions using the partition predicate to skip non-matching ones
val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions(
tableInfo,
partitionInfos.asScala.toSeq,
partitionPredicate)

val flussPartitionIdByName = mutable.LinkedHashMap.empty[String, Long]
partitionInfos.asScala.foreach {
filteredPartitionInfos.foreach {
pi => flussPartitionIdByName(pi.getPartitionName) = pi.getPartitionId
}

val lakeSplitsByPartition = groupLakeSplitsByPartition(lakeSplits)

val lakePartitions = lakeSplitsByPartition.flatMap {
case (partitionName, splitsByBucket) =>
case (partitionName, (partitionValues, splitsByBucket)) =>
flussPartitionIdByName.remove(partitionName) match {
case Some(partitionId) =>
// Partition in both lake and Fluss
// Partition in both lake and Fluss (already passed the predicate filter above)
val stoppingOffsets = getBucketOffsets(
stoppingOffsetsInitializer,
partitionName,
Expand All @@ -173,13 +181,21 @@ class FlussLakeUpsertBatch(
}

case None =>
// Partition only in lake (expired in Fluss)
buckets.flatMap {
bucketId =>
val tableBucket = new TableBucket(tableId, -1, bucketId)
splitsByBucket.getOrElse(bucketId, Seq.empty).map {
lakeSplit => FlussLakeInputPartition(tableBucket, lakeSplit)
}
// Partition only in lake (expired in Fluss). Apply the partition predicate directly
// on the resolved partition values to avoid round-tripping through partition names.
if (
SparkPartitionPredicate
.matchesPartition(tableInfo, partitionValues, partitionPredicate)
) {
buckets.flatMap {
bucketId =>
val tableBucket = new TableBucket(tableId, -1, bucketId)
splitsByBucket.getOrElse(bucketId, Seq.empty).map {
lakeSplit => FlussLakeInputPartition(tableBucket, lakeSplit)
}
}
} else {
Seq.empty
}
}
}
Expand Down Expand Up @@ -216,18 +232,25 @@ class FlussLakeUpsertBatch(
(lakePartitions ++ flussOnlyPartitions).toArray
}

/**
* Group lake splits by partition. Each entry stores the resolved partition values along with
* splits keyed by bucket id, so callers can both look up Fluss partitions by name and evaluate
* the partition predicate directly on the values without re-parsing the joined name.
*/
private def groupLakeSplitsByPartition(
lakeSplits: Seq[LakeSplit]): Map[String, mutable.Map[Int, Seq[LakeSplit]]] = {
val grouped = mutable.LinkedHashMap.empty[String, mutable.Map[Int, Seq[LakeSplit]]]
lakeSplits: Seq[LakeSplit]): Map[String, (Seq[String], mutable.Map[Int, Seq[LakeSplit]])] = {
val grouped =
mutable.LinkedHashMap.empty[String, (Seq[String], mutable.Map[Int, Seq[LakeSplit]])]
lakeSplits.foreach {
split =>
val partitionName = if (split.partition() == null || split.partition().isEmpty) {
""
} else {
split.partition().asScala.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR)
}
val partitionValues =
if (split.partition() == null) Seq.empty[String] else split.partition().asScala.toSeq
val partitionName =
if (partitionValues.isEmpty) ""
else partitionValues.mkString(ResolvedPartitionSpec.PARTITION_SPEC_SEPARATOR)
val (_, bucketMap) =
grouped.getOrElseUpdate(partitionName, (partitionValues, mutable.Map.empty))
val bucketId = split.bucket()
val bucketMap = grouped.getOrElseUpdate(partitionName, mutable.Map.empty)
val splits = bucketMap.getOrElse(bucketId, Seq.empty)
bucketMap(bucketId) = splits :+ split
}
Expand Down Expand Up @@ -271,7 +294,13 @@ class FlussLakeUpsertBatch(
val bucketOffsetsRetriever = new BucketOffsetsRetrieverImpl(admin, tablePath)

if (tableInfo.isPartitioned) {
partitionInfos.asScala.flatMap {
// Filter partitions using partition predicate early to skip non-matching partitions
val filteredPartitionInfos = SparkPartitionPredicate.filterPartitions(
tableInfo,
partitionInfos.asScala.toSeq,
partitionPredicate)

filteredPartitionInfos.flatMap {
pi =>
val partitionName = pi.getPartitionName
val kvSnapshots = admin.getLatestKvSnapshots(tablePath, partitionName).get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,40 @@ import scala.jdk.CollectionConverters._
/** Extracts a partition-key predicate and prunes the partition list at planning time. */
object SparkPartitionPredicate {

/**
 * Splits Spark predicates into those that only reference partition keys and the rest.
 *
 * A predicate counts as a partition predicate when it converts to a Fluss predicate against
 * the partition row type and the converted predicate passes `PartitionPredicateVisitor`.
 * Each predicate is converted once. Input order is kept in both results.
 *
 * @param tableInfo
 *   table whose partition keys define what a partition predicate is
 * @param predicates
 *   Spark predicates to classify
 * @return
 *   a pair of (predicates that are not partition predicates, the partition predicates
 *   AND-combined into one Fluss predicate, or `None` when there are none). For a table
 *   without partition keys this is `(predicates, None)`.
 */
def extract(
    tableInfo: TableInfo,
    predicates: Seq[Predicate]): (Seq[Predicate], Option[FlussPredicate]) = {
  val partitionKeys = tableInfo.getPartitionKeys
  if (partitionKeys.isEmpty) {
    (predicates, None)
  } else {
    val rowType = PartitionUtils.partitionRowType(tableInfo)
    val onlyPartitionKeys = new PartitionPredicateVisitor(partitionKeys)

    // Pair every Spark predicate with its converted form, kept only when the conversion
    // succeeds and the result touches nothing but partition keys.
    val classified = predicates.map {
      sparkPredicate =>
        val partitionOnly = SparkPredicateConverter
          .convert(rowType, sparkPredicate)
          .filter(_.visit(onlyPartitionKeys))
        (sparkPredicate, partitionOnly)
    }

    val nonPartitionPredicates = classified.collect {
      case (sparkPredicate, None) => sparkPredicate
    }
    val converted = classified.flatMap { case (_, partitionOnly) => partitionOnly }

    val partitionPredicate = converted match {
      case Seq() => None
      case Seq(single) => Some(single)
      case many => Some(PredicateBuilder.and(many.asJava))
    }

    (nonPartitionPredicates, partitionPredicate)
  }
}

def filterPartitions(
Expand All @@ -63,4 +78,19 @@ object SparkPartitionPredicate {
PartitionUtils.toPartitionRow(p.getResolvedPartitionSpec.getPartitionValues, rowType))
}
}

/**
 * Checks a single partition, given as its ordered partition values, against an optional
 * partition predicate.
 *
 * @param tableInfo
 *   table used to derive the partition row type
 * @param partitionValues
 *   partition values in partition-key order
 * @param partitionPredicate
 *   predicate to evaluate; `None` accepts every partition
 * @return
 *   `true` when there is no predicate or the predicate accepts the partition row
 */
def matchesPartition(
    tableInfo: TableInfo,
    partitionValues: Seq[String],
    partitionPredicate: Option[FlussPredicate]): Boolean =
  // forall on an empty Option is true, so a missing predicate matches everything.
  partitionPredicate.forall {
    predicate =>
      val partitionRowType = PartitionUtils.partitionRowType(tableInfo)
      val partitionRow = PartitionUtils.toPartitionRow(partitionValues.asJava, partitionRowType)
      predicate.test(partitionRow)
  }
}
Loading