Skip to content

Commit 1e92fab

Browse files
committed
Spark 3.4 UDF: support varargs for Hash UDFs
1 parent a7a4c03 commit 1e92fab

File tree

1 file changed

+17
-68
lines changed

1 file changed

+17
-68
lines changed

spark-3.4/clickhouse-spark/src/main/scala/xenon/clickhouse/func/MultiArgsHash.scala

Lines changed: 17 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,83 +14,32 @@
1414

1515
package xenon.clickhouse.func
1616

17+
import org.apache.spark.sql.catalyst.InternalRow
1718
import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction}
1819
import org.apache.spark.sql.types._
1920
import org.apache.spark.unsafe.types.UTF8String
2021

2122
abstract class MultiArgsHash extends UnboundFunction with ClickhouseEquivFunction {
22-
trait Base extends ScalarFunction[Long] {
23-
// Must not be a private object, and neither may its subtypes, because Spark would compile them
24-
override def canonicalName: String = s"clickhouse.$name"
25-
override def resultType: DataType = LongType
26-
override def toString: String = name
27-
}
28-
29-
object Arg1 extends Base {
30-
override def name: String = s"${funcName}_1"
31-
override def inputTypes: Array[DataType] = Array.fill(1)(StringType)
32-
def invoke(value: UTF8String): Long = invokeBase(value)
33-
}
34-
35-
object Arg2 extends Base {
36-
override def name: String = s"${funcName}_2"
37-
override def inputTypes: Array[DataType] = Array.fill(2)(StringType)
38-
def invoke(v1: UTF8String, v2: UTF8String): Long = Seq(v1, v2).map(invokeBase).reduce(combineHashes)
39-
}
40-
41-
object Arg3 extends Base {
42-
override def name: String = s"${funcName}_3"
43-
override def inputTypes: Array[DataType] = Array.fill(3)(StringType)
44-
def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String): Long =
45-
Seq(v1, v2, v3).map(invokeBase).reduce(combineHashes)
46-
}
47-
48-
object Arg4 extends Base {
49-
override def name: String = s"${funcName}_4"
50-
override def inputTypes: Array[DataType] = Array.fill(4)(StringType)
51-
def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String, v4: UTF8String): Long =
52-
Seq(v1, v2, v3, v4).map(invokeBase).reduce(combineHashes)
53-
}
54-
55-
object Arg5 extends Base {
56-
override def name: String = s"${funcName}_4"
57-
override def inputTypes: Array[DataType] = Array.fill(5)(StringType)
58-
def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String, v4: UTF8String, v5: UTF8String): Long =
59-
Seq(v1, v2, v3, v4, v5).map(invokeBase).reduce(combineHashes)
60-
}
6123
private def isExceptedType(dt: DataType): Boolean =
6224
dt.isInstanceOf[StringType]
6325

6426
final override def name: String = funcName
65-
final override def bind(inputType: StructType): BoundFunction = inputType.fields match {
66-
case Array(StructField(_, dt, _, _)) if List(dt).forall(isExceptedType) => this.Arg1
67-
case Array(
68-
StructField(_, dt1, _, _),
69-
StructField(_, dt2, _, _)
70-
) if List(dt1, dt2).forall(isExceptedType) =>
71-
this.Arg2
72-
case Array(
73-
StructField(_, dt1, _, _),
74-
StructField(_, dt2, _, _),
75-
StructField(_, dt3, _, _)
76-
) if List(dt1, dt2, dt3).forall(isExceptedType) =>
77-
this.Arg3
78-
case Array(
79-
StructField(_, dt1, _, _),
80-
StructField(_, dt2, _, _),
81-
StructField(_, dt3, _, _),
82-
StructField(_, dt4, _, _)
83-
) if List(dt1, dt2, dt3, dt4).forall(isExceptedType) =>
84-
this.Arg4
85-
case Array(
86-
StructField(_, dt1, _, _),
87-
StructField(_, dt2, _, _),
88-
StructField(_, dt3, _, _),
89-
StructField(_, dt4, _, _),
90-
StructField(_, dt5, _, _)
91-
) if List(dt1, dt2, dt3, dt4, dt5).forall(isExceptedType) =>
92-
this.Arg5
93-
case _ => throw new UnsupportedOperationException(s"Expect up to 5 STRING argument. $description")
27+
final override def bind(inputType: StructType): BoundFunction = {
28+
val inputDataTypes = inputType.fields.map(_.dataType)
29+
if (inputDataTypes.forall(isExceptedType)) new ScalarFunction[Long] {
30+
override def inputTypes(): Array[DataType] = inputDataTypes
31+
override def name: String = funcName
32+
override def canonicalName: String = s"clickhouse.$name"
33+
override def resultType: DataType = LongType
34+
override def toString: String = name
35+
override def produceResult(input: InternalRow): Long = {
36+
val inputStrings: Seq[UTF8String] =
37+
input.toSeq(Seq.fill(input.numFields)(StringType)).asInstanceOf[Seq[UTF8String]]
38+
inputStrings.map(invokeBase).reduce(combineHashes)
39+
}
40+
}
41+
else throw new UnsupportedOperationException(s"Expect multiple STRING argument. $description")
42+
9443
}
9544

9645
protected def funcName: String

0 commit comments

Comments
 (0)