Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,27 @@ case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty,
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
originalPartitionValues: Seq[InternalRow] = Seq.empty,
isPartiallyClustered: Boolean = false) extends HashPartitioningLike {

override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
// When partial clustering is active, the same partition key is intentionally spread
// across multiple tasks (partitionValues contains duplicates). In that case, this
// partitioning does NOT satisfy ClusteredDistribution, because ClusteredDistribution
// requires all rows with the same key to be co-located in a single task. Without this
// guard, downstream operators such as dropDuplicates or Window functions would skip
// their required shuffle and produce incorrect results.
// See SPARK-54378 / SPARK-55848.
Copy link
Copy Markdown
Contributor

@peter-toth peter-toth Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is SPARK-54378 related to this issue?

//
// We must check isPartiallyClustered BEFORE delegating to super.satisfies0, because
// HashPartitioningLike.satisfies0 also matches ClusteredDistribution and would return
// true based on expression matching alone, short-circuiting the partial-clustering check.
val isFullyClustered = !isPartiallyClustered && (partitionValues.isEmpty ||
uniquePartitionValues.length == partitionValues.length)

required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
isFullyClustered && {
if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
// `ClusteredDistribution`.
Expand All @@ -402,13 +417,19 @@ case class KeyGroupedPartitioning(
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}
}

case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
o.areAllClusterKeysMatched(expressions)

case _ =>
false
}
case _ =>
// For non-ClusteredDistribution cases, delegate to super (handles
// StatefulOpClusteredDistribution, OrderedDistribution, etc.)
super.satisfies0(required) || {
required match {
case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
o.areAllClusterKeysMatched(expressions)
case _ =>
false
}
}
}
}

Expand All @@ -420,7 +441,7 @@ case class KeyGroupedPartitioning(
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
partitionValues, originalPartitionValues)
partitionValues, originalPartitionValues, isPartiallyClustered = isPartiallyClustered)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
Expand All @@ -438,15 +459,16 @@ case class KeyGroupedPartitioning(
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(expressions = newChildren)
copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered)
}

object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
originalPartitionValues: Seq[InternalRow],
isPartiallyClustered: Boolean): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
Expand All @@ -461,7 +483,8 @@ object KeyGroupedPartitioning {
.map(_.row)

KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
finalPartitionValues, projectedOriginalPartitionValues)
finalPartitionValues, projectedOriginalPartitionValues,
isPartiallyClustered = isPartiallyClustered)
}

def project(
Expand Down Expand Up @@ -873,15 +896,21 @@ case class KeyGroupedShuffleSpec(
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
lazy val internalRowComparableFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
partitioning.expressions.map(_.dataType))
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
}
// SPARK-55848: If either side is partially clustered, they are not compatible
// for shuffle purposes because rows with the same key may be spread across tasks.
if (partitioning.isPartiallyClustered || otherPartitioning.isPartiallyClustered) {
false
} else {
lazy val internalRowComparableFactory =
InternalRowComparableWrapper.getInternalRowComparableWrapperFactory(
partitioning.expressions.map(_.dataType))
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
internalRowComparableFactory(left).equals(internalRowComparableFactory(right))
}
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ => false
Expand Down Expand Up @@ -961,7 +990,8 @@ case class KeyGroupedShuffleSpec(
}
KeyGroupedPartitioning(newExpressions,
partitioning.numPartitions,
partitioning.partitionValues)
partitioning.partitionValues,
isPartiallyClustered = partitioning.isPartiallyClustered)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ trait KeyGroupedPartitionedScan[T] {
}
}
basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
partitionValues = newPartValues,
isPartiallyClustered = spjParams.applyPartialClustering)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ case class EnsureRequirements(

private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
(plan.outputPartitioning, distribution) match {
case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _),
d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
val partitionOrdering: Ordering[InternalRow] = {
Expand Down Expand Up @@ -340,12 +340,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
case k @ KeyGroupedPartitioning(expressions, n, _, _, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
Expand Down Expand Up @@ -401,7 +401,7 @@ object ShuffleExchangeExec {
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _, _) =>
case KeyGroupedPartitioning(expressions, _, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case s: ShufflePartitionIdPassThrough =>
// For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase
plan: QueryPlan[T]): Partitioning = partitioning match {
case HashPartitioning(exprs, numPartitions) =>
HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
case KeyGroupedPartitioning(clustering, numPartitions, partValues,
originalPartValues, partiallyClustered) =>
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
originalPartValues)
originalPartValues, partiallyClustered)
case PartitioningCollection(partitionings) =>
PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
case RangePartitioning(ordering, numPartitions) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {

checkQueryPlan(df, catalystDistribution,
physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
partitionValues, partitionValues))
partitionValues, partitionValues, isPartiallyClustered = false))

