From baaf95892732b5fb1475f314b67f8d36aab45bdc Mon Sep 17 00:00:00 2001 From: Naveen Kumar Puppala Date: Tue, 10 Mar 2026 20:19:19 -0700 Subject: [PATCH 1/2] [SPARK-55848][SQL][4.1] Fix incorrect dedup results with SPJ partial clustering When SPJ partial clustering splits a partition across multiple tasks, post-join dedup operators (dropDuplicates, Window row_number) produce incorrect results because KeyGroupedPartitioning.satisfies0() incorrectly reports satisfaction of ClusteredDistribution via super.satisfies0() short-circuiting the isPartiallyClustered guard. This fix adds an isPartiallyClustered flag to KeyGroupedPartitioning and restructures satisfies0() to check ClusteredDistribution first, returning false when partially clustered. EnsureRequirements then inserts the necessary Exchange. Plain SPJ joins without dedup are unaffected. Closes #54378 --- .../plans/physical/partitioning.scala | 62 +++---- .../execution/KeyGroupedPartitionedScan.scala | 3 +- .../exchange/EnsureRequirements.scala | 6 +- .../exchange/ShuffleExchangeExec.scala | 4 +- .../DistributionAndOrderingSuiteBase.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 152 +++++++++++++++++- .../exchange/EnsureRequirementsSuite.scala | 2 +- 7 files changed, 193 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 101d13c6b580c..fd71e22c555cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -379,36 +379,36 @@ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, partitionValues: Seq[InternalRow] = Seq.empty, - originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike { + originalPartitionValues: Seq[InternalRow] = Seq.empty, + isPartiallyClustered: Boolean = false) extends HashPartitioningLike { + // See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to + // super.satisfies0(), because HashPartitioningLike.satisfies0() also matches + // ClusteredDistribution and returns true, which would short-circuit the + // isPartiallyClustered guard. override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (isPartiallyClustered) { + false + } else if (requireAllClusterKeys) { + c.areAllClusterKeysMatched(expressions) + } else { + val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) - } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } + } - case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => - o.areAllClusterKeysMatched(expressions) + case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => + o.areAllClusterKeysMatched(expressions) - case _ => - false - } + case _ => + super.satisfies0(required) } } @@ -420,7 +420,7 @@ case class KeyGroupedPartitioning( // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) + partitionValues, originalPartitionValues, isPartiallyClustered) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result @@ -435,7 +435,7 @@ case class KeyGroupedPartitioning( } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy(expressions = newChildren) + copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered) } object KeyGroupedPartitioning { @@ -443,7 +443,8 @@ object KeyGroupedPartitioning { expressions: Seq[Expression], projectionPositions: Seq[Int], partitionValues: Seq[InternalRow], - originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + originalPartitionValues: Seq[InternalRow], + isPartiallyClustered: Boolean): KeyGroupedPartitioning = { val projectedExpressions = projectionPositions.map(expressions(_)) val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) val projectedOriginalPartitionValues = @@ -455,7 +456,7 @@ object KeyGroupedPartitioning { .map(_.row) KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, - finalPartitionValues, projectedOriginalPartitionValues) + finalPartitionValues, projectedOriginalPartitionValues, isPartiallyClustered) } def project( @@ -867,7 +868,10 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - distribution.clustering.length == otherDistribution.clustering.length && + // SPARK-55848: partially-clustered partitioning is not compatible for SPJ + !partitioning.isPartiallyClustered && + !otherPartitioning.isPartiallyClustered && + distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index e9d8e8e6d0fb3..0a70021dc858c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -58,7 +58,8 @@ trait KeyGroupedPartitionedScan[T] { } } basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length, - partitionValues = newPartValues) + partitionValues = newPartValues, + isPartiallyClustered = spjParams.applyPartialClustering) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 088ece6554c96..0d180bd336221 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -307,7 +307,7 @@ case class EnsureRequirements( private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = { (plan.outputPartitioning, distribution) match { - case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _), + case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _), d @ OrderedDistribution(ordering)) if p.satisfies(d) => val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) val partitionOrdering: Ordering[InternalRow] = { @@ -340,12 +340,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index f052bd9068805..c59bb4d39b096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,7 +370,7 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _, _) => + case k @ KeyGroupedPartitioning(expressions, n, _, _, _) => val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) }.toMap @@ -401,7 +401,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _, _) => + case KeyGroupedPartitioning(expressions, _, _, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 1a0efa7c4aafb..80f360ed216f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + case KeyGroupedPartitioning(clustering, numPartitions, partValues, + originalPartValues, isPartiallyClustered) => KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, - originalPartValues) + originalPartValues, isPartiallyClustered) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8a65cb623f6e6..122c511bf8358 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -94,13 +94,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) } test("non-clustered distribution: no partition") { @@ -2892,4 +2892,150 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row("ccc", 30, 400.50))) } } + + test("SPARK-55848: dropDuplicates after SPJ with partial clustering should produce " + + "correct results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // dropDuplicates on the join key after a partially-clustered SPJ must still + // produce the correct number of distinct ids. Before the fix, the + // partially-clustered partitioning was incorrectly treated as satisfying + // ClusteredDistribution, so EnsureRequirements did not insert an Exchange + // before the dedup, leading to duplicate rows. + val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join dedup with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: Window dedup after SPJ with partial clustering should produce " + + "correct results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered + // SPJ. The WINDOW operator requires ClusteredDistribution on i.id; with partial + // clustering the plan must insert a shuffle so that the window + // produces exactly one row per id. + val df = sql( + s""" + |SELECT id, price FROM ( + | SELECT i.id, i.price, + | ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn + | FROM testcat.ns.$items i + | JOIN testcat.ns.$purchases p ON i.id = p.item_id + |) t WHERE rn = 1 + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join window with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: checkpointed partially-clustered join with dedup") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Checkpoint the JOIN result (not the scan) so the plan behind the + // checkpoint carries partially-clustered KeyGroupedPartitioning. + // The dedup on top must still insert an Exchange because the + // isPartiallyClustered flag causes satisfies0()=false for + // ClusteredDistribution. + val joinedDf = spark.sql( + s"""SELECT i.id, i.name, i.price + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin) + val checkpointedDf = joinedDf.checkpoint() + val df = checkpointedDf.select("id").distinct() + + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the dedup after checkpointed " + + "partially-clustered join") + + val rddScans = collect(df.queryExecution.executedPlan) { + case r: RDDScanExec => r + } + assert(rddScans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "checkpoint (RDDScanExec) should have " + + "partially-clustered KeyGroupedPartitioning") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1cc0d795d74f8..ee852da536768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1129,7 +1129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) From 3d5888c6d8c4100e3b09986cfc0ab61e28c253ae Mon Sep 17 00:00:00 2001 From: Naveen Kumar Puppala Date: Mon, 16 Mar 2026 05:49:11 -0700 Subject: [PATCH 2/2] retrigger CI