From 40a90bc8d0266417b6e2b5c4cc912c6006aa95eb Mon Sep 17 00:00:00 2001
From: ilicmarkodb
Date: Fri, 26 Dec 2025 13:56:00 +0100
Subject: [PATCH] Normalize hash join keys for collated strings by injecting
 CollationKey

---
 .../catalyst/expressions/CollationKey.scala   | 61 +++++++++++++-
 .../joins/BroadcastHashJoinExec.scala         | 28 ++++++-
 .../spark/sql/execution/joins/HashJoin.scala  | 16 ++++
 .../joins/ShuffledHashJoinExec.scala          | 26 +++++-
 .../spark/sql/collation/CollationSuite.scala  | 81 +++++++++++++++++++
 .../execution/joins/BroadcastJoinSuite.scala  |  4 +-
 6 files changed, 210 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
index 5d2fd14eee298..9a0aaea75f810 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
@@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] =
@@ -46,3 +47,61 @@ case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsIn
 
   override def child: Expression = expr
 }
+
+object CollationKey {
+  /**
+   * Recursively processes the expression, replacing non-binary-stable collated strings
+   * (including those nested in structs and arrays) with their associated collation keys.
+   */
+  def injectCollationKey(expr: Expression): Expression = {
+    injectCollationKey(expr, expr.dataType)
+  }
+
+  private def injectCollationKey(expr: Expression, dt: DataType): Expression = {
+    dt match {
+      // For binary-stable expressions, no special handling is needed.
+      case _ if UnsafeRowUtils.isBinaryStable(dt) =>
+        expr
+
+      // Inject CollationKey for non-binary collated strings.
+      case _: StringType =>
+        CollationKey(expr)
+
+      // Recursively process struct fields for non-binary structs.
+      case StructType(fields) =>
+        val transformed = fields.zipWithIndex.map { case (f, i) =>
+          val originalField = GetStructField(expr, i, Some(f.name))
+          val injected = injectCollationKey(originalField, f.dataType)
+          (f, injected, injected.fastEquals(originalField))
+        }
+        val anyChanged = transformed.exists { case (_, _, same) => !same }
+        if (!anyChanged) {
+          expr
+        } else {
+          val struct = CreateNamedStruct(
+            transformed.flatMap { case (f, injected, _) =>
+              Seq(Literal(f.name), injected)
+            }.toImmutableArraySeq)
+          if (expr.nullable) {
+            If(IsNull(expr), Literal(null, struct.dataType), struct)
+          } else {
+            struct
+          }
+        }
+
+      // Recursively process array elements for non-binary arrays.
+      case ArrayType(et, containsNull) =>
+        val param: NamedExpression = NamedLambdaVariable("a", et, containsNull)
+        val funcBody: Expression = injectCollationKey(param, et)
+        if (!funcBody.fastEquals(param)) {
+          ArrayTransform(expr, LambdaFunction(funcBody, Seq(param)))
+        } else {
+          expr
+        }
+
+      // Joins are not supported on maps, so there is no special handling for MapType.
+      case _ =>
+        expr
+    }
+  }
+}
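
For orientation, here is a minimal sketch of what injectCollationKey produces for the key shapes handled above. It is illustrative only, not part of the patch: the object name InjectCollationKeySketch is made up, and it assumes the patch is applied on a Spark build with collation support (so that StructType.fromDDL understands COLLATE).

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

object InjectCollationKeySketch extends App {
  // UTF8_LCASE is not binary stable; plain StringType (UTF8_BINARY) is.
  val lcase = StructType.fromDDL("c STRING COLLATE UTF8_LCASE").head.dataType

  // Binary-stable key: returned unchanged.
  val c1 = AttributeReference("c1", StringType)()
  assert(CollationKey.injectCollationKey(c1).fastEquals(c1))

  // Collated string key: wrapped in CollationKey.
  val c2 = AttributeReference("c2", lcase)()
  assert(CollationKey.injectCollationKey(c2) == CollationKey(c2))

  // Collated array key: rewritten element-wise,
  // roughly transform(c3, a -> collation_key(a)).
  val c3 = AttributeReference("c3", ArrayType(lcase))()
  println(CollationKey.injectCollationKey(c3))
}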
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b62d8f0798b6a..944ee3b059092 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
  * relation is not shuffled.
  */
-case class BroadcastHashJoinExec(
+case class BroadcastHashJoinExec private(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
@@ -245,3 +245,27 @@ case class BroadcastHashJoinExec(
       newLeft: SparkPlan,
       newRight: SparkPlan): BroadcastHashJoinExec = copy(left = newLeft, right = newRight)
 }
+
+object BroadcastHashJoinExec extends JoinSelectionHelper {
+  def apply(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      buildSide: BuildSide,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isNullAwareAntiJoin: Boolean = false): BroadcastHashJoinExec = {
+    val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+    new BroadcastHashJoinExec(
+      normalizedLeftKeys,
+      normalizedRightKeys,
+      joinType,
+      buildSide,
+      condition,
+      left,
+      right,
+      isNullAwareAntiJoin)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index a1abb64e262df..fab14dba444dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,6 +42,9 @@ private[joins] case class HashedRelationInfo(
     isEmpty: Boolean)
 
 trait HashJoin extends JoinCodegenSupport {
+  assert(leftKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+  assert(rightKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+
   def buildSide: BuildSide
 
   override def simpleStringWithNodeId(): String = {
@@ -724,6 +728,18 @@ trait HashJoin extends JoinCodegenSupport {
 
 object HashJoin extends CastSupport with SQLConfHelper {
+  /**
+   * Normalize join keys by injecting `CollationKey` when the keys are collated.
+   */
+  def normalizeJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+    (
+      leftKeys.map(CollationKey.injectCollationKey),
+      rightKeys.map(CollationKey.injectCollationKey)
+    )
+  }
+
   private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
     // TODO: support BooleanType, DateType and TimestampType
     keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
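
Together with the private constructor above, the companion apply is now the only public way to build the node, and it routes every key through HashJoin.normalizeJoinKeys, which is what makes the binary-stability assertions in the HashJoin trait hold; the same pattern is applied to ShuffledHashJoinExec below. A rough sketch of the key rewriting (NormalizeJoinKeysSketch and the attribute names are made up; assumes the patch is applied):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey}
import org.apache.spark.sql.execution.joins.HashJoin
import org.apache.spark.sql.types.StructType

object NormalizeJoinKeysSketch extends App {
  val lcase = StructType.fromDDL("c STRING COLLATE UTF8_LCASE").head.dataType
  val l = AttributeReference("l", lcase)()
  val r = AttributeReference("r", lcase)()

  // Collated keys come back wrapped, so hashing and comparison of the
  // resulting join keys are binary stable.
  val (lk, rk) = HashJoin.normalizeJoinKeys(Seq(l), Seq(r))
  assert(lk == Seq(CollationKey(l)) && rk == Seq(CollationKey(r)))
}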
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 97ca74aee30c0..0f90f443ad41d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet}
 /**
  * Performs a hash join of two child relations by first shuffling the data using the join keys.
  */
-case class ShuffledHashJoinExec(
+case class ShuffledHashJoinExec private (
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
@@ -659,3 +659,27 @@ case class ShuffledHashJoinExec(
       newLeft: SparkPlan,
       newRight: SparkPlan): ShuffledHashJoinExec = copy(left = newLeft, right = newRight)
 }
+
+object ShuffledHashJoinExec {
+  def apply(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      joinType: JoinType,
+      buildSide: BuildSide,
+      condition: Option[Expression],
+      left: SparkPlan,
+      right: SparkPlan,
+      isSkewJoin: Boolean = false): ShuffledHashJoinExec = {
+    val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+    new ShuffledHashJoinExec(
+      normalizedLeftKeys,
+      normalizedRightKeys,
+      joinType,
+      buildSide,
+      condition,
+      left,
+      right,
+      isSkewJoin)
+  }
+}
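
The tests below lean on the three-valued semantics of NOT IN: once the subquery side contains a NULL, the predicate can evaluate to FALSE or UNKNOWN but never TRUE, so the result set is empty. A minimal standalone illustration of just that semantics (the object name is made up; only a local SparkSession is assumed):

import org.apache.spark.sql.SparkSession

object NotInNullSemanticsSketch extends App {
  val spark = SparkSession.builder().appName("not-in-null").master("local[*]").getOrCreate()
  // 'b' is absent from {'a', NULL}, but 'b' <> NULL is UNKNOWN, so the whole
  // NOT IN evaluates to UNKNOWN rather than TRUE and the row is filtered out.
  spark.sql(
    """SELECT * FROM VALUES ('b') AS s(v)
      |WHERE v NOT IN (SELECT c FROM VALUES ('a'), (CAST(NULL AS STRING)) AS t(c))""".stripMargin
  ).show() // prints an empty result
  spark.stop()
}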
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 6cdf681d65ca3..c84647066f25d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -2114,4 +2114,85 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
       sql(s"CREATE TABLE t (c STRING COLLATE system.builtin.UTF8_LCASE)")
     }
   }
+
+  test("null aware anti join from NOT IN with collated columns") {
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+      sql(s"CREATE TABLE $tableName2 (C1 STRING COLLATE UTF8_LCASE_RTRIM)")
+      sql(s"INSERT INTO $tableName1 VALUES ('a')")
+      sql(s"INSERT INTO $tableName2 VALUES ('A ')")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES ('b')")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer ++ Seq(Row("b")))
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+        s" AND C1 = 'B '"), Row("b"))
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+        s" AND C1 > 'b'"), Seq())
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)" +
+        s" AND C1 = 'c'"), Seq())
+
+      // This case results in empty output due to the NULL in t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        Seq())
+    }
+  }
+
+  test("null aware anti join from NOT IN with collated columns in array type") {
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1 (C1 ARRAY<STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"CREATE TABLE $tableName2 (C1 ARRAY<STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"INSERT INTO $tableName1 VALUES (ARRAY('a ', 'Aa '))")
+      sql(s"INSERT INTO $tableName2 VALUES (ARRAY('A', 'aa'))")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      // This case results in empty output due to the NULL in t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        Seq())
+    }
+  }
+
+  test("null aware anti join from NOT IN with collated columns in struct type") {
+    val expectedAnswer = Seq()
+    val (tableName1, tableName2) = ("t1", "t2")
+    withTable(tableName1, tableName2) {
+      sql(s"CREATE TABLE $tableName1" +
+        s" (C1 STRUCT<x: STRING COLLATE UTF8_LCASE_RTRIM, y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"CREATE TABLE $tableName2" +
+        s" (C1 STRUCT<x: STRING COLLATE UTF8_LCASE_RTRIM, y: STRING COLLATE UTF8_LCASE_RTRIM>)")
+      sql(s"INSERT INTO $tableName1 VALUES (named_struct('x', 'a ', 'y', 'Aa '))")
+      sql(s"INSERT INTO $tableName2 VALUES (named_struct('x', 'A', 'y', 'aa'))")
+
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      sql(s"INSERT INTO $tableName1 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        expectedAnswer)
+
+      // This case results in empty output due to the NULL in t2.
+      sql(s"INSERT INTO $tableName2 VALUES (NULL)")
+      checkAnswer(sql(s"SELECT * FROM $tableName1 WHERE C1 NOT IN (SELECT * FROM $tableName2)"),
+        Seq())
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 69dd04e07d551..9bd858608cb9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -397,8 +397,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     }
   }
 
-  private val bh = BroadcastHashJoinExec.toString
-  private val bl = BroadcastNestedLoopJoinExec.toString
+  private val bh = classOf[BroadcastHashJoinExec].getSimpleName
+  private val bl = classOf[BroadcastNestedLoopJoinExec].getSimpleName
 
   private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
     val executedPlan = stripAQEPlan(sql(sqlStr).queryExecution.executedPlan)