|
14 | 14 |
|
15 | 15 | package xenon.clickhouse.func |
16 | 16 |
|
| 17 | +import org.apache.spark.sql.catalyst.InternalRow |
17 | 18 | import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ScalarFunction, UnboundFunction} |
18 | 19 | import org.apache.spark.sql.types._ |
19 | 20 | import org.apache.spark.unsafe.types.UTF8String |
20 | 21 |
|
21 | 22 | abstract class MultiArgsHash extends UnboundFunction with ClickhouseEquivFunction { |
22 | | - trait Base extends ScalarFunction[Long] { |
23 | | - // must not be private object, nor do it successors, because spark would compile them |
24 | | - override def canonicalName: String = s"clickhouse.$name" |
25 | | - override def resultType: DataType = LongType |
26 | | - override def toString: String = name |
27 | | - } |
28 | | - |
29 | | - object Arg1 extends Base { |
30 | | - override def name: String = s"${funcName}_1" |
31 | | - override def inputTypes: Array[DataType] = Array.fill(1)(StringType) |
32 | | - def invoke(value: UTF8String): Long = invokeBase(value) |
33 | | - } |
34 | | - |
35 | | - object Arg2 extends Base { |
36 | | - override def name: String = s"${funcName}_2" |
37 | | - override def inputTypes: Array[DataType] = Array.fill(2)(StringType) |
38 | | - def invoke(v1: UTF8String, v2: UTF8String): Long = Seq(v1, v2).map(invokeBase).reduce(combineHashes) |
39 | | - } |
40 | | - |
41 | | - object Arg3 extends Base { |
42 | | - override def name: String = s"${funcName}_3" |
43 | | - override def inputTypes: Array[DataType] = Array.fill(3)(StringType) |
44 | | - def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String): Long = |
45 | | - Seq(v1, v2, v3).map(invokeBase).reduce(combineHashes) |
46 | | - } |
47 | | - |
48 | | - object Arg4 extends Base { |
49 | | - override def name: String = s"${funcName}_4" |
50 | | - override def inputTypes: Array[DataType] = Array.fill(4)(StringType) |
51 | | - def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String, v4: UTF8String): Long = |
52 | | - Seq(v1, v2, v3, v4).map(invokeBase).reduce(combineHashes) |
53 | | - } |
54 | | - |
55 | | - object Arg5 extends Base { |
56 | | - override def name: String = s"${funcName}_4" |
57 | | - override def inputTypes: Array[DataType] = Array.fill(5)(StringType) |
58 | | - def invoke(v1: UTF8String, v2: UTF8String, v3: UTF8String, v4: UTF8String, v5: UTF8String): Long = |
59 | | - Seq(v1, v2, v3, v4, v5).map(invokeBase).reduce(combineHashes) |
60 | | - } |
61 | 23 | private def isExceptedType(dt: DataType): Boolean = |
62 | 24 | dt.isInstanceOf[StringType] |
63 | 25 |
|
64 | 26 | final override def name: String = funcName |
65 | | - final override def bind(inputType: StructType): BoundFunction = inputType.fields match { |
66 | | - case Array(StructField(_, dt, _, _)) if List(dt).forall(isExceptedType) => this.Arg1 |
67 | | - case Array( |
68 | | - StructField(_, dt1, _, _), |
69 | | - StructField(_, dt2, _, _) |
70 | | - ) if List(dt1, dt2).forall(isExceptedType) => |
71 | | - this.Arg2 |
72 | | - case Array( |
73 | | - StructField(_, dt1, _, _), |
74 | | - StructField(_, dt2, _, _), |
75 | | - StructField(_, dt3, _, _) |
76 | | - ) if List(dt1, dt2, dt3).forall(isExceptedType) => |
77 | | - this.Arg3 |
78 | | - case Array( |
79 | | - StructField(_, dt1, _, _), |
80 | | - StructField(_, dt2, _, _), |
81 | | - StructField(_, dt3, _, _), |
82 | | - StructField(_, dt4, _, _) |
83 | | - ) if List(dt1, dt2, dt3, dt4).forall(isExceptedType) => |
84 | | - this.Arg4 |
85 | | - case Array( |
86 | | - StructField(_, dt1, _, _), |
87 | | - StructField(_, dt2, _, _), |
88 | | - StructField(_, dt3, _, _), |
89 | | - StructField(_, dt4, _, _), |
90 | | - StructField(_, dt5, _, _) |
91 | | - ) if List(dt1, dt2, dt3, dt4, dt5).forall(isExceptedType) => |
92 | | - this.Arg5 |
93 | | - case _ => throw new UnsupportedOperationException(s"Expect up to 5 STRING argument. $description") |
| 27 | + final override def bind(inputType: StructType): BoundFunction = { |
| 28 | + val inputDataTypes = inputType.fields.map(_.dataType) |
| 29 | + if (inputDataTypes.forall(isExceptedType)) new ScalarFunction[Long] { |
| 30 | + override def inputTypes(): Array[DataType] = inputDataTypes |
| 31 | + override def name: String = funcName |
| 32 | + override def canonicalName: String = s"clickhouse.$name" |
| 33 | + override def resultType: DataType = LongType |
| 34 | + override def toString: String = name |
| 35 | + override def produceResult(input: InternalRow): Long = { |
| 36 | + val inputStrings: Seq[UTF8String] = |
| 37 | + input.toSeq(Seq.fill(input.numFields)(StringType)).asInstanceOf[Seq[UTF8String]] |
| 38 | + inputStrings.map(invokeBase).reduce(combineHashes) |
| 39 | + } |
| 40 | + } |
| 41 | + else throw new UnsupportedOperationException(s"Expect multiple STRING argument. $description") |
| 42 | + |
94 | 43 | } |
95 | 44 |
|
96 | 45 | protected def funcName: String |
|
0 commit comments