|
15 | 15 | package org.apache.spark.sql.clickhouse |
16 | 16 |
|
17 | 17 | import org.apache.spark.sql.AnalysisException |
18 | | -import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException |
| 18 | +import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, NoSuchFunctionException, TypeCoercion} |
| 19 | +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, ListQuery, Literal} |
| 20 | +import org.apache.spark.sql.catalyst.expressions.{TimeZoneAwareExpression, TransformExpression, V2ExpressionUtils} |
| 21 | +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} |
| 22 | +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} |
| 23 | +import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION} |
19 | 24 | import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper} |
20 | | -import org.apache.spark.sql.catalyst.expressions.BoundReference |
21 | | -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, TransformExpression, V2ExpressionUtils} |
22 | 25 | import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM |
23 | 26 | import org.apache.spark.sql.connector.catalog.Identifier |
24 | 27 | import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} |
25 | 28 | import org.apache.spark.sql.connector.expressions.Expressions._ |
26 | | -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder} |
27 | | -import org.apache.spark.sql.connector.expressions._ |
| 29 | +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder, _} |
28 | 30 | import org.apache.spark.sql.types.{StructField, StructType} |
29 | 31 | import xenon.clickhouse.exception.CHClientException |
30 | 32 | import xenon.clickhouse.expr._ |
@@ -94,15 +96,50 @@ object ExprUtils extends SQLConfHelper with Serializable { |
94 | 96 | def resolveTransformCatalyst( |
95 | 97 | catalystExpr: Expression, |
96 | 98 | timeZoneId: Option[String] = None |
97 | | - ): Expression = catalystExpr match { |
98 | | - case TransformExpression(function: ScalarFunction[_], args, _) => |
99 | | - val resolvedArgs: Seq[Expression] = args.map(resolveTransformCatalyst(_, timeZoneId)) |
100 | | - val castedArgs: Seq[Expression] = resolvedArgs.zip(function.inputTypes()).map { |
101 | | - case (arg, expectedType) if !arg.dataType.sameType(expectedType) => Cast(arg, expectedType, timeZoneId) |
102 | | - case (arg, _) => arg |
103 | | - } |
104 | | - V2ExpressionUtils.resolveScalarFunction(function, castedArgs) |
105 | | - case other => other |
| 99 | + ): Expression = |
| 100 | + new TypeCoercionExecutor(timeZoneId) |
| 101 | + .execute(DummyLeafNode(resolveTransformExpression(catalystExpr))) |
| 102 | + .asInstanceOf[DummyLeafNode].expr |
| 103 | + |
| 104 | + private case class DummyLeafNode(expr: Expression) extends LeafNode { |
| 105 | + override def output: Seq[Attribute] = Nil |
| 106 | + } |
| 107 | + |
| 108 | + private class CustomResolveTimeZone(timeZoneId: Option[String]) extends Rule[LogicalPlan] { |
| 109 | + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { |
| 110 | + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => |
| 111 | + e.withTimeZone(timeZoneId.getOrElse(conf.sessionLocalTimeZone)) |
| 112 | + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing |
| 113 | + // the types between the value expression and list query expression of IN expression. |
| 114 | + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone |
| 115 | + // information for time zone aware expressions. |
| 116 | + case e: ListQuery => e.withNewPlan(apply(e.plan)) |
| 117 | + } |
| 118 | + |
| 119 | + override def apply(plan: LogicalPlan): LogicalPlan = |
| 120 | + plan.resolveExpressionsWithPruning( |
| 121 | + _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION), |
| 122 | + ruleId |
| 123 | + )(transformTimeZoneExprs) |
| 124 | + } |
| 125 | + |
| 126 | + private class TypeCoercionExecutor(timeZoneId: Option[String]) extends RuleExecutor[LogicalPlan] { |
| 127 | + override val batches = |
| 128 | + Batch("Resolve TypeCoercion", FixedPoint(1), typeCoercionRules: _*) :: |
| 129 | + Batch("Resolve TimeZone", FixedPoint(1), new CustomResolveTimeZone(timeZoneId)) :: Nil |
| 130 | + } |
| 131 | + |
| 132 | + private def resolveTransformExpression(expr: Expression): Expression = expr.transform { |
| 133 | + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => |
| 134 | + V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments) |
| 135 | + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => |
| 136 | + V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments) |
| 137 | + } |
| 138 | + |
| 139 | + private def typeCoercionRules: List[Rule[LogicalPlan]] = if (conf.ansiEnabled) { |
| 140 | + AnsiTypeCoercion.typeCoercionRules |
| 141 | + } else { |
| 142 | + TypeCoercion.typeCoercionRules |
106 | 143 | } |
107 | 144 |
|
108 | 145 | def toCatalyst( |
@@ -142,7 +179,7 @@ object ExprUtils extends SQLConfHelper with Serializable { |
142 | 179 | def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression = |
143 | 180 | expr match { |
144 | 181 | case FieldRef(col) => identity(col) |
145 | | - case StringLiteral(value) => literal(value) |
| 182 | + case StringLiteral(value) => literal(value) // TODO LiteralTransform |
146 | 183 | case FuncExpr("rand", Nil) => apply("rand") |
147 | 184 | case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col) |
148 | 185 | case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) => |
|
0 commit comments