diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 101d13c6b580c..fd71e22c555cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -379,36 +379,36 @@ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, partitionValues: Seq[InternalRow] = Seq.empty, - originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike { + originalPartitionValues: Seq[InternalRow] = Seq.empty, + isPartiallyClustered: Boolean = false) extends HashPartitioningLike { + // See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to + // super.satisfies0(), because HashPartitioningLike.satisfies0() also matches + // ClusteredDistribution and returns true, which would short-circuit the + // isPartiallyClustered guard. override def satisfies0(required: Distribution): Boolean = { - super.satisfies0(required) || { - required match { - case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => - if (requireAllClusterKeys) { - // Checks whether this partitioning is partitioned on exactly same clustering keys of - // `ClusteredDistribution`. - c.areAllClusterKeysMatched(expressions) + required match { + case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) => + if (isPartiallyClustered) { + false + } else if (requireAllClusterKeys) { + c.areAllClusterKeysMatched(expressions) + } else { + val attributes = expressions.flatMap(_.collectLeaves()) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) } else { - // We'll need to find leaf attributes from the partition expressions first. - val attributes = expressions.flatMap(_.collectLeaves()) - - if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that join keys (required clustering keys) - // overlap with partition keys (KeyGroupedPartitioning attributes) - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - expressions.forall(_.collectLeaves().size == 1) - } else { - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) - } + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) } + } - case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => - o.areAllClusterKeysMatched(expressions) + case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting => + o.areAllClusterKeysMatched(expressions) - case _ => - false - } + case _ => + super.satisfies0(required) } } @@ -420,7 +420,7 @@ case class KeyGroupedPartitioning( // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) + partitionValues, originalPartitionValues, isPartiallyClustered) result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) } else { result @@ -435,7 +435,7 @@ case class KeyGroupedPartitioning( } override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - copy(expressions = newChildren) + copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered) } object KeyGroupedPartitioning { @@ -443,7 +443,8 @@ object KeyGroupedPartitioning { expressions: Seq[Expression], projectionPositions: Seq[Int], partitionValues: Seq[InternalRow], - originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + originalPartitionValues: Seq[InternalRow], + isPartiallyClustered: Boolean): KeyGroupedPartitioning = { val projectedExpressions = projectionPositions.map(expressions(_)) val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) val projectedOriginalPartitionValues = @@ -455,7 +456,7 @@ object KeyGroupedPartitioning { .map(_.row) KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length, - finalPartitionValues, projectedOriginalPartitionValues) + finalPartitionValues, projectedOriginalPartitionValues, isPartiallyClustered) } def project( @@ -867,7 +868,10 @@ case class KeyGroupedShuffleSpec( // transform functions. // 4. the partition values from both sides are following the same order. case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => - distribution.clustering.length == otherDistribution.clustering.length && + // SPARK-55848: partially-clustered partitioning is not compatible for SPJ + !partitioning.isPartiallyClustered && + !otherPartitioning.isPartiallyClustered && + distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { case (left, right) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala index e9d8e8e6d0fb3..0a70021dc858c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala @@ -58,7 +58,8 @@ trait KeyGroupedPartitionedScan[T] { } } basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length, - partitionValues = newPartValues) + partitionValues = newPartValues, + isPartiallyClustered = spjParams.applyPartialClustering) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 088ece6554c96..0d180bd336221 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -307,7 +307,7 @@ case class EnsureRequirements( private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = { (plan.outputPartitioning, distribution) match { - case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _), + case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _), d @ OrderedDistribution(ordering)) if p.satisfies(d) => val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute]) val partitionOrdering: Ordering[InternalRow] = { @@ -340,12 +340,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index f052bd9068805..c59bb4d39b096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -370,7 +370,7 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _, _) => + case k @ KeyGroupedPartitioning(expressions, n, _, _, _) => val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) }.toMap @@ -401,7 +401,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _, _) => + case KeyGroupedPartitioning(expressions, _, _, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case s: ShufflePartitionIdPassThrough => // For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index 1a0efa7c4aafb..80f360ed216f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + case KeyGroupedPartitioning(clustering, numPartitions, partValues, + originalPartValues, isPartiallyClustered) => KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, - originalPartValues) + originalPartValues, isPartiallyClustered) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 8a65cb623f6e6..122c511bf8358 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.distributions.Distributions import org.apache.spark.sql.connector.expressions._ import org.apache.spark.sql.connector.expressions.Expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -94,13 +94,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, - partitionValues, partitionValues)) + partitionValues, partitionValues, isPartiallyClustered = false)) } test("non-clustered distribution: no partition") { @@ -2892,4 +2892,150 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row("ccc", 30, 400.50))) } } + + test("SPARK-55848: dropDuplicates after SPJ with partial clustering should produce " + + "correct results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // dropDuplicates on the join key after a partially-clustered SPJ must still + // produce the correct number of distinct ids. Before the fix, the + // partially-clustered partitioning was incorrectly treated as satisfying + // ClusteredDistribution, so EnsureRequirements did not insert an Exchange + // before the dedup, leading to duplicate rows. + val df = sql( + s""" + |SELECT DISTINCT i.id + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id + |""".stripMargin) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join dedup with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: Window dedup after SPJ with partial clustering should produce " + + "correct results") { + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Use ROW_NUMBER() OVER to dedup joined rows per id after a partially-clustered + // SPJ. The WINDOW operator requires ClusteredDistribution on i.id; with partial + // clustering the plan must insert a shuffle so that the window + // produces exactly one row per id. + val df = sql( + s""" + |SELECT id, price FROM ( + | SELECT i.id, i.price, + | ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn + | FROM testcat.ns.$items i + | JOIN testcat.ns.$purchases p ON i.id = p.item_id + |) t WHERE rn = 1 + |""".stripMargin) + checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the post-join window with partial clustering") + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning") + } + } + + test("SPARK-55848: checkpointed partially-clustered join with dedup") { + withTempDir { dir => + spark.sparkContext.setCheckpointDir(dir.getPath) + val items_partitions = Array(identity("id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(identity("item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 50.0, cast('2020-01-02' as timestamp)), " + + "(2, 11.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) { + // Checkpoint the JOIN result (not the scan) so the plan behind the + // checkpoint carries partially-clustered KeyGroupedPartitioning. + // The dedup on top must still insert an Exchange because the + // isPartiallyClustered flag causes satisfies0()=false for + // ClusteredDistribution. + val joinedDf = spark.sql( + s"""SELECT i.id, i.name, i.price + |FROM testcat.ns.$items i + |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin) + val checkpointedDf = joinedDf.checkpoint() + val df = checkpointedDf.select("id").distinct() + + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + + val allShuffles = collectAllShuffles(df.queryExecution.executedPlan) + assert(allShuffles.nonEmpty, + "should contain a shuffle for the dedup after checkpointed " + + "partially-clustered join") + + val rddScans = collect(df.queryExecution.executedPlan) { + case r: RDDScanExec => r + } + assert(rddScans.exists(_.outputPartitioning match { + case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered + case _ => false + }), "checkpoint (RDDScanExec) should have " + + "partially-clustered KeyGroupedPartitioning") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 1cc0d795d74f8..ee852da536768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1129,7 +1129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil)