Commit 3a54608

[SPARK-53998] Add additional E2E tests for RTM
1 parent ab9cc08 commit 3a54608

File tree

1 file changed: +393 -0 lines changed

Lines changed: 393 additions & 0 deletions
@@ -0,0 +1,393 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable

import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class StreamRealTimeModeE2ESuite extends StreamRealTimeModeE2ESuiteBase {

  import testImplicits._

  override protected def createSparkSession =
    new TestSparkSession(
      new SparkContext(
        "local[15]",
        "streaming-rtm-e2e-context",
        sparkConf.set("spark.sql.shuffle.partitions", "5")
      )
    )

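  // Shared body for the foreach tests: feeds one LowLatencyMemoryStream (optionally
  // unioned with a second one), writes through a ForeachWriter that records every row in
  // ResultsCollector, and checks both the running per-sink output and the per-batch,
  // per-partition rows committed in close().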
  private def runForeachTest(withUnion: Boolean): Unit = {
    var query: StreamingQuery = null
    try {
      withTempDir { checkpointDir =>
        val clock = new GlobalSingletonManualClock()
        LowLatencyClock.setClock(clock)
        val uniqueSinkName = if (withUnion) {
          sinkName + "-union"
        } else {
          sinkName
        }

        val read = LowLatencyMemoryStream[(String, Int)](5)
        val read1 = LowLatencyMemoryStream[(String, Int)](5)
        val dataframe = if (withUnion) {
          read.toDF().union(read1.toDF())
        } else {
          read.toDF()
        }

        query = dataframe
          .select(col("_1").as("key"), col("_2").as("value"))
          .select(
            concat(
              col("key").cast("STRING"),
              lit("-"),
              col("value").cast("STRING")
            ).as("output")
          )
          .writeStream
          .outputMode(OutputMode.Update())
          .foreach(new ForeachWriter[Row] {
            private var batchPartitionId: String = null
            private val processedThisBatch = new ConcurrentLinkedQueue[String]()
            override def open(partitionId: Long, epochId: Long): Boolean = {
              ResultsCollector
                .computeIfAbsent(uniqueSinkName, (_) => new ConcurrentLinkedQueue[String]())
              batchPartitionId = s"$uniqueSinkName-$epochId-$partitionId"
              assert(
                !ResultsCollector.containsKey(batchPartitionId),
                s"should NOT contain batchPartitionId ${batchPartitionId}"
              )
              ResultsCollector
                .put(batchPartitionId, new ConcurrentLinkedQueue[String]())
              true
            }

            override def process(value: Row): Unit = {
              val v = value.getAs[String]("output")
              ResultsCollector.get(uniqueSinkName).add(v)
              processedThisBatch.add(v)
            }

            override def close(errorOrNull: Throwable): Unit = {
              assert(
                ResultsCollector.containsKey(batchPartitionId),
                s"should contain batchPartitionId ${batchPartitionId}"
              )
              ResultsCollector.get(batchPartitionId).addAll(processedThisBatch)
              processedThisBatch.clear()
            }
          })
          .option("checkpointLocation", checkpointDir.getName)
          .queryName("foreach")
          // The batch duration set here doesn't matter since batch boundaries are
          // driven manually via the manual clock.
          .trigger(defaultTrigger)
          .start()

        val expectedResults = mutable.ListBuffer[String]()
        val expectedResultsByBatch = mutable.HashMap[Int, mutable.ListBuffer[String]]()

        val numRows = 10
        for (i <- 0 until 3) {
          expectedResultsByBatch(i) = new mutable.ListBuffer[String]()
          for (key <- List("a", "b", "c")) {
            for (j <- 1 to numRows) {
              read.addData((key, 1))
              val data = s"$key-1"
              expectedResults += data
              expectedResultsByBatch(i) += data
            }
          }

          if (withUnion) {
            for (key <- List("d", "e", "f")) {
              for (j <- 1 to numRows) {
                read1.addData((key, 2))
                val data = s"$key-2"
                expectedResults += data
                expectedResultsByBatch(i) += data
              }
            }
          }

          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(uniqueSinkName)
              .toArray(new Array[String](ResultsCollector.get(uniqueSinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          clock.advance(defaultTrigger.batchDurationMs)
          eventually(timeout(60.seconds)) {
            query
              .asInstanceOf[StreamingQueryWrapper]
              .streamingQuery
              .getLatestExecutionContext()
              .batchId should be(i + 1)
            query.lastProgress.sources(0).numInputRows should be(numRows * 3)

            val committedResults = new mutable.ListBuffer[String]()
            val numPartitions = if (withUnion) 10 else 5
            for (v <- 0 until numPartitions) {
              val it = ResultsCollector.get(s"$uniqueSinkName-${i}-$v").iterator()
              while (it.hasNext) {
                committedResults += it.next()
              }
            }

            committedResults.sorted should equal(expectedResultsByBatch(i).sorted)
          }
        }
      }
    } finally {
      if (query != null) {
        query.stop()
      }
    }
  }

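  // Shared body for the mapPartitions tests: same source setup as runForeachTest, but rows
  // are additionally recorded inside a mapPartitions transform so the output can be checked
  // both at the transform and at the sink started by runStreamingQuery.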
  private def runMapPartitionsTest(withUnion: Boolean): Unit = {
    var query: StreamingQuery = null
    try {
      withTempDir { checkpointDir =>
        val clock = new GlobalSingletonManualClock()
        LowLatencyClock.setClock(clock)
        val uniqueSinkName = if (withUnion) {
          sinkName + "mapPartitions-union"
        } else {
          sinkName + "mapPartitions"
        }

        val read = LowLatencyMemoryStream[(String, Int)](5)
        val read1 = LowLatencyMemoryStream[(String, Int)](5)
        val dataframe = if (withUnion) {
          read.toDF().union(read1.toDF())
        } else {
          read.toDF()
        }

        val df = dataframe
          .select(col("_1").as("key"), col("_2").as("value"))
          .select(
            concat(
              col("key").cast("STRING"),
              lit("-"),
              col("value").cast("STRING")
            ).as("output")
          )
          .as[String]
          .mapPartitions(rows => {
            rows.map(row => {
              val collector = ResultsCollector
                .computeIfAbsent(uniqueSinkName, (_) => new ConcurrentLinkedQueue[String]())
              collector.add(row)
              row
            })
          })
          .toDF()

        query = runStreamingQuery(sinkName, df)

        val expectedResults = mutable.ListBuffer[String]()
        val expectedResultsByBatch = mutable.HashMap[Int, mutable.ListBuffer[String]]()

        val numRows = 10
        for (i <- 0 until 3) {
          expectedResultsByBatch(i) = new mutable.ListBuffer[String]()
          for (key <- List("a", "b", "c")) {
            for (j <- 1 to numRows) {
              read.addData((key, 1))
              val data = s"$key-1"
              expectedResults += data
              expectedResultsByBatch(i) += data
            }
          }

          if (withUnion) {
            for (key <- List("d", "e", "f")) {
              for (j <- 1 to numRows) {
                read1.addData((key, 2))
                val data = s"$key-2"
                expectedResults += data
                expectedResultsByBatch(i) += data
              }
            }
          }

          // results collected from mapPartitions
          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(uniqueSinkName)
              .toArray(new Array[String](ResultsCollector.get(uniqueSinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          // results collected from foreach sink
          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(sinkName)
              .toArray(new Array[String](ResultsCollector.get(sinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          clock.advance(defaultTrigger.batchDurationMs)
          eventually(timeout(60.seconds)) {
            query
              .asInstanceOf[StreamingQueryWrapper]
              .streamingQuery
              .getLatestExecutionContext()
              .batchId should be(i + 1)
            query.lastProgress.sources(0).numInputRows should be(numRows * 3)
          }
        }
      }
    } finally {
      if (query != null) {
        query.stop()
      }
    }
  }

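  // The four cases below run the shared bodies with and without a union of two
  // low-latency sources.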
test("foreach") {
287+
runForeachTest(withUnion = false)
288+
}
289+
290+
test("union - foreach") {
291+
runForeachTest(withUnion = true)
292+
}
293+
294+
test("mapPartitions") {
295+
runMapPartitionsTest(withUnion = false)
296+
}
297+
298+
test("union - mapPartitions") {
299+
runMapPartitionsTest(withUnion = true)
300+
}
301+
302+
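  // The remaining tests push simple stateless transformations through the real-time path
  // and rely on the base suite's createMemoryStream/runStreamingQuery/processBatches
  // helpers to drive batches and check the produced rows.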
test("scala stateless UDF") {
303+
val myUDF = (id: Int) => id + 1
304+
val udf = spark.udf.register("myUDF", myUDF)
305+
val (read, clock) = createMemoryStream()
306+
307+
val df = read
308+
.toDF()
309+
.select(col("_1").as("key"), udf(col("_2")).as("value_plus_1"))
310+
.select(concat(col("key"), lit("-"), col("value_plus_1").cast("STRING")).as("output"))
311+
312+
var query: StreamingQuery = null
313+
try {
314+
query = runStreamingQuery("scala_udf", df)
315+
processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-${value + 1}"))
316+
} finally {
317+
if (query != null) query.stop()
318+
}
319+
}
320+
321+
test("stream static join") {
322+
val (read, clock) = createMemoryStream()
323+
val staticDf = spark
324+
.range(1, 31, 1, 10)
325+
.selectExpr("id AS join_key", "id AS join_value")
326+
// This will produce HashAggregateExec which should not be blocked by allowList
327+
// since it's the batch subquery
328+
.groupBy("join_key")
329+
.agg(max($"join_value").as("join_value"))
330+
331+
val df = read
332+
.toDF()
333+
.select(col("_1").as("key"), col("_2").as("value"))
334+
.join(staticDf, col("value") === col("join_key"))
335+
.select(concat(col("key"), lit("-"), col("value"), lit("-"), col("join_value")).as("output"))
336+
337+
var query: StreamingQuery = null
338+
try {
339+
query = runStreamingQuery("stream_static_join", df)
340+
processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value-$value"))
341+
} finally {
342+
if (query != null) query.stop()
343+
}
344+
}
345+
346+
test("to_json and from_json round-trip") {
347+
val (read, clock) = createMemoryStream()
348+
val schema = new StructType().add("key", StringType).add("value", IntegerType)
349+
350+
val df = read
351+
.toDF()
352+
.select(struct(col("_1").as("key"), col("_2").as("value")).as("json"))
353+
.select(from_json(to_json(col("json")), schema).as("json"))
354+
.select(concat(col("json.key"), lit("-"), col("json.value")))
355+
356+
var query: StreamingQuery = null
357+
try {
358+
query = runStreamingQuery("json_roundtrip", df)
359+
processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value"))
360+
} finally {
361+
if (query != null) query.stop()
362+
}
363+
}
364+
365+
test("generateExec passthrough") {
366+
val (read, clock) = createMemoryStream()
367+
368+
val df = read
369+
.toDF()
370+
.select(col("_1").as("key"), col("_2").as("value"))
371+
.withColumn("value_array", array(col("value"), -col("value")))
372+
df.createOrReplaceTempView("tempView")
373+
val explodeDF =
374+
spark
375+
.sql("select key, explode(value_array) as exploded_value from tempView")
376+
.select(concat(col("key"), lit("-"), col("exploded_value").cast("STRING")).as("output"))
377+
378+
var query: StreamingQuery = null
379+
try {
380+
query = runStreamingQuery("generateExec_passthrough", explodeDF)
381+
processBatches(
382+
query,
383+
read,
384+
clock,
385+
10,
386+
3,
387+
(key, value) => Array(s"$key-$value", s"$key--$value")
388+
)
389+
} finally {
390+
if (query != null) query.stop()
391+
}
392+
}
393+
}
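
For reference, the tests above call several members of StreamRealTimeModeE2ESuiteBase that are not part of this diff: sinkName, defaultTrigger, ResultsCollector, createMemoryStream, runStreamingQuery and processBatches. The sketch below only infers their shape from the call sites in this file; the trait name and every signature shown are assumptions, not the committed implementation.

import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream
import org.apache.spark.sql.streaming.StreamingQuery
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock

// Hypothetical stand-in for StreamRealTimeModeE2ESuiteBase; names, types, and signatures
// are inferred from call sites only.
trait RtmE2ESuiteHelpers {

  // Key under which the default sink records rows in ResultsCollector.
  protected def sinkName: String

  // Shared registry that sinks and transforms write into, keyed by sink name or by a
  // "<sink>-<batch>-<partition>" string; modeled here as a member, though it may well be
  // a shared singleton object in the real suite.
  protected def ResultsCollector: ConcurrentHashMap[String, ConcurrentLinkedQueue[String]]

  // Builds a 5-partition low-latency memory source together with the manual clock that
  // LowLatencyClock is pointed at.
  protected def createMemoryStream(): (LowLatencyMemoryStream[(String, Int)], GlobalSingletonManualClock)

  // Starts `df` as a real-time query whose sink appends each row's "output" column to
  // ResultsCollector under `name`, checkpointed to a temp dir and using defaultTrigger
  // (a real-time trigger exposing batchDurationMs, per the calls above).
  protected def runStreamingQuery(name: String, df: DataFrame): StreamingQuery

  // Feeds `rowsPerKey` rows per key for `numBatches` batches, advances the manual clock
  // after each batch, and asserts that the collected output matches expected(key, value)
  // for every input row.
  protected def processBatches(
      query: StreamingQuery,
      read: LowLatencyMemoryStream[(String, Int)],
      clock: GlobalSingletonManualClock,
      rowsPerKey: Int,
      numBatches: Int,
      expected: (String, Int) => Array[String]): Unit
}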
