Commit e5809f7

Spark 3.4: refactor implicit into normal arg in ExprUtils
1 parent 286c21f
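
This commit turns the `implicit functionRegistry: FunctionRegistry` parameter used throughout `ExprUtils` into an ordinary argument and updates every call site to pass the registry explicitly. A minimal, self-contained Scala sketch of the pattern (the trait and method names here are illustrative, not the project's code):

    // Illustrative sketch only: refactoring an implicit parameter into a normal one.
    object ImplicitToExplicitSketch {
      // Stand-in registry trait; the real FunctionRegistry lives in the connector.
      trait Registry {
        def lookup(name: String): Option[String]
      }

      // Before: the registry is picked up from implicit scope at each call site.
      def resolveBefore(name: String)(implicit registry: Registry): Option[String] =
        registry.lookup(name)

      // After: the registry is a plain parameter, so the dependency is explicit.
      def resolveAfter(name: String, registry: Registry): Option[String] =
        registry.lookup(name)
    }

Passing the registry explicitly makes the dependency visible in each signature and removes the need to keep an implicit `FunctionRegistry` value in scope, which is why the `implicit val _functionRegistry` in WriteJobDescription.scala is dropped below.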

File tree

5 files changed: 46 additions, 34 deletions

spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala

Lines changed: 36 additions & 24 deletions
@@ -33,38 +33,44 @@ import scala.util.{Failure, Success, Try}
 
 object ExprUtils extends SQLConfHelper with Serializable {
 
-  def toSparkPartitions(partitionKey: Option[List[Expr]])(implicit
+  def toSparkPartitions(
+    partitionKey: Option[List[Expr]],
     functionRegistry: FunctionRegistry
   ): Array[Transform] =
-    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_)).toArray
+    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
-  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]])(implicit
+  def toSparkSplits(
+    shardingKey: Option[Expr],
+    partitionKey: Option[List[Expr]],
     functionRegistry: FunctionRegistry
   ): Array[Transform] =
-    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_)).toArray
+    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_, functionRegistry)).toArray
 
   def toSparkSortOrders(
     shardingKeyIgnoreRand: Option[Expr],
     partitionKey: Option[List[Expr]],
     sortingKey: Option[List[OrderExpr]],
-    cluster: Option[ClusterSpec]
-  )(implicit functionRegistry: FunctionRegistry): Array[SortOrder] =
+    cluster: Option[ClusterSpec],
+    functionRegistry: FunctionRegistry
+  ): Array[SortOrder] =
     toSparkSplits(
       shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight)),
-      partitionKey
+      partitionKey,
+      functionRegistry
     ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
       sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
        val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
        val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
-       toSparkTransformOpt(expr).map(trans =>
+       toSparkTransformOpt(expr, functionRegistry).map(trans =>
          Expressions.sort(trans, direction, nullOrder)
        )
      }.toArray
 
   private def loadV2FunctionOpt(
     name: String,
-    args: Seq[Expression]
-  )(implicit functionRegistry: FunctionRegistry): Option[BoundFunction] = {
+    args: Seq[Expression],
+    functionRegistry: FunctionRegistry
+  ): Option[BoundFunction] = {
     def loadFunction(ident: Identifier): UnboundFunction =
       functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
     val inputType = StructType(args.zipWithIndex.map {
@@ -85,47 +91,52 @@ object ExprUtils extends SQLConfHelper with Serializable {
 
   def toCatalyst(
     v2Expr: V2Expression,
-    fields: Array[StructField]
-  )(implicit functionRegistry: FunctionRegistry): Expression =
+    fields: Array[StructField],
+    functionRegistry: FunctionRegistry
+  ): Expression =
     v2Expr match {
-      case IdentityTransform(ref) => toCatalyst(ref, fields)
+      case IdentityTransform(ref) => toCatalyst(ref, fields, functionRegistry)
       case ref: NamedReference if ref.fieldNames.length == 1 =>
        val (field, ordinal) = fields
          .zipWithIndex
          .find { case (field, _) => field.name == ref.fieldNames.head }
          .getOrElse(throw CHClientException(s"Invalid field reference: $ref"))
        BoundReference(ordinal, field.dataType, field.nullable)
      case t: Transform =>
-       val catalystArgs = t.arguments().map(toCatalyst(_, fields))
-       loadV2FunctionOpt(t.name(), catalystArgs).map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
-         throw CHClientException(s"Unsupported expression: $v2Expr")
-       }
+       val catalystArgs = t.arguments().map(toCatalyst(_, fields, functionRegistry))
+       loadV2FunctionOpt(t.name(), catalystArgs, functionRegistry)
+         .map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
+           throw CHClientException(s"Unsupported expression: $v2Expr")
+         }
      case _ => throw CHClientException(
        s"Unsupported expression: $v2Expr"
      )
    }
 
-  def toSparkTransformOpt(expr: Expr)(implicit functionRegistry: FunctionRegistry): Option[Transform] =
-    Try(toSparkExpression(expr)) match {
+  def toSparkTransformOpt(expr: Expr, functionRegistry: FunctionRegistry): Option[Transform] =
+    Try(toSparkExpression(expr, functionRegistry)) match {
      // need this function because spark `Table`'s `partitioning` field should be `Transform`
      case Success(t: Transform) => Some(t)
      case Success(_) => None
      case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
      case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
    }
 
-  def toSparkExpression(expr: Expr)(implicit functionRegistry: FunctionRegistry): V2Expression =
+  def toSparkExpression(expr: Expr, functionRegistry: FunctionRegistry): V2Expression =
    expr match {
      case FieldRef(col) => identity(col)
      case StringLiteral(value) => literal(value)
      case FuncExpr("rand", Nil) => apply("rand")
      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
      case FuncExpr(funName, args) if functionRegistry.getFuncMappingByCk.contains(funName) =>
-       apply(functionRegistry.getFuncMappingByCk(funName), args.map(toSparkExpression): _*)
+       apply(functionRegistry.getFuncMappingByCk(funName), args.map(toSparkExpression(_, functionRegistry)): _*)
      case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
    }
 
-  def toClickHouse(transform: Transform)(implicit functionRegistry: FunctionRegistry): Expr = transform match {
+  def toClickHouse(
+    transform: Transform,
+    functionRegistry: FunctionRegistry
+  ): Expr = transform match {
    case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
    case ApplyTransform(name, args) if functionRegistry.getFuncMappingBySpark.contains(name) =>
      FuncExpr(functionRegistry.getFuncMappingBySpark(name), args.map(arg => SQLExpr(arg.describe())).toList)
@@ -136,8 +147,9 @@ object ExprUtils extends SQLConfHelper with Serializable {
   def inferTransformSchema(
     primarySchema: StructType,
     secondarySchema: StructType,
-    transform: Transform
-  )(implicit functionRegistry: FunctionRegistry): StructField = transform match {
+    transform: Transform,
+    functionRegistry: FunctionRegistry
+  ): StructField = transform match {
    case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
      .orElse(secondarySchema.find(_.name == col))
      .getOrElse(throw CHClientException(s"Invalid partition column: $col"))
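
With the implicit gone, every `ExprUtils` entry point takes the registry as a plain argument, and the remaining files in this commit update their call sites to match. A hedged example of the new calling convention (value names are illustrative; the signature is the one shown above):

    // Hypothetical call site: the FunctionRegistry is now passed explicitly.
    val partitioning: Array[Transform] =
      ExprUtils.toSparkPartitions(partitionKey, functionRegistry)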

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ class ClickHouseCatalog extends TableCatalog
 
     val partitionsClause = partitions match {
       case transforms if transforms.nonEmpty =>
-        transforms.map(ExprUtils.toClickHouse(_)(functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
+        transforms.map(ExprUtils.toClickHouse(_, functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
       case _ => ""
     }

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala

Lines changed: 2 additions & 2 deletions
@@ -133,11 +133,11 @@ case class ClickHouseTable(
   private lazy val metadataSchema: StructType =
     StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField))
 
-  override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey)(functionRegistry)
+  override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey, functionRegistry)
 
   override lazy val partitionSchema: StructType = StructType(
     partitioning.map(partTransform =>
-      ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform)(functionRegistry)
+      ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform, functionRegistry)
     )
   )
 
spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
   protected lazy val shardExpr: Option[Expression] = writeJob.sparkShardExpr match {
     case None => None
     case Some(v2Expr) =>
-      val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields)(writeJob.functionRegistry)
+      val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields, writeJob.functionRegistry)
       catalystExpr match {
         case BoundReference(_, dataType, _)
             if dataType.isInstanceOf[ByteType] // list all integral types here because we can not access `IntegralType`

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala

Lines changed: 6 additions & 6 deletions
@@ -41,8 +41,6 @@ case class WriteJobDescription(
   functionRegistry: FunctionRegistry
 ) {
 
-  implicit val _functionRegistry: FunctionRegistry = functionRegistry
-
   def targetDatabase(convert2Local: Boolean): String = tableEngineSpec match {
     case dist: DistributedEngineSpec if convert2Local => dist.local_db
     case _ => tableSpec.database
@@ -59,7 +57,7 @@ case class WriteJobDescription(
   }
 
   def sparkShardExpr: Option[Expression] = shardingKeyIgnoreRand match {
-    case Some(expr) => ExprUtils.toSparkTransformOpt(expr)
+    case Some(expr) => ExprUtils.toSparkTransformOpt(expr, functionRegistry)
     case _ => None
   }
 
@@ -73,18 +71,20 @@ case class WriteJobDescription(
     if (writeOptions.repartitionByPartition) {
       ExprUtils.toSparkSplits(
         shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
-        partitionKey
+        partitionKey,
+        functionRegistry
       )
     } else {
       ExprUtils.toSparkSplits(
         shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
-        None
+        None,
+        functionRegistry
      )
    }
 
   def sparkSortOrders: Array[SortOrder] = {
     val _partitionKey = if (writeOptions.localSortByPartition) partitionKey else None
     val _sortingKey = if (writeOptions.localSortByKey) sortingKey else None
-    ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster)
+    ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster, functionRegistry)
   }
 }