 package org.apache.spark.sql.clickhouse

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
+import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
+import org.apache.spark.sql.catalyst.{expressions, SQLConfHelper}
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, TransformExpression, V2ExpressionUtils}
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.IGNORE_UNSUPPORTED_TRANSFORM
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
 import org.apache.spark.sql.connector.expressions.Expressions._
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, _}
-import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, SortOrder => V2SortOrder}
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.types.{StructField, StructType}
 import xenon.clickhouse.exception.CHClientException
 import xenon.clickhouse.expr._
+import xenon.clickhouse.func.FunctionRegistry
+import xenon.clickhouse.spec.ClusterSpec

-import scala.annotation.tailrec
 import scala.util.{Failure, Success, Try}

-object ExprUtils extends SQLConfHelper {
+object ExprUtils extends SQLConfHelper with Serializable {

-  def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
-    partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
+  def toSparkPartitions(
+      partitionKey: Option[List[Expr]],
+      functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray

-  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
-    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
+  def toSparkSplits(
+      shardingKey: Option[Expr],
+      partitionKey: Option[List[Expr]],
+      functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray

   def toSparkSortOrders(
       shardingKeyIgnoreRand: Option[Expr],
       partitionKey: Option[List[Expr]],
-      sortingKey: Option[List[OrderExpr]]
-  ): Array[SortOrder] =
-    toSparkSplits(shardingKeyIgnoreRand, partitionKey).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
+      sortingKey: Option[List[OrderExpr]],
+      cluster: Option[ClusterSpec],
+      functionRegistry: FunctionRegistry
+  ): Array[V2SortOrder] =
+    toSparkSplits(
+      shardingKeyIgnoreRand,
+      partitionKey,
+      functionRegistry
+    ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
       sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
         val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
         val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
-        toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
+        toSparkTransformOpt(expr, functionRegistry).map(trans =>
+          Expressions.sort(trans, direction, nullOrder)
+        )
       }.toArray

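For orientation, a minimal usage sketch of the new signature; the column names and `registry` are hypothetical stand-ins for whatever `FunctionRegistry` the catalog wires in:

// Sharded by user_id, partitioned by toYYYYMM(created), sorting key id DESC NULLS LAST.
// registry is assumed to map toYYYYMM to a registered Spark function.
val sortOrders: Array[V2SortOrder] = ExprUtils.toSparkSortOrders(
  Some(FieldRef("user_id")),                                   // shardingKeyIgnoreRand
  Some(List(FuncExpr("toYYYYMM", List(FieldRef("created"))))), // partitionKey
  Some(List(OrderExpr(FieldRef("id"), false, false))),         // sortingKey: DESC NULLS LAST
  None,                                                        // cluster
  registry
)
// Split transforms sort ASCENDING first, followed by the explicit sorting key.
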
-  @tailrec
-  def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
+  private def loadV2FunctionOpt(
+      name: String,
+      args: Seq[Expression],
+      functionRegistry: FunctionRegistry
+  ): Option[BoundFunction] = {
+    def loadFunction(ident: Identifier): UnboundFunction =
+      functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
+    val inputType = StructType(args.zipWithIndex.map {
+      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+    })
+    try {
+      val unbound = loadFunction(Identifier.of(Array.empty, name))
+      Some(unbound.bind(inputType))
+    } catch {
+      case e: NoSuchFunctionException =>
+        throw e
+      case _: UnsupportedOperationException if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) =>
+        None
+      case e: UnsupportedOperationException =>
+        throw new AnalysisException(e.getMessage, cause = Some(e))
+    }
+  }
+
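`loadV2FunctionOpt` leans on Spark's DataSourceV2 function API: the registry serves `UnboundFunction`s, which are then bound against an input schema derived from the argument expressions. A minimal sketch of a function such a registry could serve; the name and semantics are hypothetical, while the interfaces are the standard Spark 3.x ones:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
import org.apache.spark.sql.types._

// Hypothetical shard helper: mod_shard(a, b) = a % b. Self-binding for brevity.
object ModShard extends UnboundFunction with ScalarFunction[Long] {
  override def name: String = "mod_shard"
  override def description: String = "mod_shard(a, b) - remainder of a divided by b"
  override def bind(inputType: StructType): BoundFunction = inputType.fields match {
    case Array(a, b) if a.dataType == LongType && b.dataType == LongType => this
    case _ => throw new UnsupportedOperationException("mod_shard requires (bigint, bigint)")
  }
  override def inputTypes: Array[DataType] = Array(LongType, LongType)
  override def resultType: DataType = LongType
  override def produceResult(input: InternalRow): Long = input.getLong(0) % input.getLong(1)
}

Note that an `UnsupportedOperationException` thrown from `bind` is exactly what the `IGNORE_UNSUPPORTED_TRANSFORM` branch above either swallows or surfaces as an `AnalysisException`.
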
+  def resolveTransformCatalyst(
+      catalystExpr: Expression,
+      timeZoneId: Option[String] = None
+  ): Expression = catalystExpr match {
+    case TransformExpression(function: ScalarFunction[_], args, _) =>
+      val resolvedArgs: Seq[Expression] = args.map(resolveTransformCatalyst(_, timeZoneId))
+      val castedArgs: Seq[Expression] = resolvedArgs.zip(function.inputTypes()).map {
+        case (arg, expectedType) if !arg.dataType.sameType(expectedType) => Cast(arg, expectedType, timeZoneId)
+        case (arg, _) => arg
+      }
+      V2ExpressionUtils.resolveScalarFunction(function, castedArgs)
+    case other => other
+  }
+
+  def toCatalyst(
+      v2Expr: V2Expression,
+      fields: Array[StructField],
+      functionRegistry: FunctionRegistry
+  ): Expression =
     v2Expr match {
-      case IdentityTransform(ref) => toCatalyst(ref, fields)
+      case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
       case ref: NamedReference if ref.fieldNames.length == 1 =>
         val (field, ordinal) = fields
           .zipWithIndex
           .find { case (field, _) => field.name == ref.fieldNames.head }
           .getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
         BoundReference(ordinal, field.dataType, field.nullable)
+      case t: Transform =>
+        val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
+        loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
+          .map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
+            throw CHClientException(s"Unsupported expression: $v2Expr")
+          }
+      case literal: LiteralValue[Any] => expressions.Literal(literal.value)
       case _ => throw CHClientException(
-        s"Unsupported V2 expression: $v2Expr, SPARK-33779: Spark 3.3 only support IdentityTransform"
+        s"Unsupported expression: $v2Expr"
       )
     }

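Taken together: `toCatalyst` yields an unevaluable `TransformExpression` node, and `resolveTransformCatalyst` later rewrites it into an executable expression via `V2ExpressionUtils.resolveScalarFunction`. A hypothetical end-to-end sketch (imports as in the file), reusing the `ModShard` function from above and an assumed `registry` that resolves it:

val fields = Array(StructField("user_id", LongType), StructField("buckets", LongType))
val transform = Expressions.apply("mod_shard", Expressions.column("user_id"), Expressions.column("buckets"))
val unresolved = ExprUtils.toCatalyst(transform, fields, registry) // TransformExpression(ModShard, args)
val evaluable = ExprUtils.resolveTransformCatalyst(unresolved)     // ends up calling ModShard.produceResult
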
-  def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkTransform(expr)) match {
-    case Success(t) => Some(t)
-    case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
-    case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
-  }
-
-  // Some functions of ClickHouse which match Spark pre-defined Transforms
-  //
-  // toYear, YEAR - Converts a date or date with time to a UInt16 (AD)
-  // toYYYYMM - Converts a date or date with time to a UInt32 (YYYY*100 + MM)
-  // toYYYYMMDD - Converts a date or date with time to a UInt32 (YYYY*10000 + MM*100 + DD)
-  // toHour, HOUR - Converts a date with time to a UInt8 (0-23)
+  def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
+    Try(toSparkExpression(expr, functionRegistry)) match {
+      // Spark's `Table#partitioning` only accepts `Transform`s, so drop any other expression here.
+      case Success(t: Transform) => Some(t)
+      case Success(_) => None
+      case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
+      case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
+    }

-  def toSparkTransform(expr: Expr): Transform = expr match {
-    case FieldRef(col) => identity(col)
-    case FuncExpr("toYear", List(FieldRef(col))) => years(col)
-    case FuncExpr("YEAR", List(FieldRef(col))) => years(col)
-    case FuncExpr("toYYYYMM", List(FieldRef(col))) => months(col)
-    case FuncExpr("toYYYYMMDD", List(FieldRef(col))) => days(col)
-    case FuncExpr("toHour", List(FieldRef(col))) => hours(col)
-    case FuncExpr("HOUR", List(FieldRef(col))) => hours(col)
-    // TODO support arbitrary functions
-    // case FuncExpr("xxHash64", List(FieldRef(col))) => apply("ck_xx_hash64", column(col))
-    case FuncExpr("rand", Nil) => apply("rand")
-    case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
-    case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
-  }
+  def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
+    expr match {
+      case FieldRef(col) => identity(col)
+      case StringLiteral(value) => literal(value)
+      case FuncExpr("rand", Nil) => apply("rand")
+      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
+      case FuncExpr(funName, args) if functionRegistry.clickHouseToSparkFunc.contains(funName) =>
+        apply(functionRegistry.clickHouseToSparkFunc(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
+      case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
+    }

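The fixed `toYear`/`toYYYYMM`/`toYYYYMMDD`/`toHour` cases that used to map onto Spark's built-in year/month/day/hour transforms are gone; arbitrary functions now flow through the registry's ClickHouse-to-Spark name mapping. A sketch of the intended behavior (the mapping entry is hypothetical; the real entries live wherever the connector registers its functions):

// Suppose registry.clickHouseToSparkFunc contains "xxHash64" -> "clickhouse_xxHash64".
ExprUtils.toSparkExpression(FuncExpr("xxHash64", List(FieldRef("user_id"))), registry)
// => apply("clickhouse_xxHash64", identity("user_id")), i.e. an ApplyTransform
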
-  def toClickHouse(transform: Transform): Expr = transform match {
-    case YearsTransform(FieldReference(Seq(col))) => FuncExpr("toYear", List(FieldRef(col)))
-    case MonthsTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMM", List(FieldRef(col)))
-    case DaysTransform(FieldReference(Seq(col))) => FuncExpr("toYYYYMMDD", List(FieldRef(col)))
-    case HoursTransform(FieldReference(Seq(col))) => FuncExpr("toHour", List(FieldRef(col)))
+  def toClickHouse(
+      transform: Transform,
+      functionRegistry: FunctionRegistry
+  ): Expr = transform match {
     case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
-    case ApplyTransform(name, args) => FuncExpr(name, args.map(arg => SQLExpr(arg.describe())).toList)
+    case ApplyTransform(name, args) if functionRegistry.sparkToClickHouseFunc.contains(name) =>
+      FuncExpr(functionRegistry.sparkToClickHouseFunc(name), args.map(arg => SQLExpr(arg.describe())).toList)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not supported yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }

   def inferTransformSchema(
       primarySchema: StructType,
       secondarySchema: StructType,
-      transform: Transform
+      transform: Transform,
+      functionRegistry: FunctionRegistry
   ): StructField = transform match {
-    case years: YearsTransform => StructField(years.toString, IntegerType)
-    case months: MonthsTransform => StructField(months.toString, IntegerType)
-    case days: DaysTransform => StructField(days.toString, IntegerType)
-    case hours: HoursTransform => StructField(hours.toString, IntegerType)
     case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
       .orElse(secondarySchema.find(_.name == col))
       .getOrElse(throw CHClientException(s"Invalid partition column: $col"))
-    case ckXxhHash64 @ ApplyTransform("ck_xx_hash64", _) => StructField(ckXxhHash64.toString, LongType)
+    case t @ ApplyTransform(transformName, _) if functionRegistry.load(transformName).isDefined =>
+      val resType =
+        functionRegistry.load(transformName).getOrElse(throw new NoSuchFunctionException(transformName)) match {
+          case f: ScalarFunction[_] => f.resultType()
+          case other => throw CHClientException(s"Unsupported function: $other")
+        }
+      StructField(t.toString, resType)
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not supported yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }

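Result types are likewise no longer hard-coded (`IntegerType` for the date-part transforms, `LongType` for `ck_xx_hash64`) but are read from the registered `ScalarFunction`. A hypothetical sketch with `ModShard` from above:

val primary = StructType(Seq(StructField("user_id", LongType), StructField("buckets", LongType)))
val field = ExprUtils.inferTransformSchema(
  primary,
  StructType(Nil),
  Expressions.apply("mod_shard", Expressions.column("user_id"), Expressions.column("buckets")),
  registry // assumed to resolve "mod_shard" to ModShard
)
// => roughly StructField("mod_shard(user_id, buckets)", LongType)
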