Skip to content

Commit fb3bf50

Browse files
committed
Spark 3.4: Cast type when calling projection, support recursive resolve
(cherry picked from commit 936a18a)
1 parent 2c08b48 commit fb3bf50

File tree

3 files changed

+45
-28
lines changed

3 files changed

+45
-28
lines changed

spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterShardByTransformSuite.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,21 @@ class ClusterShardByTransformSuite extends SparkClickHouseClusterTest {
9494
}
9595

9696
Seq(
97+
// wait for SPARK-44180 to be fixed, then add implicit cast test cases
9798
("toYear", Array("create_date")),
99+
// ("toYear", Array("create_time")),
98100
("toYYYYMM", Array("create_date")),
101+
// ("toYYYYMM", Array("create_time")),
99102
("toYYYYMMDD", Array("create_date")),
103+
// ("toYYYYMMDD", Array("create_time")),
100104
("toHour", Array("create_time")),
101105
("xxHash64", Array("value")),
102106
("murmurHash2_64", Array("value")),
103107
("murmurHash2_32", Array("value")),
104108
("murmurHash3_64", Array("value")),
105109
("murmurHash3_32", Array("value")),
106-
("cityHash64", Array("value"))
110+
("cityHash64", Array("value")),
111+
("positiveModulo", Array("toYYYYMM(create_date)", "10"))
107112
).foreach {
108113
case (func_name: String, func_args: Array[String]) =>
109114
test(s"shard by $func_name(${func_args.mkString(",")})")(runTest(func_name, func_args))

spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,15 @@
1515
package org.apache.spark.sql.clickhouse
1616

1717
import org.apache.spark.sql.AnalysisException
18-
import org.apache.spark.sql.catalyst.SQLConfHelper
1918
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
20-
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, TransformExpression}
19+
import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
20+
import org.apache.spark.sql.catalyst.expressions.{
21+
BoundReference,
22+
Cast,
23+
Expression,
24+
TransformExpression,
25+
V2ExpressionUtils
26+
}
2127
import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
2228
import org.apache.spark.sql.connector.catalog.Identifier
2329
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
@@ -89,6 +95,20 @@ object ExprUtils extends SQLConfHelper with Serializable {
8995
}
9096
}
9197

98+
def resolveTransformCatalyst(
99+
catalystExpr: Expression,
100+
timeZoneId: Option[String] = None
101+
): Expression = catalystExpr match {
102+
case TransformExpression(function: ScalarFunction[_], args, _) =>
103+
val resolvedArgs: Seq[Expression] = args.map(resolveTransformCatalyst(_, timeZoneId))
104+
val castedArgs: Seq[Expression] = resolvedArgs.zip(function.inputTypes()).map {
105+
case (arg, expectedType) if !arg.dataType.sameType(expectedType) => Cast(arg, expectedType, timeZoneId)
106+
case (arg, _) => arg
107+
}
108+
V2ExpressionUtils.resolveScalarFunction(function, castedArgs)
109+
case other => other
110+
}
111+
92112
def toCatalyst(
93113
v2Expr: V2Expression,
94114
fields: Array[StructField],
@@ -108,6 +128,7 @@ object ExprUtils extends SQLConfHelper with Serializable {
108128
.map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
109129
throw CHClientException(s"Unsupported expression: $v2Expr")
110130
}
131+
case literal: LiteralValue[Any] => expressions.Literal(literal.value)
111132
case _ => throw CHClientException(
112133
s"Unsupported expression: $v2Expr"
113134
)

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,9 @@ package xenon.clickhouse.write
1717
import com.clickhouse.client.ClickHouseProtocol
1818
import com.clickhouse.data.ClickHouseCompression
1919
import org.apache.commons.io.IOUtils
20-
import org.apache.spark.sql.catalyst.expressions.{
21-
BoundReference,
22-
Expression,
23-
SafeProjection,
24-
TransformExpression,
25-
V2ExpressionUtils
26-
}
20+
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, SafeProjection, TransformExpression}
2721
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
2822
import org.apache.spark.sql.clickhouse.ExprUtils
29-
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
3023
import org.apache.spark.sql.connector.metric.CustomTaskMetric
3124
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
3225
import org.apache.spark.sql.types._
@@ -86,23 +79,21 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
8679

8780
protected lazy val shardProjection: Option[expressions.Projection] = shardExpr
8881
.filter(_ => writeJob.writeOptions.convertDistributedToLocal)
89-
.flatMap(expr =>
90-
expr match {
91-
case BoundReference(_, _, _) =>
92-
Some(SafeProjection.create(Seq(expr)))
93-
case TransformExpression(function, args, _) =>
94-
val retType = function.resultType() match {
95-
case ByteType => classOf[Byte]
96-
case ShortType => classOf[Short]
97-
case IntegerType => classOf[Int]
98-
case LongType => classOf[Long]
99-
case _ => throw CHClientException(s"Invalid return data type for function ${function.name()}," +
100-
s"sharding field: ${function.resultType()}")
101-
}
102-
val expr = V2ExpressionUtils.resolveScalarFunction(function.asInstanceOf[ScalarFunction[retType.type]], args)
103-
Some(SafeProjection.create(Seq(expr)))
104-
}
105-
)
82+
.flatMap {
83+
case expr: BoundReference =>
84+
Some(SafeProjection.create(Seq(expr)))
85+
case expr @ TransformExpression(function, _, _) =>
86+
// result type must be integer class
87+
function.resultType() match {
88+
case ByteType => classOf[Byte]
89+
case ShortType => classOf[Short]
90+
case IntegerType => classOf[Int]
91+
case LongType => classOf[Long]
92+
case _ => throw CHClientException(s"Invalid return data type for function ${function.name()}," +
93+
s"sharding field: ${function.resultType()}")
94+
}
95+
Some(SafeProjection.create(Seq(ExprUtils.resolveTransformCatalyst(expr, Some(writeJob.tz.getId)))))
96+
}
10697

10798
// put the node select strategy on the executor side because we need to calculate the shard and don't know the records
10899
// until DataWriter#write(InternalRow) is invoked.

0 commit comments

Comments
 (0)