/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.mutable

import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock
import org.apache.spark.sql.execution.streaming.LowLatencyMemoryStream
import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
import org.apache.spark.sql.test.TestSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

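/**
 * End-to-end tests for streaming queries running in real-time mode. The suite drives
 * [[LowLatencyMemoryStream]] sources with a [[GlobalSingletonManualClock]] so that batch
 * boundaries are advanced explicitly by the tests rather than by wall-clock time.
 */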
class StreamRealTimeModeE2ESuite extends StreamRealTimeModeE2ESuiteBase {

  import testImplicits._

  override protected def createSparkSession =
    new TestSparkSession(
      new SparkContext(
        "local[15]",
        "streaming-rtm-e2e-context",
        sparkConf.set("spark.sql.shuffle.partitions", "5")
      )
    )

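  /**
   * Starts a real-time mode query writing to a [[ForeachWriter]], optionally unioning two
   * [[LowLatencyMemoryStream]] sources. For each batch it checks that rows show up in the sink
   * before the batch is closed, then advances the manual clock and verifies the per-partition
   * committed results and the reported input-row counts.
   */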
  private def runForeachTest(withUnion: Boolean): Unit = {
    var query: StreamingQuery = null
    try {
      withTempDir { checkpointDir =>
        val clock = new GlobalSingletonManualClock()
        LowLatencyClock.setClock(clock)
        val uniqueSinkName = if (withUnion) {
          sinkName + "-union"
        } else {
          sinkName
        }

        val read = LowLatencyMemoryStream[(String, Int)](5)
        val read1 = LowLatencyMemoryStream[(String, Int)](5)
        val dataframe = if (withUnion) {
          read.toDF().union(read1.toDF())
        } else {
          read.toDF()
        }

        query = dataframe
          .select(col("_1").as("key"), col("_2").as("value"))
          .select(
            concat(
              col("key").cast("STRING"),
              lit("-"),
              col("value").cast("STRING")
            ).as("output")
          )
          .writeStream
          .outputMode(OutputMode.Update())
          .foreach(new ForeachWriter[Row] {
            private var batchPartitionId: String = null
            private val processedThisBatch = new ConcurrentLinkedQueue[String]()

            override def open(partitionId: Long, epochId: Long): Boolean = {
              ResultsCollector
                .computeIfAbsent(uniqueSinkName, (_) => new ConcurrentLinkedQueue[String]())
              batchPartitionId = s"$uniqueSinkName-$epochId-$partitionId"
              assert(
                !ResultsCollector.containsKey(batchPartitionId),
                s"should NOT contain batchPartitionId ${batchPartitionId}"
              )
              ResultsCollector
                .put(batchPartitionId, new ConcurrentLinkedQueue[String]())
              true
            }

            override def process(value: Row): Unit = {
              val v = value.getAs[String]("output")
              ResultsCollector.get(uniqueSinkName).add(v)
              processedThisBatch.add(v)
            }

            override def close(errorOrNull: Throwable): Unit = {
              assert(
                ResultsCollector.containsKey(batchPartitionId),
                s"should contain batchPartitionId ${batchPartitionId}"
              )
              ResultsCollector.get(batchPartitionId).addAll(processedThisBatch)
              processedThisBatch.clear()
            }
          })
          .option("checkpointLocation", checkpointDir.getName)
          .queryName("foreach")
          // The batch duration set here doesn't matter, since batch boundaries are
          // controlled manually via the manual clock.
          .trigger(defaultTrigger)
          .start()

        val expectedResults = mutable.ListBuffer[String]()
        val expectedResultsByBatch = mutable.HashMap[Int, mutable.ListBuffer[String]]()

        val numRows = 10
        for (i <- 0 until 3) {
          expectedResultsByBatch(i) = new mutable.ListBuffer[String]()
          for (key <- List("a", "b", "c")) {
            for (j <- 1 to numRows) {
              read.addData((key, 1))
              val data = s"$key-1"
              expectedResults += data
              expectedResultsByBatch(i) += data
            }
          }

          if (withUnion) {
            for (key <- List("d", "e", "f")) {
              for (j <- 1 to numRows) {
                read1.addData((key, 2))
                val data = s"$key-2"
                expectedResults += data
                expectedResultsByBatch(i) += data
              }
            }
          }

          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(uniqueSinkName)
              .toArray(new Array[String](ResultsCollector.get(uniqueSinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          clock.advance(defaultTrigger.batchDurationMs)
          eventually(timeout(60.seconds)) {
            query
              .asInstanceOf[StreamingQueryWrapper]
              .streamingQuery
              .getLatestExecutionContext()
              .batchId should be(i + 1)
            query.lastProgress.sources(0).numInputRows should be(numRows * 3)

            val committedResults = new mutable.ListBuffer[String]()
            val numPartitions = if (withUnion) 10 else 5
            for (v <- 0 until numPartitions) {
              val it = ResultsCollector.get(s"$uniqueSinkName-${i}-$v").iterator()
              while (it.hasNext) {
                committedResults += it.next()
              }
            }

            committedResults.sorted should equal(expectedResultsByBatch(i).sorted)
          }
        }
      }
    } finally {
      if (query != null) {
        query.stop()
      }
    }
  }

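  /**
   * Same flow as [[runForeachTest]], but rows are also recorded inside a `mapPartitions`
   * transformation; results are verified both from the mapPartitions collector and from the
   * foreach sink attached by `runStreamingQuery`.
   */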
  private def runMapPartitionsTest(withUnion: Boolean): Unit = {
    var query: StreamingQuery = null
    try {
      withTempDir { checkpointDir =>
        val clock = new GlobalSingletonManualClock()
        LowLatencyClock.setClock(clock)
        val uniqueSinkName = if (withUnion) {
          sinkName + "mapPartitions-union"
        } else {
          sinkName + "mapPartitions"
        }

        val read = LowLatencyMemoryStream[(String, Int)](5)
        val read1 = LowLatencyMemoryStream[(String, Int)](5)
        val dataframe = if (withUnion) {
          read.toDF().union(read1.toDF())
        } else {
          read.toDF()
        }

        val df = dataframe
          .select(col("_1").as("key"), col("_2").as("value"))
          .select(
            concat(
              col("key").cast("STRING"),
              lit("-"),
              col("value").cast("STRING")
            ).as("output")
          )
          .as[String]
          .mapPartitions(rows => {
            rows.map(row => {
              val collector = ResultsCollector
                .computeIfAbsent(uniqueSinkName, (_) => new ConcurrentLinkedQueue[String]())
              collector.add(row)
              row
            })
          })
          .toDF()

        query = runStreamingQuery(sinkName, df)

        val expectedResults = mutable.ListBuffer[String]()
        val expectedResultsByBatch = mutable.HashMap[Int, mutable.ListBuffer[String]]()

        val numRows = 10
        for (i <- 0 until 3) {
          expectedResultsByBatch(i) = new mutable.ListBuffer[String]()
          for (key <- List("a", "b", "c")) {
            for (j <- 1 to numRows) {
              read.addData((key, 1))
              val data = s"$key-1"
              expectedResults += data
              expectedResultsByBatch(i) += data
            }
          }

          if (withUnion) {
            for (key <- List("d", "e", "f")) {
              for (j <- 1 to numRows) {
                read1.addData((key, 2))
                val data = s"$key-2"
                expectedResults += data
                expectedResultsByBatch(i) += data
              }
            }
          }

          // results collected from mapPartitions
          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(uniqueSinkName)
              .toArray(new Array[String](ResultsCollector.get(uniqueSinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          // results collected from foreach sink
          eventually(timeout(60.seconds)) {
            ResultsCollector
              .get(sinkName)
              .toArray(new Array[String](ResultsCollector.get(sinkName).size()))
              .toList
              .sorted should equal(expectedResults.sorted)
          }

          clock.advance(defaultTrigger.batchDurationMs)
          eventually(timeout(60.seconds)) {
            query
              .asInstanceOf[StreamingQueryWrapper]
              .streamingQuery
              .getLatestExecutionContext()
              .batchId should be(i + 1)
            query.lastProgress.sources(0).numInputRows should be(numRows * 3)
          }
        }
      }
    } finally {
      if (query != null) {
        query.stop()
      }
    }
  }

  test("foreach") {
    runForeachTest(withUnion = false)
  }

  test("union - foreach") {
    runForeachTest(withUnion = true)
  }

  test("mapPartitions") {
    runMapPartitionsTest(withUnion = false)
  }

  test("union - mapPartitions") {
    runMapPartitionsTest(withUnion = true)
  }

  test("scala stateless UDF") {
    val myUDF = (id: Int) => id + 1
    val udf = spark.udf.register("myUDF", myUDF)
    val (read, clock) = createMemoryStream()

    val df = read
      .toDF()
      .select(col("_1").as("key"), udf(col("_2")).as("value_plus_1"))
      .select(concat(col("key"), lit("-"), col("value_plus_1").cast("STRING")).as("output"))

    var query: StreamingQuery = null
    try {
      query = runStreamingQuery("scala_udf", df)
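      // Assuming processBatches (defined in the base suite) feeds 10 rows per key for 3 batches
      // and advances the manual clock between batches; each (key, value) input is expected to
      // surface as "key-(value + 1)" after the UDF.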
      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-${value + 1}"))
    } finally {
      if (query != null) query.stop()
    }
  }

  test("stream static join") {
    val (read, clock) = createMemoryStream()
    val staticDf = spark
      .range(1, 31, 1, 10)
      .selectExpr("id AS join_key", "id AS join_value")
      // This produces a HashAggregateExec, which should not be blocked by the allowList
      // since it is part of the batch subquery.
      .groupBy("join_key")
      .agg(max($"join_value").as("join_value"))

    val df = read
      .toDF()
      .select(col("_1").as("key"), col("_2").as("value"))
      .join(staticDf, col("value") === col("join_key"))
      .select(concat(col("key"), lit("-"), col("value"), lit("-"), col("join_value")).as("output"))

    var query: StreamingQuery = null
    try {
      query = runStreamingQuery("stream_static_join", df)
      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value-$value"))
    } finally {
      if (query != null) query.stop()
    }
  }

  test("to_json and from_json round-trip") {
    val (read, clock) = createMemoryStream()
    val schema = new StructType().add("key", StringType).add("value", IntegerType)

    val df = read
      .toDF()
      .select(struct(col("_1").as("key"), col("_2").as("value")).as("json"))
      .select(from_json(to_json(col("json")), schema).as("json"))
      .select(concat(col("json.key"), lit("-"), col("json.value")).as("output"))

    var query: StreamingQuery = null
    try {
      query = runStreamingQuery("json_roundtrip", df)
      processBatches(query, read, clock, 10, 3, (key, value) => Array(s"$key-$value"))
    } finally {
      if (query != null) query.stop()
    }
  }

  test("generateExec passthrough") {
    val (read, clock) = createMemoryStream()

    val df = read
      .toDF()
      .select(col("_1").as("key"), col("_2").as("value"))
      .withColumn("value_array", array(col("value"), -col("value")))
    df.createOrReplaceTempView("tempView")
    val explodeDF =
      spark
        .sql("select key, explode(value_array) as exploded_value from tempView")
        .select(concat(col("key"), lit("-"), col("exploded_value").cast("STRING")).as("output"))

    var query: StreamingQuery = null
    try {
      query = runStreamingQuery("generateExec_passthrough", explodeDF)
      processBatches(
        query,
        read,
        clock,
        10,
        3,
        (key, value) => Array(s"$key-$value", s"$key--$value")
      )
    } finally {
      if (query != null) query.stop()
    }
  }
}