
Commit 4ee49bf

refactor
1 parent 0847757 commit 4ee49bf

2 files changed: +67 additions, -32 deletions

spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala

Lines changed: 52 additions & 15 deletions
@@ -15,16 +15,18 @@
 package org.apache.spark.sql.clickhouse
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal}
+import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
 import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
-import org.apache.spark.sql.catalyst.expressions.BoundReference
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, TransformExpression, V2ExpressionUtils}
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
 import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.Expressions._
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _}
 import org.apache.spark.sql.types.{StructField, StructType}
 import xenon.clickhouse.exception.CHClientException
 import xenon.clickhouse.expr._
@@ -94,15 +96,50 @@ object ExprUtils extends SQLConfHelper with Serializable {
   def resolveTransformCatalyst(
     catalystExpr: Expression,
     timeZoneId: Option[String] = None
-  ): Expression = catalystExpr match {
-    case TransformExpression(function: ScalarFunction[_], args, _) =>
-      val resolvedArgs: Seq[Expression] = args.map(resolveTransformCatalyst(_, timeZoneId))
-      val castedArgs: Seq[Expression] = resolvedArgs.zip(function.inputTypes()).map {
-        case (arg, expectedType) if !arg.dataType.sameType(expectedType) => Cast(arg, expectedType, timeZoneId)
-        case (arg, _) => arg
-      }
-      V2ExpressionUtils.resolveScalarFunction(function, castedArgs)
-    case other => other
+  ): Expression =
+    new TypeCoercionExecutor(timeZoneId)
+      .execute(DummyLeafNode(resolveTransformExpression(catalystExpr)))
+      .asInstanceOf[DummyLeafNode].expr
+
+  private case class DummyLeafNode(expr: Expression) extends LeafNode {
+    override def output: Seq[Attribute] = Nil
+  }
+
+  private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] {
+    private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = {
+      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+        e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone))
+      // Casts could be added in the subquery plan through the rule TypeCoercion while coercing
+      // the types between the value expression and list query expression of IN expression.
+      // We need to subject the subquery plan through ResolveTimeZone again to setup timezone
+      // information for time zone aware expressions.
+      case e: ListQuery => e.withNewPlan(apply(e.plan))
+    }
+
+    override def apply(plan: LogicalPlan): LogicalPlan =
+      plan.resolveExpressionsWithPruning(
+        _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION),
+        ruleId
+      )(transformTimeZoneExprs)
+  }
+
+  private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] {
+    override val batches =
+      Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) ::
+        Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil
+  }
+
+  private def resolveTransformExpression(expr: Expression): Expression = expr.transform {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)
+  }
+
+  private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) {
+    AnsiTypeCoercion.typeCoercionRules
+  } else {
+    TypeCoercion.typeCoercionRules
   }
 
   def toCatalyst(
@@ -142,7 +179,7 @@ object ExprUtils extends SQLConfHelper with Serializable {
   def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
    expr match {
      case FieldRef(col) => identity(col)
-     case StringLiteral(value) => literal(value)
+     case StringLiteral(value) => literal(value) // TODO LiteralTransform
      case FuncExpr("rand", Nil) => apply("rand")
      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
      case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) =>
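
What the ExprUtils change does: instead of hand-inserting Cast nodes around scalar-function arguments, resolveTransformCatalyst now resolves the TransformExpression first (prepending the bucket count as a Literal for bucketed transforms) and then runs the result through Spark's own type-coercion rules (AnsiTypeCoercion or TypeCoercion, depending on spark.sql.ansi.enabled), followed by a time-zone resolution pass. Since Catalyst rules operate on LogicalPlan rather than on a bare Expression, the expression is wrapped in a throwaway leaf node, executed, and unwrapped. A minimal sketch of that wrap-run-unwrap pattern, assuming Spark 3.4 on the classpath; ExprHolder and CoerceOnce are hypothetical stand-ins for the diff's DummyLeafNode and TypeCoercionExecutor:

import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

object CoercionSketch {
  // Throwaway plan node whose only job is to carry an expression through the executor.
  private case class ExprHolder(expr: Expression) extends LeafNode {
    override def output: Seq[Attribute] = Nil
  }

  // Runs Spark's non-ANSI coercion rules in a single pass over whatever plan it is given.
  private object CoerceOnce extends RuleExecutor[LogicalPlan] {
    override val batches =
      Batch("Coerce", FixedPoint(1), TypeCoercion.typeCoercionRules: _*) :: Nil
  }

  // Wrap, execute, unwrap: the expression comes back with coercions applied.
  def coerce(e: Expression): Expression =
    CoerceOnce.execute(ExprHolder(e)).asInstanceOf[ExprHolder].expr
}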

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala

Lines changed: 15 additions & 17 deletions
@@ -17,7 +17,8 @@ package xenon.clickhouse.write
 import com.clickhouse.client.ClickHouseProtocol
 import com.clickhouse.data.ClickHouseCompression
 import org.apache.commons.io.IOUtils
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, SafeProjection, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, TransformExpression}
+import org.apache.spark.sql.catalyst.expressions.{Projection, SafeProjection}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.clickhouse.ExprUtils
 import org.apache.spark.sql.connector.metric.CustomTaskMetric
@@ -118,26 +119,23 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
 
   def calcShard(record: InternalRow): Option[Int] = (shardExpr, shardProjection) match {
     case (Some(BoundReference(_, dataType, _)), Some(projection)) =>
-      val shardValue = dataType match {
-        case ByteType => Some(projection(record).getByte(0).toLong)
-        case ShortType => Some(projection(record).getShort(0).toLong)
-        case IntegerType => Some(projection(record).getInt(0).toLong)
-        case LongType => Some(projection(record).getLong(0))
-        case _ => None
-      }
-      shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
+      doCalcShard(record, dataType, projection)
     case (Some(TransformExpression(function, _, _)), Some(projection)) =>
-      val shardValue = function.resultType match {
-        case ByteType => Some(projection(record).getByte(0).toLong)
-        case ShortType => Some(projection(record).getShort(0).toLong)
-        case IntegerType => Some(projection(record).getInt(0).toLong)
-        case LongType => Some(projection(record).getLong(0))
-        case _ => None
-      }
-      shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
+      doCalcShard(record, function.resultType, projection)
     case _ => None
   }
 
+  private def doCalcShard(record: InternalRow, dataType: DataType, projection: Projection): Option[Int] = {
+    val shardValue = dataType match {
+      case ByteType => Some(projection(record).getByte(0).toLong)
+      case ShortType => Some(projection(record).getShort(0).toLong)
+      case IntegerType => Some(projection(record).getInt(0).toLong)
+      case LongType => Some(projection(record).getLong(0))
+      case _ => None
+    }
+    shardValue.map(value => ShardUtils.calcShard(writeJob.cluster.get, value).num)
+  }
+
   val _currentBufferedRows = new LongAdder
   def currentBufferedRows: Long = _currentBufferedRows.longValue
   val _totalRecordsWritten = new LongAdder
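
On the writer side, the two identical shard-value computations (one for a plain BoundReference shard key, one for a TransformExpression) are deduplicated into a single doCalcShard helper keyed on the result DataType, with Projection imported alongside SafeProjection to type the new parameter. A small runnable sketch of the helper's type dispatch, assuming Spark 3.4; shardKeyOf mirrors doCalcShard's match but stops short of ShardUtils.calcShard, which needs a cluster definition:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Projection, UnsafeProjection}
import org.apache.spark.sql.types._

object ShardSketch {
  // Only integral shard keys are supported; each is widened to Long.
  def shardKeyOf(record: InternalRow, dataType: DataType, projection: Projection): Option[Long] =
    dataType match {
      case ByteType    => Some(projection(record).getByte(0).toLong)
      case ShortType   => Some(projection(record).getShort(0).toLong)
      case IntegerType => Some(projection(record).getInt(0).toLong)
      case LongType    => Some(projection(record).getLong(0))
      case _           => None // e.g. StringType: no shard can be derived
    }

  def main(args: Array[String]): Unit = {
    // Project column 0 (an Int) out of the row, then widen it to Long.
    val projection = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, nullable = false)))
    println(shardKeyOf(InternalRow(42), IntegerType, projection)) // prints Some(42)
  }
}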
