Commit 683c558

Spark 3.3: Implement SupportsRuntimeFiltering (#276)

1 parent: 5e83d0c
5 files changed: +95 -4 lines changed
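In short: ClickHouseBatchScan now implements Spark's SupportsRuntimeFiltering, so filters that only materialize at execution time (such as the IN-subquery results exercised by the new tests) are compiled into the query sent to ClickHouse, gated by a new conf that defaults to false. A minimal usage sketch, assuming a SparkSession named spark whose catalog is backed by this connector and a table db.tbl with id and create_time columns (names borrowed from the tests):

    // Opt in per session; the feature is off by default.
    spark.sql("SET spark.clickhouse.read.runtimeFilter.enabled=true")

    // The IN-subquery shape the new tests use to provoke a runtime filter
    // on the outer scan:
    spark.sql(
      """SELECT id FROM db.tbl
        |WHERE id IN (
        |  SELECT id FROM db.tbl
        |  WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd')
        |    BETWEEN '2021-01-01' AND '2022-01-01'
        |)""".stripMargin).show()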

spark-3.3/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/cluster/ClickHouseClusterReadSuite.scala

Lines changed: 29 additions & 0 deletions

@@ -16,6 +16,8 @@ package org.apache.spark.sql.clickhouse.cluster
 
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf.READ_DISTRIBUTED_CONVERT_LOCAL
 import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 
 class ClickHouseClusterReadSuite extends SparkClickHouseClusterTest {
 
@@ -84,4 +86,31 @@ class ClickHouseClusterReadSuite extends SparkClickHouseClusterTest {
       )
     }
   }
+
+  test("runtime filter - distributed table") {
+    withSimpleDistTable("single_replica", "runtime_db", "runtime_tbl", true) { (_, db, tbl_dist, _) =>
+      spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false")
+      checkAnswer(
+        spark.sql(s"SELECT id FROM $db.$tbl_dist " +
+          s"WHERE id IN (" +
+          s"  SELECT id FROM $db.$tbl_dist " +
+          s"  WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
+          s")"),
+        Row(1)
+      )
+
+      spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true")
+      val df = spark.sql(s"SELECT id FROM $db.$tbl_dist " +
+        s"WHERE id IN (" +
+        s"  SELECT id FROM $db.$tbl_dist " +
+        s"  WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
+        s")")
+      checkAnswer(df, Row(1))
+      val runtimeFilterExists = df.queryExecution.sparkPlan.exists {
+        case BatchScanExec(_, _, runtimeFilters, _) if runtimeFilters.nonEmpty => true
+        case _ => false
+      }
+      assert(runtimeFilterExists)
+    }
+  }
 }
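The BatchScanExec(_, _, runtimeFilters, _) match above relies on the Spark 3.3 case-class shape (output, scan, runtimeFilters, keyGroupedPartitioning). An equivalent check against the same df that does not depend on the constructor arity:

    import org.apache.spark.sql.execution.datasources.v2.BatchScanExec

    // Collect the runtime filters carried by each DSv2 batch scan in the
    // plan and assert that at least one scan actually received some.
    val hasRuntimeFilter = df.queryExecution.sparkPlan.collect {
      case scan: BatchScanExec => scan.runtimeFilters
    }.exists(_.nonEmpty)
    assert(hasRuntimeFilter)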

spark-3.3/clickhouse-spark-it/src/test/scala/org/apache/spark/sql/clickhouse/single/ClickHouseSingleSuite.scala

Lines changed: 32 additions & 0 deletions

@@ -15,6 +15,8 @@
 package org.apache.spark.sql.clickhouse.single
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.types._
 
 class ClickHouseSingleSuite extends SparkClickHouseSingleTest {
 
@@ -451,4 +453,34 @@ class ClickHouseSingleSuite extends SparkClickHouseSingleTest {
      spark.sql(s"UNCACHE TABLE $db.$tbl")
    }
  }
+
+  test("runtime filter") {
+    val db = "runtime_db"
+    val tbl = "runtime_tbl"
+
+    withSimpleTable(db, tbl, true) {
+      spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=false")
+      checkAnswer(
+        spark.sql(s"SELECT id FROM $db.$tbl " +
+          s"WHERE id IN (" +
+          s"  SELECT id FROM $db.$tbl " +
+          s"  WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
+          s")"),
+        Row(1)
+      )
+
+      spark.sql("set spark.clickhouse.read.runtimeFilter.enabled=true")
+      val df = spark.sql(s"SELECT id FROM $db.$tbl " +
+        s"WHERE id IN (" +
+        s"  SELECT id FROM $db.$tbl " +
+        s"  WHERE DATE_FORMAT(create_time, 'yyyy-MM-dd') between '2021-01-01' and '2022-01-01'" +
+        s")")
+      checkAnswer(df, Row(1))
+      val runtimeFilterExists = df.queryExecution.sparkPlan.exists {
+        case BatchScanExec(_, _, runtimeFilters, _) if runtimeFilters.nonEmpty => true
+        case _ => false
+      }
+      assert(runtimeFilterExists)
+    }
+  }
 }

spark-3.3/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/ClickHouseSQLConf.scala

Lines changed: 7 additions & 0 deletions

@@ -173,6 +173,13 @@ object ClickHouseSQLConf {
       .transform(_.toLowerCase)
       .createWithDefault("json")
 
+  val RUNTIME_FILTER_ENABLED: ConfigEntry[Boolean] =
+    buildConf("spark.clickhouse.read.runtimeFilter.enabled")
+      .doc("Enable runtime filter for reading.")
+      .version("0.8.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val WRITE_FORMAT: ConfigEntry[String] =
     buildConf("spark.clickhouse.write.format")
       .doc("Serialize format for writing. Supported formats: json, arrow")

spark-3.3/clickhouse-spark/src/main/scala/org/apache/spark/sql/clickhouse/SparkOptions.scala

Lines changed: 3 additions & 0 deletions

@@ -48,6 +48,9 @@ class ReadOptions(_options: JMap[String, String]) extends SparkOptions {
 
   def format: String =
     eval(READ_FORMAT.key, READ_FORMAT)
+
+  def runtimeFilterEnabled: Boolean =
+    eval(RUNTIME_FILTER_ENABLED.key, RUNTIME_FILTER_ENABLED)
 }
 
 class WriteOptions(_options: JMap[String, String]) extends SparkOptions {
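eval is defined elsewhere in SparkOptions; judging from its use here, it resolves the value from the per-read data source options first and falls back to the session conf entry. A generic sketch of that precedence pattern (resolveFlag and its signature are illustrative assumptions, not the connector's API):

    // Illustrative only: an explicit read option wins over the session default.
    def resolveFlag(options: Map[String, String], key: String, sessionDefault: Boolean): Boolean =
      options.get(key).map(_.trim.toBoolean).getOrElse(sessionDefault)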

spark-3.3/clickhouse-spark/src/main/scala/xenon/clickhouse/read/ClickHouseRead.scala

Lines changed: 24 additions & 4 deletions

@@ -16,7 +16,7 @@ package xenon.clickhouse.read
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.clickhouse.ClickHouseSQLConf._
-import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference, Transform}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.read._

@@ -127,8 +127,14 @@ class ClickHouseScanBuilder(
 
 class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch
     with SupportsReportPartitioning
+    with SupportsRuntimeFiltering
     with PartitionReaderFactory
-    with ClickHouseHelper {
+    with ClickHouseHelper
+    with SQLHelper {
+
+  implicit private val tz: ZoneId = scanJob.tz
+
+  private var runtimeFilters: Array[Filter] = Array.empty
 
   val database: String = scanJob.database
   val table: String = scanJob.table

@@ -187,9 +193,13 @@ class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch
   override def createReader(_partition: InputPartition): PartitionReader[InternalRow] = {
     val format = scanJob.readOptions.format
     val partition = _partition.asInstanceOf[ClickHouseInputPartition]
+    val finalScanJob = scanJob.copy(filtersExpr =
+      scanJob.filtersExpr + " AND "
+        + compileFilters(AlwaysTrue :: runtimeFilters.toList)
+    )
     format match {
-      case "json" => new ClickHouseJsonReader(scanJob, partition)
-      case "binary" => new ClickHouseBinaryReader(scanJob, partition)
+      case "json" => new ClickHouseJsonReader(finalScanJob, partition)
+      case "binary" => new ClickHouseBinaryReader(finalScanJob, partition)
       case unsupported => throw CHClientException(s"Unsupported read format: $unsupported")
     }
   }

@@ -198,4 +208,14 @@ class ClickHouseBatchScan(scanJob: ScanJobDescription) extends Scan with Batch
     BlocksReadMetric(),
     BytesReadMetric()
   )
+
+  override def filterAttributes(): Array[NamedReference] =
+    if (scanJob.readOptions.runtimeFilterEnabled) {
+      scanJob.readSchema.fields.map(field => Expressions.column(field.name))
+    } else {
+      Array.empty
+    }
+
+  override def filter(filters: Array[Filter]): Unit =
+    runtimeFilters = filters
 }
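For reference, the DSv2 contract the scan now implements: at planning time Spark calls filterAttributes() to learn which columns may accept runtime filters, then calls filter(...) with the resolved filters before execution. A self-contained toy sketch (ExampleScan is hypothetical, not the connector's class):

    import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
    import org.apache.spark.sql.connector.read.{Scan, SupportsRuntimeFiltering}
    import org.apache.spark.sql.sources.Filter
    import org.apache.spark.sql.types.StructType

    class ExampleScan(schema: StructType) extends Scan with SupportsRuntimeFiltering {
      private var runtimeFilters: Array[Filter] = Array.empty

      override def readSchema(): StructType = schema

      // Advertise every read column as eligible for runtime filtering.
      override def filterAttributes(): Array[NamedReference] =
        schema.fields.map(f => Expressions.column(f.name))

      // Receive the filters Spark resolved at runtime; a real scan folds
      // them into the source query, as ClickHouseBatchScan does via
      // compileFilters in createReader above.
      override def filter(filters: Array[Filter]): Unit =
        runtimeFilters = filters
    }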
