Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
Expand All @@ -38,6 +39,7 @@
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Redistribute;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter;
Expand Down Expand Up @@ -379,12 +381,19 @@ public WriteResult expandUntriggered(
PCollection<KV<DestinationT, StorageApiWritePayload>> successfulConvertedRows =
convertMessagesResult.get(successfulConvertedRowsTag);

if (numShards > 0) {
boolean streaming = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming();
if (numShards > 0 && streaming) {
successfulConvertedRows =
successfulConvertedRows.apply(
"ResdistibuteNumShards",
Redistribute.<KV<DestinationT, StorageApiWritePayload>>arbitrarily()
.withNumBuckets(numShards));
} else if (numShards > 0 && !streaming) {
successfulConvertedRows =
successfulConvertedRows
.apply("AddKeyWithSideInputs", ParDo.of(new AddShardKeyFn<>(numShards)))
.apply("RedistributeNumShards", Redistribute.byKey())
.apply("Remove shard", Values.create());
}

PCollectionTuple writeRecordsResult =
Expand Down Expand Up @@ -457,6 +466,52 @@ private void addErrorCollections(
}
}

/**
* A {@link DoFn} that applies a composite sharding key to incoming records to optimize BigQuery
* Storage API throughput.
*
* <p>This transform manages the balance between connection count (resource overhead) and
* processing parallelism by distributing data across {@code numShards} buckets:
*
* <ul>
* <li><b>Data Affinity:</b> By using a composite key {@code KV<DestT, Integer>}, this transform
* ensures that all records for a specific destination (table) are grouped into specific
* shard buckets. This allows downstream transforms to maintain stable {@code
* StreamConnection} sessions for each destination, minimizing connection thrashing.
* <li><b>Parallel Throughput:</b> By appending a pseudo-random integer shard index, this
* transform allows the runner to distribute the records for a single destination across up
* to {@code numShards} parallel streams, parallelizing the write throughput of "hot"
* (high-volume) destinations.
* <li><b>Concurrency Scaling:</b> The {@code numShards} parameter acts as the parallelism
* multiplier per destination. The total potential concurrency across the pipeline is {@code
* numShards * total_destinations}, allowing users to scale write throughput by increasing
* {@code numShards} for bottlenecked tables.
* </ul>
*
* <p>The output structure is {@code KV<KV<DestT, Integer>, KV<DestT, Payload>>}. Downstream,
* {@link Redistribute#byKey()} uses this composite key to partition the data, ensuring the runner
* effectively balances load while respecting the per-destination parallelism limits configured
* here.
*/
private static class AddShardKeyFn<DestT, ElemT>
extends DoFn<
KV<DestT, StorageApiWritePayload>,
KV<KV<DestT, Integer>, KV<DestT, StorageApiWritePayload>>> {
private final int shardBound;

public AddShardKeyFn(int numShards) {
this.shardBound = Math.max(1, numShards);
}

@ProcessElement
public void processElement(
@Element KV<DestT, StorageApiWritePayload> element,
OutputReceiver<KV<KV<DestT, Integer>, KV<DestT, StorageApiWritePayload>>> outputReceiver) {
int shard = ThreadLocalRandom.current().nextInt(shardBound);
outputReceiver.output(KV.of(KV.of(element.getKey(), shard), element));
}
}
Comment thread
stankiewicz marked this conversation as resolved.

private static class ConvertInsertErrorToBadRecord
extends DoFn<BigQueryStorageApiInsertError, BadRecord> {

Expand Down
Loading