From ef119582a9dab2dbb3f5eeaaaa8019477bf32752 Mon Sep 17 00:00:00 2001
From: Naveen Kumar Puppala
Date: Mon, 16 Mar 2026 22:33:28 -0700
Subject: [PATCH 1/3] [SPARK-55848][SQL][4.0] Fix incorrect dedup results with
 SPJ partial clustering

When SPJ partial clustering splits a partition across multiple tasks,
post-join dedup operators (dropDuplicates, Window row_number) produce
incorrect results: KeyGroupedPartitioning.satisfies0() incorrectly reports
that ClusteredDistribution is satisfied, because the delegation to
super.satisfies0() returns true first and short-circuits the
isPartiallyClustered guard.

This fix adds an isPartiallyClustered flag to KeyGroupedPartitioning and
restructures satisfies0() to check ClusteredDistribution first, returning
false when partially clustered. EnsureRequirements then inserts the
necessary Exchange. Plain SPJ joins without dedup are unaffected.

Closes #54378
---
 .../plans/physical/partitioning.scala         |  67 ++++----
 .../datasources/v2/BatchScanExec.scala        |   3 +-
 .../exchange/EnsureRequirements.scala         |   6 +-
 .../exchange/ShuffleExchangeExec.scala        |   4 +-
 .../DistributionAndOrderingSuiteBase.scala    |   5 +-
 .../KeyGroupedPartitioningSuite.scala         | 150 +++++++++++++++++-
 .../exchange/EnsureRequirementsSuite.scala    |   2 +-
 7 files changed, 196 insertions(+), 41 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index dc66b6f30e521..9777c31426c0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -379,36 +379,41 @@ case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
     partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
+    originalPartitionValues: Seq[InternalRow] = Seq.empty,
+    isPartiallyClustered: Boolean = false) extends HashPartitioningLike {
 
+  // See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to
+  // super.satisfies0(), because HashPartitioningLike.satisfies0() also matches
+  // ClusteredDistribution and returns true, which would short-circuit the
+  // isPartiallyClustered guard.
   override def satisfies0(required: Distribution): Boolean = {
-    super.satisfies0(required) || {
-      required match {
-        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
-          if (requireAllClusterKeys) {
-            // Checks whether this partitioning is partitioned on exactly same clustering keys of
-            // `ClusteredDistribution`.
-            c.areAllClusterKeysMatched(expressions)
+    required match {
+      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        if (isPartiallyClustered) {
+          false
+        } else if (requireAllClusterKeys) {
+          // Checks whether this partitioning is partitioned on exactly the same clustering
+          // keys as `ClusteredDistribution`.
+          c.areAllClusterKeysMatched(expressions)
+        } else {
+          // We'll need to find leaf attributes from the partition expressions first.
+          val attributes = expressions.flatMap(_.collectLeaves())
+
+          if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+            // check that join keys (required clustering keys)
+            // overlap with partition keys (KeyGroupedPartitioning attributes)
+            requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
+              expressions.forall(_.collectLeaves().size == 1)
           } else {
-            // We'll need to find leaf attributes from the partition expressions first.
-            val attributes = expressions.flatMap(_.collectLeaves())
-
-            if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
-              // check that join keys (required clustering keys)
-              // overlap with partition keys (KeyGroupedPartitioning attributes)
-              requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
-                expressions.forall(_.collectLeaves().size == 1)
-            } else {
-              attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
-            }
+            attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
           }
+        }
 
-          case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
-            o.areAllClusterKeysMatched(expressions)
+      case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
+        o.areAllClusterKeysMatched(expressions)
 
-          case _ =>
-            false
-        }
+      case _ =>
+        super.satisfies0(required)
     }
   }
@@ -420,7 +425,7 @@ case class KeyGroupedPartitioning(
       // the returned shuffle spec.
       val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
       val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
-        partitionValues, originalPartitionValues)
+        partitionValues, originalPartitionValues, isPartiallyClustered)
       result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
     } else {
       result
@@ -435,7 +440,7 @@ case class KeyGroupedPartitioning(
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
-    copy(expressions = newChildren)
+    copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered)
 }
 
 object KeyGroupedPartitioning {
@@ -443,7 +448,8 @@ object KeyGroupedPartitioning {
       expressions: Seq[Expression],
       projectionPositions: Seq[Int],
       partitionValues: Seq[InternalRow],
-      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
+      originalPartitionValues: Seq[InternalRow],
+      isPartiallyClustered: Boolean): KeyGroupedPartitioning = {
     val projectedExpressions = projectionPositions.map(expressions(_))
     val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
     val projectedOriginalPartitionValues =
@@ -455,7 +461,7 @@ object KeyGroupedPartitioning {
       .map(_.row)
 
     KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
-      finalPartitionValues, projectedOriginalPartitionValues)
+      finalPartitionValues, projectedOriginalPartitionValues, isPartiallyClustered)
   }
 
   def project(
@@ -823,7 +829,10 @@ case class KeyGroupedShuffleSpec(
       //    transform functions.
       // 4. the partition values from both sides are following the same order.
       case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
-        distribution.clustering.length == otherDistribution.clustering.length &&
+        // SPARK-55848: partially-clustered partitioning is not compatible with SPJ
+        !partitioning.isPartiallyClustered &&
+          !otherPartitioning.isPartiallyClustered &&
+          distribution.clustering.length == otherDistribution.clustering.length &&
           numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
           partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
             case (left, right) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 82f28bdfbd492..99ae973896c3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -142,7 +142,8 @@ case class BatchScanExec(
           }
         }
         k.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
-          partitionValues = newPartValues)
+          partitionValues = newPartValues,
+          isPartiallyClustered = spjParams.applyPartialClustering)
       case p => p
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 503dca02490a5..41c8c5868e627 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -292,7 +292,7 @@ case class EnsureRequirements(
 
   private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
     (plan.outputPartitioning, distribution) match {
-      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
+      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _),
           d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
         val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
         val partitionOrdering: Ordering[InternalRow] = {
@@ -325,12 +325,12 @@ case class EnsureRequirements(
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
           .orElse(reorderJoinKeysRecursively(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 31a3f53eb7191..83697b715dfba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -366,7 +366,7 @@
         ascending = true,
         samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
     case SinglePartition => new ConstantPartitioner
-    case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
+    case k @ KeyGroupedPartitioning(expressions, n, _, _, _) =>
       val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) =>
         (partition.toSeq(expressions.map(_.dataType)), index)
       }.toMap
@@ -397,7 +397,7 @@ object ShuffleExchangeExec {
       val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
       row => projection(row)
     case SinglePartition => identity
-    case KeyGroupedPartitioning(expressions, _, _, _) =>
+    case KeyGroupedPartitioning(expressions, _, _, _, _) =>
       row => bindReferences(expressions, outputAttributes).map(_.eval(row))
     case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index 1a0efa7c4aafb..80f360ed216f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase
       plan: QueryPlan[T]): Partitioning = partitioning match {
     case HashPartitioning(exprs, numPartitions) =>
       HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
-    case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
+    case KeyGroupedPartitioning(clustering, numPartitions, partValues,
+        originalPartValues, isPartiallyClustered) =>
       KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
-        originalPartValues)
+        originalPartValues, isPartiallyClustered)
     case PartitioningCollection(partitionings) =>
       PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
     case RangePartitioning(ordering, numPartitions) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index b61371285debf..01507610b6db3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.connector.distributions.Distributions
 import org.apache.spark.sql.connector.expressions._
 import org.apache.spark.sql.connector.expressions.Expressions._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -93,13 +93,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     checkQueryPlan(df, catalystDistribution,
       physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
-        partitionValues, partitionValues))
+        partitionValues, partitionValues, isPartiallyClustered = false))
 
     // multiple group keys should work too as long as partition keys are subset of them
     df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
 
     checkQueryPlan(df, catalystDistribution,
       physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
-        partitionValues, partitionValues))
+        partitionValues, partitionValues, isPartiallyClustered = false))
   }
 
   test("non-clustered distribution: no partition") {
@@ -2747,4 +2747,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
         Row("ccc", 30, 400.50)))
     }
   }
+
+  test("SPARK-55848: dropDuplicates after SPJ with partial clustering") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 50.0, cast('2020-01-02' as timestamp)), " +
+      "(2, 11.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    withSQLConf(
+        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
+      // dropDuplicates on the join key after a partially-clustered SPJ must still
+      // produce the correct number of distinct ids. Before the fix, the
+      // partially-clustered partitioning was incorrectly treated as satisfying
+      // ClusteredDistribution, so EnsureRequirements did not insert an Exchange
+      // before the dedup, leading to duplicate rows.
+      val df = sql(
+        s"""
+           |SELECT DISTINCT i.id
+           |FROM testcat.ns.$items i
+           |JOIN testcat.ns.$purchases p ON i.id = p.item_id
+           |""".stripMargin)
+      checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
+
+      val allShuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(allShuffles.nonEmpty,
+        "should contain a shuffle for the post-join dedup with partial clustering")
+
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans.exists(_.outputPartitioning match {
+        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
+        case _ => false
+      }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
+    }
+  }
+
+  test("SPARK-55848: Window dedup after SPJ with partial clustering") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+      "(1, 50.0, cast('2020-01-02' as timestamp)), " +
+      "(2, 11.0, cast('2020-01-01' as timestamp)), " +
+      "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    withSQLConf(
+        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
+      // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered
+      // SPJ. The Window operator requires ClusteredDistribution on i.id; with partial
+      // clustering the plan must insert a shuffle so that the window produces
+      // exactly one row per id.
+      val df = sql(
+        s"""
+           |SELECT id, price FROM (
+           |  SELECT i.id, i.price,
+           |    ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn
+           |  FROM testcat.ns.$items i
+           |  JOIN testcat.ns.$purchases p ON i.id = p.item_id
+           |) t WHERE rn = 1
+           |""".stripMargin)
+      checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f)))
+
+      val allShuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(allShuffles.nonEmpty,
+        "should contain a shuffle for the post-join window with partial clustering")
+
+      val scans = collectScans(df.queryExecution.executedPlan)
+      assert(scans.exists(_.outputPartitioning match {
+        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
+        case _ => false
+      }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
+    }
+  }
+
+  test("SPARK-55848: checkpointed partially-clustered join with dedup") {
+    withTempDir { dir =>
+      spark.sparkContext.setCheckpointDir(dir.getPath)
+      val items_partitions = Array(identity("id"))
+      createTable(items, itemsColumns, items_partitions)
+      sql(s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+        "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+      val purchases_partitions = Array(identity("item_id"))
+      createTable(purchases, purchasesColumns, purchases_partitions)
+      sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        "(2, 11.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+      withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
+        // Checkpoint the JOIN result (not the scan) so the plan behind the
+        // checkpoint carries partially-clustered KeyGroupedPartitioning.
+        // The dedup on top must still insert an Exchange because the
+        // isPartiallyClustered flag causes satisfies0() to return false for
+        // ClusteredDistribution.
+        val joinedDf = spark.sql(
+          s"""SELECT i.id, i.name, i.price
+             |FROM testcat.ns.$items i
+             |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin)
+        val checkpointedDf = joinedDf.checkpoint()
+        val df = checkpointedDf.select("id").distinct()
+
+        checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
+
+        val allShuffles = collectAllShuffles(df.queryExecution.executedPlan)
+        assert(allShuffles.nonEmpty,
+          "should contain a shuffle for the dedup after checkpointed " +
+            "partially-clustered join")
+
+        val rddScans = collect(df.queryExecution.executedPlan) {
+          case r: RDDScanExec => r
+        }
+        assert(rddScans.exists(_.outputPartitioning match {
+          case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
+          case _ => false
+        }), "checkpoint (RDDScanExec) should have " +
+          "partially-clustered KeyGroupedPartitioning")
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 3b0bb088a1076..259564fcd3a82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     EnsureRequirements.apply(smjExec) match {
       case ShuffledHashJoinExec(_, _, _, _, _,
         DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
-        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
+        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _),
          DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
         assert(left.expressions == a1 :: Nil)
         assert(attrs == a1 :: Nil)

From 95762aa0d54508825c7750c0fb2062f4cb7f6724 Mon Sep 17 00:00:00 2001
From: Naveen Kumar Puppala
Date: Wed, 18 Mar 2026 10:45:30 -0700
Subject: [PATCH 2/3] retrigger CI

From bf95feac99c1eee0d7abcc119e501639b93dd535 Mon Sep 17 00:00:00 2001
From: Naveen Kumar Puppala
Date: Thu, 19 Mar 2026 21:07:29 -0700
Subject: [PATCH 3/3] retrigger CI
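
Reviewer note (illustrative, not part of the patch series): a minimal
standalone sketch of how the bug surfaces and how the fix changes the plan.
It assumes a DSv2 catalog registered as `testcat` whose identity-partitioned
items/purchases tables are already populated as in the new tests (id 1 twice
on each side, so partial clustering can split that key group across tasks).
The session bootstrap and the V2_BUCKETING_ENABLED flag are assumptions of
the sketch; the other two configs mirror the tests above.

// Sketch only: assumes a populated `testcat` catalog as described above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object Spark55848Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("spark-55848-repro").getOrCreate()

    // Enable storage-partitioned joins with partially-clustered distribution,
    // and disable broadcast joins so the SPJ path is actually taken.
    spark.conf.set(SQLConf.V2_BUCKETING_ENABLED.key, "true")
    spark.conf.set(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key, "true")
    spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")

    val distinctIds = spark.sql(
      """SELECT DISTINCT i.id
        |FROM testcat.ns.items i
        |JOIN testcat.ns.purchases p ON i.id = p.item_id""".stripMargin)

    // Before the fix: the scan's partially-clustered KeyGroupedPartitioning
    // claimed to satisfy the dedup's ClusteredDistribution, EnsureRequirements
    // inserted no Exchange, and duplicate ids could be returned. After the
    // fix: the plan shows a ShuffleExchangeExec and exactly one row per id.
    distinctIds.explain()
    distinctIds.show()
  }
}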