// multiple group keys should work too as long as partition keys are subset of them
df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
checkQueryPlan(df, catalystDistribution,
physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
partitionValues, partitionValues))
partitionValues, partitionValues, isPartiallyClustered = false))
}

test("non-clustered distribution: no partition") {
Expand Down Expand Up @@ -1130,6 +1130,131 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
}
}

test("[SPARK-54378] dropDuplicates after SPJ with partial clustering should give correct " +
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please start the new test names with the prefix "SPARK-55848:"?

"results") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)
// Insert two copies of id=1 so that the left side has duplicate rows for id=1 after the join,
// and three distinct id values in total.
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val purchases_partitions = Array(identity("item_id"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { partiallyClustered =>
withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to turn REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION off for this test?

V2_BUCKETING_PUSH_PART_VALUES_ENABLED is enabled by default so we can omit it.

SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
partiallyClustered.toString) {
// dropDuplicates on the join key: must produce exactly 3 distinct id values regardless
// of whether partial clustering is active.
val df = sql(
s"""
|SELECT DISTINCT i.id
|FROM testcat.ns.$items i
|JOIN testcat.ns.$purchases p ON i.id = p.item_id
|""".stripMargin)
checkAnswer(df, Seq(Row(1), Row(2), Row(3)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please verify that the expected shuffles are present in the plan, and that the scans have the expected number of partitions?

}
}
}

test("[SPARK-54378] Window dedup after SPJ with partial clustering should give correct " +
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

"results") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)
// Two rows with id=1 so that a naive per-task row_number() without a shuffle would
// keep both when partial clustering splits them across tasks.
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val purchases_partitions = Array(identity("item_id"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

Seq(true, false).foreach { partiallyClustered =>
withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key ->
partiallyClustered.toString) {
// row_number() OVER (PARTITION BY id) should produce rn=1 for exactly one row per id.
val df = sql(
s"""
|SELECT id, price
|FROM (
| SELECT i.id, i.price,
| ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn
| FROM testcat.ns.$items i
| JOIN testcat.ns.$purchases p ON i.id = p.item_id
|) t
|WHERE rn = 1
|""".stripMargin)
// For id=1 only the row with the highest price (41.0) should survive.
checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f)))
}
}
}

test("SPARK-55848: dropDuplicates after SPJ with partial clustering should produce " +
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test different to the first one?

"correct results") {
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)
// Two rows for id=1 so partial clustering may split them across tasks
sql(
s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val purchases_partitions = Array(identity("item_id"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(
s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(1, 50.0, cast('2020-01-02' as timestamp)), " +
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString,
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> true.toString,
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> true.toString) {
// dropDuplicates on the join key after a partially-clustered SPJ must still
// produce the correct number of distinct ids. Before SPARK-55848, the
// isPartiallyClustered flag was missing, so EnsureRequirements did not insert
// an Exchange before the dedup, leading to duplicate rows.
val df = sql(s"""
|SELECT DISTINCT i.id
|FROM testcat.ns.$items i
|JOIN testcat.ns.$purchases p ON i.id = p.item_id
|""".stripMargin)
checkAnswer(df, Seq(Row(1), Row(2), Row(3)))

// Also verify the plan inserts a shuffle for the dedup when partial clustering is active.
val allShuffles = collectAllShuffles(df.queryExecution.executedPlan)
assert(
allShuffles.nonEmpty,
"should contain a shuffle for the post-join dedup with partial clustering")
}
}

test("data source partitioning + dynamic partition filtering") {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
EnsureRequirements.apply(smjExec) match {
case ShuffledHashJoinExec(_, _, _, _, _,
DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _),
DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
assert(left.expressions == a1 :: Nil)
assert(attrs == a1 :: Nil)
Expand Down