diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index b0fa4f889cda1..4a4c72298944b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -379,12 +379,27 @@ case class KeyGroupedPartitioning(
     expressions: Seq[Expression],
     numPartitions: Int,
     partitionValues: Seq[InternalRow] = Seq.empty,
-    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
+    originalPartitionValues: Seq[InternalRow] = Seq.empty,
+    isPartiallyClustered: Boolean = false) extends HashPartitioningLike {
 
   override def satisfies0(required: Distribution): Boolean = {
-    super.satisfies0(required) || {
-      required match {
-        case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+    // When partial clustering is active, the same partition key is intentionally spread
+    // across multiple tasks (partitionValues contains duplicates). In that case, this
+    // partitioning does NOT satisfy ClusteredDistribution, because ClusteredDistribution
+    // requires all rows with the same key to be co-located in a single task. Without this
+    // guard, downstream operators such as dropDuplicates or Window functions would skip
+    // their required shuffle and produce incorrect results.
+    // See SPARK-54378 / SPARK-55848.
+    //
+    // isPartiallyClustered must be checked BEFORE delegating to super.satisfies0, because
+    // HashPartitioningLike.satisfies0 also matches ClusteredDistribution and would return
+    // true based on expression matching alone, short-circuiting the partial-clustering check.
+    val isFullyClustered = !isPartiallyClustered && (partitionValues.isEmpty ||
+      uniquePartitionValues.length == partitionValues.length)
+
+    required match {
+      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        isFullyClustered && {
           if (requireAllClusterKeys) {
             // Checks whether this partitioning is partitioned on exactly same clustering keys of
             // `ClusteredDistribution`.
@@ -402,13 +417,19 @@ case class KeyGroupedPartitioning(
               attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
             }
           }
+        }
 
-        case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
-          o.areAllClusterKeysMatched(expressions)
-
-        case _ =>
-          false
-      }
+      case _ =>
+        // For non-ClusteredDistribution cases, delegate to super (handles
+        // StatefulOpClusteredDistribution, OrderedDistribution, etc.)
+        super.satisfies0(required) || {
+          required match {
+            case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
+              o.areAllClusterKeysMatched(expressions)
+            case _ =>
+              false
+          }
+        }
     }
   }
 
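The guard above is the core of the fix. As a minimal, self-contained illustration of the property it enforces (stand-in types only, invented for this sketch; this is not the Catalyst API):

```scala
// Not the actual Catalyst classes: a simplified model of the isFullyClustered check.
// A key-grouped layout only guarantees that equal keys are co-located in one task when
// it is not partially clustered and partitionValues contains no duplicates.
case class LayoutSketch(partitionValues: Seq[Int], isPartiallyClustered: Boolean) {
  def satisfiesClusteredDistribution: Boolean =
    !isPartiallyClustered &&
      (partitionValues.isEmpty || partitionValues.distinct.length == partitionValues.length)
}

object GuardSketch extends App {
  // id=1 replicated across two tasks by partial clustering: dedup/Window must re-shuffle.
  assert(!LayoutSketch(Seq(1, 1, 2, 3), isPartiallyClustered = true)
    .satisfiesClusteredDistribution)
  // One task per key value: safe to skip the shuffle.
  assert(LayoutSketch(Seq(1, 2, 3), isPartiallyClustered = false)
    .satisfiesClusteredDistribution)
}
```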
@@ -420,7 +441,7 @@ case class KeyGroupedPartitioning(
       // the returned shuffle spec.
       val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
       val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
-        partitionValues, originalPartitionValues)
+        partitionValues, originalPartitionValues, isPartiallyClustered = isPartiallyClustered)
       result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
     } else {
       result
@@ -438,7 +459,7 @@ case class KeyGroupedPartitioning(
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
-    copy(expressions = newChildren)
+    copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered)
 }
 
 object KeyGroupedPartitioning {
@@ -446,7 +467,8 @@ object KeyGroupedPartitioning {
       expressions: Seq[Expression],
       projectionPositions: Seq[Int],
       partitionValues: Seq[InternalRow],
-      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
+      originalPartitionValues: Seq[InternalRow],
+      isPartiallyClustered: Boolean): KeyGroupedPartitioning = {
     val projectedExpressions = projectionPositions.map(expressions(_))
     val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
     val projectedOriginalPartitionValues =
@@ -461,7 +483,8 @@ object KeyGroupedPartitioning {
       .map(_.row)
 
     KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
-      finalPartitionValues, projectedOriginalPartitionValues)
+      finalPartitionValues, projectedOriginalPartitionValues,
+      isPartiallyClustered = isPartiallyClustered)
   }
 
   def project(
@@ -873,15 +896,21 @@ case class KeyGroupedShuffleSpec(
     //    transform functions.
     // 4. the partition values from both sides are following the same order.
     case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
-      lazy val internalRowComparableFactory =
-        InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
-          partitioning.expressions.map(_.dataType))
-      distribution.clustering.length == otherDistribution.clustering.length &&
-        numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
-        partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
-          case (left, right) =>
-            internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
-        }
+      // SPARK-55848: If either side is partially clustered, they are not compatible
+      // for shuffle purposes because rows with the same key may be spread across tasks.
+      if (partitioning.isPartiallyClustered || otherPartitioning.isPartiallyClustered) {
+        false
+      } else {
+        lazy val internalRowComparableFactory =
+          InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
+            partitioning.expressions.map(_.dataType))
+        distribution.clustering.length == otherDistribution.clustering.length &&
+          numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
+          partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
+            case (left, right) =>
+              internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
+          }
+      }
     case ShuffleSpecCollection(specs) =>
       specs.exists(isCompatibleWith)
     case _ => false
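The compatibility side follows the same reasoning. Reusing the hypothetical `LayoutSketch` model from the earlier snippet, a sketch of what `isCompatibleWith` now requires (illustrative only, not the real `KeyGroupedShuffleSpec` logic with `InternalRowComparableWrapper`):

```scala
// Sketch: two key-grouped sides may skip the join shuffle only if neither is
// partially clustered and their partition values line up pairwise in the same order.
def compatibleSketch(left: LayoutSketch, right: LayoutSketch): Boolean =
  !left.isPartiallyClustered && !right.isPartiallyClustered &&
    left.partitionValues.length == right.partitionValues.length &&
    left.partitionValues.zip(right.partitionValues).forall { case (l, r) => l == r }

// Aligned, fully clustered sides: compatible.
assert(compatibleSketch(
  LayoutSketch(Seq(1, 2, 3), isPartiallyClustered = false),
  LayoutSketch(Seq(1, 2, 3), isPartiallyClustered = false)))
// Either side partially clustered: never compatible, a shuffle is required.
assert(!compatibleSketch(
  LayoutSketch(Seq(1, 1, 2), isPartiallyClustered = true),
  LayoutSketch(Seq(1, 1, 2), isPartiallyClustered = true)))
```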
@@ -961,7 +990,8 @@ case class KeyGroupedShuffleSpec(
     }
 
     KeyGroupedPartitioning(newExpressions, partitioning.numPartitions,
-      partitioning.partitionValues)
+      partitioning.partitionValues,
+      isPartiallyClustered = partitioning.isPartiallyClustered)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index cac4a9bc852f6..3647373b23f7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -62,7 +62,8 @@ trait KeyGroupedPartitionedScan[T] {
       }
     }
     basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
-      partitionValues = newPartValues)
+      partitionValues = newPartValues,
+      isPartiallyClustered = spjParams.applyPartialClustering)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index e239174e40ad4..05bfc1d5c0bb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -307,7 +307,7 @@ case class EnsureRequirements(
 
   private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
     (plan.outputPartitioning, distribution) match {
-      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
+      case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _),
           d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
         val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
         val partitionOrdering: Ordering[InternalRow] = {
@@ -340,12 +340,12 @@ case class EnsureRequirements(
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
           .orElse(reorderJoinKeysRecursively(
             leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) =>
         val leafExprs = clustering.flatMap(_.collectLeaves())
         reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
           .orElse(reorderJoinKeysRecursively(
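The `EnsureRequirements` edits here (and the `ShuffleExchangeExec` edits below) are mechanical fallout of the new constructor parameter: adding a field to a Scala case class changes the arity of its compiler-generated extractor, so every existing pattern needs an extra wildcard. A toy example with hypothetical names:

```scala
// Toy type standing in for KeyGroupedPartitioning after the new field was added.
case class PartSketch(exprs: Seq[String], n: Int, partial: Boolean = false)

PartSketch(Seq("id"), 4) match {
  // case PartSketch(e, n) => ...            // no longer compiles: extractor is 3-ary now
  case PartSketch(e, n, _) => println(s"$e grouped into $n partitions")  // the mechanical fix
}
```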
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 95120039a6f94..e7233535ec677 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -370,7 +370,7 @@ object ShuffleExchangeExec {
         ascending = true,
         samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
       case SinglePartition => new ConstantPartitioner
-      case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
+      case k @ KeyGroupedPartitioning(expressions, n, _, _, _) =>
         val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) =>
           (partition.toSeq(expressions.map(_.dataType)), index)
         }.toMap
@@ -401,7 +401,7 @@ object ShuffleExchangeExec {
       val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
       row => projection(row)
     case SinglePartition => identity
-    case KeyGroupedPartitioning(expressions, _, _, _) =>
+    case KeyGroupedPartitioning(expressions, _, _, _, _) =>
       row => bindReferences(expressions, outputAttributes).map(_.eval(row))
     case s: ShufflePartitionIdPassThrough =>
       // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
index 1a0efa7c4aafb..da5bb8ff28f79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala
@@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase
       plan: QueryPlan[T]): Partitioning = partitioning match {
     case HashPartitioning(exprs, numPartitions) =>
       HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
-    case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
+    case KeyGroupedPartitioning(clustering, numPartitions, partValues,
+        originalPartValues, partiallyClustered) =>
       KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
-        originalPartValues)
+        originalPartValues, partiallyClustered)
     case PartitioningCollection(partitionings) =>
       PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
     case RangePartitioning(ordering, numPartitions) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 56bd028464e54..71a5380912739 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -94,13 +94,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
     checkQueryPlan(df, catalystDistribution,
       physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
-        partitionValues, partitionValues))
+        partitionValues, partitionValues, isPartiallyClustered = false))
 
     // multiple group keys should work too as long as partition keys are subset of them
     df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
 
     checkQueryPlan(df, catalystDistribution,
       physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
-        partitionValues, partitionValues))
+        partitionValues, partitionValues, isPartiallyClustered = false))
   }
 
   test("non-clustered distribution: no partition") {
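The new tests below toggle storage-partitioned-join behavior through SQL confs. For reference, the session-level equivalent of the `withSQLConf` blocks is sketched here; the conf key strings are as found in recent Spark releases, so verify them against the version you are running:

```scala
// Session-level equivalent of the withSQLConf blocks in the tests below.
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "true")
spark.conf.set("spark.sql.requireAllClusterKeysForCoPartition", "false")
```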
partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) } test("non-clustered distribution: no partition") { @@ -1130,6 +1130,131 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("[SPARK-54378] dropDuplicates after SPJ with partial clustering should give correct " + + "results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + // Insert two copies of id=1 so that the left side has duplicate rows for id=1 after the join, + // and three distinct id values in total. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString) { + // dropDuplicates on the join key: must produce exactly 3 distinct id values regardless + // of whether partial clustering is active. + val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + } + } + } + + test("[SPARK-54378] Window dedup after SPJ with partial clustering should give correct " + + "results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + // Two rows with id=1 so that a naive per-task row_number() without a shuffle would + // keep both when partial clustering splits them across tasks. + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString) { + // row_number() OVER (PARTITION BY id) should produce rn=1 for exactly one row per id. 
+  test("SPARK-55848: dropDuplicates after SPJ with partial clustering should produce " +
+    "correct results") {
+    val items_partitions = Array(identity("id"))
+    createTable(items, itemsColumns, items_partitions)
+    // Two rows for id=1 so partial clustering may split them across tasks
+    sql(
+      s"INSERT INTO testcat.ns.$items VALUES " +
+        "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
+        "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(
+      s"INSERT INTO testcat.ns.$purchases VALUES " +
+        "(1, 42.0, cast('2020-01-01' as timestamp)), " +
+        "(1, 50.0, cast('2020-01-02' as timestamp)), " +
+        "(2, 11.0, cast('2020-01-01' as timestamp)), " +
+        "(3, 19.5, cast('2020-02-01' as timestamp))")
+
+    withSQLConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
+      // dropDuplicates on the join key after a partially-clustered SPJ must still
+      // produce the correct number of distinct ids. Before SPARK-55848, the
+      // isPartiallyClustered flag was missing, so EnsureRequirements did not insert
+      // an Exchange before the dedup, leading to duplicate rows.
+      val df = sql(s"""
+        |SELECT DISTINCT i.id
+        |FROM testcat.ns.$items i
+        |JOIN testcat.ns.$purchases p ON i.id = p.item_id
+        |""".stripMargin)
+      checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
+
+      // Also verify the plan inserts a shuffle for the dedup when partial clustering is active.
+      val allShuffles = collectAllShuffles(df.queryExecution.executedPlan)
+      assert(
+        allShuffles.nonEmpty,
+        "should contain a shuffle for the post-join dedup with partial clustering")
+    }
+  }
+
   test("data source partitioning + dynamic partition filtering") {
     withSQLConf(
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
index 1cc0d795d74f8..ee852da536768 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -1129,7 +1129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
     EnsureRequirements.apply(smjExec) match {
       case ShuffledHashJoinExec(_, _, _, _, _,
         DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
-        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
+        ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _),
           DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
         assert(left.expressions == a1 :: Nil)
         assert(attrs == a1 :: Nil)
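As a usage note: `collectAllShuffles` is a helper internal to the test suite. Outside the suite, the same shuffle assertion can be approximated by collecting `ShuffleExchangeExec` nodes directly (a sketch assuming an active SparkSession and a DataFrame `df` built like the test's query; with AQE enabled the final plan may only expose shuffles after execution):

```scala
// Ad-hoc equivalent of the suite's shuffle assertion.
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

val shuffles = df.queryExecution.executedPlan.collect { case s: ShuffleExchangeExec => s }
assert(shuffles.nonEmpty, "expected a shuffle before the post-join dedup")
```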