Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,69 @@ package org.apache.spark.sql.execution

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.plans.{AliasAwareOutputExpression, AliasAwareQueryOutputOrdering}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}

/**
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
* satisfies distribution requirements.
*/
trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
with AliasAwareOutputExpression {
/**
 * Collects an ExprId -> Attribute lookup table from the simple column aliases found in
 * `outputExpressions` (e.g. `id AS new_id`).
 *
 * KeyedPartitioning carries @transient partitionKeys that must survive aliasing, so its
 * expressions are rewritten by direct ExprId substitution using this map instead of the
 * generic alias-projection machinery. See SPARK-46367.
 */
private def buildExprIdAliasMap: Map[Long, Attribute] =
  outputExpressions.collect {
    // Only direct attribute aliases participate; anything else cannot be remapped 1:1.
    case alias @ Alias(source: Attribute, _) => source.exprId.id -> alias.toAttribute
  }.toMap

/**
 * Rewrites the partition expressions of a [[KeyedPartitioning]] through the alias map built
 * from `outputExpressions`, substituting aliased output attributes for the originals by
 * ExprId while leaving the partitioning's @transient partitionKeys untouched.
 *
 * Returns None when any referenced attribute is neither aliased nor present in the current
 * output, in which case the partitioning must be dropped. See SPARK-46367.
 */
private def remapKeyedPartitioning(
    partitioning: KeyedPartitioning,
    exprIdAliasMap: Map[Long, Attribute]): Option[KeyedPartitioning] = {
  val visibleOutput = AttributeSet(outputExpressions.map(_.toAttribute))

  // Option-sequencing: Some(values) only when every element survived the rewrite.
  def sequence(candidates: Seq[Option[Expression]]): Option[Seq[Expression]] =
    if (candidates.forall(_.isDefined)) Some(candidates.map(_.get)) else None

  // Rewrite one partition expression; None prunes the whole partitioning.
  def rewrite(expression: Expression): Option[Expression] = expression match {
    case attribute: Attribute =>
      exprIdAliasMap.get(attribute.exprId.id) match {
        case replacement @ Some(_) => replacement
        // Not aliased: keep it only if it is still part of this node's output.
        case None => if (visibleOutput.contains(attribute)) Some(attribute) else None
      }
    case transform =>
      // Transform expressions (e.g. years(ts)) are rewritten child by child.
      sequence(transform.children.map(rewrite)).map(transform.withNewChildren)
  }

  sequence(partitioning.expressions.map(rewrite))
    .map(rewritten => partitioning.copy(expressions = rewritten))
}

final override def outputPartitioning: Partitioning = {
val partitionings: Seq[Partitioning] = if (hasAlias) {
val exprIdAliasMap = buildExprIdAliasMap
flattenPartitioning(child.outputPartitioning).iterator.flatMap {
case k: KeyedPartitioning =>
// SPARK-46367: KeyedPartitioning must be remapped via direct ExprId substitution
// rather than the generic projectExpression/multiTransformDown path, because
// multiTransformDown may not correctly propagate remappings through KeyedPartitioning's
// @transient partitionKeys field.
remapKeyedPartitioning(k, exprIdAliasMap).toSeq
case e: Expression =>
// We need unique partitionings but if the input partitioning is
// `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3913,4 +3913,51 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
}
}
}

test("SPARK-46367: KeyedPartitioning should be remapped through column aliases in ProjectExec") {
// Both tables are identity-partitioned on the join key, so the join below can be planned
// as a storage-partitioned join with no shuffle when V2 bucketing is enabled.
val items_partitions = Array(identity("id"))
createTable(items, itemsColumns, items_partitions)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

val purchases_partitions = Array(identity("item_id"))
createTable(purchases, purchasesColumns, purchases_partitions)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 11.0, cast('2020-01-01' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp))")

// Broadcast joins are disabled so the plan must rely on the V2 partitioning to avoid
// shuffles; any Exchange found below is therefore attributable to the alias bug.
withSQLConf(
SQLConf.V2_BUCKETING_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
// ProjectExec renames i.id -> new_id, introducing a new ExprId.
// The downstream GROUP BY on new_id requires ClusteredDistribution on new_id.
// If KeyedPartitioning is not remapped through the alias, EnsureRequirements
// will insert a spurious Exchange shuffle before the aggregation (SPARK-46367).
val df = sql(
s"""
|SELECT new_id, SUM(price)
|FROM (
| SELECT i.id AS new_id, i.price
| FROM testcat.ns.$items i
| JOIN testcat.ns.$purchases p ON i.id = p.item_id
|) t
|GROUP BY new_id
|""".stripMargin)

val executedPlan = df.queryExecution.executedPlan
val shuffles = collectAllShuffles(executedPlan)
assert(shuffles.isEmpty,
"SPARK-46367: KeyedPartitioning was not remapped through the alias in ProjectExec, " +
"causing EnsureRequirements to insert an unnecessary Exchange before GROUP BY. " +
s"Found ${shuffles.size} shuffle(s):\n" +
executedPlan)

// Each item joins exactly one purchase, so SUM(price) per new_id equals the item's
// own price. NOTE(review): expected Rows assume id is a bigint (Long) in
// itemsColumns — confirm against the suite's shared schema fixtures.
// SUM(i.price) per item: items prices are 40.0f, 10.0f, 15.5f (float -> double for SUM)
checkAnswer(df.sort("new_id"),
Seq(Row(1L, 40.0), Row(2L, 10.0), Row(3L, 15.5)))
}
}
}