From c4d81eeb5a921b2d9e6365cc8860c67da2f1806a Mon Sep 17 00:00:00 2001 From: Radek Stankiewicz Date: Wed, 3 Jun 2026 11:34:29 +0200 Subject: [PATCH] Improve sharding for bounded pcolleciton, to better control concurrent connections. this will keep elements for same destination close to each other and shard them. For single table write it's same behaviour, for dynamic destination it will improve reduce amount of connections used --- .../sdk/io/gcp/bigquery/StorageApiLoads.java | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 007bba5c6cdf..05f3dcb90be5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -30,6 +30,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; @@ -379,12 +381,19 @@ public WriteResult expandUntriggered( PCollection> successfulConvertedRows = convertMessagesResult.get(successfulConvertedRowsTag); - if (numShards > 0) { + boolean streaming = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming(); + if (numShards > 0 && streaming) { successfulConvertedRows = successfulConvertedRows.apply( "ResdistibuteNumShards", Redistribute.>arbitrarily() .withNumBuckets(numShards)); + } else if (numShards > 0 && !streaming) { + successfulConvertedRows = + successfulConvertedRows + .apply("AddKeyWithSideInputs", ParDo.of(new AddShardKeyFn<>(numShards))) + .apply("RedistributeNumShards", Redistribute.byKey()) + .apply("Remove shard", Values.create()); } PCollectionTuple writeRecordsResult = @@ -457,6 +466,52 @@ private void addErrorCollections( } } + /** + * A {@link DoFn} that applies a composite sharding key to incoming records to optimize BigQuery + * Storage API throughput. + * + *

This transform manages the balance between connection count (resource overhead) and + * processing parallelism by distributing data across {@code numShards} buckets: + * + *

    + *
  • Data Affinity: By using a composite key {@code KV}, this transform + * ensures that all records for a specific destination (table) are grouped into specific + * shard buckets. This allows downstream transforms to maintain stable {@code + * StreamConnection} sessions for each destination, minimizing connection thrashing. + *
  • Parallel Throughput: By appending a pseudo-random integer shard index, this + * transform allows the runner to distribute the records for a single destination across up + * to {@code numShards} parallel streams, parallelizing the write throughput of "hot" + * (high-volume) destinations. + *
  • Concurrency Scaling: The {@code numShards} parameter acts as the parallelism + * multiplier per destination. The total potential concurrency across the pipeline is {@code + * numShards * total_destinations}, allowing users to scale write throughput by increasing + * {@code numShards} for bottlenecked tables. + *
+ * + *

The output structure is {@code KV, KV>}. Downstream, + * {@link Redistribute#byKey()} uses this composite key to partition the data, ensuring the runner + * effectively balances load while respecting the per-destination parallelism limits configured + * here. + */ + private static class AddShardKeyFn + extends DoFn< + KV, + KV, KV>> { + private final int shardBound; + + public AddShardKeyFn(int numShards) { + this.shardBound = Math.max(1, numShards); + } + + @ProcessElement + public void processElement( + @Element KV element, + OutputReceiver, KV>> outputReceiver) { + int shard = ThreadLocalRandom.current().nextInt(shardBound); + outputReceiver.output(KV.of(KV.of(element.getKey(), shard), element)); + } + } + private static class ConvertInsertErrorToBadRecord extends DoFn {