diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java index 007bba5c6cdf..05f3dcb90be5 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java @@ -30,6 +30,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Redistribute; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.errorhandling.BadRecord; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter; import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter.ThrowingBadRecordRouter; @@ -379,12 +381,19 @@ public WriteResult expandUntriggered( PCollection> successfulConvertedRows = convertMessagesResult.get(successfulConvertedRowsTag); - if (numShards > 0) { + boolean streaming = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming(); + if (numShards > 0 && streaming) { successfulConvertedRows = successfulConvertedRows.apply( "ResdistibuteNumShards", Redistribute.>arbitrarily() .withNumBuckets(numShards)); + } else if (numShards > 0 && !streaming) { + successfulConvertedRows = + successfulConvertedRows + .apply("AddKeyWithSideInputs", ParDo.of(new AddShardKeyFn<>(numShards))) + .apply("RedistributeNumShards", Redistribute.byKey()) 
+ .apply("Remove shard", Values.create()); } PCollectionTuple writeRecordsResult = @@ -457,6 +466,52 @@ private void addErrorCollections( } } + /** + * A {@link DoFn} that applies a composite sharding key to incoming records to optimize BigQuery + * Storage API throughput. + * + *
<p>This transform manages the balance between connection count (resource overhead) and + * processing parallelism by distributing data across {@code numShards} buckets: each element is + * assigned a shard index drawn at random from {@code [0, max(1, numShards))}. The bound is + * clamped to at least 1, so a non-positive {@code numShards} places every element in shard 0. + * + *
<p>The output is a {@code KV} whose key is {@code KV.of(originalKey, shard)} and whose value is + * the unmodified input element. Downstream, + * {@link Redistribute#byKey()} uses this composite key to partition the data, ensuring the runner + * effectively balances load while respecting the per-destination parallelism limits configured + * here. + */ + private static class AddShardKeyFn + extends DoFn< + KV, + KV, KV>> { + private final int shardBound; + + public AddShardKeyFn(int numShards) { + this.shardBound = Math.max(1, numShards); + } + + @ProcessElement + public void processElement( + @Element KV element, + OutputReceiver, KV>> outputReceiver) { + int shard = ThreadLocalRandom.current().nextInt(shardBound); + outputReceiver.output(KV.of(KV.of(element.getKey(), shard), element)); + } + } + private static class ConvertInsertErrorToBadRecord extends DoFn {