diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index fd0ca36d36dfe..08ace7adbde5c 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -78,6 +78,11 @@ org.scala-lang.modules scala-parallel-collections_${scala.binary.version} + org.apache.kafka kafka-clients diff --git a/connector/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/connector/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index e096f120b8926..28a3af7b9074e 100644 --- a/connector/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/connector/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -16,3 +16,4 @@ # org.apache.spark.sql.kafka010.KafkaSourceProvider +org.apache.spark.sql.kafka010.share.KafkaShareSourceProvider diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareMicroBatchStream.scala new file mode 100644 index 0000000000000..8f944977a466b --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareMicroBatchStream.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import java.{util => ju} +import java.util.Optional +import java.util.concurrent.atomic.AtomicLong + +import scala.jdk.CollectionConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming._ +import org.apache.spark.sql.kafka010.share.consumer.AcknowledgmentMode +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ + +/** + * A [[MicroBatchStream]] that reads data from Kafka share groups. + * + * Unlike [[org.apache.spark.sql.kafka010.KafkaMicroBatchStream]] which tracks sequential offsets + * per partition, this implementation: + * + * 1. Subscribes to topics via share groups (multiple consumers can read from same partition) + * 2. Uses batch IDs for ordering instead of Kafka offsets + * 3. Tracks non-sequential acquired record offsets + * 4. 
Supports acknowledgment-based commit (ACCEPT/RELEASE/REJECT) + * + * The [[KafkaShareSourceOffset]] is the custom [[Offset]] that contains: + * - Share group ID + * - Batch ID for ordering + * - Map of TopicPartition to acquired record ranges (non-sequential offsets) + * + * Fault tolerance is achieved via: + * - Kafka's acquisition locks (auto-release on timeout) + * - Checkpointing of acknowledgment state + * - At-least-once semantics by default + * - Optional exactly-once via idempotent sinks or checkpoint-based dedup + * + * @param shareGroupId The Kafka share group identifier + * @param topics Set of topics to subscribe to + * @param executorKafkaParams Kafka params for executor-side consumers + * @param options Source options + * @param metadataPath Path for storing checkpoint metadata + * @param acknowledgmentMode Implicit or explicit acknowledgment + * @param exactlyOnceStrategy Strategy for exactly-once semantics + */ +private[kafka010] class KafkaShareMicroBatchStream( + val shareGroupId: String, + val topics: Set[String], + executorKafkaParams: ju.Map[String, Object], + options: CaseInsensitiveStringMap, + metadataPath: String, + acknowledgmentMode: AcknowledgmentMode, + exactlyOnceStrategy: ExactlyOnceStrategy) + extends SupportsTriggerAvailableNow + with ReportsSourceMetrics + with MicroBatchStream + with Logging { + + private val pollTimeoutMs = options.getLong( + KafkaShareSourceProvider.CONSUMER_POLL_TIMEOUT, + SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L) + + private val lockTimeoutMs = options.getLong( + KafkaShareSourceProvider.SHARE_LOCK_TIMEOUT, + 30000L) + + private val maxRecordsPerBatch = Option(options.get( + KafkaShareSourceProvider.MAX_RECORDS_PER_BATCH)).map(_.toLong) + + private val parallelism = options.getInt( + KafkaShareSourceProvider.SHARE_PARALLELISM, + SparkSession.active.sparkContext.defaultParallelism) + + private val includeHeaders = options.getBoolean( + KafkaShareSourceProvider.INCLUDE_HEADERS, false) + + // Batch ID counter - used for offset ordering since share groups don't have sequential offsets + private val batchIdCounter = new AtomicLong(0) + + // Track in-flight batches for acknowledgment + private val inFlightManager = new ShareInFlightManager() + + // Checkpoint writer for exactly-once semantics + private lazy val checkpointWriter = new ShareCheckpointWriter(metadataPath) + + // State for Trigger.AvailableNow + private var isTriggerAvailableNow: Boolean = false + private var allDataForTriggerAvailableNow: Option[KafkaShareSourceOffset] = None + + /** + * Return the initial offset for the stream. + * For share groups, this is an empty offset since Kafka manages state. + */ + override def initialOffset(): Offset = { + // Check for existing checkpoint + val existingOffset = checkpointWriter.getLatest() + existingOffset.getOrElse(KafkaShareSourceOffset.empty(shareGroupId)) + } + + /** + * Return the latest available offset. + * For share groups, we create a new batch ID since Kafka doesn't expose "latest" in the same way. + */ + override def latestOffset(): Offset = { + throw new UnsupportedOperationException( + "latestOffset(Offset, ReadLimit) should be called instead of this method") + } + + override def latestOffset(start: Offset, readLimit: ReadLimit): Offset = { + val startOffset = KafkaShareSourceOffset(start) + val newBatchId = batchIdCounter.incrementAndGet() + + // For share groups, we don't have a concept of "latest offset" like traditional consumers. 
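Illustrative sketch (not part of the patch): how the batch-ID-based offsets produced by this stream look when built with the KafkaShareSourceOffset API added later in the same change; all values are made up.

    import org.apache.kafka.common.TopicPartition

    // Batch 0 carries no acquired records; each latestOffset() call advances the batch ID.
    val start = KafkaShareSourceOffset.empty("my-share-group")
    val batch1 = KafkaShareSourceOffset.forBatch("my-share-group", 1L)

    // As readers acquire records, per-partition offset *sets* (possibly non-contiguous) are tracked.
    val withRecord = batch1.withAcquiredRecord(
      new TopicPartition("topic1", 0), offset = 42L,
      acquiredAt = System.currentTimeMillis(), lockTimeoutMs = 30000L)

    // json is what ends up in the offset log, roughly:
    // {"shareGroupId":"my-share-group","batchId":1,"partitions":[{"topic":"topic1","partition":0,"offsets":[42],...}]}
    println(withRecord.json)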
+ // Instead, we create a placeholder offset that will be populated when we actually poll. + // The actual records will be determined during planInputPartitions. + val offset = KafkaShareSourceOffset.forBatch(shareGroupId, newBatchId) + + // Apply rate limiting if configured + readLimit match { + case rows: ReadMaxRows => + // Rate limiting will be applied during partition planning + logDebug(s"Rate limit set to ${rows.maxRows()} records per batch") + case _ => // No rate limiting + } + + if (isTriggerAvailableNow && allDataForTriggerAvailableNow.isEmpty) { + allDataForTriggerAvailableNow = Some(offset) + } + + offset + } + + /** + * Plan input partitions for the batch. + * + * Unlike traditional Kafka source which creates one partition per TopicPartition, + * share groups can have multiple consumers reading from the same partition. + * We create partitions based on configured parallelism. + */ + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + val startOffset = KafkaShareSourceOffset(start) + val endOffset = KafkaShareSourceOffset(end) + + logInfo(s"Planning ${parallelism} input partitions for share group $shareGroupId " + + s"(batch ${startOffset.batchId} to ${endOffset.batchId})") + + // Create input partitions based on parallelism, not Kafka partitions + // Each Spark task will have its own share consumer that will receive records + // from any available partition in the subscribed topics + (0 until parallelism).map { partitionId => + KafkaShareInputPartition( + shareGroupId = shareGroupId, + sparkPartitionId = partitionId, + topics = topics, + kafkaParams = executorKafkaParams, + pollTimeoutMs = pollTimeoutMs, + lockTimeoutMs = lockTimeoutMs, + maxRecordsPerBatch = maxRecordsPerBatch, + includeHeaders = includeHeaders, + acknowledgmentMode = acknowledgmentMode, + exactlyOnceStrategy = exactlyOnceStrategy, + batchId = endOffset.batchId + ) + }.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = { + KafkaShareBatchReaderFactory + } + + override def deserializeOffset(json: String): Offset = { + KafkaShareSourceOffset(json) + } + + /** + * Commit the batch by acknowledging records to Kafka. + * + * This is called after the batch has been successfully processed. + * For share groups, this commits the acknowledgments (not offsets). 
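Illustrative sketch (not part of the patch) of what the parallelism-based planning in planInputPartitions yields; `stream`, `start` and `end` are hypothetical bindings.

    // With kafka.share.parallelism = 4, a batch is planned as four identical
    // KafkaShareInputPartition instances; the share-group protocol on the broker
    // decides which records each task's consumer actually receives.
    val partitions = stream.planInputPartitions(start, end)
    assert(partitions.length == 4)
    assert(partitions.forall(_.isInstanceOf[KafkaShareInputPartition]))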
+ */ + override def commit(end: Offset): Unit = { + val offset = KafkaShareSourceOffset(end) + val batchId = offset.batchId + + logInfo(s"Committing batch $batchId for share group $shareGroupId") + + // In implicit mode, we assume all records are successfully processed + // and commit ACCEPT acknowledgments + if (acknowledgmentMode == AcknowledgmentMode.Implicit) { + inFlightManager.getBatch(batchId).foreach { batch => + logDebug(s"Implicitly acknowledging ${batch.pending} pending records as ACCEPT") + batch.acknowledgeAllAsAccept() + } + } + + // Write checkpoint for exactly-once strategies + exactlyOnceStrategy match { + case ExactlyOnceStrategy.CheckpointDedup => + checkpointWriter.write(offset) + case _ => // No checkpoint needed for other strategies + } + + // Clean up completed batch + inFlightManager.removeBatch(batchId) + } + + override def stop(): Unit = { + // Release any pending records before stopping + val released = inFlightManager.releaseAll() + if (released > 0) { + logWarning(s"Released $released pending records during shutdown") + } + inFlightManager.clear() + } + + override def toString: String = s"KafkaShareMicroBatchStream[shareGroup=$shareGroupId, topics=$topics]" + + override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String, String] = { + val offset = Option(latestConsumedOffset.orElse(null)) + offset match { + case Some(o: KafkaShareSourceOffset) => + Map( + "shareGroupId" -> shareGroupId, + "batchId" -> o.batchId.toString, + "totalAcquiredRecords" -> o.totalRecords.toString, + "partitionsWithRecords" -> o.partitionsWithRecords.size.toString + ).asJava + case _ => + ju.Collections.emptyMap() + } + } + + override def prepareForTriggerAvailableNow(): Unit = { + isTriggerAvailableNow = true + } + + /** + * Register acquired records for acknowledgment tracking. + * Called by partition readers when they acquire records. + */ + def registerAcquiredRecords( + batchId: Long, + records: Seq[ShareInFlightRecord]): Unit = { + val batch = inFlightManager.getOrCreateBatch(batchId) + records.foreach(batch.addRecord) + } + + /** + * Get the in-flight manager for external access (e.g., explicit acknowledgment). + */ + def getInFlightManager: ShareInFlightManager = inFlightManager +} + +/** + * Input partition for share group consumers. + * Each partition represents a Spark task that will poll from the share group. + */ +case class KafkaShareInputPartition( + shareGroupId: String, + sparkPartitionId: Int, + topics: Set[String], + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + lockTimeoutMs: Long, + maxRecordsPerBatch: Option[Long], + includeHeaders: Boolean, + acknowledgmentMode: AcknowledgmentMode, + exactlyOnceStrategy: ExactlyOnceStrategy, + batchId: Long) extends InputPartition { + + // Preferred locations for share groups are not as meaningful as traditional consumers + // since any consumer can receive records from any partition + override def preferredLocations(): Array[String] = Array.empty +} + +/** + * Factory for creating share group partition readers. 
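When acknowledgmentMode is Explicit, downstream code is expected to acknowledge through the in-flight manager exposed by getInFlightManager. A hedged sketch (hypothetical `stream` instance and record keys; ShareInFlightManager itself is added elsewhere in this patch):

    // Acknowledge two records of batch 7: one processed successfully, one poisoned.
    val batch = stream.getInFlightManager.getOrCreateBatch(7L)
    batch.acknowledge(RecordKey("topic1", 0, 42L), AcknowledgmentType.ACCEPT)
    batch.acknowledge(RecordKey("topic1", 0, 43L), AcknowledgmentType.REJECT)

    // Records never acknowledged are released (by the reader on close, or via stop())
    // and redelivered by Kafka once their acquisition locks expire.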
+ */ +object KafkaShareBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): KafkaSharePartitionReader = { + val sharePartition = partition.asInstanceOf[KafkaShareInputPartition] + new KafkaSharePartitionReader( + shareGroupId = sharePartition.shareGroupId, + sparkPartitionId = sharePartition.sparkPartitionId, + topics = sharePartition.topics, + kafkaParams = sharePartition.kafkaParams, + pollTimeoutMs = sharePartition.pollTimeoutMs, + lockTimeoutMs = sharePartition.lockTimeoutMs, + maxRecordsPerBatch = sharePartition.maxRecordsPerBatch, + includeHeaders = sharePartition.includeHeaders, + acknowledgmentMode = sharePartition.acknowledgmentMode, + exactlyOnceStrategy = sharePartition.exactlyOnceStrategy, + batchId = sharePartition.batchId + ) + } +} + +/** + * Strategy for achieving exactly-once semantics. + */ +sealed trait ExactlyOnceStrategy +object ExactlyOnceStrategy { + /** No exactly-once guarantees (at-least-once only) */ + case object None extends ExactlyOnceStrategy + + /** Deduplicate at sink using record keys (topic, partition, offset) */ + case object Idempotent extends ExactlyOnceStrategy + + /** Two-phase commit with transaction coordinator */ + case object TwoPhaseCommit extends ExactlyOnceStrategy + + /** Track processed records in checkpoint for deduplication */ + case object CheckpointDedup extends ExactlyOnceStrategy + + def fromString(s: String): ExactlyOnceStrategy = s.toLowerCase match { + case "none" | "" => None + case "idempotent" => Idempotent + case "two-phase-commit" | "2pc" => TwoPhaseCommit + case "checkpoint-dedup" | "checkpoint" => CheckpointDedup + case _ => throw new IllegalArgumentException(s"Unknown exactly-once strategy: $s") + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaSharePartitionReader.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaSharePartitionReader.scala new file mode 100644 index 0000000000000..02019ab7d3391 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaSharePartitionReader.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010.share + +import java.{util => ju} +import java.nio.charset.StandardCharsets + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.header.Headers + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} +import org.apache.spark.sql.connector.read.PartitionReader +import org.apache.spark.sql.kafka010.share.consumer._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Partition reader for Kafka share groups. + * + * This reader: + * 1. Polls records from the share group + * 2. Tracks acquired records for acknowledgment + * 3. Converts records to Spark InternalRow format + * 4. Handles acknowledgment based on configured mode + * + * Unlike traditional Kafka readers that seek to specific offsets, + * share group readers receive whatever records Kafka assigns to them. + * Records can come from any partition in the subscribed topics. + */ +private[kafka010] class KafkaSharePartitionReader( + shareGroupId: String, + sparkPartitionId: Int, + topics: Set[String], + kafkaParams: ju.Map[String, Object], + pollTimeoutMs: Long, + lockTimeoutMs: Long, + maxRecordsPerBatch: Option[Long], + includeHeaders: Boolean, + acknowledgmentMode: AcknowledgmentMode, + exactlyOnceStrategy: ExactlyOnceStrategy, + batchId: Long) extends PartitionReader[InternalRow] with Logging { + + // The share consumer for this reader + private val shareConsumer = KafkaShareDataConsumer.acquire( + shareGroupId, topics, kafkaParams, lockTimeoutMs) + + // Buffer for polled records + private var recordBuffer: Iterator[ShareInFlightRecord] = Iterator.empty + + // Current record being processed + private var currentRecord: ShareInFlightRecord = _ + + // Track all records acquired in this batch for acknowledgment + private val acquiredRecords = ArrayBuffer[ShareInFlightRecord]() + + // Track processed record keys for checkpoint-based deduplication + private val processedKeys = new ju.HashSet[RecordKey]() + + // State tracking + private var recordsRead: Long = 0 + private var hasPolled: Boolean = false + private var isExhausted: Boolean = false + + /** + * Move to the next record. + * Returns false when there are no more records to process. + */ + override def next(): Boolean = { + // Check if we've hit the record limit + if (maxRecordsPerBatch.exists(recordsRead >= _)) { + logDebug(s"Reached max records per batch: $maxRecordsPerBatch") + return false + } + + // Try to get next record from buffer + if (recordBuffer.hasNext) { + currentRecord = recordBuffer.next() + recordsRead += 1 + return true + } + + // If we haven't polled yet, or buffer is empty, poll for more records + if (!isExhausted) { + val polledRecords = pollRecords() + if (polledRecords.nonEmpty) { + recordBuffer = polledRecords.iterator + currentRecord = recordBuffer.next() + recordsRead += 1 + return true + } else { + // No more records available + isExhausted = true + return false + } + } + + false + } + + /** + * Get the current record as an InternalRow. + */ + override def get(): InternalRow = { + assert(currentRecord != null, "next() must be called before get()") + recordToInternalRow(currentRecord.record) + } + + /** + * Close the reader and handle acknowledgments. 
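Illustrative sketch (not part of the patch): the reader follows the standard DSv2 PartitionReader contract, roughly driven like this by the runtime (`reader` being a KafkaSharePartitionReader).

    while (reader.next()) {      // polls the share consumer lazily, honoring maxRecordsPerBatch
      val row = reader.get()     // InternalRow with the Kafka record fields (schema further below)
      // ... hand the row to the rest of the query ...
    }
    reader.close()               // acknowledges or releases everything that was acquired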
+ */ + override def close(): Unit = { + try { + // Handle acknowledgment based on mode + acknowledgmentMode match { + case AcknowledgmentMode.Implicit => + // Acknowledge all records as ACCEPT + acknowledgeAllAsAccept() + + case AcknowledgmentMode.Explicit => + // In explicit mode, any unacknowledged records are released + releaseUnacknowledgedRecords() + } + + // Commit acknowledgments to Kafka + shareConsumer.commitSync() + } catch { + case e: Exception => + logError(s"Error during acknowledgment commit: ${e.getMessage}", e) + // Try to release records on error + try { + releaseAllRecords() + shareConsumer.commitSync() + } catch { + case inner: Exception => + logError(s"Error releasing records: ${inner.getMessage}", inner) + } + } finally { + shareConsumer.release() + } + + logInfo(s"Share partition reader closed. " + + s"Read $recordsRead records, acquired ${acquiredRecords.size} records") + } + + /** + * Poll for records from the share consumer. + */ + private def pollRecords(): Seq[ShareInFlightRecord] = { + hasPolled = true + + // Calculate how many records we can still fetch + val remainingCapacity = maxRecordsPerBatch.map(_ - recordsRead).getOrElse(Long.MaxValue) + + if (remainingCapacity <= 0) { + return Seq.empty + } + + try { + val records = shareConsumer.poll(pollTimeoutMs) + + // Apply rate limiting if needed + val limitedRecords = if (maxRecordsPerBatch.isDefined) { + records.take(remainingCapacity.toInt) + } else { + records + } + + // Apply deduplication for checkpoint-based strategy + val dedupedRecords = exactlyOnceStrategy match { + case ExactlyOnceStrategy.CheckpointDedup => + limitedRecords.filterNot(r => processedKeys.contains(r.recordKey)) + case _ => + limitedRecords + } + + // Track acquired records + acquiredRecords ++= dedupedRecords + dedupedRecords.foreach(r => processedKeys.add(r.recordKey)) + + logDebug(s"Polled ${records.size} records, kept ${dedupedRecords.size} after dedup") + dedupedRecords + } catch { + case e: Exception => + logError(s"Error polling from share group: ${e.getMessage}", e) + Seq.empty + } + } + + /** + * Acknowledge all acquired records as ACCEPT. + */ + private def acknowledgeAllAsAccept(): Unit = { + val pending = acquiredRecords.filter(_.acknowledgment.isEmpty) + if (pending.nonEmpty) { + logDebug(s"Acknowledging ${pending.size} records as ACCEPT") + pending.foreach { record => + shareConsumer.acknowledge(record.recordKey, AcknowledgmentType.ACCEPT) + } + } + } + + /** + * Release all unacknowledged records. + */ + private def releaseUnacknowledgedRecords(): Unit = { + val unacked = acquiredRecords.filter(_.acknowledgment.isEmpty) + if (unacked.nonEmpty) { + logWarning(s"Releasing ${unacked.size} unacknowledged records") + unacked.foreach { record => + shareConsumer.acknowledge(record.recordKey, AcknowledgmentType.RELEASE) + } + } + } + + /** + * Release all records (for error recovery). + */ + private def releaseAllRecords(): Unit = { + logWarning(s"Releasing all ${acquiredRecords.size} records due to error") + acquiredRecords.foreach { record => + shareConsumer.acknowledge(record.recordKey, AcknowledgmentType.RELEASE) + } + } + + /** + * Convert a Kafka ConsumerRecord to Spark InternalRow. 
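The acknowledge-then-commit sequence used in close() above, shown standalone as a sketch. KafkaShareDataConsumer is introduced elsewhere in this patch; the calls below only mirror how this reader uses it, and `kafkaParams` is a hypothetical map.

    val consumer = KafkaShareDataConsumer.acquire("my-share-group", Set("topic1"), kafkaParams, 30000L)
    try {
      val records = consumer.poll(512L)  // presumably Seq[ShareInFlightRecord]
      records.foreach(r => consumer.acknowledge(r.recordKey, AcknowledgmentType.ACCEPT))
      consumer.commitSync()              // pushes the acknowledgments to the broker
    } finally {
      consumer.release()
    }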
+ * + * Schema: key, value, topic, partition, offset, timestamp, timestampType, headers + */ + private def recordToInternalRow( + record: ConsumerRecord[Array[Byte], Array[Byte]]): InternalRow = { + val row = new GenericInternalRow( + if (includeHeaders) 8 else 7 + ) + + row.update(0, record.key()) // key: binary + row.update(1, record.value()) // value: binary + row.update(2, UTF8String.fromString(record.topic())) // topic: string + row.update(3, record.partition()) // partition: int + row.update(4, record.offset()) // offset: long + row.update(5, record.timestamp()) // timestamp: long + row.update(6, record.timestampType().id) // timestampType: int + + if (includeHeaders) { + row.update(7, headersToArrayData(record.headers())) // headers: array + } + + row + } + + /** + * Convert Kafka headers to Spark ArrayData. + */ + private def headersToArrayData(headers: Headers): GenericArrayData = { + val headersList = headers.asScala.map { header => + val headerRow = new GenericInternalRow(2) + headerRow.update(0, UTF8String.fromString(header.key())) + headerRow.update(1, header.value()) + headerRow + }.toArray + new GenericArrayData(headersList) + } +} + +/** + * Schema for Kafka share source records. + */ +object KafkaShareRecordSchema { + val BASE_SCHEMA: StructType = StructType(Seq( + StructField("key", BinaryType, nullable = true), + StructField("value", BinaryType, nullable = true), + StructField("topic", StringType, nullable = false), + StructField("partition", IntegerType, nullable = false), + StructField("offset", LongType, nullable = false), + StructField("timestamp", LongType, nullable = false), + StructField("timestampType", IntegerType, nullable = false) + )) + + val HEADER_SCHEMA: StructType = StructType(Seq( + StructField("key", StringType, nullable = false), + StructField("value", BinaryType, nullable = true) + )) + + val SCHEMA_WITH_HEADERS: StructType = BASE_SCHEMA.add( + StructField("headers", ArrayType(HEADER_SCHEMA), nullable = true) + ) + + def getSchema(includeHeaders: Boolean): StructType = { + if (includeHeaders) SCHEMA_WITH_HEADERS else BASE_SCHEMA + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffset.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffset.scala new file mode 100644 index 0000000000000..cf5d4d51383a5 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffset.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010.share + +import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s.JsonAST.{JArray, JObject, JString} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, parse, render} + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.connector.read.streaming.{Offset => SparkOffset} +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} + +/** + * Represents a unique identifier for a record in a share group. + * Unlike traditional consumer groups, share groups can have non-sequential offset tracking + * because records can be: + * 1. Acquired by different consumers in any order + * 2. Acknowledged out of order + * 3. Released and reacquired multiple times + * + * @param topic The Kafka topic name + * @param partition The partition number + * @param offset The record offset + */ +case class RecordKey(topic: String, partition: Int, offset: Long) { + def topicPartition: TopicPartition = new TopicPartition(topic, partition) + + override def toString: String = s"$topic-$partition:$offset" +} + +object RecordKey { + def apply(tp: TopicPartition, offset: Long): RecordKey = { + RecordKey(tp.topic(), tp.partition(), offset) + } + + def fromString(s: String): RecordKey = { + val parts = s.split(":") + require(parts.length == 2, s"Invalid RecordKey string: $s") + val topicPart = parts(0).lastIndexOf("-") + require(topicPart > 0, s"Invalid topic-partition format: ${parts(0)}") + RecordKey( + parts(0).substring(0, topicPart), + parts(0).substring(topicPart + 1).toInt, + parts(1).toLong + ) + } +} + +/** + * Represents a range of acquired records for checkpointing. + * This captures the non-sequential nature of share group offset tracking. + * + * @param offsets Set of individual offsets acquired in this batch (non-sequential) + * @param acquiredAt Timestamp when records were acquired + * @param lockExpiresAt Timestamp when acquisition lock expires + * @param deliveryCount Number of times these records have been delivered + */ +case class AcquiredRecordRange( + offsets: Set[Long], + acquiredAt: Long, + lockExpiresAt: Long, + deliveryCount: Short = 1) { + + def size: Int = offsets.size + + def minOffset: Long = if (offsets.isEmpty) -1L else offsets.min + + def maxOffset: Long = if (offsets.isEmpty) -1L else offsets.max + + /** Check if a specific offset is in this range */ + def contains(offset: Long): Boolean = offsets.contains(offset) + + /** Check if the lock has expired */ + def isLockExpired(currentTimeMs: Long): Boolean = currentTimeMs >= lockExpiresAt + + /** Add an offset to this range */ + def withOffset(offset: Long): AcquiredRecordRange = copy(offsets = offsets + offset) + + /** Remove an offset from this range (when acknowledged) */ + def withoutOffset(offset: Long): AcquiredRecordRange = copy(offsets = offsets - offset) +} + +object AcquiredRecordRange { + /** Create an empty range */ + def empty(acquiredAt: Long, lockTimeoutMs: Long): AcquiredRecordRange = { + AcquiredRecordRange(Set.empty, acquiredAt, acquiredAt + lockTimeoutMs) + } + + /** Create a range from a contiguous offset range (for initial acquisition) */ + def fromRange(start: Long, end: Long, acquiredAt: Long, lockTimeoutMs: Long): AcquiredRecordRange = { + AcquiredRecordRange((start to end).toSet, acquiredAt, acquiredAt + lockTimeoutMs) + } +} + +/** + * Offset for the Kafka Share Source that tracks non-sequential offset acquisition. 
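A quick sketch of the two building blocks defined in this file, RecordKey and AcquiredRecordRange (values are made up, not part of the patch):

    val key = RecordKey("orders", 3, 128L)
    assert(key.toString == "orders-3:128")
    assert(RecordKey.fromString("orders-3:128") == key)

    // Ranges are plain offset sets, so gaps after out-of-order acknowledgments are expected.
    val range = AcquiredRecordRange.fromRange(100L, 104L,
      acquiredAt = 1700000000000L, lockTimeoutMs = 30000L)
    assert(range.withoutOffset(102L).offsets == Set(100L, 101L, 103L, 104L))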
+ * + * Unlike [[org.apache.spark.sql.kafka010.KafkaSourceOffset]] which tracks a single offset + * per partition, this tracks: + * 1. The batch ID for ordering + * 2. Per-partition sets of acquired record offsets (can be non-sequential) + * 3. Acquisition timestamps and lock expiry times + * + * This design accounts for the fact that in share groups: + * - Multiple consumers can acquire records from the same partition concurrently + * - Records can be acknowledged/released out of order + * - Kafka broker assigns random available offsets, not sequential ranges + * + * @param shareGroupId The share group identifier + * @param batchId The micro-batch identifier (for ordering) + * @param acquiredRecords Map of TopicPartition to acquired record ranges + */ +case class KafkaShareSourceOffset( + shareGroupId: String, + batchId: Long, + acquiredRecords: Map[TopicPartition, AcquiredRecordRange]) extends Offset { + + implicit val formats: Formats = DefaultFormats + + override val json: String = { + val partitions = acquiredRecords.map { case (tp, range) => + ("topic" -> tp.topic()) ~ + ("partition" -> tp.partition()) ~ + ("offsets" -> range.offsets.toSeq.sorted) ~ + ("acquiredAt" -> range.acquiredAt) ~ + ("lockExpiresAt" -> range.lockExpiresAt) ~ + ("deliveryCount" -> range.deliveryCount.toInt) + } + + compact(render( + ("shareGroupId" -> shareGroupId) ~ + ("batchId" -> batchId) ~ + ("partitions" -> partitions) + )) + } + + /** Get all record keys in this offset */ + def getAllRecordKeys: Set[RecordKey] = { + acquiredRecords.flatMap { case (tp, range) => + range.offsets.map(offset => RecordKey(tp, offset)) + }.toSet + } + + /** Get the total number of acquired records */ + def totalRecords: Int = acquiredRecords.values.map(_.size).sum + + /** Check if any partition has records */ + def hasRecords: Boolean = acquiredRecords.exists(_._2.offsets.nonEmpty) + + /** Get partitions that have records */ + def partitionsWithRecords: Set[TopicPartition] = { + acquiredRecords.filter(_._2.offsets.nonEmpty).keySet + } + + /** Create a new offset with an additional acquired record */ + def withAcquiredRecord(tp: TopicPartition, offset: Long, acquiredAt: Long, lockTimeoutMs: Long): KafkaShareSourceOffset = { + val currentRange = acquiredRecords.getOrElse(tp, + AcquiredRecordRange.empty(acquiredAt, lockTimeoutMs)) + copy(acquiredRecords = acquiredRecords + (tp -> currentRange.withOffset(offset))) + } + + /** Create a new offset with a record removed (acknowledged) */ + def withAcknowledgedRecord(tp: TopicPartition, offset: Long): KafkaShareSourceOffset = { + acquiredRecords.get(tp) match { + case Some(range) => + val newRange = range.withoutOffset(offset) + if (newRange.offsets.isEmpty) { + copy(acquiredRecords = acquiredRecords - tp) + } else { + copy(acquiredRecords = acquiredRecords + (tp -> newRange)) + } + case None => this + } + } +} + +object KafkaShareSourceOffset { + implicit val formats: Formats = DefaultFormats + + /** Create an empty offset for the start of processing */ + def empty(shareGroupId: String): KafkaShareSourceOffset = { + KafkaShareSourceOffset(shareGroupId, 0L, Map.empty) + } + + /** Create an initial offset for a new batch */ + def forBatch(shareGroupId: String, batchId: Long): KafkaShareSourceOffset = { + KafkaShareSourceOffset(shareGroupId, batchId, Map.empty) + } + + /** Parse from JSON string */ + def apply(json: String): KafkaShareSourceOffset = { + val parsed = parse(json) + + val shareGroupId = (parsed \ "shareGroupId").extract[String] + val batchId = (parsed \ 
"batchId").extract[Long] + val partitions = (parsed \ "partitions").extract[List[JValue]] + + val acquiredRecords = partitions.map { p => + val topic = (p \ "topic").extract[String] + val partition = (p \ "partition").extract[Int] + val offsets = (p \ "offsets").extract[List[Long]].toSet + val acquiredAt = (p \ "acquiredAt").extract[Long] + val lockExpiresAt = (p \ "lockExpiresAt").extract[Long] + val deliveryCount = (p \ "deliveryCount").extract[Int].toShort + + val tp = new TopicPartition(topic, partition) + val range = AcquiredRecordRange(offsets, acquiredAt, lockExpiresAt, deliveryCount) + tp -> range + }.toMap + + KafkaShareSourceOffset(shareGroupId, batchId, acquiredRecords) + } + + /** Convert from SerializedOffset */ + def apply(offset: SerializedOffset): KafkaShareSourceOffset = { + apply(offset.json) + } + + /** Convert from Spark streaming Offset */ + def apply(offset: SparkOffset): KafkaShareSourceOffset = { + offset match { + case k: KafkaShareSourceOffset => k + case so: SerializedOffset => apply(so) + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to KafkaShareSourceOffset") + } + } + + /** Extract acquired records from an Offset */ + def getAcquiredRecords(offset: Offset): Map[TopicPartition, AcquiredRecordRange] = { + offset match { + case o: KafkaShareSourceOffset => o.acquiredRecords + case so: SerializedOffset => KafkaShareSourceOffset(so).acquiredRecords + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to KafkaShareSourceOffset") + } + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceProvider.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceProvider.scala new file mode 100644 index 0000000000000..dc0a228bc2afa --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceProvider.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010.share + +import java.{util => ju} +import java.util.{Locale, UUID} + +import scala.jdk.CollectionConverters._ + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.ByteArrayDeserializer + +import org.apache.spark.internal.Logging +import org.apache.spark.kafka010.KafkaConfigUpdater +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream +import org.apache.spark.sql.internal.connector.SimpleTableProvider +import org.apache.spark.sql.kafka010.share.consumer.AcknowledgmentMode +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Provider for Kafka Share Group data source. + * + * This provider enables Spark Structured Streaming to consume from Kafka + * using the share group protocol introduced in Kafka 4.x (KIP-932). + * + * Key differences from the standard Kafka source: + * 1. Uses share groups instead of consumer groups + * 2. Multiple consumers can read from the same partition concurrently + * 3. Acknowledgment-based delivery (ACCEPT/RELEASE/REJECT) + * 4. Non-sequential offset tracking + * + * Usage: + * {{{ + * spark.readStream + * .format("kafka-share") + * .option("kafka.bootstrap.servers", "localhost:9092") + * .option("kafka.share.group.id", "my-share-group") + * .option("subscribe", "topic1,topic2") + * .load() + * }}} + */ +private[kafka010] class KafkaShareSourceProvider + extends DataSourceRegister + with SimpleTableProvider + with Logging { + + import KafkaShareSourceProvider._ + + override def shortName(): String = "kafka-share" + + override def getTable(options: CaseInsensitiveStringMap): KafkaShareTable = { + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + new KafkaShareTable(includeHeaders) + } +} + +/** + * Table implementation for Kafka Share source. + */ +class KafkaShareTable(includeHeaders: Boolean) + extends Table with SupportsRead { + + import KafkaShareSourceProvider._ + + override def name(): String = "KafkaShareTable" + + override def schema(): StructType = KafkaShareRecordSchema.getSchema(includeHeaders) + + override def capabilities(): ju.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + () => new KafkaShareScan(options, includeHeaders) + } +} + +/** + * Scan implementation for Kafka Share source. 
+ */ +class KafkaShareScan(options: CaseInsensitiveStringMap, includeHeaders: Boolean) + extends Scan { + + override def readSchema(): StructType = KafkaShareRecordSchema.getSchema(includeHeaders) + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { + val caseInsensitiveOptions = CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap) + validateOptions(caseInsensitiveOptions) + + val shareGroupId = getShareGroupId(caseInsensitiveOptions) + val topics = getTopics(caseInsensitiveOptions) + val executorKafkaParams = kafkaParamsForExecutors(caseInsensitiveOptions, shareGroupId) + val acknowledgmentMode = getAcknowledgmentMode(caseInsensitiveOptions) + val exactlyOnceStrategy = getExactlyOnceStrategy(caseInsensitiveOptions) + + new KafkaShareMicroBatchStream( + shareGroupId = shareGroupId, + topics = topics, + executorKafkaParams = executorKafkaParams, + options = options, + metadataPath = checkpointLocation, + acknowledgmentMode = acknowledgmentMode, + exactlyOnceStrategy = exactlyOnceStrategy + ) + } + + private def validateOptions(params: CaseInsensitiveMap[String]): Unit = { + // Validate bootstrap servers + if (!params.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) { + throw new IllegalArgumentException( + s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified") + } + + // Validate share group ID + if (!params.contains(SHARE_GROUP_ID) && !params.contains(s"kafka.$SHARE_GROUP_ID")) { + throw new IllegalArgumentException( + s"Option '$SHARE_GROUP_ID' must be specified for Kafka share source") + } + + // Validate topics + val hasSubscribe = params.contains(SUBSCRIBE) + val hasSubscribePattern = params.contains(SUBSCRIBE_PATTERN) + + if (!hasSubscribe && !hasSubscribePattern) { + throw new IllegalArgumentException( + s"One of '$SUBSCRIBE' or '$SUBSCRIBE_PATTERN' must be specified") + } + + if (hasSubscribe && hasSubscribePattern) { + throw new IllegalArgumentException( + s"Only one of '$SUBSCRIBE' or '$SUBSCRIBE_PATTERN' can be specified") + } + + // Validate acknowledgment mode + params.get(ACKNOWLEDGMENT_MODE).foreach { mode => + try { + AcknowledgmentMode.fromString(mode) + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid acknowledgment mode: '$mode'. Must be 'implicit' or 'explicit'") + } + } + + // Validate exactly-once strategy + params.get(EXACTLY_ONCE_STRATEGY).foreach { strategy => + try { + ExactlyOnceStrategy.fromString(strategy) + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid exactly-once strategy: '$strategy'. 
" + + "Must be 'none', 'idempotent', 'two-phase-commit', or 'checkpoint-dedup'") + } + } + + // Warn about unsupported options from traditional Kafka source + val unsupportedOptions = Seq( + "startingOffsets", "endingOffsets", "startingTimestamp", "endingTimestamp", + "assign" // Share groups subscribe to topics, not assign partitions + ) + unsupportedOptions.foreach { opt => + if (params.contains(opt)) { + logWarning(s"Option '$opt' is not applicable for Kafka share source and will be ignored") + } + } + } + + private def getShareGroupId(params: CaseInsensitiveMap[String]): String = { + params.getOrElse(SHARE_GROUP_ID, + params.getOrElse(s"kafka.$SHARE_GROUP_ID", + throw new IllegalArgumentException(s"Share group ID not specified"))) + } + + private def getTopics(params: CaseInsensitiveMap[String]): Set[String] = { + params.get(SUBSCRIBE) match { + case Some(topics) => + topics.split(",").map(_.trim).filter(_.nonEmpty).toSet + case None => + // Pattern subscription - return empty set, will be handled differently + Set.empty + } + } + + private def getAcknowledgmentMode(params: CaseInsensitiveMap[String]): AcknowledgmentMode = { + params.get(ACKNOWLEDGMENT_MODE) + .map(AcknowledgmentMode.fromString) + .getOrElse(AcknowledgmentMode.Implicit) + } + + private def getExactlyOnceStrategy(params: CaseInsensitiveMap[String]): ExactlyOnceStrategy = { + params.get(EXACTLY_ONCE_STRATEGY) + .map(ExactlyOnceStrategy.fromString) + .getOrElse(ExactlyOnceStrategy.None) + } + + private def kafkaParamsForExecutors( + params: CaseInsensitiveMap[String], + shareGroupId: String): ju.Map[String, Object] = { + val specifiedParams = params + .filter { case (k, _) => k.toLowerCase(Locale.ROOT).startsWith("kafka.") } + .map { case (k, v) => k.substring(6) -> v } // Remove "kafka." prefix + + KafkaConfigUpdater("executor", specifiedParams.toMap) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + .set("group.id", shareGroupId) // Share group ID + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + .build() + } +} + +/** + * Configuration options for Kafka Share source. 
+ */ +object KafkaShareSourceProvider { + // Topic subscription options + val SUBSCRIBE = "subscribe" + val SUBSCRIBE_PATTERN = "subscribePattern" + + // Share group options + val SHARE_GROUP_ID = "kafka.share.group.id" + val SHARE_LOCK_TIMEOUT = "kafka.share.lock.timeout.ms" + val SHARE_PARALLELISM = "kafka.share.parallelism" + + // Acknowledgment options + val ACKNOWLEDGMENT_MODE = "kafka.share.acknowledgment.mode" + + // Exactly-once options + val EXACTLY_ONCE_STRATEGY = "kafka.share.exactly.once.strategy" + val DEDUP_COLUMNS = "kafka.share.dedup.columns" + + // Consumer options + val CONSUMER_POLL_TIMEOUT = "kafka.consumer.poll.timeout.ms" + val MAX_RECORDS_PER_BATCH = "kafka.share.max.records.per.batch" + + // Output options + val INCLUDE_HEADERS = "includeHeaders" + + // Default values + val DEFAULT_LOCK_TIMEOUT_MS = 30000L + val DEFAULT_POLL_TIMEOUT_MS = 512L + val DEFAULT_ACKNOWLEDGMENT_MODE = "implicit" + val DEFAULT_EXACTLY_ONCE_STRATEGY = "none" +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareCheckpointWriter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareCheckpointWriter.scala new file mode 100644 index 0000000000000..c643a6c0a274b --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareCheckpointWriter.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import java.io.{BufferedWriter, InputStreamReader, OutputStreamWriter} +import java.nio.charset.StandardCharsets + +import scala.io.Source +import scala.util.{Failure, Success, Try} + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, MetadataVersionUtil} + +/** + * Writes and reads share group checkpoint state to/from HDFS-compatible storage. + * + * The checkpoint contains: + * 1. Batch ID for ordering + * 2. Share group ID + * 3. Acquired record offsets per partition + * 4. Acknowledgment state + * + * This is used for: + * - Recovery after driver failure + * - Exactly-once semantics (checkpoint-based deduplication) + * - Progress tracking + */ +class ShareCheckpointWriter(metadataPath: String) extends Logging { + + private val VERSION = 1 + private val checkpointDir = s"$metadataPath/share-offsets" + + // Lazy initialization of file system + private lazy val (hadoopConf, fs) = { + val spark = SparkSession.active + val conf = spark.sparkContext.hadoopConfiguration + val fileSystem = FileSystem.get(new Path(checkpointDir).toUri, conf) + (conf, fileSystem) + } + + /** + * Write the offset for a batch to the checkpoint. 
+ */ + def write(offset: KafkaShareSourceOffset): Unit = { + val batchPath = new Path(checkpointDir, offset.batchId.toString) + + try { + // Ensure parent directory exists + val parentPath = batchPath.getParent + if (!fs.exists(parentPath)) { + fs.mkdirs(parentPath) + } + + // Write to temporary file first + val tempPath = new Path(s"${batchPath.toString}.tmp") + val outputStream = fs.create(tempPath, true) + val writer = new BufferedWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) + + try { + writer.write(s"v$VERSION\n") + writer.write(offset.json) + writer.newLine() + } finally { + writer.close() + } + + // Atomic rename + if (fs.exists(batchPath)) { + fs.delete(batchPath, false) + } + fs.rename(tempPath, batchPath) + + logDebug(s"Wrote checkpoint for batch ${offset.batchId}") + } catch { + case e: Exception => + logError(s"Failed to write checkpoint for batch ${offset.batchId}: ${e.getMessage}", e) + throw e + } + } + + /** + * Read the offset for a specific batch from the checkpoint. + */ + def read(batchId: Long): Option[KafkaShareSourceOffset] = { + val batchPath = new Path(checkpointDir, batchId.toString) + + if (!fs.exists(batchPath)) { + return None + } + + try { + val inputStream = fs.open(batchPath) + val reader = Source.fromInputStream(inputStream, StandardCharsets.UTF_8.name()) + + try { + val lines = reader.getLines().toList + + if (lines.isEmpty) { + logWarning(s"Empty checkpoint file for batch $batchId") + return None + } + + // Validate version + val versionLine = lines.head + if (!versionLine.startsWith("v")) { + logWarning(s"Invalid checkpoint version format: $versionLine") + return None + } + + val version = versionLine.substring(1).toInt + if (version > VERSION) { + logWarning(s"Checkpoint version $version is newer than supported version $VERSION") + return None + } + + // Parse offset JSON + val json = lines.drop(1).mkString("\n") + Some(KafkaShareSourceOffset(json)) + } finally { + reader.close() + } + } catch { + case e: Exception => + logError(s"Failed to read checkpoint for batch $batchId: ${e.getMessage}", e) + None + } + } + + /** + * Get the latest committed batch offset. + */ + def getLatest(): Option[KafkaShareSourceOffset] = { + try { + val checkpointPath = new Path(checkpointDir) + + if (!fs.exists(checkpointPath)) { + return None + } + + // Find the highest batch ID + val batchIds = fs.listStatus(checkpointPath) + .filter(_.isFile) + .flatMap { status => + Try(status.getPath.getName.toLong).toOption + } + .sorted + .reverse + + batchIds.headOption.flatMap(read) + } catch { + case e: Exception => + logError(s"Failed to get latest checkpoint: ${e.getMessage}", e) + None + } + } + + /** + * Get all committed batch offsets. + */ + def getAll(): Seq[KafkaShareSourceOffset] = { + try { + val checkpointPath = new Path(checkpointDir) + + if (!fs.exists(checkpointPath)) { + return Seq.empty + } + + fs.listStatus(checkpointPath) + .filter(_.isFile) + .flatMap { status => + Try(status.getPath.getName.toLong).toOption + } + .sorted + .flatMap(read) + .toSeq + } catch { + case e: Exception => + logError(s"Failed to get all checkpoints: ${e.getMessage}", e) + Seq.empty + } + } + + /** + * Purge old checkpoints, keeping only the most recent N. 
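A sketch of the checkpoint life cycle implemented by this class, including the purge helper defined a bit further down (the path is a placeholder under the query's checkpoint location; not part of the patch):

    val writer = new ShareCheckpointWriter("/checkpoints/query-1/sources/0")
    writer.write(KafkaShareSourceOffset.forBatch("my-share-group", 7L))

    writer.getLatest()            // Some(offset for batch 7)
    writer.read(6L)               // None unless batch 6 was ever written
    writer.purge(keepCount = 10)  // drop everything but the 10 most recent batch files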
+ */ + def purge(keepCount: Int): Int = { + try { + val checkpointPath = new Path(checkpointDir) + + if (!fs.exists(checkpointPath)) { + return 0 + } + + val batchIds = fs.listStatus(checkpointPath) + .filter(_.isFile) + .flatMap { status => + Try(status.getPath.getName.toLong).toOption + } + .sorted + + val toDelete = batchIds.dropRight(keepCount) + + toDelete.foreach { batchId => + val path = new Path(checkpointDir, batchId.toString) + fs.delete(path, false) + logDebug(s"Purged checkpoint for batch $batchId") + } + + toDelete.length + } catch { + case e: Exception => + logError(s"Failed to purge checkpoints: ${e.getMessage}", e) + 0 + } + } + + /** + * Delete checkpoint for a specific batch. + */ + def delete(batchId: Long): Boolean = { + try { + val batchPath = new Path(checkpointDir, batchId.toString) + if (fs.exists(batchPath)) { + fs.delete(batchPath, false) + true + } else { + false + } + } catch { + case e: Exception => + logError(s"Failed to delete checkpoint for batch $batchId: ${e.getMessage}", e) + false + } + } +} + +/** + * Recovery manager for share group streaming queries. + * + * Handles recovery scenarios: + * 1. Driver restart - resume from last committed batch + * 2. Task failure - acquire locks expire, records redelivered + * 3. Executor failure - same as task failure + */ +class ShareRecoveryManager( + shareGroupId: String, + checkpointPath: String) extends Logging { + + private val checkpointWriter = new ShareCheckpointWriter(checkpointPath) + + /** + * Recover the streaming query state after a restart. + * + * @return The offset to start from, and any pending records to release + */ + def recover(): RecoveryState = { + logInfo(s"Starting recovery for share group $shareGroupId") + + val latestOffset = checkpointWriter.getLatest() + + latestOffset match { + case Some(offset) => + logInfo(s"Recovered from batch ${offset.batchId} with ${offset.totalRecords} records") + + // Any acquired records in the checkpoint were not acknowledged + // They should be released (Kafka will handle this via lock expiry) + val pendingRecords = offset.getAllRecordKeys + + if (pendingRecords.nonEmpty) { + logWarning(s"Found ${pendingRecords.size} unacknowledged records from batch ${offset.batchId}. " + + "These will be redelivered by Kafka after lock expiry.") + } + + RecoveryState( + startBatchId = offset.batchId + 1, + lastCommittedOffset = Some(offset), + pendingRecords = pendingRecords, + isCleanStart = false + ) + + case None => + logInfo("No checkpoint found, starting fresh") + RecoveryState( + startBatchId = 0, + lastCommittedOffset = None, + pendingRecords = Set.empty, + isCleanStart = true + ) + } + } + + /** + * Get the checkpoint writer for external use. + */ + def getCheckpointWriter: ShareCheckpointWriter = checkpointWriter +} + +/** + * State recovered from checkpoint. + */ +case class RecoveryState( + startBatchId: Long, + lastCommittedOffset: Option[KafkaShareSourceOffset], + pendingRecords: Set[RecordKey], + isCleanStart: Boolean) + +/** + * Exception types for share group operations. 
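Driver-restart recovery with the ShareRecoveryManager above, as a minimal sketch (paths and group name are placeholders):

    val recovery = new ShareRecoveryManager("my-share-group", "/checkpoints/query-1/sources/0")
    val state = recovery.recover()

    if (!state.isCleanStart) {
      // Records in state.pendingRecords were acquired but never acknowledged; Kafka
      // redelivers them to some consumer once their acquisition locks expire.
      println(s"Resuming at batch ${state.startBatchId}; ${state.pendingRecords.size} records pending redelivery")
    }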
+ */ +object KafkaShareExceptions { + + class ShareGroupException(message: String, cause: Throwable = null) + extends RuntimeException(message, cause) + + class AcquisitionLockExpiredException(recordKey: RecordKey) + extends ShareGroupException(s"Acquisition lock expired for record: $recordKey") + + class AcknowledgmentFailedException(message: String, cause: Throwable = null) + extends ShareGroupException(s"Failed to acknowledge: $message", cause) + + class CheckpointException(message: String, cause: Throwable = null) + extends ShareGroupException(s"Checkpoint error: $message", cause) + + class RecoveryException(message: String, cause: Throwable = null) + extends ShareGroupException(s"Recovery error: $message", cause) + + def acquisitionLockExpired(key: RecordKey): AcquisitionLockExpiredException = + new AcquisitionLockExpiredException(key) + + def acknowledgmentFailed(message: String, cause: Throwable = null): AcknowledgmentFailedException = + new AcknowledgmentFailedException(message, cause) + + def checkpointFailed(message: String, cause: Throwable = null): CheckpointException = + new CheckpointException(message, cause) + + def recoveryFailed(message: String, cause: Throwable = null): RecoveryException = + new RecoveryException(message, cause) +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecord.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecord.scala new file mode 100644 index 0000000000000..944b9ec8cc475 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecord.scala @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +import scala.jdk.CollectionConverters._ + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +/** + * Acknowledgment types for share consumer records. + * Mirrors Kafka's AcknowledgeType enum. + */ +object AcknowledgmentType extends Enumeration { + type AcknowledgmentType = Value + val ACCEPT = Value("ACCEPT") // Successfully processed - mark as complete + val RELEASE = Value("RELEASE") // Release back for redelivery - mark as available + val REJECT = Value("REJECT") // Permanently reject - move to dead letter / archive +} + +/** + * Represents a single in-flight record in a share group. + * Tracks the record data along with its acquisition state and acknowledgment status. 
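A tiny hedged sketch of how the exception helpers above might be used together with the ShareInFlightRecord type defined just below (the helper function is hypothetical):

    import org.apache.spark.sql.kafka010.share.KafkaShareExceptions._

    def ensureLockStillHeld(rec: ShareInFlightRecord, now: Long): Unit =
      if (rec.isLockExpired(now)) throw acquisitionLockExpired(rec.recordKey)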
+ * + * @param recordKey Unique identifier for the record + * @param record The actual Kafka consumer record + * @param acquiredAt Timestamp when the record was acquired + * @param lockExpiresAt Timestamp when the acquisition lock expires + * @param deliveryCount Number of times this record has been delivered + * @param acknowledgment The pending acknowledgment type (if any) + */ +case class ShareInFlightRecord( + recordKey: RecordKey, + record: ConsumerRecord[Array[Byte], Array[Byte]], + acquiredAt: Long, + lockExpiresAt: Long, + deliveryCount: Short, + acknowledgment: Option[AcknowledgmentType.AcknowledgmentType] = None) { + + def topic: String = record.topic() + def partition: Int = record.partition() + def offset: Long = record.offset() + def topicPartition: TopicPartition = new TopicPartition(topic, partition) + + /** Check if the lock has expired */ + def isLockExpired(currentTimeMs: Long): Boolean = currentTimeMs >= lockExpiresAt + + /** Check if this record has been acknowledged */ + def isAcknowledged: Boolean = acknowledgment.isDefined + + /** Mark this record as acknowledged with the given type */ + def withAcknowledgment(ackType: AcknowledgmentType.AcknowledgmentType): ShareInFlightRecord = { + copy(acknowledgment = Some(ackType)) + } + + /** Get remaining lock time in milliseconds */ + def remainingLockTimeMs(currentTimeMs: Long): Long = { + math.max(0, lockExpiresAt - currentTimeMs) + } + + override def toString: String = { + s"ShareInFlightRecord[${recordKey}, ack=${acknowledgment.map(_.toString).getOrElse("pending")}, " + + s"deliveryCount=$deliveryCount, lockExpires=$lockExpiresAt]" + } +} + +object ShareInFlightRecord { + /** Create from a Kafka consumer record */ + def apply( + record: ConsumerRecord[Array[Byte], Array[Byte]], + acquiredAt: Long, + lockTimeoutMs: Long, + deliveryCount: Short): ShareInFlightRecord = { + ShareInFlightRecord( + recordKey = RecordKey(record.topic(), record.partition(), record.offset()), + record = record, + acquiredAt = acquiredAt, + lockExpiresAt = acquiredAt + lockTimeoutMs, + deliveryCount = deliveryCount + ) + } +} + +/** + * Tracks in-flight records for a share consumer batch. + * Manages the lifecycle of records from acquisition through acknowledgment. + * + * Thread-safe implementation using ConcurrentHashMap for executor-side usage. 
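A short sketch of the in-flight tracking types working together (ShareInFlightRecord above, ShareInFlightBatch just below); the ConsumerRecord values are made up and not part of the patch.

    import org.apache.kafka.clients.consumer.ConsumerRecord

    val cr = new ConsumerRecord[Array[Byte], Array[Byte]]("topic1", 0, 42L, "k".getBytes, "v".getBytes)
    val inFlight = ShareInFlightRecord(cr,
      acquiredAt = System.currentTimeMillis(), lockTimeoutMs = 30000L, deliveryCount = 1.toShort)

    val batch = new ShareInFlightBatch(batchId = 7L)
    batch.addRecord(inFlight)
    batch.acknowledge(inFlight.recordKey, AcknowledgmentType.ACCEPT)   // true on the first call
    assert(batch.getPendingRecords.isEmpty)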
+ */ +class ShareInFlightBatch(val batchId: Long) { + // Map of RecordKey -> ShareInFlightRecord for all in-flight records + private val records = new ConcurrentHashMap[RecordKey, ShareInFlightRecord]() + + // Counters for tracking acknowledgment state + private val pendingCount = new AtomicInteger(0) + private val acceptedCount = new AtomicInteger(0) + private val releasedCount = new AtomicInteger(0) + private val rejectedCount = new AtomicInteger(0) + + /** Add a record to the in-flight batch */ + def addRecord(record: ShareInFlightRecord): Unit = { + records.put(record.recordKey, record) + pendingCount.incrementAndGet() + } + + /** Get a record by its key */ + def getRecord(key: RecordKey): Option[ShareInFlightRecord] = { + Option(records.get(key)) + } + + /** Acknowledge a record with the given type */ + def acknowledge(key: RecordKey, ackType: AcknowledgmentType.AcknowledgmentType): Boolean = { + val record = records.get(key) + if (record != null && record.acknowledgment.isEmpty) { + records.put(key, record.withAcknowledgment(ackType)) + pendingCount.decrementAndGet() + ackType match { + case AcknowledgmentType.ACCEPT => acceptedCount.incrementAndGet() + case AcknowledgmentType.RELEASE => releasedCount.incrementAndGet() + case AcknowledgmentType.REJECT => rejectedCount.incrementAndGet() + } + true + } else { + false + } + } + + /** Acknowledge all pending records as ACCEPT */ + def acknowledgeAllAsAccept(): Int = { + var count = 0 + records.forEach { (key, record) => + if (record.acknowledgment.isEmpty) { + records.put(key, record.withAcknowledgment(AcknowledgmentType.ACCEPT)) + pendingCount.decrementAndGet() + acceptedCount.incrementAndGet() + count += 1 + } + } + count + } + + /** Release all pending records (for failure recovery) */ + def releaseAllPending(): Int = { + var count = 0 + records.forEach { (key, record) => + if (record.acknowledgment.isEmpty) { + records.put(key, record.withAcknowledgment(AcknowledgmentType.RELEASE)) + pendingCount.decrementAndGet() + releasedCount.incrementAndGet() + count += 1 + } + } + count + } + + /** Get all pending (unacknowledged) records */ + def getPendingRecords: Seq[ShareInFlightRecord] = { + records.values().asScala.filter(_.acknowledgment.isEmpty).toSeq + } + + /** Get all records with a specific acknowledgment type */ + def getRecordsByAckType(ackType: AcknowledgmentType.AcknowledgmentType): Seq[ShareInFlightRecord] = { + records.values().asScala.filter(_.acknowledgment.contains(ackType)).toSeq + } + + /** Get all accepted records */ + def getAcceptedRecords: Seq[ShareInFlightRecord] = getRecordsByAckType(AcknowledgmentType.ACCEPT) + + /** Get all released records */ + def getReleasedRecords: Seq[ShareInFlightRecord] = getRecordsByAckType(AcknowledgmentType.RELEASE) + + /** Get all rejected records */ + def getRejectedRecords: Seq[ShareInFlightRecord] = getRecordsByAckType(AcknowledgmentType.REJECT) + + /** Get records grouped by TopicPartition */ + def getRecordsByPartition: Map[TopicPartition, Seq[ShareInFlightRecord]] = { + records.values().asScala.groupBy(_.topicPartition).toMap + } + + /** Get acknowledgments grouped by TopicPartition for commit */ + def getAcknowledgmentsForCommit: Map[TopicPartition, Map[Long, AcknowledgmentType.AcknowledgmentType]] = { + records.values().asScala + .filter(_.acknowledgment.isDefined) + .groupBy(_.topicPartition) + .map { case (tp, recs) => + tp -> recs.map(r => r.offset -> r.acknowledgment.get).toMap + } + .toMap + } + + /** Check if all records have been acknowledged */ + def isComplete: Boolean 
= pendingCount.get() == 0 + + /** Get counts */ + def totalRecords: Int = records.size() + def pending: Int = pendingCount.get() + def accepted: Int = acceptedCount.get() + def released: Int = releasedCount.get() + def rejected: Int = rejectedCount.get() + + /** Clear all records */ + def clear(): Unit = { + records.clear() + pendingCount.set(0) + acceptedCount.set(0) + releasedCount.set(0) + rejectedCount.set(0) + } + + override def toString: String = { + s"ShareInFlightBatch[batchId=$batchId, total=${totalRecords}, " + + s"pending=${pending}, accepted=${accepted}, released=${released}, rejected=${rejected}]" + } +} + +/** + * Manager for in-flight batches across multiple micro-batches. + * Used by the driver to track acknowledgment state. + */ +class ShareInFlightManager { + // Map of batchId -> ShareInFlightBatch + private val batches = new ConcurrentHashMap[Long, ShareInFlightBatch]() + + /** Create a new in-flight batch for the given batch ID */ + def createBatch(batchId: Long): ShareInFlightBatch = { + val batch = new ShareInFlightBatch(batchId) + batches.put(batchId, batch) + batch + } + + /** Get an existing batch */ + def getBatch(batchId: Long): Option[ShareInFlightBatch] = { + Option(batches.get(batchId)) + } + + /** Get or create a batch */ + def getOrCreateBatch(batchId: Long): ShareInFlightBatch = { + batches.computeIfAbsent(batchId, _ => new ShareInFlightBatch(batchId)) + } + + /** Remove a completed batch */ + def removeBatch(batchId: Long): Option[ShareInFlightBatch] = { + Option(batches.remove(batchId)) + } + + /** Get all incomplete batches */ + def getIncompleteBatches: Seq[ShareInFlightBatch] = { + batches.values().asScala.filterNot(_.isComplete).toSeq + } + + /** Release all pending records in all batches (for shutdown) */ + def releaseAll(): Int = { + batches.values().asScala.map(_.releaseAllPending()).sum + } + + /** Clear all batches */ + def clear(): Unit = { + batches.clear() + } + + /** Get total pending records across all batches */ + def totalPending: Int = { + batches.values().asScala.map(_.pending).sum + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareStateBatch.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareStateBatch.scala new file mode 100644 index 0000000000000..13d5740f743bc --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/ShareStateBatch.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import org.apache.kafka.common.TopicPartition + +/** + * Represents the delivery state of a record in a share group. + * Mirrors Kafka's RecordState enum. 
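+ *
+ * Illustrative sketch of the byte round trip used when persisting state (note that value 3 is
+ * not mapped here):
+ * {{{
+ *   DeliveryState.fromByte(0)              // AVAILABLE
+ *   DeliveryState.ACKNOWLEDGED.id.toByte   // 2
+ *   DeliveryState.fromByte(3)              // throws IllegalArgumentException
+ * }}}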
+ */ +object DeliveryState extends Enumeration { + type DeliveryState = Value + val AVAILABLE = Value(0, "AVAILABLE") // Record is available for acquisition + val ACQUIRED = Value(1, "ACQUIRED") // Record is acquired by a consumer + val ACKNOWLEDGED = Value(2, "ACKNOWLEDGED") // Record has been acknowledged (processed) + val ARCHIVED = Value(4, "ARCHIVED") // Record is archived (dead-lettered) + + def fromByte(b: Byte): DeliveryState = b match { + case 0 => AVAILABLE + case 1 => ACQUIRED + case 2 => ACKNOWLEDGED + case 4 => ARCHIVED + case _ => throw new IllegalArgumentException(s"Unknown delivery state: $b") + } +} + +/** + * Represents a contiguous range of offsets with their delivery state and count. + * This mirrors Kafka's PersisterStateBatch which is the fundamental unit for + * tracking state in share groups. + * + * Unlike traditional consumer groups that track a single offset per partition, + * share groups track ranges of offsets with different states because: + * 1. Multiple consumers can acquire records from the same partition concurrently + * 2. Records can be acknowledged out of order + * 3. Records can be released (redelivered) without affecting other records + * + * @param firstOffset Start of the offset range (inclusive) + * @param lastOffset End of the offset range (inclusive) + * @param deliveryState Current state of records in this range + * @param deliveryCount Number of times records in this range have been delivered + */ +case class ShareStateBatch( + firstOffset: Long, + lastOffset: Long, + deliveryState: DeliveryState.DeliveryState, + deliveryCount: Short) { + + require(firstOffset <= lastOffset, + s"firstOffset ($firstOffset) must be <= lastOffset ($lastOffset)") + require(deliveryCount >= 0, s"deliveryCount must be non-negative: $deliveryCount") + + /** Returns the number of records in this batch */ + def recordCount: Long = lastOffset - firstOffset + 1 + + /** Check if a specific offset is within this batch */ + def contains(offset: Long): Boolean = offset >= firstOffset && offset <= lastOffset + + /** Check if this batch can be merged with another batch */ + def canMergeWith(other: ShareStateBatch): Boolean = { + // Batches can be merged if they are adjacent and have the same state/count + (this.lastOffset + 1 == other.firstOffset || other.lastOffset + 1 == this.firstOffset) && + this.deliveryState == other.deliveryState && + this.deliveryCount == other.deliveryCount + } + + /** Merge with another adjacent batch with the same state */ + def mergeWith(other: ShareStateBatch): ShareStateBatch = { + require(canMergeWith(other), s"Cannot merge non-adjacent batches: $this and $other") + ShareStateBatch( + math.min(this.firstOffset, other.firstOffset), + math.max(this.lastOffset, other.lastOffset), + this.deliveryState, + this.deliveryCount + ) + } + + /** + * Split this batch at the given offset, returning two batches. + * The offset will be the lastOffset of the first batch. 
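+ *
+ * For example (illustrative values only):
+ * {{{
+ *   val batch = ShareStateBatch.acquired(10L, 20L)
+ *   val (left, right) = batch.splitAt(14L)
+ *   // left  == ShareStateBatch(10, 14, ACQUIRED, 1)
+ *   // right == ShareStateBatch(15, 20, ACQUIRED, 1)
+ * }}}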
+ */ + def splitAt(offset: Long): (ShareStateBatch, ShareStateBatch) = { + require(contains(offset) && offset < lastOffset, + s"Cannot split at offset $offset - must be within range [$firstOffset, $lastOffset)") + ( + ShareStateBatch(firstOffset, offset, deliveryState, deliveryCount), + ShareStateBatch(offset + 1, lastOffset, deliveryState, deliveryCount) + ) + } + + /** Create a copy with updated delivery state */ + def withState(newState: DeliveryState.DeliveryState): ShareStateBatch = { + copy(deliveryState = newState) + } + + /** Create a copy with incremented delivery count */ + def withIncrementedDeliveryCount(): ShareStateBatch = { + copy(deliveryCount = (deliveryCount + 1).toShort) + } + + override def toString: String = { + s"ShareStateBatch[$firstOffset-$lastOffset: $deliveryState, count=$deliveryCount]" + } +} + +object ShareStateBatch { + /** + * Create a batch for a single offset + */ + def single(offset: Long, state: DeliveryState.DeliveryState, count: Short = 1): ShareStateBatch = { + ShareStateBatch(offset, offset, state, count) + } + + /** + * Create an ACQUIRED batch for a range of offsets + */ + def acquired(firstOffset: Long, lastOffset: Long, deliveryCount: Short = 1): ShareStateBatch = { + ShareStateBatch(firstOffset, lastOffset, DeliveryState.ACQUIRED, deliveryCount) + } + + /** + * Create an ACKNOWLEDGED batch for a range of offsets + */ + def acknowledged(firstOffset: Long, lastOffset: Long, deliveryCount: Short = 1): ShareStateBatch = { + ShareStateBatch(firstOffset, lastOffset, DeliveryState.ACKNOWLEDGED, deliveryCount) + } +} + +/** + * Container for share group state for a specific topic partition. + * This mirrors Kafka's ShareGroupOffset which is persisted to __share_group_state topic. + * + * @param snapshotEpoch Epoch for snapshot versioning + * @param stateEpoch Epoch for state changes (incremented on each state modification) + * @param leaderEpoch Leader epoch for fencing stale coordinators + * @param startOffset The lowest offset that is still tracked (offset before this are complete) + * @param stateBatches List of state batches tracking non-complete offsets + */ +case class SharePartitionState( + topicPartition: TopicPartition, + snapshotEpoch: Int, + stateEpoch: Int, + leaderEpoch: Int, + startOffset: Long, + stateBatches: Seq[ShareStateBatch]) { + + /** + * Get the state for a specific offset. + * Returns None if the offset is below startOffset (already complete). 
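+ *
+ * Illustrative sketch, assuming offsets below 100 are already complete and offsets 100-104 are
+ * currently acquired (topic name and epochs are placeholders):
+ * {{{
+ *   val state = SharePartitionState(new TopicPartition("events", 0),
+ *     snapshotEpoch = 0, stateEpoch = 1, leaderEpoch = 0, startOffset = 100L,
+ *     stateBatches = Seq(ShareStateBatch.acquired(100L, 104L)))
+ *   state.getState(42L)    // None - below startOffset, already complete
+ *   state.getState(102L)   // Some(ShareStateBatch[100-104: ACQUIRED, count=1])
+ * }}}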
+ */ + def getState(offset: Long): Option[ShareStateBatch] = { + if (offset < startOffset) { + None // Already completed/archived + } else { + stateBatches.find(_.contains(offset)) + } + } + + /** + * Get all ACQUIRED offsets in this partition state + */ + def getAcquiredOffsets: Seq[Long] = { + stateBatches + .filter(_.deliveryState == DeliveryState.ACQUIRED) + .flatMap(batch => batch.firstOffset to batch.lastOffset) + } + + /** + * Get the highest tracked offset + */ + def highestOffset: Long = { + if (stateBatches.isEmpty) startOffset - 1 + else stateBatches.map(_.lastOffset).max + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/InternalKafkaShareConsumerPool.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/InternalKafkaShareConsumerPool.scala new file mode 100644 index 0000000000000..511621ac6f4e5 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/InternalKafkaShareConsumerPool.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share.consumer + +import java.{util => ju} +import java.time.Duration +import java.util.concurrent.ConcurrentHashMap + +import org.apache.commons.pool2.{BaseKeyedPooledObjectFactory, PooledObject, SwallowedExceptionListener} +import org.apache.commons.pool2.impl.{BaseObjectPoolConfig, DefaultEvictionPolicy, DefaultPooledObject, GenericKeyedObjectPool, GenericKeyedObjectPoolConfig} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010._ + +/** + * Object pool for [[InternalKafkaShareConsumer]] which is keyed by [[ShareConsumerCacheKey]]. + * + * Unlike traditional Kafka consumers that are keyed by (groupId, topicPartition), + * share consumers are keyed by (shareGroupId, topics) because share groups: + * 1. Subscribe to topics, not individual partitions + * 2. Allow multiple consumers to receive records from the same partition + * + * This pool leverages [[GenericKeyedObjectPool]] internally with the same contract: + * after using a borrowed object, you must either call returnObject() if healthy, + * or invalidateObject() if the object should be destroyed. 
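+ *
+ * Illustrative borrow/return sketch (`pool`, `kafkaParams` and the group/topic names are
+ * placeholders):
+ * {{{
+ *   val key = ShareConsumerCacheKey("analytics-share-group", Set("events"))
+ *   val consumer = pool.borrowObject(key, kafkaParams, lockTimeoutMs = 30000L)
+ *   try {
+ *     consumer.poll(1000L)
+ *   } finally {
+ *     pool.returnObject(consumer)   // or pool.invalidateObject(consumer) on error
+ *   }
+ * }}}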
+ */ +private[consumer] class InternalKafkaShareConsumerPool( + objectFactory: ShareObjectFactory, + poolConfig: SharePoolConfig) extends Logging { + + def this(conf: SparkConf) = { + this(new ShareObjectFactory, new SharePoolConfig(conf)) + } + + // Pool is intended to have soft capacity only + assert(poolConfig.getMaxTotal < 0) + + private val pool = { + val internalPool = new GenericKeyedObjectPool[ShareConsumerCacheKey, InternalKafkaShareConsumer]( + objectFactory, poolConfig) + internalPool.setSwallowedExceptionListener(ShareSwallowedExceptionListener) + internalPool + } + + /** + * Borrow a [[InternalKafkaShareConsumer]] from the pool. + * If there's no idle consumer for the key, a new one will be created. + * + * @param key The cache key (shareGroupId, topics) + * @param kafkaParams Kafka configuration parameters + * @param lockTimeoutMs Acquisition lock timeout + * @return A borrowed InternalKafkaShareConsumer + */ + def borrowObject( + key: ShareConsumerCacheKey, + kafkaParams: ju.Map[String, Object], + lockTimeoutMs: Long): InternalKafkaShareConsumer = { + updateParamsForKey(key, kafkaParams, lockTimeoutMs) + + if (size >= poolConfig.softMaxSize) { + logWarning("Share consumer pool exceeds its soft max size, cleaning up idle objects...") + pool.clearOldest() + } + + pool.borrowObject(key) + } + + /** Return borrowed consumer to the pool. */ + def returnObject(consumer: InternalKafkaShareConsumer): Unit = { + val key = extractCacheKey(consumer) + pool.returnObject(key, consumer) + } + + /** Invalidate (destroy) a borrowed consumer. */ + def invalidateObject(consumer: InternalKafkaShareConsumer): Unit = { + val key = extractCacheKey(consumer) + pool.invalidateObject(key, consumer) + } + + /** Invalidate all idle consumers for the given key. */ + def invalidateKey(key: ShareConsumerCacheKey): Unit = { + pool.clear(key) + } + + /** + * Close the pool. Once closed, borrowObject will fail. + * returnObject and invalidateObject will continue to work. + */ + def close(): Unit = { + pool.close() + } + + def reset(): Unit = { + pool.clear() + } + + def numIdle: Int = pool.getNumIdle + def numIdle(key: ShareConsumerCacheKey): Int = pool.getNumIdle(key) + def numActive: Int = pool.getNumActive + def numActive(key: ShareConsumerCacheKey): Int = pool.getNumActive(key) + def size: Int = numIdle + numActive + def size(key: ShareConsumerCacheKey): Int = numIdle(key) + numActive(key) + + private def updateParamsForKey( + key: ShareConsumerCacheKey, + kafkaParams: ju.Map[String, Object], + lockTimeoutMs: Long): Unit = { + objectFactory.keyToParams.putIfAbsent(key, ShareConsumerParams(kafkaParams, lockTimeoutMs)) + } + + private def extractCacheKey(consumer: InternalKafkaShareConsumer): ShareConsumerCacheKey = { + ShareConsumerCacheKey(consumer.shareGroupId, consumer.topics) + } +} + +/** + * Parameters needed to create a share consumer. + */ +private[consumer] case class ShareConsumerParams( + kafkaParams: ju.Map[String, Object], + lockTimeoutMs: Long) + +/** + * Exception listener for the pool. + */ +private[consumer] object ShareSwallowedExceptionListener + extends SwallowedExceptionListener with Logging { + override def onSwallowException(e: Exception): Unit = { + logError(s"Error closing Kafka share consumer", e) + } +} + +/** + * Pool configuration for share consumers. 
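+ *
+ * Tuning goes through the existing consumer cache keys that this class reads (an illustrative
+ * sketch; the values are arbitrary):
+ * {{{
+ *   val conf = new SparkConf()
+ *     .set("spark.kafka.consumer.cache.capacity", "128")   // soft max size of the pool
+ *     .set("spark.kafka.consumer.cache.timeout", "10m")    // idle eviction timeout
+ *   val poolConfig = new SharePoolConfig(conf)
+ * }}}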
+ */ +private[consumer] class SharePoolConfig(conf: SparkConf) + extends GenericKeyedObjectPoolConfig[InternalKafkaShareConsumer] { + + private var _softMaxSize = Int.MaxValue + + def softMaxSize: Int = _softMaxSize + + init() + + private def init(): Unit = { + // Use the same configuration keys as regular consumers + _softMaxSize = conf.get(CONSUMER_CACHE_CAPACITY) + + val jmxEnabled = conf.get(CONSUMER_CACHE_JMX_ENABLED) + val minEvictableIdleTimeMillis = conf.get(CONSUMER_CACHE_TIMEOUT) + val evictorThreadRunIntervalMillis = conf.get(CONSUMER_CACHE_EVICTOR_THREAD_RUN_INTERVAL) + + // Behavior configuration: + // 1. Min idle per key = 0: don't create unnecessary consumers + // 2. Max idle per key = 3: keep a few idle for reuse + // 3. Max total per key = infinite: don't restrict borrowing + // 4. Max total = infinite: all objects managed in pool + setMinIdlePerKey(0) + setMaxIdlePerKey(3) + setMaxTotalPerKey(-1) + setMaxTotal(-1) + + // Eviction configuration + setMinEvictableIdleDuration(Duration.ofMillis(minEvictableIdleTimeMillis)) + setSoftMinEvictableIdleDuration(BaseObjectPoolConfig.DEFAULT_SOFT_MIN_EVICTABLE_IDLE_DURATION) + setTimeBetweenEvictionRuns(Duration.ofMillis(evictorThreadRunIntervalMillis)) + setNumTestsPerEvictionRun(10) + setEvictionPolicy(new DefaultEvictionPolicy[InternalKafkaShareConsumer]()) + + // Fail immediately on exhausted pool + setBlockWhenExhausted(false) + + setJmxEnabled(jmxEnabled) + setJmxNamePrefix("kafka010-cached-share-consumer-pool") + } +} + +/** + * Factory for creating and destroying share consumers. + */ +private[consumer] class ShareObjectFactory + extends BaseKeyedPooledObjectFactory[ShareConsumerCacheKey, InternalKafkaShareConsumer] { + + val keyToParams = new ConcurrentHashMap[ShareConsumerCacheKey, ShareConsumerParams]() + + override def create(key: ShareConsumerCacheKey): InternalKafkaShareConsumer = { + Option(keyToParams.get(key)) match { + case Some(params) => + new InternalKafkaShareConsumer( + shareGroupId = key.shareGroupId, + topics = key.topics, + kafkaParams = params.kafkaParams, + lockTimeoutMs = params.lockTimeoutMs + ) + case None => + throw new IllegalStateException( + "Share consumer params should be set before borrowing object.") + } + } + + override def wrap(value: InternalKafkaShareConsumer): PooledObject[InternalKafkaShareConsumer] = { + new DefaultPooledObject[InternalKafkaShareConsumer](value) + } + + override def destroyObject( + key: ShareConsumerCacheKey, + p: PooledObject[InternalKafkaShareConsumer]): Unit = { + p.getObject.close() + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/KafkaShareDataConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/KafkaShareDataConsumer.scala new file mode 100644 index 0000000000000..b41158dced8a4 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/consumer/KafkaShareDataConsumer.scala @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share.consumer + +import java.{util => ju} +import java.io.Closeable +import java.time.Duration +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.consumer.{ConsumerRecord, KafkaShareConsumer} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.deploy.security.HadoopDelegationTokenManager +import org.apache.spark.internal.{Logging, MDC} +import org.apache.spark.internal.LogKeys._ +import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaTokenUtil} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS +import org.apache.spark.sql.kafka010.share._ + +/** + * Internal wrapper around Kafka's [[KafkaShareConsumer]] for use in Spark executors. + * + * Unlike the traditional [[org.apache.spark.sql.kafka010.consumer.InternalKafkaConsumer]], + * this consumer: + * 1. Subscribes to topics (not assigned to specific partitions) + * 2. Tracks acquired records for acknowledgment + * 3. Supports explicit acknowledgment with ACCEPT/RELEASE/REJECT semantics + * + * NOTE: Like KafkaShareConsumer, this class is not thread-safe. + * + * @param shareGroupId The share group identifier + * @param topics Set of topics to subscribe to + * @param kafkaParams Kafka consumer configuration parameters + * @param lockTimeoutMs Acquisition lock timeout in milliseconds + */ +private[kafka010] class InternalKafkaShareConsumer( + val shareGroupId: String, + val topics: Set[String], + val kafkaParams: ju.Map[String, Object], + val lockTimeoutMs: Long = 30000L) extends Closeable with Logging { + + // Track acquired records for acknowledgment + private val acquiredRecords = new ConcurrentHashMap[RecordKey, ShareInFlightRecord]() + + // Statistics + private val totalRecordsPolled = new AtomicLong(0) + private val totalRecordsAcknowledged = new AtomicLong(0) + private val totalPollCalls = new AtomicInteger(0) + + // Exposed for testing + private[consumer] val clusterConfig = KafkaTokenUtil.findMatchingTokenClusterConfig( + SparkEnv.get.conf, kafkaParams.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + .asInstanceOf[String]) + + // Kafka consumer - using var for token refresh support + private[consumer] var kafkaParamsWithSecurity: ju.Map[String, Object] = _ + private var _consumer: KafkaShareConsumer[Array[Byte], Array[Byte]] = createConsumer() + + private def consumer: KafkaShareConsumer[Array[Byte], Array[Byte]] = _consumer + + /** + * Poll for new records from subscribed topics. + * Records are automatically acquired and tracked for acknowledgment. 
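+ *
+ * Illustrative poll-acknowledge sketch (`consumer` denotes an instance of this class; error
+ * handling omitted):
+ * {{{
+ *   val records = consumer.poll(pollTimeoutMs = 1000L)
+ *   records.foreach { r =>
+ *     // ... process r.record ...
+ *     consumer.acknowledge(r.recordKey, AcknowledgmentType.ACCEPT)
+ *   }
+ *   consumer.commitSync()
+ * }}}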
+ * + * @param pollTimeoutMs Maximum time to block waiting for records + * @return List of acquired records wrapped as ShareInFlightRecord + */ + def poll(pollTimeoutMs: Long): Seq[ShareInFlightRecord] = { + val records = consumer.poll(Duration.ofMillis(pollTimeoutMs)) + val acquiredAt = System.currentTimeMillis() + val result = ArrayBuffer[ShareInFlightRecord]() + + totalPollCalls.incrementAndGet() + + records.forEach { record => + // Get delivery count from record headers if available (Kafka 4.x) + val deliveryCount = getDeliveryCount(record) + + val inFlightRecord = ShareInFlightRecord( + record = record, + acquiredAt = acquiredAt, + lockTimeoutMs = lockTimeoutMs, + deliveryCount = deliveryCount + ) + + acquiredRecords.put(inFlightRecord.recordKey, inFlightRecord) + result += inFlightRecord + totalRecordsPolled.incrementAndGet() + } + + logDebug(s"Polled ${result.size} records from share group $shareGroupId") + result.toSeq + } + + /** + * Acknowledge a single record with the specified type. + * + * @param key The record key to acknowledge + * @param ackType The acknowledgment type (ACCEPT, RELEASE, REJECT) + * @return true if acknowledgment was recorded, false if record not found + */ + def acknowledge(key: RecordKey, ackType: AcknowledgmentType.AcknowledgmentType): Boolean = { + val record = acquiredRecords.get(key) + if (record != null) { + val updatedRecord = record.withAcknowledgment(ackType) + acquiredRecords.put(key, updatedRecord) + totalRecordsAcknowledged.incrementAndGet() + + // Apply acknowledgment to Kafka consumer + ackType match { + case AcknowledgmentType.ACCEPT => + consumer.acknowledge(record.record) + case AcknowledgmentType.RELEASE => + consumer.acknowledge(record.record, org.apache.kafka.clients.consumer.AcknowledgeType.RELEASE) + case AcknowledgmentType.REJECT => + consumer.acknowledge(record.record, org.apache.kafka.clients.consumer.AcknowledgeType.REJECT) + } + true + } else { + logWarning(s"Record $key not found for acknowledgment") + false + } + } + + /** + * Acknowledge multiple records with the specified type. + */ + def acknowledgeAll( + keys: Iterable[RecordKey], + ackType: AcknowledgmentType.AcknowledgmentType): Int = { + keys.count(key => acknowledge(key, ackType)) + } + + /** + * Acknowledge a record by offset for a specific TopicPartition. + */ + def acknowledgeByOffset( + tp: TopicPartition, + offset: Long, + ackType: AcknowledgmentType.AcknowledgmentType): Boolean = { + acknowledge(RecordKey(tp, offset), ackType) + } + + /** + * Acknowledge all pending records as ACCEPT. + * Used for implicit acknowledgment mode. + */ + def acknowledgeAllPendingAsAccept(): Int = { + var count = 0 + acquiredRecords.forEach { (key, record) => + if (record.acknowledgment.isEmpty) { + acknowledge(key, AcknowledgmentType.ACCEPT) + count += 1 + } + } + count + } + + /** + * Release all pending records (for failure recovery). + * This allows Kafka to redeliver them to other consumers. + */ + def releaseAllPending(): Int = { + var count = 0 + acquiredRecords.forEach { (key, record) => + if (record.acknowledgment.isEmpty) { + acknowledge(key, AcknowledgmentType.RELEASE) + count += 1 + } + } + count + } + + /** + * Commit all pending acknowledgments synchronously. + * This ensures durability of acknowledgments in Kafka. + */ + def commitSync(): Unit = { + consumer.commitSync() + logDebug(s"Committed acknowledgments for share group $shareGroupId") + } + + /** + * Commit acknowledgments asynchronously. 
+ */ + def commitAsync(): Unit = { + consumer.commitAsync() + } + + /** + * Get all acquired records that haven't been acknowledged yet. + */ + def getPendingRecords: Seq[ShareInFlightRecord] = { + acquiredRecords.values().asScala.filter(_.acknowledgment.isEmpty).toSeq + } + + /** + * Get acquired records grouped by TopicPartition. + */ + def getRecordsByPartition: Map[TopicPartition, Seq[ShareInFlightRecord]] = { + acquiredRecords.values().asScala.groupBy(_.topicPartition).toMap + } + + /** + * Get acknowledgments ready for commit, grouped by TopicPartition. + */ + def getAcknowledgmentsForCommit: Map[TopicPartition, Map[Long, AcknowledgmentType.AcknowledgmentType]] = { + acquiredRecords.values().asScala + .filter(_.acknowledgment.isDefined) + .groupBy(_.topicPartition) + .map { case (tp, recs) => + tp -> recs.map(r => r.offset -> r.acknowledgment.get).toMap + } + .toMap + } + + /** + * Clear acknowledged records from tracking (after successful commit). + */ + def clearAcknowledged(): Unit = { + val keysToRemove = acquiredRecords.asScala + .filter(_._2.acknowledgment.isDefined) + .keys + .toSeq + keysToRemove.foreach(acquiredRecords.remove) + } + + /** + * Check if any records have expired locks. + */ + def hasExpiredLocks: Boolean = { + val now = System.currentTimeMillis() + acquiredRecords.values().asScala.exists(_.isLockExpired(now)) + } + + /** + * Get records with expired locks. + */ + def getExpiredLockRecords: Seq[ShareInFlightRecord] = { + val now = System.currentTimeMillis() + acquiredRecords.values().asScala.filter(_.isLockExpired(now)).toSeq + } + + /** + * Get statistics about this consumer. + */ + def getStats: ShareConsumerStats = ShareConsumerStats( + shareGroupId = shareGroupId, + totalRecordsPolled = totalRecordsPolled.get(), + totalRecordsAcknowledged = totalRecordsAcknowledged.get(), + totalPollCalls = totalPollCalls.get(), + pendingRecords = acquiredRecords.values().asScala.count(_.acknowledgment.isEmpty), + acknowledgedRecords = acquiredRecords.values().asScala.count(_.acknowledgment.isDefined) + ) + + override def close(): Unit = { + // Release any pending records before closing + val pendingCount = releaseAllPending() + if (pendingCount > 0) { + logWarning(s"Released $pendingCount pending records before closing share consumer") + try { + commitSync() + } catch { + case e: Exception => + logWarning(s"Failed to commit releases on close: ${e.getMessage}") + } + } + consumer.close() + } + + def wakeup(): Unit = { + consumer.wakeup() + } + + /** Create the underlying Kafka share consumer */ + private def createConsumer(): KafkaShareConsumer[Array[Byte], Array[Byte]] = { + kafkaParamsWithSecurity = KafkaConfigUpdater("executor", kafkaParams.asScala.toMap) + .setAuthenticationConfigIfNeeded(clusterConfig) + .build() + + val c = new KafkaShareConsumer[Array[Byte], Array[Byte]](kafkaParamsWithSecurity) + c.subscribe(topics.asJava) + c + } + + /** + * Extract delivery count from record headers. + * Kafka 4.x adds delivery count information to share group records. + */ + private def getDeliveryCount(record: ConsumerRecord[Array[Byte], Array[Byte]]): Short = { + // TODO: Extract from Kafka 4.x headers when available + // For now, default to 1 (first delivery) + 1 + } +} + +/** + * Statistics for a share consumer. + */ +case class ShareConsumerStats( + shareGroupId: String, + totalRecordsPolled: Long, + totalRecordsAcknowledged: Long, + totalPollCalls: Int, + pendingRecords: Int, + acknowledgedRecords: Int) + +/** + * Configuration for acknowledgment behavior. 
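+ *
+ * Illustrative sketch of mapping an option string onto a mode (the source option key itself is
+ * defined by the provider and not shown here):
+ * {{{
+ *   AcknowledgmentMode.fromString("implicit")   // Implicit - auto-ACCEPT when the batch completes
+ *   AcknowledgmentMode.fromString("manual")     // Explicit - user code acknowledges records
+ *   AcknowledgmentMode.fromString("bogus")      // throws IllegalArgumentException
+ * }}}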
+ */ +sealed trait AcknowledgmentMode +object AcknowledgmentMode { + /** Automatically acknowledge all records as ACCEPT when batch completes */ + case object Implicit extends AcknowledgmentMode + /** Require explicit acknowledgment via foreachBatch or user code */ + case object Explicit extends AcknowledgmentMode + + def fromString(s: String): AcknowledgmentMode = s.toLowerCase match { + case "implicit" | "auto" => Implicit + case "explicit" | "manual" => Explicit + case _ => throw new IllegalArgumentException(s"Unknown acknowledgment mode: $s") + } +} + +/** + * Helper object for acquiring KafkaShareDataConsumer instances. + * Manages consumer lifecycle including token refresh. + */ +private[kafka010] class KafkaShareDataConsumer( + shareGroupId: String, + topics: Set[String], + kafkaParams: ju.Map[String, Object], + lockTimeoutMs: Long, + consumerPool: InternalKafkaShareConsumerPool) extends Logging { + + private val isTokenProviderEnabled = + HadoopDelegationTokenManager.isServiceEnabled(SparkEnv.get.conf, "kafka") + + @volatile private[consumer] var _consumer: Option[InternalKafkaShareConsumer] = None + + // Tracking stats + private var startTimestampNano: Long = System.nanoTime() + private var totalTimeReadNanos: Long = 0 + private var totalRecordsRead: Long = 0 + + /** + * Get or create the internal consumer. + */ + def getOrRetrieveConsumer(): InternalKafkaShareConsumer = { + if (_consumer.isEmpty) { + retrieveConsumer() + } + require(_consumer.isDefined, "Consumer must be defined") + + // Check if token refresh is needed + if (isTokenProviderEnabled && KafkaTokenUtil.needTokenUpdate( + _consumer.get.kafkaParamsWithSecurity, _consumer.get.clusterConfig)) { + logDebug("Cached share consumer uses an old delegation token, invalidating.") + releaseConsumer() + consumerPool.invalidateKey(ShareConsumerCacheKey(shareGroupId, topics)) + retrieveConsumer() + } + + _consumer.get + } + + /** + * Poll for records from the share group. + */ + def poll(pollTimeoutMs: Long): Seq[ShareInFlightRecord] = { + val consumer = getOrRetrieveConsumer() + val startTime = System.nanoTime() + val records = consumer.poll(pollTimeoutMs) + totalTimeReadNanos += (System.nanoTime() - startTime) + totalRecordsRead += records.size + records + } + + /** + * Acknowledge a record. + */ + def acknowledge(key: RecordKey, ackType: AcknowledgmentType.AcknowledgmentType): Boolean = { + _consumer.exists(_.acknowledge(key, ackType)) + } + + /** + * Commit acknowledgments synchronously. + */ + def commitSync(): Unit = { + _consumer.foreach(_.commitSync()) + } + + /** + * Release the consumer back to the pool. + */ + def release(): Unit = { + val kafkaMeta = _consumer + .map(c => s"shareGroupId=${c.shareGroupId} topics=${c.topics}") + .getOrElse("") + val walTime = System.nanoTime() - startTimestampNano + + val taskCtx = TaskContext.get() + val taskContextInfo = if (taskCtx != null) { + s" for taskId=${taskCtx.taskAttemptId()} partitionId=${taskCtx.partitionId()}." + } else { + "." 
+ } + + logInfo(s"From Kafka share group $kafkaMeta read $totalRecordsRead records, " + + s"taking ${totalTimeReadNanos / NANOS_PER_MILLIS.toDouble} ms, " + + s"during time span of ${walTime / NANOS_PER_MILLIS.toDouble} ms$taskContextInfo") + + releaseConsumer() + } + + private def retrieveConsumer(): Unit = { + _consumer = Option(consumerPool.borrowObject( + ShareConsumerCacheKey(shareGroupId, topics), + kafkaParams, + lockTimeoutMs + )) + startTimestampNano = System.nanoTime() + totalTimeReadNanos = 0 + totalRecordsRead = 0 + require(_consumer.isDefined, "borrowing consumer from pool must always succeed.") + } + + private def releaseConsumer(): Unit = { + if (_consumer.isDefined) { + consumerPool.returnObject(_consumer.get) + _consumer = None + } + } +} + +/** + * Cache key for share consumer pool. + */ +case class ShareConsumerCacheKey(shareGroupId: String, topics: Set[String]) { + override def hashCode(): Int = shareGroupId.hashCode * 31 + topics.hashCode() + + override def equals(obj: Any): Boolean = obj match { + case other: ShareConsumerCacheKey => + this.shareGroupId == other.shareGroupId && this.topics == other.topics + case _ => false + } +} + +/** + * Companion object for acquiring KafkaShareDataConsumer instances. + */ +object KafkaShareDataConsumer extends Logging { + private val sparkConf = SparkEnv.get.conf + private val consumerPool = new InternalKafkaShareConsumerPool(sparkConf) + + /** + * Acquire a share data consumer for the given share group and topics. + */ + def acquire( + shareGroupId: String, + topics: Set[String], + kafkaParams: ju.Map[String, Object], + lockTimeoutMs: Long = 30000L): KafkaShareDataConsumer = { + if (TaskContext.get() != null && TaskContext.get().attemptNumber() >= 1) { + // If this is a reattempt, invalidate cached consumer + consumerPool.invalidateKey(ShareConsumerCacheKey(shareGroupId, topics)) + } + + new KafkaShareDataConsumer(shareGroupId, topics, kafkaParams, lockTimeoutMs, consumerPool) + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/CheckpointDeduplication.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/CheckpointDeduplication.scala new file mode 100644 index 0000000000000..9c2c67a5b1002 --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/CheckpointDeduplication.scala @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010.share.exactlyonce + +import java.io.{DataInputStream, DataOutputStream} +import java.util.concurrent.ConcurrentHashMap + +import scala.jdk.CollectionConverters._ +import scala.util.Try + +import org.json4s.{DefaultFormats, Formats} +import org.json4s.jackson.JsonMethods.{compact, parse, render} +import org.json4s.JsonDSL._ + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.kafka010.share._ + +/** + * Strategy C: Checkpoint-Based Deduplication for Exactly-Once Semantics + * + * This strategy achieves exactly-once semantics by tracking processed record IDs + * in Spark's checkpoint and skipping redelivered records. + * + * How it works: + * 1. When a record is acquired, check if it's in the processed set + * 2. If already processed, skip it (and ACCEPT immediately) + * 3. If not processed, process it and add to processed set + * 4. Periodically persist processed set to checkpoint + * + * State Management: + * - ProcessedRecords: Records that have been successfully processed + * - PendingRecords: Records currently being processed (for recovery) + * + * Recovery: + * - On restart, load processed records from checkpoint + * - Release any pending records (they weren't committed) + * + * Requirements: + * - Checkpoint storage (HDFS, S3, etc.) + * + * Advantages: + * - Works with any sink + * - Medium complexity + * - Built on Spark's checkpoint mechanism + * + * Disadvantages: + * - State growth over time (needs cleanup) + * - Checkpoint overhead + * - Memory usage for in-memory dedup set + */ +class CheckpointDedupManager(checkpointPath: String) extends Logging { + + implicit val formats: Formats = DefaultFormats + + private val processedRecordsPath = s"$checkpointPath/processed_records" + private val pendingRecordsPath = s"$checkpointPath/pending_records" + private val metadataPath = s"$checkpointPath/dedup_metadata" + + // In-memory cache of processed records + private val processedRecords = ConcurrentHashMap.newKeySet[RecordKey]() + + // Currently pending records (being processed in current batch) + private val pendingRecords = ConcurrentHashMap.newKeySet[RecordKey]() + + // Batch tracking + private var currentBatchId: Long = -1 + private var lastCleanupBatchId: Long = -1 + + // Configuration + private val maxProcessedRecords = 1000000 // Max records to keep in memory + private val cleanupIntervalBatches = 100 // Cleanup every N batches + + /** + * Initialize the dedup manager by loading state from checkpoint. + */ + def initialize(): Unit = { + loadProcessedRecords() + recoverPendingRecords() + logInfo(s"Initialized checkpoint dedup manager with ${processedRecords.size()} processed records") + } + + /** + * Check if a record has already been processed. + */ + def isProcessed(key: RecordKey): Boolean = { + processedRecords.contains(key) + } + + /** + * Mark a record as pending (currently being processed). + */ + def markPending(key: RecordKey): Unit = { + pendingRecords.add(key) + } + + /** + * Mark a record as processed (successfully completed). + */ + def markProcessed(key: RecordKey): Unit = { + pendingRecords.remove(key) + processedRecords.add(key) + } + + /** + * Mark multiple records as processed. + */ + def markProcessed(keys: Iterable[RecordKey]): Unit = { + keys.foreach { key => + pendingRecords.remove(key) + processedRecords.add(key) + } + } + + /** + * Get pending records that need to be released on failure. 
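+ *
+ * Illustrative failure-recovery sketch (`dedup` is an instance of this class and
+ * `shareConsumer` is the hypothetical consumer used to hand records back for redelivery):
+ * {{{
+ *   dedup.getPendingRecords.foreach { key =>
+ *     shareConsumer.acknowledge(key, AcknowledgmentType.RELEASE)
+ *   }
+ *   dedup.clearPending()
+ * }}}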
+ */ + def getPendingRecords: Set[RecordKey] = { + pendingRecords.asScala.toSet + } + + /** + * Clear pending records (after successful commit). + */ + def clearPending(): Unit = { + pendingRecords.clear() + } + + /** + * Start a new batch. + */ + def startBatch(batchId: Long): Unit = { + currentBatchId = batchId + + // Check if cleanup is needed + if (batchId - lastCleanupBatchId >= cleanupIntervalBatches) { + cleanupOldRecords() + lastCleanupBatchId = batchId + } + } + + /** + * Commit a batch - persist processed records to checkpoint. + */ + def commitBatch(batchId: Long): Unit = { + // Move pending to processed + val pendingSnapshot = pendingRecords.asScala.toSet + pendingSnapshot.foreach { key => + processedRecords.add(key) + pendingRecords.remove(key) + } + + // Persist to checkpoint + persistProcessedRecords() + + logDebug(s"Committed batch $batchId with ${pendingSnapshot.size} records") + } + + /** + * Rollback a batch - clear pending records. + */ + def rollbackBatch(batchId: Long): Set[RecordKey] = { + val rollbackRecords = pendingRecords.asScala.toSet + pendingRecords.clear() + logWarning(s"Rolled back batch $batchId, ${rollbackRecords.size} records to release") + rollbackRecords + } + + /** + * Cleanup old processed records to prevent unbounded growth. + * + * Strategy: Keep only records from recent batches based on offset ranges. + */ + private def cleanupOldRecords(): Unit = { + val sizeBefore = processedRecords.size() + + if (sizeBefore > maxProcessedRecords) { + // Group by topic-partition and keep only recent offsets + val grouped = processedRecords.asScala.groupBy(r => (r.topic, r.partition)) + val cleaned = ConcurrentHashMap.newKeySet[RecordKey]() + + grouped.foreach { case ((topic, partition), records) => + // Keep records with highest offsets + val sorted = records.toSeq.sortBy(_.offset).reverse + val toKeep = sorted.take(maxProcessedRecords / grouped.size) + toKeep.foreach(cleaned.add) + } + + processedRecords.clear() + cleaned.forEach(processedRecords.add) + + logInfo(s"Cleaned up processed records: $sizeBefore -> ${processedRecords.size()}") + } + } + + /** + * Persist processed records to checkpoint storage. + */ + private def persistProcessedRecords(): Unit = { + try { + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + val path = new Path(processedRecordsPath) + fs.mkdirs(path.getParent) + + val tempPath = new Path(s"$processedRecordsPath.tmp") + val out = new DataOutputStream(fs.create(tempPath, true)) + + try { + // Write version + out.writeInt(1) + + // Write count + out.writeInt(processedRecords.size()) + + // Write each record key + processedRecords.forEach { key => + out.writeUTF(key.topic) + out.writeInt(key.partition) + out.writeLong(key.offset) + } + } finally { + out.close() + } + + // Atomic rename + fs.delete(path, false) + fs.rename(tempPath, path) + + logDebug(s"Persisted ${processedRecords.size()} processed records to checkpoint") + } catch { + case e: Exception => + logError(s"Failed to persist processed records: ${e.getMessage}", e) + } + } + + /** + * Load processed records from checkpoint storage. 
+ */ + private def loadProcessedRecords(): Unit = { + try { + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + val path = new Path(processedRecordsPath) + if (!fs.exists(path)) { + logInfo("No existing processed records checkpoint found") + return + } + + val in = new DataInputStream(fs.open(path)) + try { + // Read version + val version = in.readInt() + require(version == 1, s"Unsupported checkpoint version: $version") + + // Read count + val count = in.readInt() + + // Read each record key + (0 until count).foreach { _ => + val topic = in.readUTF() + val partition = in.readInt() + val offset = in.readLong() + processedRecords.add(RecordKey(topic, partition, offset)) + } + + logInfo(s"Loaded $count processed records from checkpoint") + } finally { + in.close() + } + } catch { + case e: Exception => + logError(s"Failed to load processed records: ${e.getMessage}", e) + } + } + + /** + * Recover from pending records checkpoint. + * Pending records from a failed batch should be released. + */ + private def recoverPendingRecords(): Unit = { + try { + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + val path = new Path(pendingRecordsPath) + if (!fs.exists(path)) { + return + } + + val in = new DataInputStream(fs.open(path)) + val recoveredPending = scala.collection.mutable.Set[RecordKey]() + + try { + val version = in.readInt() + val count = in.readInt() + + (0 until count).foreach { _ => + val topic = in.readUTF() + val partition = in.readInt() + val offset = in.readLong() + recoveredPending.add(RecordKey(topic, partition, offset)) + } + } finally { + in.close() + } + + if (recoveredPending.nonEmpty) { + logWarning(s"Found ${recoveredPending.size} pending records from previous run - " + + "these will need to be released") + // Note: The caller should release these records + recoveredPending.foreach(pendingRecords.add) + } + + // Delete pending checkpoint after recovery + fs.delete(path, false) + } catch { + case e: Exception => + logError(s"Failed to recover pending records: ${e.getMessage}", e) + } + } + + /** + * Persist pending records for crash recovery. + */ + def persistPendingRecords(): Unit = { + if (pendingRecords.isEmpty) return + + try { + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + val path = new Path(pendingRecordsPath) + fs.mkdirs(path.getParent) + + val out = new DataOutputStream(fs.create(path, true)) + try { + out.writeInt(1) // version + out.writeInt(pendingRecords.size()) + + pendingRecords.forEach { key => + out.writeUTF(key.topic) + out.writeInt(key.partition) + out.writeLong(key.offset) + } + } finally { + out.close() + } + } catch { + case e: Exception => + logError(s"Failed to persist pending records: ${e.getMessage}", e) + } + } + + /** + * Close the manager and persist final state. + */ + def close(): Unit = { + persistProcessedRecords() + persistPendingRecords() + } + + /** + * Get statistics about the dedup manager. + */ + def getStats: DedupStats = DedupStats( + processedCount = processedRecords.size(), + pendingCount = pendingRecords.size(), + currentBatchId = currentBatchId, + lastCleanupBatchId = lastCleanupBatchId + ) +} + +/** + * Statistics for the checkpoint dedup manager. 
+ */ +case class DedupStats( + processedCount: Int, + pendingCount: Int, + currentBatchId: Long, + lastCleanupBatchId: Long) + +/** + * Checkpoint state for a share group batch. + * Stored in Spark's checkpoint for recovery. + */ +case class ShareBatchCheckpoint( + batchId: Long, + shareGroupId: String, + processedRecords: Set[RecordKey], + pendingRecords: Set[RecordKey], + timestamp: Long = System.currentTimeMillis()) { + + implicit val formats: Formats = DefaultFormats + + def toJson: String = { + compact(render( + ("batchId" -> batchId) ~ + ("shareGroupId" -> shareGroupId) ~ + ("processedRecords" -> processedRecords.map(_.toString).toList) ~ + ("pendingRecords" -> pendingRecords.map(_.toString).toList) ~ + ("timestamp" -> timestamp) + )) + } +} + +object ShareBatchCheckpoint { + implicit val formats: Formats = DefaultFormats + + def fromJson(json: String): ShareBatchCheckpoint = { + val parsed = parse(json) + val batchId = (parsed \ "batchId").extract[Long] + val shareGroupId = (parsed \ "shareGroupId").extract[String] + val processedRecords = (parsed \ "processedRecords").extract[List[String]] + .map(RecordKey.fromString).toSet + val pendingRecords = (parsed \ "pendingRecords").extract[List[String]] + .map(RecordKey.fromString).toSet + val timestamp = (parsed \ "timestamp").extract[Long] + + ShareBatchCheckpoint(batchId, shareGroupId, processedRecords, pendingRecords, timestamp) + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/IdempotentSink.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/IdempotentSink.scala new file mode 100644 index 0000000000000..ba9fe4ea012db --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/IdempotentSink.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share.exactlyonce + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.DataStreamWriter +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} + +/** + * Strategy A: Idempotent Processing for Exactly-Once Semantics + * + * This strategy achieves exactly-once semantics by deduplicating records at the sink + * using unique record identifiers (topic, partition, offset). + * + * How it works: + * 1. Each Kafka record has a unique (topic, partition, offset) tuple + * 2. The sink uses UPSERT/MERGE semantics with this tuple as the key + * 3. 
If a record is redelivered (due to failure), the UPSERT is a no-op + * + * Requirements: + * - Sink must support idempotent writes (INSERT ON CONFLICT DO NOTHING/UPDATE) + * - Supported sinks: Delta Lake, JDBC with UPSERT, Cassandra, MongoDB, etc. + * + * Advantages: + * - Simple to implement + * - Low latency (no coordination overhead) + * - Works with at-least-once delivery + * + * Disadvantages: + * - Requires sink support for idempotent writes + * - Storage overhead for deduplication keys + */ +object IdempotentSink extends Logging { + + /** Column names used for deduplication */ + val DEDUP_KEY_TOPIC = "__kafka_topic" + val DEDUP_KEY_PARTITION = "__kafka_partition" + val DEDUP_KEY_OFFSET = "__kafka_offset" + + /** Schema for deduplication columns */ + val DEDUP_KEY_SCHEMA: StructType = StructType(Seq( + StructField(DEDUP_KEY_TOPIC, StringType, nullable = false), + StructField(DEDUP_KEY_PARTITION, IntegerType, nullable = false), + StructField(DEDUP_KEY_OFFSET, LongType, nullable = false) + )) + + /** + * Add deduplication key columns to a DataFrame read from Kafka share source. + * + * The Kafka share source already provides topic, partition, and offset columns. + * This method renames them to avoid conflicts with user schema. + * + * @param df DataFrame from Kafka share source + * @return DataFrame with deduplication key columns added + */ + def addDedupColumns(df: DataFrame): DataFrame = { + df.withColumn(DEDUP_KEY_TOPIC, col("topic")) + .withColumn(DEDUP_KEY_PARTITION, col("partition")) + .withColumn(DEDUP_KEY_OFFSET, col("offset")) + } + + /** + * Create a composite deduplication key from topic, partition, offset. + * + * @param df DataFrame with Kafka source columns + * @param keyColumnName Name for the composite key column + * @return DataFrame with composite key column added + */ + def addCompositeKey(df: DataFrame, keyColumnName: String = "__kafka_key"): DataFrame = { + df.withColumn(keyColumnName, + concat(col("topic"), lit("-"), col("partition"), lit(":"), col("offset"))) + } + + /** + * Configure Delta Lake sink for idempotent writes. + * + * Uses MERGE operation with the deduplication key as the merge condition. + */ + def configureDeltaSink( + df: DataFrame, + tablePath: String, + partitionBy: Seq[String] = Seq.empty): DataStreamWriter[Row] = { + val dfWithKeys = addDedupColumns(df) + + val writer = dfWithKeys.writeStream + .format("delta") + .outputMode("append") + .option("checkpointLocation", s"$tablePath/_checkpoints") + .option("mergeSchema", "true") + + if (partitionBy.nonEmpty) { + writer.partitionBy(partitionBy: _*) + } else { + writer + } + } + + /** + * Write to Delta Lake with deduplication using foreachBatch. + * + * This uses Delta Lake's MERGE operation to achieve exactly-once semantics. 
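+ *
+ * Illustrative usage sketch (paths are placeholders):
+ * {{{
+ *   IdempotentSink.writeToDeltaWithDedup(kafkaDf, "/data/delta/events", spark)
+ *     .option("checkpointLocation", "/data/checkpoints/events")
+ *     .start()
+ * }}}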
+ */ + def writeToDeltaWithDedup( + df: DataFrame, + tablePath: String, + spark: SparkSession): DataStreamWriter[Row] = { + import io.delta.tables.DeltaTable + + df.writeStream + .foreachBatch { (batchDf: DataFrame, batchId: Long) => + val dfWithKeys = addDedupColumns(batchDf) + + if (DeltaTable.isDeltaTable(spark, tablePath)) { + val deltaTable = DeltaTable.forPath(spark, tablePath) + + // MERGE with deduplication key + deltaTable.as("target") + .merge( + dfWithKeys.as("source"), + s"target.$DEDUP_KEY_TOPIC = source.$DEDUP_KEY_TOPIC AND " + + s"target.$DEDUP_KEY_PARTITION = source.$DEDUP_KEY_PARTITION AND " + + s"target.$DEDUP_KEY_OFFSET = source.$DEDUP_KEY_OFFSET" + ) + .whenNotMatched() + .insertAll() + .execute() + } else { + // First write - create the table + dfWithKeys.write + .format("delta") + .mode("overwrite") + .save(tablePath) + } + } + } + + /** + * Configure JDBC sink for idempotent writes using INSERT ON CONFLICT. + * + * Note: Requires the target table to have a unique constraint on + * (topic, partition, offset) columns. + */ + def configureJdbcSink( + df: DataFrame, + url: String, + table: String, + connectionProperties: java.util.Properties): DataStreamWriter[Row] = { + val dfWithKeys = addDedupColumns(df) + + dfWithKeys.writeStream + .foreachBatch { (batchDf: DataFrame, batchId: Long) => + batchDf.write + .mode("append") + .jdbc(url, table, connectionProperties) + } + } + + /** + * Generate SQL for creating a deduplication-enabled table. + * + * @param tableName Table name + * @param userSchema User's schema (excluding dedup columns) + * @param dialect SQL dialect (postgresql, mysql, sqlite) + * @return CREATE TABLE SQL statement + */ + def generateCreateTableSql( + tableName: String, + userSchema: StructType, + dialect: String = "postgresql"): String = { + val allFields = DEDUP_KEY_SCHEMA.fields ++ userSchema.fields + + val columnDefs = allFields.map { field => + val sqlType = dialect match { + case "postgresql" => sparkTypeToPostgres(field.dataType) + case "mysql" => sparkTypeToMysql(field.dataType) + case _ => sparkTypeToAnsi(field.dataType) + } + s"${field.name} $sqlType${if (field.nullable) "" else " NOT NULL"}" + }.mkString(",\n ") + + val uniqueConstraint = s"UNIQUE ($DEDUP_KEY_TOPIC, $DEDUP_KEY_PARTITION, $DEDUP_KEY_OFFSET)" + + s"""CREATE TABLE IF NOT EXISTS $tableName ( + | $columnDefs, + | $uniqueConstraint + |)""".stripMargin + } + + private def sparkTypeToPostgres(dt: org.apache.spark.sql.types.DataType): String = dt match { + case StringType => "TEXT" + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case org.apache.spark.sql.types.BinaryType => "BYTEA" + case org.apache.spark.sql.types.TimestampType => "TIMESTAMP" + case org.apache.spark.sql.types.DoubleType => "DOUBLE PRECISION" + case org.apache.spark.sql.types.BooleanType => "BOOLEAN" + case _ => "TEXT" + } + + private def sparkTypeToMysql(dt: org.apache.spark.sql.types.DataType): String = dt match { + case StringType => "TEXT" + case IntegerType => "INT" + case LongType => "BIGINT" + case org.apache.spark.sql.types.BinaryType => "BLOB" + case org.apache.spark.sql.types.TimestampType => "TIMESTAMP" + case org.apache.spark.sql.types.DoubleType => "DOUBLE" + case org.apache.spark.sql.types.BooleanType => "BOOLEAN" + case _ => "TEXT" + } + + private def sparkTypeToAnsi(dt: org.apache.spark.sql.types.DataType): String = dt match { + case StringType => "VARCHAR" + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case org.apache.spark.sql.types.BinaryType => "BINARY" + case 
org.apache.spark.sql.types.TimestampType => "TIMESTAMP" + case org.apache.spark.sql.types.DoubleType => "DOUBLE" + case org.apache.spark.sql.types.BooleanType => "BOOLEAN" + case _ => "VARCHAR" + } +} + +/** + * Helper for deduplicating within a DataFrame using Spark's dropDuplicates. + * + * Useful when processing data in memory before writing to a non-idempotent sink. + */ +object InMemoryDeduplication extends Logging { + + /** + * Deduplicate records within a batch using Kafka coordinates. + * + * @param df DataFrame with topic, partition, offset columns + * @return Deduplicated DataFrame + */ + def deduplicateByKafkaCoordinates(df: DataFrame): DataFrame = { + df.dropDuplicates("topic", "partition", "offset") + } + + /** + * Deduplicate records using a custom key expression. + * + * @param df DataFrame to deduplicate + * @param keyColumns Columns to use as deduplication key + * @return Deduplicated DataFrame + */ + def deduplicateByKey(df: DataFrame, keyColumns: String*): DataFrame = { + df.dropDuplicates(keyColumns: _*) + } + + /** + * Deduplicate with watermark for streaming state management. + * + * @param df DataFrame with timestamp column + * @param timestampColumn Column containing event timestamps + * @param watermarkDelay Watermark delay (e.g., "10 minutes") + * @param keyColumns Deduplication key columns + * @return Deduplicated DataFrame with watermark + */ + def deduplicateWithWatermark( + df: DataFrame, + timestampColumn: String, + watermarkDelay: String, + keyColumns: String*): DataFrame = { + df.withWatermark(timestampColumn, watermarkDelay) + .dropDuplicates(keyColumns: _*) + } +} + diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/TwoPhaseCommitCoordinator.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/TwoPhaseCommitCoordinator.scala new file mode 100644 index 0000000000000..c1d2ce74ff93d --- /dev/null +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/share/exactlyonce/TwoPhaseCommitCoordinator.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010.share.exactlyonce + +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} + +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.kafka010.share._ +import org.apache.spark.sql.kafka010.share.consumer._ + +/** + * Strategy B: Two-Phase Commit for True Exactly-Once Semantics + * + * This strategy achieves true exactly-once semantics by atomically committing + * both the output writes and Kafka acknowledgments using a two-phase commit protocol. + * + * How it works: + * Phase 1 (Prepare): + * 1. Write output to staging location + * 2. Prepare Kafka acknowledgments (but don't commit) + * 3. Record transaction state to WAL + * + * Phase 2 (Commit): + * 1. Move staged output to final location + * 2. Commit Kafka acknowledgments + * 3. Mark transaction as complete + * + * Rollback (on failure): + * 1. Delete staged output + * 2. Release Kafka records (RELEASE acknowledgment) + * 3. Mark transaction as aborted + * + * Requirements: + * - Transaction-capable sink (staging support) + * - WAL for transaction state + * - Coordination between driver and executors + * + * Advantages: + * - True exactly-once semantics + * - Works with any sink that supports staging + * + * Disadvantages: + * - Higher latency (two phases) + * - More complex implementation + * - Requires coordination + */ +class TwoPhaseCommitCoordinator( + shareGroupId: String, + checkpointPath: String) extends Logging { + + import TwoPhaseCommitCoordinator._ + + // Transaction state tracking + private val activeTransactions = new ConcurrentHashMap[Long, TransactionState]() + + // WAL for transaction recovery + private val transactionLog = new TransactionLog(checkpointPath) + + // Transaction ID generator + private val transactionIdCounter = new AtomicInteger(0) + + /** + * Begin a new transaction for a batch. + * + * @param batchId The micro-batch ID + * @return Transaction handle for the batch + */ + def beginTransaction(batchId: Long): TransactionHandle = { + val transactionId = transactionIdCounter.incrementAndGet() + val state = TransactionState( + transactionId = transactionId, + batchId = batchId, + phase = Phase.STARTED, + startTime = System.currentTimeMillis(), + pendingAcks = new ConcurrentHashMap[RecordKey, AcknowledgmentType.AcknowledgmentType]() + ) + + activeTransactions.put(batchId, state) + transactionLog.write(TransactionRecord(transactionId, batchId, Phase.STARTED)) + + logInfo(s"Started transaction $transactionId for batch $batchId") + TransactionHandle(transactionId, batchId, this) + } + + /** + * Phase 1: Prepare the transaction. + * + * This should be called after output has been written to staging and + * acknowledgments have been recorded (but not committed to Kafka). 
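+   *
+   * A minimal sketch of where this call sits in the overall 2PC flow (the
+   * `coordinator`, `consumer`, `batchId`, `tp` and `offset` values are assumed
+   * to be provided by the caller):
+   * {{{
+   *   val handle = coordinator.beginTransaction(batchId)
+   *   handle.recordAck(RecordKey(tp, offset), AcknowledgmentType.ACCEPT)
+   *   // stage the batch output, then:
+   *   if (handle.prepare()) handle.commit(consumer) else handle.rollback(consumer)
+   * }}}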
+ */ + def prepareTransaction(batchId: Long): Boolean = { + val state = activeTransactions.get(batchId) + if (state == null) { + logError(s"No active transaction found for batch $batchId") + return false + } + + if (state.phase != Phase.STARTED) { + logError(s"Transaction ${state.transactionId} is in ${state.phase}, expected STARTED") + return false + } + + // Update state to PREPARED + state.phase = Phase.PREPARED + state.prepareTime = System.currentTimeMillis() + transactionLog.write(TransactionRecord(state.transactionId, batchId, Phase.PREPARED)) + + logInfo(s"Prepared transaction ${state.transactionId} for batch $batchId " + + s"with ${state.pendingAcks.size()} acknowledgments") + true + } + + /** + * Phase 2: Commit the transaction. + * + * This should be called after verifying the output has been moved to final location. + */ + def commitTransaction(batchId: Long, consumer: InternalKafkaShareConsumer): Boolean = { + val state = activeTransactions.get(batchId) + if (state == null) { + logError(s"No active transaction found for batch $batchId") + return false + } + + if (state.phase != Phase.PREPARED) { + logError(s"Transaction ${state.transactionId} is in ${state.phase}, expected PREPARED") + return false + } + + try { + // Apply all pending acknowledgments to Kafka + state.pendingAcks.forEach { (key, ackType) => + consumer.acknowledge(key, ackType) + } + + // Commit to Kafka + consumer.commitSync() + + // Update state to COMMITTED + state.phase = Phase.COMMITTED + state.commitTime = System.currentTimeMillis() + transactionLog.write(TransactionRecord(state.transactionId, batchId, Phase.COMMITTED)) + + // Clean up + activeTransactions.remove(batchId) + + logInfo(s"Committed transaction ${state.transactionId} for batch $batchId") + true + } catch { + case e: Exception => + logError(s"Failed to commit transaction ${state.transactionId}: ${e.getMessage}", e) + // Attempt rollback + rollbackTransaction(batchId, consumer) + false + } + } + + /** + * Rollback a transaction. + * + * This releases all acquired records back to Kafka for redelivery. + */ + def rollbackTransaction(batchId: Long, consumer: InternalKafkaShareConsumer): Boolean = { + val state = activeTransactions.get(batchId) + if (state == null) { + logWarning(s"No active transaction found for batch $batchId to rollback") + return false + } + + try { + // Release all records + state.pendingAcks.forEach { (key, _) => + consumer.acknowledge(key, AcknowledgmentType.RELEASE) + } + + // Commit releases to Kafka + consumer.commitSync() + + // Update state to ABORTED + state.phase = Phase.ABORTED + transactionLog.write(TransactionRecord(state.transactionId, batchId, Phase.ABORTED)) + + // Clean up + activeTransactions.remove(batchId) + + logInfo(s"Rolled back transaction ${state.transactionId} for batch $batchId") + true + } catch { + case e: Exception => + logError(s"Failed to rollback transaction ${state.transactionId}: ${e.getMessage}", e) + false + } + } + + /** + * Record an acknowledgment for the current transaction. + * The acknowledgment won't be applied until the transaction commits. + */ + def recordAcknowledgment( + batchId: Long, + key: RecordKey, + ackType: AcknowledgmentType.AcknowledgmentType): Unit = { + val state = activeTransactions.get(batchId) + if (state != null) { + state.pendingAcks.put(key, ackType) + } + } + + /** + * Get the current phase of a transaction. 
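+   *
+   * Returns None when no transaction is tracked for the batch, including after a
+   * successful commit or rollback, since completed transactions are removed from
+   * the active set.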
+ */ + def getTransactionPhase(batchId: Long): Option[Phase] = { + Option(activeTransactions.get(batchId)).map(_.phase) + } + + /** + * Recover incomplete transactions on startup. + * + * - STARTED transactions -> Abort + * - PREPARED transactions -> Attempt to complete or abort + * - COMMITTED/ABORTED -> Clean up + */ + def recoverTransactions(consumer: InternalKafkaShareConsumer): Unit = { + val incompleteTransactions = transactionLog.readIncomplete() + + incompleteTransactions.foreach { record => + record.phase match { + case Phase.STARTED => + logWarning(s"Found incomplete STARTED transaction ${record.transactionId}, aborting") + // STARTED transactions should be aborted + + case Phase.PREPARED => + logWarning(s"Found PREPARED transaction ${record.transactionId}, " + + "checking if output was committed") + // For PREPARED, we need to check if output was committed + // If yes, commit to Kafka; if no, abort + + case Phase.COMMITTED | Phase.ABORTED => + logInfo(s"Transaction ${record.transactionId} already completed, cleaning up") + transactionLog.markComplete(record.transactionId) + + case _ => + logWarning(s"Unknown transaction phase: ${record.phase}") + } + } + } + + /** + * Close the coordinator and clean up resources. + */ + def close(): Unit = { + // Abort any active transactions + activeTransactions.forEach { (batchId, state) => + if (state.phase != Phase.COMMITTED && state.phase != Phase.ABORTED) { + logWarning(s"Aborting incomplete transaction ${state.transactionId} on close") + transactionLog.write(TransactionRecord(state.transactionId, batchId, Phase.ABORTED)) + } + } + activeTransactions.clear() + transactionLog.close() + } +} + +object TwoPhaseCommitCoordinator { + + /** + * Transaction phases in 2PC protocol. + */ + sealed trait Phase + object Phase { + case object STARTED extends Phase // Transaction started, processing in progress + case object PREPARED extends Phase // Output staged, ready to commit + case object COMMITTED extends Phase // Transaction committed successfully + case object ABORTED extends Phase // Transaction aborted/rolled back + } + + /** + * Internal state for an active transaction. + */ + case class TransactionState( + transactionId: Int, + batchId: Long, + var phase: Phase, + startTime: Long, + var prepareTime: Long = 0L, + var commitTime: Long = 0L, + pendingAcks: ConcurrentHashMap[RecordKey, AcknowledgmentType.AcknowledgmentType]) + + /** + * Record persisted to transaction log. + */ + case class TransactionRecord( + transactionId: Int, + batchId: Long, + phase: Phase, + timestamp: Long = System.currentTimeMillis()) + + /** + * Handle for interacting with an active transaction. + */ + case class TransactionHandle( + transactionId: Int, + batchId: Long, + coordinator: TwoPhaseCommitCoordinator) { + + def recordAck(key: RecordKey, ackType: AcknowledgmentType.AcknowledgmentType): Unit = { + coordinator.recordAcknowledgment(batchId, key, ackType) + } + + def prepare(): Boolean = coordinator.prepareTransaction(batchId) + + def commit(consumer: InternalKafkaShareConsumer): Boolean = { + coordinator.commitTransaction(batchId, consumer) + } + + def rollback(consumer: InternalKafkaShareConsumer): Boolean = { + coordinator.rollbackTransaction(batchId, consumer) + } + } +} + +/** + * Write-ahead log for transaction state. + * Persists transaction records for recovery. 
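+ *
+ * Note: records are currently kept in an in-memory map only; durable persistence
+ * to the checkpoint location is still a TODO (see `write` and `loadExisting`).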
+ */ +private[exactlyonce] class TransactionLog(checkpointPath: String) extends Logging { + import TwoPhaseCommitCoordinator._ + + private val logPath = s"$checkpointPath/transactions" + private val records = new ConcurrentHashMap[Int, TransactionRecord]() + + // Load existing records on initialization + loadExisting() + + def write(record: TransactionRecord): Unit = { + records.put(record.transactionId, record) + // TODO: Persist to file system for durability + logDebug(s"Wrote transaction record: $record") + } + + def read(transactionId: Int): Option[TransactionRecord] = { + Option(records.get(transactionId)) + } + + def readIncomplete(): Seq[TransactionRecord] = { + records.values().asScala + .filter(r => r.phase != Phase.COMMITTED && r.phase != Phase.ABORTED) + .toSeq + } + + def markComplete(transactionId: Int): Unit = { + records.remove(transactionId) + } + + def close(): Unit = { + // TODO: Flush any pending writes + } + + private def loadExisting(): Unit = { + // TODO: Load from file system + } +} + +/** + * Helper for staging output writes. + */ +object StagingHelper extends Logging { + + /** + * Generate a staging path for a batch. + */ + def getStagingPath(basePath: String, batchId: Long): String = { + s"$basePath/_staging/batch_$batchId" + } + + /** + * Generate the final path for a batch. + */ + def getFinalPath(basePath: String, batchId: Long): String = { + s"$basePath/data/batch_$batchId" + } + + /** + * Move staged data to final location. + * This should be atomic (rename operation). + */ + def commitStaging(stagingPath: String, finalPath: String): Boolean = { + try { + import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.sql.SparkSession + + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + val srcPath = new Path(stagingPath) + val dstPath = new Path(finalPath) + + // Ensure parent directory exists + fs.mkdirs(dstPath.getParent) + + // Atomic rename + fs.rename(srcPath, dstPath) + true + } catch { + case e: Exception => + logError(s"Failed to commit staging: ${e.getMessage}", e) + false + } + } + + /** + * Delete staged data on rollback. + */ + def rollbackStaging(stagingPath: String): Boolean = { + try { + import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.sql.SparkSession + + val spark = SparkSession.active + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + + fs.delete(new Path(stagingPath), true) + true + } catch { + case e: Exception => + logError(s"Failed to rollback staging: ${e.getMessage}", e) + false + } + } +} + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceFaultToleranceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceFaultToleranceSuite.scala new file mode 100644 index 0000000000000..2e4fdc2087276 --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceFaultToleranceSuite.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import java.io.File +import java.nio.file.Files + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.kafka010.share.exactlyonce._ + +/** + * Integration tests for Kafka Share Source fault tolerance. + * + * These tests verify: + * 1. At-least-once semantics via acquisition lock expiry + * 2. Recovery from driver failure + * 3. Checkpoint-based deduplication for exactly-once + * 4. Two-phase commit coordinator + */ +class KafkaShareSourceFaultToleranceSuite extends SparkFunSuite { + + private var tempDir: File = _ + + override def beforeEach(): Unit = { + super.beforeEach() + tempDir = Files.createTempDirectory("kafka-share-test").toFile + } + + override def afterEach(): Unit = { + if (tempDir != null) { + deleteRecursively(tempDir) + } + super.afterEach() + } + + // ==================== Offset Tracking Tests ==================== + + test("KafkaShareSourceOffset - non-sequential offset tracking") { + // Simulate share group behavior where offsets are not sequential + val tp = new TopicPartition("test-topic", 0) + + // Share consumer might receive offsets 100, 105, 110 (not sequential) + val range = AcquiredRecordRange(Set(100L, 105L, 110L), 1000L, 31000L, 1) + val offset = KafkaShareSourceOffset("group", 1L, Map(tp -> range)) + + assert(offset.totalRecords === 3) + assert(offset.acquiredRecords(tp).contains(100L)) + assert(offset.acquiredRecords(tp).contains(105L)) + assert(offset.acquiredRecords(tp).contains(110L)) + // Offset 101 was given to another consumer + assert(!offset.acquiredRecords(tp).contains(101L)) + } + + test("KafkaShareSourceOffset - multiple partitions with non-sequential offsets") { + val tp0 = new TopicPartition("test-topic", 0) + val tp1 = new TopicPartition("test-topic", 1) + + // Consumer A got offsets from tp0: 100, 102, 104 + // Consumer B (different share member) got tp0: 101, 103, 105 + // This consumer got tp1: 200, 201 + val range0 = AcquiredRecordRange(Set(100L, 102L, 104L), 1000L, 31000L, 1) + val range1 = AcquiredRecordRange(Set(200L, 201L), 1000L, 31000L, 1) + + val offset = KafkaShareSourceOffset("group", 1L, Map(tp0 -> range0, tp1 -> range1)) + + assert(offset.totalRecords === 5) + assert(offset.partitionsWithRecords.size === 2) + } + + // ==================== Acknowledgment State Tests ==================== + + test("ShareInFlightBatch - track acknowledgments for non-sequential offsets") { + val batch = new ShareInFlightBatch(1L) + val tp = new TopicPartition("topic", 0) + + // Add non-sequential offsets + Seq(100L, 105L, 110L, 115L).foreach { offset => + val key = RecordKey(tp, offset) + batch.addRecord(ShareInFlightRecord( + recordKey = key, + record = null, + acquiredAt = 1000L, + lockExpiresAt = 31000L, + deliveryCount = 1, + acknowledgment = None + )) + } + + // Acknowledge some (simulating out-of-order processing) + batch.acknowledge(RecordKey(tp, 110L), AcknowledgmentType.ACCEPT) + batch.acknowledge(RecordKey(tp, 100L), AcknowledgmentType.ACCEPT) + batch.acknowledge(RecordKey(tp, 105L), AcknowledgmentType.RELEASE) // 
Needs retry + + val acks = batch.getAcknowledgmentsForCommit + assert(acks(tp).size === 3) + assert(acks(tp)(100L) === AcknowledgmentType.ACCEPT) + assert(acks(tp)(105L) === AcknowledgmentType.RELEASE) + assert(acks(tp)(110L) === AcknowledgmentType.ACCEPT) + + // 115 is still pending + assert(batch.pending === 1) + } + + // ==================== Recovery Tests ==================== + + test("ShareRecoveryManager - recovery from checkpoint") { + val checkpointPath = new File(tempDir, "checkpoint").getAbsolutePath + + // Create initial state + val manager1 = new ShareRecoveryManager("test-group", checkpointPath) + val writer = manager1.getCheckpointWriter + + val tp = new TopicPartition("topic", 0) + val range = AcquiredRecordRange(Set(100L, 105L), 1000L, 31000L, 1) + val offset = KafkaShareSourceOffset("test-group", 5L, Map(tp -> range)) + + writer.write(offset) + + // Simulate restart - create new manager + val manager2 = new ShareRecoveryManager("test-group", checkpointPath) + val recoveryState = manager2.recover() + + assert(!recoveryState.isCleanStart) + assert(recoveryState.startBatchId === 6L) // Resume from next batch + assert(recoveryState.lastCommittedOffset.isDefined) + assert(recoveryState.lastCommittedOffset.get.batchId === 5L) + + // Pending records should be flagged for release (lock expiry) + assert(recoveryState.pendingRecords.size === 2) + assert(recoveryState.pendingRecords.contains(RecordKey(tp, 100L))) + assert(recoveryState.pendingRecords.contains(RecordKey(tp, 105L))) + } + + test("ShareRecoveryManager - clean start with no checkpoint") { + val checkpointPath = new File(tempDir, "new-checkpoint").getAbsolutePath + val manager = new ShareRecoveryManager("test-group", checkpointPath) + + val recoveryState = manager.recover() + + assert(recoveryState.isCleanStart) + assert(recoveryState.startBatchId === 0L) + assert(recoveryState.lastCommittedOffset.isEmpty) + assert(recoveryState.pendingRecords.isEmpty) + } + + // ==================== Checkpoint Deduplication Tests ==================== + + test("CheckpointDedupManager - deduplicate redelivered records") { + val checkpointPath = new File(tempDir, "dedup-checkpoint").getAbsolutePath + val manager = new CheckpointDedupManager(checkpointPath) + manager.initialize() + + val key1 = RecordKey("topic", 0, 100L) + val key2 = RecordKey("topic", 0, 101L) + + // First delivery + manager.startBatch(1L) + assert(!manager.isProcessed(key1)) + assert(!manager.isProcessed(key2)) + + manager.markPending(key1) + manager.markPending(key2) + manager.markProcessed(key1) + manager.commitBatch(1L) + + // key1 is now processed, key2 was not completed (simulating failure) + + // Redelivery - key1 should be skipped + manager.startBatch(2L) + assert(manager.isProcessed(key1)) // Should be deduplicated + assert(!manager.isProcessed(key2)) // Was not committed, needs reprocessing + } + + test("CheckpointDedupManager - rollback releases pending records") { + val checkpointPath = new File(tempDir, "dedup-checkpoint2").getAbsolutePath + val manager = new CheckpointDedupManager(checkpointPath) + manager.initialize() + + val key1 = RecordKey("topic", 0, 100L) + val key2 = RecordKey("topic", 0, 101L) + + manager.startBatch(1L) + manager.markPending(key1) + manager.markPending(key2) + + // Simulate failure - rollback + val rolledBack = manager.rollbackBatch(1L) + + assert(rolledBack.size === 2) + assert(rolledBack.contains(key1)) + assert(rolledBack.contains(key2)) + + // These should NOT be marked as processed + assert(!manager.isProcessed(key1)) + 
assert(!manager.isProcessed(key2)) + } + + // ==================== Two-Phase Commit Tests ==================== + + test("TwoPhaseCommitCoordinator - successful commit flow") { + val checkpointPath = new File(tempDir, "2pc-checkpoint").getAbsolutePath + val coordinator = new TwoPhaseCommitCoordinator("test-group", checkpointPath) + + val handle = coordinator.beginTransaction(1L) + val key = RecordKey("topic", 0, 100L) + + handle.recordAck(key, AcknowledgmentType.ACCEPT) + + // Phase 1: Prepare + assert(handle.prepare()) + assert(coordinator.getTransactionPhase(1L).contains(TwoPhaseCommitCoordinator.Phase.PREPARED)) + + // Note: Cannot test Phase 2 without a real Kafka consumer + // In a real test, we would: + // assert(handle.commit(consumer)) + // assert(coordinator.getTransactionPhase(1L).isEmpty) // Completed and cleaned up + + coordinator.close() + } + + test("TwoPhaseCommitCoordinator - transaction phases") { + val checkpointPath = new File(tempDir, "2pc-checkpoint2").getAbsolutePath + val coordinator = new TwoPhaseCommitCoordinator("test-group", checkpointPath) + + // STARTED phase + val handle = coordinator.beginTransaction(1L) + assert(coordinator.getTransactionPhase(1L).contains(TwoPhaseCommitCoordinator.Phase.STARTED)) + + // Cannot prepare without being in STARTED state after already preparing + handle.recordAck(RecordKey("topic", 0, 100L), AcknowledgmentType.ACCEPT) + assert(handle.prepare()) + assert(coordinator.getTransactionPhase(1L).contains(TwoPhaseCommitCoordinator.Phase.PREPARED)) + + // Cannot prepare again + assert(!coordinator.prepareTransaction(1L)) + + coordinator.close() + } + + // ==================== Lock Expiry Tests ==================== + + test("AcquisitionLock - detect expired locks") { + val tp = new TopicPartition("topic", 0) + + // Lock acquired at time 1000, expires at 31000 (30 second timeout) + val range = AcquiredRecordRange(Set(100L, 105L), 1000L, 31000L, 1) + + assert(!range.isLockExpired(1000L)) + assert(!range.isLockExpired(15000L)) + assert(!range.isLockExpired(30999L)) + assert(range.isLockExpired(31000L)) + assert(range.isLockExpired(32000L)) + } + + test("ShareInFlightRecord - lock expiry detection") { + val key = RecordKey("topic", 0, 100L) + val record = ShareInFlightRecord( + recordKey = key, + record = null, + acquiredAt = 1000L, + lockExpiresAt = 31000L, + deliveryCount = 1, + acknowledgment = None + ) + + assert(!record.isLockExpired(1000L)) + assert(!record.isLockExpired(30000L)) + assert(record.isLockExpired(31000L)) + assert(record.isLockExpired(35000L)) + + assert(record.remainingLockTimeMs(1000L) === 30000L) + assert(record.remainingLockTimeMs(30000L) === 1000L) + assert(record.remainingLockTimeMs(31000L) === 0L) + assert(record.remainingLockTimeMs(32000L) === 0L) + } + + // ==================== Helper Methods ==================== + + private def deleteRecursively(file: File): Unit = { + if (file.isDirectory) { + file.listFiles().foreach(deleteRecursively) + } + file.delete() + } +} + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffsetSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffsetSuite.scala new file mode 100644 index 0000000000000..81ad73318cf6f --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/KafkaShareSourceOffsetSuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for KafkaShareSourceOffset and related classes. + */ +class KafkaShareSourceOffsetSuite extends SparkFunSuite { + + test("RecordKey - basic creation") { + val key = RecordKey("test-topic", 0, 100L) + + assert(key.topic === "test-topic") + assert(key.partition === 0) + assert(key.offset === 100L) + } + + test("RecordKey - from TopicPartition") { + val tp = new TopicPartition("test-topic", 2) + val key = RecordKey(tp, 200L) + + assert(key.topic === "test-topic") + assert(key.partition === 2) + assert(key.offset === 200L) + assert(key.topicPartition === tp) + } + + test("RecordKey - toString and fromString round-trip") { + val original = RecordKey("my-topic", 5, 12345L) + val str = original.toString + val parsed = RecordKey.fromString(str) + + assert(parsed === original) + } + + test("AcquiredRecordRange - basic creation") { + val offsets = Set(100L, 105L, 110L) + val range = AcquiredRecordRange(offsets, 1000L, 31000L, 1) + + assert(range.offsets === offsets) + assert(range.size === 3) + assert(range.minOffset === 100L) + assert(range.maxOffset === 110L) + assert(range.acquiredAt === 1000L) + assert(range.lockExpiresAt === 31000L) + } + + test("AcquiredRecordRange - contains offset") { + val range = AcquiredRecordRange(Set(100L, 105L, 110L), 1000L, 31000L, 1) + + assert(range.contains(100L)) + assert(range.contains(105L)) + assert(!range.contains(101L)) + assert(!range.contains(99L)) + } + + test("AcquiredRecordRange - isLockExpired") { + val range = AcquiredRecordRange(Set(100L), 1000L, 31000L, 1) + + assert(!range.isLockExpired(1000L)) + assert(!range.isLockExpired(30999L)) + assert(range.isLockExpired(31000L)) + assert(range.isLockExpired(32000L)) + } + + test("AcquiredRecordRange - withOffset adds offset") { + val range = AcquiredRecordRange(Set(100L), 1000L, 31000L, 1) + val updated = range.withOffset(200L) + + assert(updated.offsets === Set(100L, 200L)) + assert(range.offsets === Set(100L)) // Original unchanged + } + + test("AcquiredRecordRange - withoutOffset removes offset") { + val range = AcquiredRecordRange(Set(100L, 200L), 1000L, 31000L, 1) + val updated = range.withoutOffset(100L) + + assert(updated.offsets === Set(200L)) + assert(range.offsets === Set(100L, 200L)) // Original unchanged + } + + test("AcquiredRecordRange - empty creation") { + val range = AcquiredRecordRange.empty(1000L, 30000L) + + assert(range.offsets.isEmpty) + assert(range.size === 0) + assert(range.acquiredAt === 1000L) + assert(range.lockExpiresAt === 31000L) + } + + test("AcquiredRecordRange - fromRange creation") { + val range = AcquiredRecordRange.fromRange(100L, 104L, 1000L, 30000L) + + assert(range.offsets === Set(100L, 101L, 102L, 103L, 104L)) + 
assert(range.size === 5) + } + + test("KafkaShareSourceOffset - empty offset") { + val offset = KafkaShareSourceOffset.empty("test-group") + + assert(offset.shareGroupId === "test-group") + assert(offset.batchId === 0L) + assert(offset.acquiredRecords.isEmpty) + assert(!offset.hasRecords) + assert(offset.totalRecords === 0) + } + + test("KafkaShareSourceOffset - forBatch creation") { + val offset = KafkaShareSourceOffset.forBatch("test-group", 5L) + + assert(offset.shareGroupId === "test-group") + assert(offset.batchId === 5L) + assert(offset.acquiredRecords.isEmpty) + } + + test("KafkaShareSourceOffset - JSON serialization round-trip") { + val tp1 = new TopicPartition("topic1", 0) + val tp2 = new TopicPartition("topic2", 1) + + val range1 = AcquiredRecordRange(Set(100L, 105L, 110L), 1000L, 31000L, 1) + val range2 = AcquiredRecordRange(Set(200L, 210L), 2000L, 32000L, 2) + + val offset = KafkaShareSourceOffset( + shareGroupId = "my-share-group", + batchId = 10L, + acquiredRecords = Map(tp1 -> range1, tp2 -> range2) + ) + + val json = offset.json + val parsed = KafkaShareSourceOffset(json) + + assert(parsed.shareGroupId === offset.shareGroupId) + assert(parsed.batchId === offset.batchId) + assert(parsed.acquiredRecords.size === 2) + assert(parsed.acquiredRecords(tp1).offsets === range1.offsets) + assert(parsed.acquiredRecords(tp2).offsets === range2.offsets) + } + + test("KafkaShareSourceOffset - getAllRecordKeys") { + val tp1 = new TopicPartition("topic1", 0) + val tp2 = new TopicPartition("topic2", 1) + + val range1 = AcquiredRecordRange(Set(100L, 105L), 1000L, 31000L, 1) + val range2 = AcquiredRecordRange(Set(200L), 2000L, 32000L, 1) + + val offset = KafkaShareSourceOffset("group", 1L, Map(tp1 -> range1, tp2 -> range2)) + val keys = offset.getAllRecordKeys + + assert(keys.size === 3) + assert(keys.contains(RecordKey(tp1, 100L))) + assert(keys.contains(RecordKey(tp1, 105L))) + assert(keys.contains(RecordKey(tp2, 200L))) + } + + test("KafkaShareSourceOffset - totalRecords") { + val tp1 = new TopicPartition("topic1", 0) + val tp2 = new TopicPartition("topic2", 1) + + val range1 = AcquiredRecordRange(Set(100L, 105L, 110L), 1000L, 31000L, 1) + val range2 = AcquiredRecordRange(Set(200L, 210L), 2000L, 32000L, 1) + + val offset = KafkaShareSourceOffset("group", 1L, Map(tp1 -> range1, tp2 -> range2)) + + assert(offset.totalRecords === 5) + } + + test("KafkaShareSourceOffset - hasRecords and partitionsWithRecords") { + val tp1 = new TopicPartition("topic1", 0) + val tp2 = new TopicPartition("topic2", 1) + + val emptyOffset = KafkaShareSourceOffset.empty("group") + assert(!emptyOffset.hasRecords) + assert(emptyOffset.partitionsWithRecords.isEmpty) + + val range = AcquiredRecordRange(Set(100L), 1000L, 31000L, 1) + val nonEmptyOffset = KafkaShareSourceOffset("group", 1L, Map(tp1 -> range)) + + assert(nonEmptyOffset.hasRecords) + assert(nonEmptyOffset.partitionsWithRecords === Set(tp1)) + } + + test("KafkaShareSourceOffset - withAcquiredRecord") { + val tp = new TopicPartition("topic1", 0) + val offset = KafkaShareSourceOffset.empty("group") + + val updated = offset.withAcquiredRecord(tp, 100L, 1000L, 30000L) + + assert(updated.acquiredRecords.contains(tp)) + assert(updated.acquiredRecords(tp).contains(100L)) + assert(offset.acquiredRecords.isEmpty) // Original unchanged + } + + test("KafkaShareSourceOffset - withAcknowledgedRecord") { + val tp = new TopicPartition("topic1", 0) + val range = AcquiredRecordRange(Set(100L, 105L), 1000L, 31000L, 1) + val offset = KafkaShareSourceOffset("group", 1L, Map(tp 
-> range)) + + val updated = offset.withAcknowledgedRecord(tp, 100L) + + assert(updated.acquiredRecords(tp).offsets === Set(105L)) + } + + test("KafkaShareSourceOffset - withAcknowledgedRecord removes empty partition") { + val tp = new TopicPartition("topic1", 0) + val range = AcquiredRecordRange(Set(100L), 1000L, 31000L, 1) + val offset = KafkaShareSourceOffset("group", 1L, Map(tp -> range)) + + val updated = offset.withAcknowledgedRecord(tp, 100L) + + assert(!updated.acquiredRecords.contains(tp)) + } +} + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecordSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecordSuite.scala new file mode 100644 index 0000000000000..3957401604c14 --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareInFlightRecordSuite.scala @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for ShareInFlightRecord and ShareInFlightBatch. 
+ */ +class ShareInFlightRecordSuite extends SparkFunSuite { + + test("AcknowledgmentType - values exist") { + assert(AcknowledgmentType.ACCEPT.toString === "ACCEPT") + assert(AcknowledgmentType.RELEASE.toString === "RELEASE") + assert(AcknowledgmentType.REJECT.toString === "REJECT") + } + + test("ShareInFlightBatch - add and retrieve records") { + val batch = new ShareInFlightBatch(1L) + val key1 = RecordKey("topic", 0, 100L) + val key2 = RecordKey("topic", 0, 101L) + + // Create mock in-flight records + val record1 = createMockInFlightRecord(key1, 1000L, 31000L) + val record2 = createMockInFlightRecord(key2, 1000L, 31000L) + + batch.addRecord(record1) + batch.addRecord(record2) + + assert(batch.totalRecords === 2) + assert(batch.pending === 2) + assert(batch.accepted === 0) + assert(!batch.isComplete) + } + + test("ShareInFlightBatch - acknowledge with ACCEPT") { + val batch = new ShareInFlightBatch(1L) + val key = RecordKey("topic", 0, 100L) + val record = createMockInFlightRecord(key, 1000L, 31000L) + + batch.addRecord(record) + assert(batch.pending === 1) + + val result = batch.acknowledge(key, AcknowledgmentType.ACCEPT) + assert(result) + assert(batch.pending === 0) + assert(batch.accepted === 1) + assert(batch.isComplete) + } + + test("ShareInFlightBatch - acknowledge with RELEASE") { + val batch = new ShareInFlightBatch(1L) + val key = RecordKey("topic", 0, 100L) + val record = createMockInFlightRecord(key, 1000L, 31000L) + + batch.addRecord(record) + batch.acknowledge(key, AcknowledgmentType.RELEASE) + + assert(batch.pending === 0) + assert(batch.released === 1) + assert(batch.accepted === 0) + } + + test("ShareInFlightBatch - acknowledge with REJECT") { + val batch = new ShareInFlightBatch(1L) + val key = RecordKey("topic", 0, 100L) + val record = createMockInFlightRecord(key, 1000L, 31000L) + + batch.addRecord(record) + batch.acknowledge(key, AcknowledgmentType.REJECT) + + assert(batch.pending === 0) + assert(batch.rejected === 1) + } + + test("ShareInFlightBatch - cannot acknowledge same record twice") { + val batch = new ShareInFlightBatch(1L) + val key = RecordKey("topic", 0, 100L) + val record = createMockInFlightRecord(key, 1000L, 31000L) + + batch.addRecord(record) + assert(batch.acknowledge(key, AcknowledgmentType.ACCEPT)) + assert(!batch.acknowledge(key, AcknowledgmentType.RELEASE)) // Should return false + } + + test("ShareInFlightBatch - acknowledgeAllAsAccept") { + val batch = new ShareInFlightBatch(1L) + + (0 until 5).foreach { i => + val key = RecordKey("topic", 0, i.toLong) + batch.addRecord(createMockInFlightRecord(key, 1000L, 31000L)) + } + + assert(batch.pending === 5) + + val count = batch.acknowledgeAllAsAccept() + assert(count === 5) + assert(batch.pending === 0) + assert(batch.accepted === 5) + assert(batch.isComplete) + } + + test("ShareInFlightBatch - releaseAllPending") { + val batch = new ShareInFlightBatch(1L) + + (0 until 5).foreach { i => + val key = RecordKey("topic", 0, i.toLong) + batch.addRecord(createMockInFlightRecord(key, 1000L, 31000L)) + } + + // Acknowledge some + batch.acknowledge(RecordKey("topic", 0, 0L), AcknowledgmentType.ACCEPT) + batch.acknowledge(RecordKey("topic", 0, 1L), AcknowledgmentType.ACCEPT) + + assert(batch.pending === 3) + + val released = batch.releaseAllPending() + assert(released === 3) + assert(batch.pending === 0) + assert(batch.released === 3) + assert(batch.accepted === 2) + } + + test("ShareInFlightBatch - getPendingRecords") { + val batch = new ShareInFlightBatch(1L) + + (0 until 3).foreach { i => + val key = 
RecordKey("topic", 0, i.toLong) + batch.addRecord(createMockInFlightRecord(key, 1000L, 31000L)) + } + + batch.acknowledge(RecordKey("topic", 0, 1L), AcknowledgmentType.ACCEPT) + + val pending = batch.getPendingRecords + assert(pending.size === 2) + assert(pending.exists(_.recordKey.offset === 0L)) + assert(pending.exists(_.recordKey.offset === 2L)) + } + + test("ShareInFlightBatch - getRecordsByPartition") { + val batch = new ShareInFlightBatch(1L) + val tp0 = new TopicPartition("topic", 0) + val tp1 = new TopicPartition("topic", 1) + + batch.addRecord(createMockInFlightRecord(RecordKey(tp0, 100L), 1000L, 31000L)) + batch.addRecord(createMockInFlightRecord(RecordKey(tp0, 101L), 1000L, 31000L)) + batch.addRecord(createMockInFlightRecord(RecordKey(tp1, 200L), 1000L, 31000L)) + + val byPartition = batch.getRecordsByPartition + + assert(byPartition.size === 2) + assert(byPartition(tp0).size === 2) + assert(byPartition(tp1).size === 1) + } + + test("ShareInFlightBatch - getAcknowledgmentsForCommit") { + val batch = new ShareInFlightBatch(1L) + val tp = new TopicPartition("topic", 0) + + batch.addRecord(createMockInFlightRecord(RecordKey(tp, 100L), 1000L, 31000L)) + batch.addRecord(createMockInFlightRecord(RecordKey(tp, 101L), 1000L, 31000L)) + batch.addRecord(createMockInFlightRecord(RecordKey(tp, 102L), 1000L, 31000L)) + + batch.acknowledge(RecordKey(tp, 100L), AcknowledgmentType.ACCEPT) + batch.acknowledge(RecordKey(tp, 101L), AcknowledgmentType.RELEASE) + + val acks = batch.getAcknowledgmentsForCommit + + assert(acks.size === 1) // One partition + assert(acks(tp).size === 2) // Two acknowledged records + assert(acks(tp)(100L) === AcknowledgmentType.ACCEPT) + assert(acks(tp)(101L) === AcknowledgmentType.RELEASE) + } + + test("ShareInFlightBatch - clear") { + val batch = new ShareInFlightBatch(1L) + + (0 until 5).foreach { i => + val key = RecordKey("topic", 0, i.toLong) + batch.addRecord(createMockInFlightRecord(key, 1000L, 31000L)) + } + + batch.acknowledgeAllAsAccept() + assert(batch.totalRecords === 5) + + batch.clear() + assert(batch.totalRecords === 0) + assert(batch.pending === 0) + assert(batch.accepted === 0) + } + + test("ShareInFlightManager - create and get batch") { + val manager = new ShareInFlightManager() + + val batch1 = manager.createBatch(1L) + val batch2 = manager.createBatch(2L) + + assert(manager.getBatch(1L).isDefined) + assert(manager.getBatch(2L).isDefined) + assert(manager.getBatch(3L).isEmpty) + } + + test("ShareInFlightManager - getOrCreateBatch") { + val manager = new ShareInFlightManager() + + val batch1 = manager.getOrCreateBatch(1L) + val batch2 = manager.getOrCreateBatch(1L) // Should return same batch + + assert(batch1 eq batch2) + } + + test("ShareInFlightManager - removeBatch") { + val manager = new ShareInFlightManager() + + manager.createBatch(1L) + assert(manager.getBatch(1L).isDefined) + + val removed = manager.removeBatch(1L) + assert(removed.isDefined) + assert(manager.getBatch(1L).isEmpty) + } + + test("ShareInFlightManager - getIncompleteBatches") { + val manager = new ShareInFlightManager() + + val batch1 = manager.createBatch(1L) + val batch2 = manager.createBatch(2L) + + batch1.addRecord(createMockInFlightRecord(RecordKey("t", 0, 0L), 1000L, 31000L)) + batch2.addRecord(createMockInFlightRecord(RecordKey("t", 0, 1L), 1000L, 31000L)) + + batch1.acknowledgeAllAsAccept() // Complete batch1 + + val incomplete = manager.getIncompleteBatches + assert(incomplete.size === 1) + assert(incomplete.head.batchId === 2L) + } + + test("ShareInFlightManager - 
releaseAll") { + val manager = new ShareInFlightManager() + + val batch1 = manager.createBatch(1L) + val batch2 = manager.createBatch(2L) + + (0 until 3).foreach { i => + batch1.addRecord(createMockInFlightRecord(RecordKey("t", 0, i.toLong), 1000L, 31000L)) + } + (0 until 2).foreach { i => + batch2.addRecord(createMockInFlightRecord(RecordKey("t", 0, (i + 10).toLong), 1000L, 31000L)) + } + + val released = manager.releaseAll() + assert(released === 5) + assert(manager.totalPending === 0) + } + + test("ShareInFlightManager - totalPending") { + val manager = new ShareInFlightManager() + + val batch1 = manager.createBatch(1L) + val batch2 = manager.createBatch(2L) + + (0 until 3).foreach { i => + batch1.addRecord(createMockInFlightRecord(RecordKey("t", 0, i.toLong), 1000L, 31000L)) + } + (0 until 2).foreach { i => + batch2.addRecord(createMockInFlightRecord(RecordKey("t", 0, (i + 10).toLong), 1000L, 31000L)) + } + + assert(manager.totalPending === 5) + + batch1.acknowledgeAllAsAccept() + assert(manager.totalPending === 2) + } + + // Helper method to create mock in-flight records + private def createMockInFlightRecord( + key: RecordKey, + acquiredAt: Long, + lockExpiresAt: Long): ShareInFlightRecord = { + ShareInFlightRecord( + recordKey = key, + record = null, // We don't need actual ConsumerRecord for these tests + acquiredAt = acquiredAt, + lockExpiresAt = lockExpiresAt, + deliveryCount = 1, + acknowledgment = None + ) + } +} + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareStateBatchSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareStateBatchSuite.scala new file mode 100644 index 0000000000000..e6ddc86e75222 --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/share/ShareStateBatchSuite.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010.share + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite + +/** + * Unit tests for ShareStateBatch and related classes. 
+ */ +class ShareStateBatchSuite extends SparkFunSuite { + + test("ShareStateBatch - basic creation and properties") { + val batch = ShareStateBatch(100L, 150L, DeliveryState.ACQUIRED, 1) + + assert(batch.firstOffset === 100L) + assert(batch.lastOffset === 150L) + assert(batch.deliveryState === DeliveryState.ACQUIRED) + assert(batch.deliveryCount === 1) + assert(batch.recordCount === 51L) + } + + test("ShareStateBatch - contains offset check") { + val batch = ShareStateBatch(100L, 150L, DeliveryState.ACQUIRED, 1) + + assert(batch.contains(100L)) + assert(batch.contains(125L)) + assert(batch.contains(150L)) + assert(!batch.contains(99L)) + assert(!batch.contains(151L)) + } + + test("ShareStateBatch - single offset creation") { + val batch = ShareStateBatch.single(100L, DeliveryState.ACKNOWLEDGED, 2) + + assert(batch.firstOffset === 100L) + assert(batch.lastOffset === 100L) + assert(batch.recordCount === 1L) + assert(batch.deliveryCount === 2) + } + + test("ShareStateBatch - merge adjacent batches with same state") { + val batch1 = ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1) + val batch2 = ShareStateBatch(121L, 150L, DeliveryState.ACKNOWLEDGED, 1) + + assert(batch1.canMergeWith(batch2)) + val merged = batch1.mergeWith(batch2) + + assert(merged.firstOffset === 100L) + assert(merged.lastOffset === 150L) + assert(merged.recordCount === 51L) + } + + test("ShareStateBatch - cannot merge non-adjacent batches") { + val batch1 = ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1) + val batch2 = ShareStateBatch(125L, 150L, DeliveryState.ACKNOWLEDGED, 1) + + assert(!batch1.canMergeWith(batch2)) + } + + test("ShareStateBatch - cannot merge batches with different states") { + val batch1 = ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1) + val batch2 = ShareStateBatch(121L, 150L, DeliveryState.ACQUIRED, 1) + + assert(!batch1.canMergeWith(batch2)) + } + + test("ShareStateBatch - split at offset") { + val batch = ShareStateBatch(100L, 150L, DeliveryState.ACQUIRED, 1) + val (left, right) = batch.splitAt(125L) + + assert(left.firstOffset === 100L) + assert(left.lastOffset === 125L) + assert(right.firstOffset === 126L) + assert(right.lastOffset === 150L) + } + + test("ShareStateBatch - withState creates copy with new state") { + val original = ShareStateBatch(100L, 150L, DeliveryState.ACQUIRED, 1) + val acknowledged = original.withState(DeliveryState.ACKNOWLEDGED) + + assert(acknowledged.deliveryState === DeliveryState.ACKNOWLEDGED) + assert(original.deliveryState === DeliveryState.ACQUIRED) // Original unchanged + } + + test("ShareStateBatch - withIncrementedDeliveryCount") { + val original = ShareStateBatch(100L, 150L, DeliveryState.AVAILABLE, 1) + val incremented = original.withIncrementedDeliveryCount() + + assert(incremented.deliveryCount === 2) + assert(original.deliveryCount === 1) // Original unchanged + } + + test("ShareStateBatch - validation rejects invalid ranges") { + intercept[IllegalArgumentException] { + ShareStateBatch(150L, 100L, DeliveryState.ACQUIRED, 1) // lastOffset < firstOffset + } + + intercept[IllegalArgumentException] { + ShareStateBatch(100L, 150L, DeliveryState.ACQUIRED, -1) // Negative delivery count + } + } + + test("DeliveryState - fromByte conversion") { + assert(DeliveryState.fromByte(0) === DeliveryState.AVAILABLE) + assert(DeliveryState.fromByte(1) === DeliveryState.ACQUIRED) + assert(DeliveryState.fromByte(2) === DeliveryState.ACKNOWLEDGED) + assert(DeliveryState.fromByte(4) === DeliveryState.ARCHIVED) + + intercept[IllegalArgumentException] 
{ + DeliveryState.fromByte(99) + } + } + + test("SharePartitionState - basic creation") { + val tp = new TopicPartition("test-topic", 0) + val batches = Seq( + ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1), + ShareStateBatch(121L, 150L, DeliveryState.ACQUIRED, 1) + ) + + val state = SharePartitionState(tp, 1, 5, 0, 100L, batches) + + assert(state.topicPartition === tp) + assert(state.snapshotEpoch === 1) + assert(state.stateEpoch === 5) + assert(state.startOffset === 100L) + assert(state.stateBatches.size === 2) + } + + test("SharePartitionState - getState for offset") { + val tp = new TopicPartition("test-topic", 0) + val batches = Seq( + ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1), + ShareStateBatch(121L, 150L, DeliveryState.ACQUIRED, 1) + ) + + val state = SharePartitionState(tp, 1, 5, 0, 100L, batches) + + assert(state.getState(99L).isEmpty) // Below startOffset + assert(state.getState(110L).isDefined) + assert(state.getState(110L).get.deliveryState === DeliveryState.ACKNOWLEDGED) + assert(state.getState(130L).isDefined) + assert(state.getState(130L).get.deliveryState === DeliveryState.ACQUIRED) + } + + test("SharePartitionState - getAcquiredOffsets") { + val tp = new TopicPartition("test-topic", 0) + val batches = Seq( + ShareStateBatch(100L, 110L, DeliveryState.ACKNOWLEDGED, 1), + ShareStateBatch(111L, 120L, DeliveryState.ACQUIRED, 1) + ) + + val state = SharePartitionState(tp, 1, 5, 0, 100L, batches) + val acquired = state.getAcquiredOffsets + + assert(acquired.size === 10) // 111 to 120 inclusive + assert(acquired.contains(111L)) + assert(acquired.contains(120L)) + assert(!acquired.contains(100L)) + } + + test("SharePartitionState - highestOffset") { + val tp = new TopicPartition("test-topic", 0) + val batches = Seq( + ShareStateBatch(100L, 120L, DeliveryState.ACKNOWLEDGED, 1), + ShareStateBatch(121L, 150L, DeliveryState.ACQUIRED, 1) + ) + + val state = SharePartitionState(tp, 1, 5, 0, 100L, batches) + + assert(state.highestOffset === 150L) + } + + test("SharePartitionState - empty batches") { + val tp = new TopicPartition("test-topic", 0) + val state = SharePartitionState(tp, 1, 5, 0, 100L, Seq.empty) + + assert(state.highestOffset === 99L) // startOffset - 1 + assert(state.getAcquiredOffsets.isEmpty) + } +} + diff --git a/docs/streaming/structured-streaming-kafka-share-groups.md b/docs/streaming/structured-streaming-kafka-share-groups.md new file mode 100644 index 0000000000000..d40ce476f32b9 --- /dev/null +++ b/docs/streaming/structured-streaming-kafka-share-groups.md @@ -0,0 +1,199 @@ +--- +layout: global +title: Structured Streaming + Kafka Share Groups Integration Guide (Kafka broker version 4.0 or higher) +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+--- + +Structured Streaming integration for Kafka 4.0+ Share Groups (KIP-932) to read data using queue semantics. + +## Overview + +Kafka Share Groups enable queue-style consumption where multiple consumers can receive records from the same partition concurrently. Unlike traditional consumer groups with exclusive partition assignment, share groups provide: + +- Per-record acknowledgment (ACCEPT, RELEASE, REJECT) +- Automatic redelivery on failure via acquisition locks +- Non-sequential offset tracking + +## Linking + + groupId = org.apache.spark + artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +Requires Kafka client 4.0+ with ShareConsumer API support. + +## Reading Data from Kafka Share Groups + +
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+val df = spark
+  .readStream
+  .format("kafka-share")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("kafka.share.group.id", "my-share-group")
+  .option("subscribe", "topic1")
+  .load()
+
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% highlight python %}
+df = spark \
+  .readStream \
+  .format("kafka-share") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("kafka.share.group.id", "my-share-group") \
+  .option("subscribe", "topic1") \
+  .load()
+
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% highlight java %}
+Dataset<Row> df = spark
+  .readStream()
+  .format("kafka-share")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("kafka.share.group.id", "my-share-group")
+  .option("subscribe", "topic1")
+  .load();
+
+df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+{% endhighlight %}
+</div>
+
+</div>
+ +## Schema + +Each row has the following schema: + +| Column | Type | +|--------|------| +| key | binary | +| value | binary | +| topic | string | +| partition | int | +| offset | long | +| timestamp | long | +| timestampType | int | +| headers (optional) | array | + +## Configuration Options + +| Option | Required | Default | Description | +|--------|----------|---------|-------------| +| kafka.bootstrap.servers | yes | none | Kafka broker addresses | +| kafka.share.group.id | yes | none | Share group identifier | +| subscribe | yes | none | Comma-separated list of topics | +| subscribePattern | no | none | Topic pattern (alternative to subscribe) | +| kafka.share.acknowledgment.mode | no | implicit | `implicit` or `explicit` | +| kafka.share.exactly.once.strategy | no | none | `none`, `idempotent`, `two-phase-commit`, or `checkpoint-dedup` | +| kafka.share.parallelism | no | spark.default.parallelism | Number of concurrent consumers | +| kafka.share.lock.timeout.ms | no | 30000 | Acquisition lock timeout | +| includeHeaders | no | false | Include Kafka headers | + +## Acknowledgment Modes + +### Implicit Mode (Default) + +Records are automatically acknowledged as ACCEPT when the batch completes successfully. On failure, acquisition locks expire and Kafka redelivers records. + +### Explicit Mode + +Use `foreachBatch` to manually acknowledge records: + +
+
+{% highlight scala %} +df.writeStream + .foreachBatch { (batchDf: DataFrame, batchId: Long) => + // Process records + // Acknowledgments handled by user logic + } + .start() +{% endhighlight %} +
+
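+To enable explicit mode, set `kafka.share.acknowledgment.mode` on the reader (a
+minimal sketch; the option is described in the configuration table above):
+
+{% highlight scala %}
+val df = spark
+  .readStream
+  .format("kafka-share")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("kafka.share.group.id", "my-share-group")
+  .option("kafka.share.acknowledgment.mode", "explicit")
+  .option("subscribe", "topic1")
+  .load()
+{% endhighlight %}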
+ +## Fault Tolerance + +### At-Least-Once Semantics + +Default behavior ensures no data loss: + +1. Records acquired with time-limited lock (default 30s) +2. On task failure, lock expires automatically +3. Kafka transitions records from ACQUIRED to AVAILABLE +4. Records redelivered to other consumers + +### Recovery Scenarios + +| Failure | Recovery | Guarantee | +|---------|----------|-----------| +| Task failure | Lock expires, records redelivered | At-least-once | +| Driver failure | Resume from checkpoint | At-least-once | +| Kafka broker failure | WAL replay on new leader | No data loss | + +## Exactly-Once Strategies + +### Idempotent Sink + +Deduplicate at sink using record coordinates (topic, partition, offset): + +{% highlight scala %} +df.writeStream + .format("delta") + .option("mergeSchema", "true") + .start("/path/to/table") +{% endhighlight %} + +Works with Delta Lake, JDBC with UPSERT, or any idempotent sink. + +### Checkpoint Deduplication + +Track processed records in Spark checkpoint: + +{% highlight scala %} +spark.readStream + .format("kafka-share") + .option("kafka.share.exactly.once.strategy", "checkpoint-dedup") + ... +{% endhighlight %} + +### Two-Phase Commit + +Atomic commit of output and acknowledgments. Higher latency but true exactly-once: + +{% highlight scala %} +spark.readStream + .format("kafka-share") + .option("kafka.share.exactly.once.strategy", "two-phase-commit") + ... +{% endhighlight %} + + +## Deploying + +Same as traditional Kafka source. Include the `spark-sql-kafka-0-10` artifact and its dependencies when deploying your application. + diff --git a/docs/structured-streaming-kafka-share-groups.md b/docs/structured-streaming-kafka-share-groups.md new file mode 100644 index 0000000000000..f0f4eb747b2d9 --- /dev/null +++ b/docs/structured-streaming-kafka-share-groups.md @@ -0,0 +1,23 @@ +--- +layout: global +title: Structured Streaming + Kafka Share Groups Integration Guide (Kafka broker version 4.0 or higher) +redirect: ./streaming/structured-streaming-kafka-share-groups.html +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + +This page has moved [here](./streaming/structured-streaming-kafka-share-groups.html). +