
Commit 3dcdd81

Spark 3.4: Change ExprUtils to implicit
1 parent a8bdcbf commit 3dcdd81

6 files changed: 50 additions and 43 deletions

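The commit title is terse, so here is a minimal sketch, using invented names rather than the connector's actual API, of the overall refactor: a helper class that carried the FunctionRegistry as a constructor argument becomes a stateless object whose methods accept the registry through an implicit parameter list.

    // Hypothetical, simplified types; not the connector's real code.
    trait Registry { def contains(name: String): Boolean }

    // Before: every call site built an instance, e.g. new UtilsBefore(registry).supported("f")
    class UtilsBefore(registry: Registry) {
      def supported(name: String): Boolean = registry.contains(name)
    }

    // After: a singleton object; the registry is threaded through implicitly.
    object UtilsAfter {
      def supported(name: String)(implicit registry: Registry): Boolean =
        registry.contains(name)
    }

Call sites can either have an implicit registry in scope or pass it explicitly in the second parameter list, which is exactly the split visible in the file diffs below.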

spark-3.4/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClusterShardByTransformSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ class ClusterShardByTransformSuite extends SparkClickHouseClusterTest {
       ("cityHash64", Array("value"))
     ).foreach {
       case (func_name: String, func_args: Array[String]) =>
-        test(s"shard by $func_name")(runTest(func_name, func_args))
+        test(s"shard by $func_name(${func_args.mkString(",")})")(runTest(func_name, func_args))
     }
 
 }
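A side note on the test change above: the generated test name now embeds the argument list, presumably so each case produced by the foreach loop gets a distinct, self-describing name. A tiny sketch of the naming idiom, with an invented object name:

    // Hypothetical helper showing how the interpolated name is built.
    object TestNameDemo {
      def name(func: String, args: Array[String]): String =
        s"shard by $func(${args.mkString(",")})"

      def main(argv: Array[String]): Unit =
        // prints: shard by cityHash64(value)
        println(name("cityHash64", Array("value")))
    }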

spark-3.4/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ExprUtils.scala

Lines changed: 39 additions & 34 deletions
@@ -31,34 +31,40 @@ import xenon.clickhouse.spec.ClusterSpec
 
 import scala.util.{Failure, Success, Try}
 
-class ExprUtils(functionRegistry: FunctionRegistry) extends SQLConfHelper with Serializable {
+object ExprUtils extends SQLConfHelper with Serializable {
 
-  def toSparkPartitions(partitionKey: Option[List[Expr]]): Array[Transform] =
-    partitionKey.seq.flatten.flatten(toSparkTransformOpt).toArray
+  def toSparkPartitions(partitionKey: Option[List[Expr]])(implicit
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    partitionKey.seq.flatten.flatten(toSparkTransformOpt(_)).toArray
 
-  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]]): Array[Transform] =
-    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt).toArray
+  def toSparkSplits(shardingKey: Option[Expr], partitionKey: Option[List[Expr]])(implicit
+    functionRegistry: FunctionRegistry
+  ): Array[Transform] =
+    (shardingKey.seq ++ partitionKey.seq.flatten).flatten(toSparkTransformOpt(_)).toArray
 
   def toSparkSortOrders(
     shardingKeyIgnoreRand: Option[Expr],
     partitionKey: Option[List[Expr]],
     sortingKey: Option[List[OrderExpr]],
     cluster: Option[ClusterSpec]
-  ): Array[SortOrder] =
+  )(implicit functionRegistry: FunctionRegistry): Array[SortOrder] =
     toSparkSplits(
       shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight)),
       partitionKey
     ).map(Expressions.sort(_, SortDirection.ASCENDING)) ++:
       sortingKey.seq.flatten.flatten { case OrderExpr(expr, asc, nullFirst) =>
         val direction = if (asc) SortDirection.ASCENDING else SortDirection.DESCENDING
         val nullOrder = if (nullFirst) NullOrdering.NULLS_FIRST else NullOrdering.NULLS_LAST
-        toSparkTransformOpt(expr).map(trans => Expressions.sort(trans, direction, nullOrder))
+        toSparkTransformOpt(expr).map(trans =>
+          Expressions.sort(trans, direction, nullOrder)
+        )
       }.toArray
 
   private def loadV2FunctionOpt(
     name: String,
     args: Seq[Expression]
-  ): Option[BoundFunction] = {
+  )(implicit functionRegistry: FunctionRegistry): Option[BoundFunction] = {
     def loadFunction(ident: Identifier): UnboundFunction =
       functionRegistry.load(ident.name).getOrElse(throw new NoSuchFunctionException(ident))
     val inputType = StructType(args.zipWithIndex.map {
@@ -77,7 +83,10 @@ class ExprUtils(functionRegistry: FunctionRegistry) extends SQLConfHelper with Serializable {
     }
   }
 
-  def toCatalyst(v2Expr: V2Expression, fields: Array[StructField]): Expression =
+  def toCatalyst(
+    v2Expr: V2Expression,
+    fields: Array[StructField]
+  )(implicit functionRegistry: FunctionRegistry): Expression =
     v2Expr match {
       case IdentityTransform(ref) => toCatalyst(ref, fields)
       case ref: NamedReference if ref.fieldNames.length == 1 =>
@@ -88,35 +97,35 @@ class ExprUtils(functionRegistry: FunctionRegistry) extends SQLConfHelper with Serializable {
         BoundReference(ordinal, field.dataType, field.nullable)
       case t: Transform =>
         val catalystArgs = t.arguments().map(toCatalyst(_, fields))
-        loadV2FunctionOpt(t.name(), catalystArgs).map { bound =>
-          TransformExpression(bound, catalystArgs)
-        }.getOrElse {
+        loadV2FunctionOpt(t.name(), catalystArgs).map(bound => TransformExpression(bound, catalystArgs)).getOrElse {
           throw CHClientException(s"Unsupported expression: $v2Expr")
         }
       case _ => throw CHClientException(
        s"Unsupported expression: $v2Expr"
      )
    }
 
-  def toSparkTransformOpt(expr: Expr): Option[Transform] = Try(toSparkExpression(expr)) match {
-    // need this function because spark `Table`'s `partitioning` field should be `Transform`
-    case Success(t: Transform) => Some(t)
-    case Success(_) => None
-    case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
-    case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
-  }
+  def toSparkTransformOpt(expr: Expr)(implicit functionRegistry: FunctionRegistry): Option[Transform] =
+    Try(toSparkExpression(expr)) match {
+      // need this function because spark `Table`'s `partitioning` field should be `Transform`
+      case Success(t: Transform) => Some(t)
+      case Success(_) => None
+      case Failure(_) if conf.getConf(IGNORE_UNSUPPORTED_TRANSFORM) => None
+      case Failure(rethrow) => throw new AnalysisException(rethrow.getMessage, cause = Some(rethrow))
+    }
 
-  def toSparkExpression(expr: Expr): V2Expression = expr match {
-    case FieldRef(col) => identity(col)
-    case StringLiteral(value) => literal(value)
-    case FuncExpr("rand", Nil) => apply("rand")
-    case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
-    case FuncExpr(funName, args) if functionRegistry.getFuncMappingByCk.contains(funName) =>
-      apply(functionRegistry.getFuncMappingByCk(funName), args.map(toSparkExpression): _*)
-    case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
-  }
+  def toSparkExpression(expr: Expr)(implicit functionRegistry: FunctionRegistry): V2Expression =
+    expr match {
+      case FieldRef(col) => identity(col)
+      case StringLiteral(value) => literal(value)
+      case FuncExpr("rand", Nil) => apply("rand")
+      case FuncExpr("toYYYYMMDD", List(FuncExpr("toDate", List(FieldRef(col))))) => identity(col)
+      case FuncExpr(funName, args) if functionRegistry.getFuncMappingByCk.contains(funName) =>
+        apply(functionRegistry.getFuncMappingByCk(funName), args.map(toSparkExpression): _*)
+      case unsupported => throw CHClientException(s"Unsupported ClickHouse expression: $unsupported")
+    }
 
-  def toClickHouse(transform: Transform): Expr = transform match {
+  def toClickHouse(transform: Transform)(implicit functionRegistry: FunctionRegistry): Expr = transform match {
     case IdentityTransform(fieldRefs) => FieldRef(fieldRefs.describe)
     case ApplyTransform(name, args) if functionRegistry.getFuncMappingBySpark.contains(name) =>
       FuncExpr(functionRegistry.getFuncMappingBySpark(name), args.map(arg => SQLExpr(arg.describe())).toList)
@@ -128,7 +137,7 @@ class ExprUtils(functionRegistry: FunctionRegistry) extends SQLConfHelper with Serializable {
     primarySchema: StructType,
     secondarySchema: StructType,
     transform: Transform
-  ): StructField = transform match {
+  )(implicit functionRegistry: FunctionRegistry): StructField = transform match {
     case IdentityTransform(FieldReference(Seq(col))) => primarySchema.find(_.name == col)
       .orElse(secondarySchema.find(_.name == col))
       .getOrElse(throw CHClientException(s"Invalid partition column: $col"))
@@ -142,10 +151,6 @@ class ExprUtils(functionRegistry: FunctionRegistry) extends SQLConfHelper with Serializable {
     case bucket: BucketTransform => throw CHClientException(s"Bucket transform not support yet: $bucket")
     case other: Transform => throw CHClientException(s"Unsupported transform: $other")
   }
-}
-
-object ExprUtils {
-  def apply(functionRegistry: FunctionRegistry): ExprUtils = new ExprUtils(functionRegistry)
 
   def toSplitWithModulo(shardingKey: Expr, weight: Int): FuncExpr =
     FuncExpr("positiveModulo", List(shardingKey, StringLiteral(weight.toString)))

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseCatalog.scala

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ class ClickHouseCatalog extends TableCatalog
 
     val partitionsClause = partitions match {
       case transforms if transforms.nonEmpty =>
-        transforms.map(ExprUtils(functionRegistry).toClickHouse(_).sql).mkString("PARTITION BY (", ", ", ")")
+        transforms.map(ExprUtils.toClickHouse(_)(functionRegistry).sql).mkString("PARTITION BY (", ", ", ")")
       case _ => ""
     }
 
spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/ClickHouseTable.scala

Lines changed: 2 additions & 2 deletions
@@ -133,11 +133,11 @@ case class ClickHouseTable(
   private lazy val metadataSchema: StructType =
     StructType(metadataColumns.map(_.asInstanceOf[ClickHouseMetadataColumn].toStructField))
 
-  override lazy val partitioning: Array[Transform] = ExprUtils(functionRegistry).toSparkPartitions(partitionKey)
+  override lazy val partitioning: Array[Transform] = ExprUtils.toSparkPartitions(partitionKey)(functionRegistry)
 
   override lazy val partitionSchema: StructType = StructType(
     partitioning.map(partTransform =>
-      ExprUtils(functionRegistry).inferTransformSchema(schema, metadataSchema, partTransform)
+      ExprUtils.inferTransformSchema(schema, metadataSchema, partTransform)(functionRegistry)
     )
   )
 
spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/ClickHouseWriter.scala

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ abstract class ClickHouseWriter(writeJob: WriteJobDescription)
   protected lazy val shardExpr: Option[Expression] = writeJob.sparkShardExpr match {
     case None => None
     case Some(v2Expr) =>
-      val catalystExpr = ExprUtils(writeJob.functionRegistry).toCatalyst(v2Expr, writeJob.dataSetSchema.fields)
+      val catalystExpr = ExprUtils.toCatalyst(v2Expr, writeJob.dataSetSchema.fields)(writeJob.functionRegistry)
       catalystExpr match {
         case BoundReference(_, dataType, _)
           if dataType.isInstanceOf[ByteType] // list all integral types here because we can not access `IntegralType`
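The catalog, table, and writer above have no implicit FunctionRegistry in scope, so each supplies the registry explicitly in the second parameter list, as in ExprUtils.toClickHouse(_)(functionRegistry). A minimal sketch of passing an implicit parameter explicitly, with invented names:

    case class Registry(known: Set[String])

    object Render {
      def describe(name: String)(implicit r: Registry): String =
        if (r.known.contains(name)) s"known($name)" else s"unknown($name)"
    }

    object CatalogLikeCaller {
      // No implicit Registry here, so it is passed explicitly, like `(functionRegistry)`.
      private val registry = Registry(Set("xxHash64"))
      val clause: String =
        List("xxHash64", "rand").map(Render.describe(_)(registry)).mkString("PARTITION BY (", ", ", ")")
    }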

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/write/WriteJobDescription.scala

Lines changed: 6 additions & 4 deletions
@@ -41,6 +41,8 @@ case class WriteJobDescription(
   functionRegistry: FunctionRegistry
 ) {
 
+  implicit val _functionRegistry: FunctionRegistry = functionRegistry
+
   def targetDatabase(convert2Local: Boolean): String = tableEngineSpec match {
     case dist: DistributedEngineSpec if convert2Local => dist.local_db
     case _ => tableSpec.database
@@ -57,7 +59,7 @@ case class WriteJobDescription(
   }
 
   def sparkShardExpr: Option[Expression] = shardingKeyIgnoreRand match {
-    case Some(expr) => ExprUtils(functionRegistry).toSparkTransformOpt(expr)
+    case Some(expr) => ExprUtils.toSparkTransformOpt(expr)
     case _ => None
   }
 
@@ -69,12 +71,12 @@ case class WriteJobDescription(
     // need to apply module during sorting in `toSparkSortOrders`), data belongs to shard 1 will be sorted in the
     // front for all tasks, resulting in instant high pressure for shard 1 when stage starts.
     if (writeOptions.repartitionByPartition) {
-      ExprUtils(functionRegistry).toSparkSplits(
+      ExprUtils.toSparkSplits(
         shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
         partitionKey
       )
     } else {
-      ExprUtils(functionRegistry).toSparkSplits(
+      ExprUtils.toSparkSplits(
         shardingKeyIgnoreRand.map(k => ExprUtils.toSplitWithModulo(k, cluster.get.totalWeight * 5)),
         None
       )
@@ -83,6 +85,6 @@ case class WriteJobDescription(
   def sparkSortOrders: Array[SortOrder] = {
     val _partitionKey = if (writeOptions.localSortByPartition) partitionKey else None
     val _sortingKey = if (writeOptions.localSortByKey) sortingKey else None
-    ExprUtils(functionRegistry).toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster)
+    ExprUtils.toSparkSortOrders(shardingKeyIgnoreRand, _partitionKey, _sortingKey, cluster)
   }
 }
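WriteJobDescription takes the opposite route: instead of repeating (functionRegistry) at each call, it re-exposes the constructor field as an implicit val, so every ExprUtils call inside the class body picks it up automatically. A minimal sketch of the pattern, with invented names:

    case class Registry(known: Set[String])

    object Exprs {
      def supported(name: String)(implicit r: Registry): Boolean = r.known.contains(name)
    }

    case class JobDescription(registry: Registry, shardFunc: String) {
      // Mirrors `implicit val _functionRegistry` above: one declaration puts the
      // registry in implicit scope for the whole class body.
      implicit val _registry: Registry = registry

      // The implicit val is found automatically here, no explicit argument needed.
      def shardFuncSupported: Boolean = Exprs.supported(shardFunc)
    }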
