From 5de557fadf5c00724751ba36fa1ba2cabccbd8d1 Mon Sep 17 00:00:00 2001 From: Naveen Kumar Puppala Date: Sat, 28 Feb 2026 15:18:54 -0800 Subject: [PATCH 1/2] [SPARK-54378] Fix incorrect dedup results with SPJ partiallyClusteredDistribution --- .../plans/physical/partitioning.scala | 39 +++++---- .../KeyGroupedPartitioningSuite.scala | 81 +++++++++++++++++++ 2 files changed, 106 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index b0fa4f889cda1..cade49190e472 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -385,21 +385,32 @@ case class KeyGroupedPartitioning( super.satisfies0(required) || { required match { case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) - } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) + // When partial clustering is active, the same partition key is intentionally spread + // across multiple tasks (partitionValues contains duplicates). 
In that case, this + // partitioning does NOT satisfy ClusteredDistribution, because ClusteredDistribution + // requires all rows with the same key to be co-located in a single task. Without this + // guard, downstream operators such as dropDuplicates or Window functions would skip + // their required shuffle and produce incorrect results. + // See SPARK-54378. + val isFullyClustered = partitionValues.isEmpty || + uniquePartitionValues.length == partitionValues.length + isFullyClustered && { + if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + // We'll need to find leaf attributes from the partition expressions first. + val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that join keys (required clustering keys) + // overlap with partition keys (KeyGroupedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 56bd028464e54..506365f0865b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1130,6 +1130,87 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("[SPARK-54378] dropDuplicates after SPJ with partial clustering should give correct " + + "results") { + val items_partitions = 
Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + // Insert two copies of id=1 so that the left side has duplicate rows for id=1 after the join, + // and three distinct id values in total. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString) { + // dropDuplicates on the join key: must produce exactly 3 distinct id values regardless + // of whether partial clustering is active. + val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + } + } + } + + test("[SPARK-54378] Window dedup after SPJ with partial clustering should give correct " + + "results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + // Two rows with id=1 so that a naive per-task row_number() without a shuffle would + // keep both when partial clustering splits them across tasks. 
+ sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString) { + // row_number() OVER (PARTITION BY id) should produce rn=1 for exactly one row per id. + val df = sql( + s""" + |SELECT id, price + |FROM ( + | SELECT i.id, i.price, + | ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn + | FROM testcat.ns.$items i + | JOIN testcat.ns.$purchases p ON i.id = p.item_id + |) t + |WHERE rn = 1 + |""".stripMargin) + // For id=1 only the row with the highest price (41.0) should survive. 
+ checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f))) + } + } + } + + test("data source partitioning + dynamic partition filtering") { withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", From 4f9752ef4d9db3ccda7ca936b66e4f7c5109a62a Mon Sep 17 00:00:00 2001 From: Naveen Kumar Puppala Date: Sat, 7 Mar 2026 22:39:44 -0800 Subject: [PATCH 2/2] [SPARK-55848][SQL] Fix incorrect dedup results with SPJ partial clustering When SPJ partial clustering splits a partition across multiple tasks, post-join dedup operators (dropDuplicates, Window row_number) produce incorrect results: KeyGroupedPartitioning.satisfies0() delegates to super.satisfies0() first, which matches ClusteredDistribution on expressions alone and returns true before the partial-clustering guard is ever evaluated. This fix adds an isPartiallyClustered flag to KeyGroupedPartitioning and restructures satisfies0() to check ClusteredDistribution first, returning false when partially clustered. EnsureRequirements then inserts the necessary Exchange. KeyGroupedShuffleSpec.isCompatibleWith() likewise returns false when either side is partially clustered. Plain SPJ joins without dedup are unaffected.
Closes #54378 --- .../plans/physical/partitioning.scala | 117 ++++++++++-------- .../execution/KeyGroupedPartitionedScan.scala | 3 +- .../exchange/EnsureRequirements.scala | 6 +- .../exchange/ShuffleExchangeExec.scala | 4 +- .../DistributionAndOrderingSuiteBase.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 48 ++++++- .../exchange/EnsureRequirementsSuite.scala | 2 +- 7 files changed, 125 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cade49190e472..4a4c72298944b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -379,47 +379,57 @@ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, partitionValues: Seq[InternalRow] = Seq.empty, - originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike { + originalPartitionValues: Seq[InternalRow] = Seq.empty, + isPartiallyClustered: Boolean = false) extends HashPartitioningLike { override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - // When partial clustering is active, the same partition key is intentionally spread - // across multiple tasks (partitionValues contains duplicates). In that case, this - // partitioning does NOT satisfy ClusteredDistribution, because ClusteredDistribution - // requires all rows with the same key to be co-located in a single task. Without this - // guard, downstream operators such as dropDuplicates or Window functions would skip - // their required shuffle and produce incorrect results. - // See SPARK-54378. 
- val isFullyClustered = partitionValues.isEmpty || - uniquePartitionValues.length == partitionValues.length - isFullyClustered && { - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) + // When partial clustering is active, the same partition key is intentionally spread + // across multiple tasks (partitionValues contains duplicates). In that case, this + // partitioning does NOT satisfy ClusteredDistribution, because ClusteredDistribution + // requires all rows with the same key to be co-located in a single task. Without this + // guard, downstream operators such as dropDuplicates or Window functions would skip + // their required shuffle and produce incorrect results. + // See SPARK-54378 / SPARK-55848. + // + // We must check isPartiallyClustered BEFORE delegating to super.satisfies0, because + // HashPartitioningLike.satisfies0 also matches ClusteredDistribution and would return + // true based on expression matching alone, short-circuiting the partial-clustering check. + val isFullyClustered = !isPartiallyClustered && (partitionValues.isEmpty || + uniquePartitionValues.length == partitionValues.length) + + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + isFullyClustered && { + if (requireAllClusterKeys) { + // Checks whether this partitioning is partitioned on exactly same clustering keys of + // `ClusteredDistribution`. + c.areAllClusterKeysMatched(expressions) + } else { + // We'll need to find leaf attributes from the partition expressions first. 
+ val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that join keys (required clustering keys) + // overlap with partition keys (KeyGroupedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) - } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } } + } - case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => - o.areAllClusterKeysMatched(expressions) - - case _ => - false - } + case _ => + // For non-ClusteredDistribution cases, delegate to super (handles + // StatefulOpClusteredDistribution, OrderedDistribution, etc.) + super.satisfies0(required) || { + required match { + case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => + o.areAllClusterKeysMatched(expressions) + case _ => + false + } + } } } @@ -431,7 +441,7 @@ case class KeyGroupedPartitioning( // the returned shuffle spec. 
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) + partitionValues, originalPartitionValues, isPartiallyClustered = isPartiallyClustered) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result @@ -449,7 +459,7 @@ case class KeyGroupedPartitioning( } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy(expressions = newChildren) + copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered) } object KeyGroupedPartitioning { @@ -457,7 +467,8 @@ object KeyGroupedPartitioning { expressions: Seq[Expression], projectionPositions: Seq[Int], partitionValues: Seq[InternalRow], - originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + originalPartitionValues: Seq[InternalRow], + isPartiallyClustered: Boolean): KeyGroupedPartitioning = { val projectedExpressions = projectionPositions.map(expressions(_)) val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) val projectedOriginalPartitionValues = @@ -472,7 +483,8 @@ object KeyGroupedPartitioning { .map(_.row) KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, - finalPartitionValues, projectedOriginalPartitionValues) + finalPartitionValues, projectedOriginalPartitionValues, + isPartiallyClustered = isPartiallyClustered) } def project( @@ -884,15 +896,21 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. 
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - lazy val internalRowComparableFactory = - InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( - partitioning.expressions.map(_.dataType)) - distribution.clustering.length == otherDistribution.clustering.length && - numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && - partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { - case (left, right) => - internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) - } + // SPARK-55848: If either side is partially clustered, they are not compatible + // for shuffle purposes because rows with the same key may be spread across tasks. + if (partitioning.isPartiallyClustered || otherPartitioning.isPartiallyClustered) { + false + } else { + lazy val internalRowComparableFactory = + InternalRowComparableWrapper.getInternalRowComparableWrapperFactory( + partitioning.expressions.map(_.dataType)) + distribution.clustering.length == otherDistribution.clustering.length && + numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && + partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { + case (left, right) => + internalRowComparableFactory(left).equals(internalRowComparableFactory(right)) + } + } case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => false @@ -972,7 +990,8 @@ case class KeyGroupedShuffleSpec( } KeyGroupedPartitioning(newExpressions, partitioning.numPartitions, - partitioning.partitionValues) + partitioning.partitionValues, + isPartiallyClustered = partitioning.isPartiallyClustered) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index cac4a9bc852f6..3647373b23f7f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -62,7 +62,8 @@ trait KeyGroupedPartitionedScan[T] { } } basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length, - partitionValues = newPartValues) + partitionValues = newPartValues, + isPartiallyClustered = spjParams.applyPartialClustering) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e239174e40ad4..05bfc1d5c0bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -307,7 +307,7 @@ case class EnsureRequirements( private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = { (plan.outputPartitioning, distribution) match { - case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _), + case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _), d @ OrderedDistribution(ordering)) if p.satisfies(d) => val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) val partitionOrdering: Ordering[InternalRow] = { @@ -340,12 +340,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, 
Some(KeyGroupedPartitioning(clustering, _, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 95120039a6f94..e7233535ec677 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,7 +370,7 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _, _) => + case k @ KeyGroupedPartitioning(expressions, n, _, _, _) => val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) }.toMap @@ -401,7 +401,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _, _) => + case KeyGroupedPartitioning(expressions, _, _, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 1a0efa7c4aafb..da5bb8ff28f79 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + case KeyGroupedPartitioning(clustering, numPartitions, partValues, + originalPartValues, partiallyClustered) => KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, - originalPartValues) + originalPartValues, partiallyClustered) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 506365f0865b4..71a5380912739 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -94,13 +94,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - 
partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) } test("non-clustered distribution: no partition") { @@ -1211,6 +1211,50 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-55848: dropDuplicates after SPJ with partial clustering should produce " + + "correct results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + // Two rows for id=1 so partial clustering may split them across tasks + sql( + s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql( + s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // dropDuplicates on the join key after a partially-clustered SPJ must still + // produce the correct number of distinct ids. Before SPARK-55848, the + // isPartiallyClustered flag was missing, so EnsureRequirements did not insert + // an Exchange before the dedup, leading to duplicate rows. 
+ val df = sql(s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + // Also verify the plan inserts a shuffle for the dedup when partial clustering is active. + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert( + allShuffles.nonEmpty, + "should contain a shuffle for the post-join dedup with partial clustering") + } + } + test("data source partitioning + dynamic partition filtering") { withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1cc0d795d74f8..ee852da536768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1129,7 +1129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil)