From b270e53c4f42f7b9e49b240c9f104c92fd983d71 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Tue, 7 Oct 2025 15:50:09 +0700 Subject: [PATCH 1/5] add support for s3 multipart upload --- pom.xml | 2 +- .../payloadoffloading/AwsManagedCmk.java | 6 + .../amazon/payloadoffloading/CustomerKey.java | 7 + .../MultipartPayloadStore.java | 27 +++ .../MultipartPayloadStoreAsync.java | 29 +++ .../PayloadStorageAsyncConfiguration.java | 20 ++ .../PayloadStorageConfiguration.java | 20 ++ .../PayloadStorageConfigurationBase.java | 62 +++++- .../amazon/payloadoffloading/S3AsyncDao.java | 91 +++++++++ .../S3BackedMultipartPayloadStore.java | 27 +++ .../S3BackedMultipartPayloadStoreAsync.java | 30 +++ .../amazon/payloadoffloading/S3Dao.java | 78 ++++++++ .../ServerSideEncryptionStrategy.java | 2 + .../payloadoffloading/AwsManagedCmkTest.java | 14 ++ .../payloadoffloading/CustomerKeyTest.java | 13 ++ .../PayloadStorageAsyncConfigurationTest.java | 47 +++++ .../PayloadStorageConfigurationTest.java | 47 +++++ .../payloadoffloading/S3AsyncDaoTest.java | 183 +++++++++++++++++- ...3BackedMultipartPayloadStoreAsyncTest.java | 127 ++++++++++++ .../S3BackedMultipartPayloadStoreTest.java | 102 ++++++++++ .../amazon/payloadoffloading/S3DaoTest.java | 157 +++++++++++++++ 21 files changed, 1083 insertions(+), 8 deletions(-) create mode 100644 src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java create mode 100644 src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java create mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java create mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java create mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java create mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java diff --git a/pom.xml b/pom.xml index 0cea9a4..0cc01eb 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 2.2.0 + 2.3.0 jar Payload offloading common library for AWS Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3. diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java index ae291f4..e4d2109 100644 --- a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java +++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java @@ -1,6 +1,7 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; public class AwsManagedCmk implements ServerSideEncryptionStrategy { @@ -8,4 +9,9 @@ public class AwsManagedCmk implements ServerSideEncryptionStrategy { public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); } + + @Override + public void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder) { + createMultipartUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + } } diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java index 7f62d49..bcc3175 100644 --- a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java +++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java @@ -1,6 +1,7 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; public class CustomerKey implements ServerSideEncryptionStrategy { @@ -15,4 +16,10 @@ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); putObjectRequestBuilder.ssekmsKeyId(awsKmsKeyId); } + + @Override + public void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder) { + createMultipartUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + createMultipartUploadRequestBuilder.ssekmsKeyId(awsKmsKeyId); + } } diff --git a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java new file mode 100644 index 0000000..ba86f63 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java @@ -0,0 +1,27 @@ +package software.amazon.payloadoffloading; + +/** + * Optional extension interface for {@link PayloadStore} implementations that can perform multipart + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface MultipartPayloadStore extends PayloadStore { + + /** + * Store a payload using streaming/multipart-aware logic. Default falls back to normal storage. + * @param payload UTF-8 text payload + * @param s3Key pre-generated S3 key (or equivalent object key) to use + * @return pointer string (JSON) referencing stored payload + */ + default String storeOriginalPayloadMultipart(String payload, String s3Key) { + return storeOriginalPayload(payload, s3Key); + } + + /** + * Retrieve a payload potentially using streaming/multipart-aware logic. Default falls back to normal retrieval. + * @param payloadPointer pointer JSON string + * @return original payload content + */ + default String getOriginalPayloadMultipart(String payloadPointer) { + return getOriginalPayload(payloadPointer); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java new file mode 100644 index 0000000..0c7d01a --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java @@ -0,0 +1,29 @@ +package software.amazon.payloadoffloading; + +import java.util.concurrent.CompletableFuture; + +/** + * Optional extension interface for {@link PayloadStoreAsync} implementations that can perform multipart + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface MultipartPayloadStoreAsync extends PayloadStoreAsync { + + /** + * Store payload content using multipart semantics when possible. + * @param payload UTF-8 text payload + * @param s3Key object key to use + * @return future pointer string (JSON) referencing stored payload + */ + default CompletableFuture storeOriginalPayloadMultipart(String payload, String s3Key) { + return storeOriginalPayload(payload, s3Key); + } + + /** + * Retrieve payload using streaming/multipart logic when available. Default delegates to normal method. + * @param payloadPointer pointer JSON string + * @return future original payload content + */ + default CompletableFuture getOriginalPayloadMultipart(String payloadPointer) { + return getOriginalPayload(payloadPointer); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java index 3bd8d08..2cf3edd 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java @@ -150,4 +150,24 @@ public PayloadStorageAsyncConfiguration withObjectCannedACL(ObjectCannedACL obje setObjectCannedACL(objectCannedACL); return this; } + + /** + * Enables or disables multipart upload support. + * @param enabled true to enable multipart uploads when threshold exceeded. + * @return updated configuration + */ + public PayloadStorageAsyncConfiguration withMultipartUploadEnabled(boolean enabled) { + setMultipartUploadEnabled(enabled); + return this; + } + + /** + * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. + * @return updated configuration + */ + public PayloadStorageAsyncConfiguration withMultipartUploadThreshold(int threshold) { + setMultipartUploadThreshold(threshold); + return this; + } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index 9ab3c10..cdcd9eb 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -148,4 +148,24 @@ public PayloadStorageConfiguration withObjectCannedACL(ObjectCannedACL objectCan setObjectCannedACL(objectCannedACL); return this; } + + /** + * Enables or disables multipart upload support. + * @param enabled true to enable multipart uploads when threshold exceeded. + * @return updated configuration + */ + public PayloadStorageConfiguration withMultipartUploadEnabled(boolean enabled) { + setMultipartUploadEnabled(enabled); + return this; + } + + /** + * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. + * @return updated configuration + */ + public PayloadStorageConfiguration withMultipartUploadThreshold(int threshold) { + setMultipartUploadThreshold(threshold); + return this; + } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java index 7d08746..d9a353a 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java @@ -4,7 +4,6 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; /** @@ -22,6 +21,12 @@ public abstract class PayloadStorageConfigurationBase { private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; private boolean payloadSupport = false; + // Enable multipart upload support (opt-in, default false) + private boolean multipartUploadEnabled = false; + // Threshold (bytes) above which multipart should be attempted when enabled (default 5MB) + private int multipartUploadThreshold = 5 * 1024 * 1024; + // Multipart part size (bytes). Each part (except last) must be >=5MB. Default 5MB. + private int multipartUploadPartSize = 5 * 1024 * 1024; /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ @@ -44,6 +49,9 @@ public PayloadStorageConfigurationBase(PayloadStorageConfigurationBase other) { this.payloadSizeThreshold = other.getPayloadSizeThreshold(); this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy(); this.objectCannedACL = other.getObjectCannedACL(); + this.multipartUploadEnabled = other.isMultipartUploadEnabled(); + this.multipartUploadThreshold = other.getMultipartUploadThreshold(); + this.multipartUploadPartSize = other.getMultipartUploadPartSize(); } /** @@ -175,4 +183,56 @@ public boolean isObjectCannedACLDefined() { public ObjectCannedACL getObjectCannedACL() { return objectCannedACL; } + + /** + * Checks whether multipart upload support is enabled. Default: false. + * + * @return true if multipart upload support is enabled. + */ + public boolean isMultipartUploadEnabled() { return multipartUploadEnabled; } + + /** + * Enable or disable multipart upload support. When enabling, callers should ensure they provide a PayloadStore + * implementation capable of multipart operations; otherwise normal single PUT behavior will be used. + * + * @param multipartUploadEnabled flag to enable/disable multipart support. + */ + public void setMultipartUploadEnabled(boolean multipartUploadEnabled) { this.multipartUploadEnabled = multipartUploadEnabled; } + + /** + * Gets the multipart upload threshold in bytes. Default 5MB. + * + * @return threshold in bytes. + */ + public int getMultipartUploadThreshold() { return multipartUploadThreshold; } + + /** + * Sets the multipart upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). + * + * @param multipartUploadThreshold threshold in bytes + */ + public void setMultipartUploadThreshold(int multipartUploadThreshold) { + if (multipartUploadThreshold <= 0) { + this.multipartUploadThreshold = 5 * 1024 * 1024; + } else { + this.multipartUploadThreshold = multipartUploadThreshold; + } + } + + /** + * Gets the configured multipart upload part size (bytes). Default 5MB. + */ + public int getMultipartUploadPartSize() { return multipartUploadPartSize; } + + /** + * Sets the multipart upload part size (bytes). Values < 5MB are rounded up to 5MB. + */ + public void setMultipartUploadPartSize(int partSize) { + int min = 5 * 1024 * 1024; + if (partSize < min) { + this.multipartUploadPartSize = min; + } else { + this.multipartUploadPartSize = partSize; + } + } } \ No newline at end of file diff --git a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java index a0dc868..60db795 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java @@ -15,6 +15,18 @@ import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; + +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.nio.charset.StandardCharsets; + /** * Dao layer to access S3. @@ -115,4 +127,83 @@ public CompletableFuture deletePayloadFromS3(String s3BucketName, String s return null; }); } + + + public CompletableFuture storeTextMultipartInS3(String bucket, String key, String payloadContentStr, int partSize, int multipartThreshold) { + byte[] data = payloadContentStr.getBytes(StandardCharsets.UTF_8); + if (data.length < multipartThreshold) { + return storeTextInS3(bucket, key, payloadContentStr); + } + + CreateMultipartUploadRequest.Builder createBuilder = CreateMultipartUploadRequest.builder() + .bucket(bucket) + .key(key); + if (objectCannedACL != null) { + createBuilder.acl(objectCannedACL); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(createBuilder); + } + + return s3Client.createMultipartUpload(createBuilder.build()) + .thenCompose(createResp -> { + String uploadId = createResp.uploadId(); + int partCount = (int) Math.ceil((double) data.length / partSize); + List> partFutures = new ArrayList<>(partCount); + + int offset = 0; + int partNumber = 1; + while (offset < data.length) { + int currSize = Math.min(partSize, data.length - offset); + byte[] slice = Arrays.copyOfRange(data, offset, offset + currSize); + offset += currSize; + final int thisPartNumber = partNumber++; + + UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(thisPartNumber) + .contentLength((long) slice.length) + .build(); + + CompletableFuture fut = s3Client.uploadPart(uploadPartRequest, AsyncRequestBody.fromBytes(slice)) + .thenApply(resp -> CompletedPart.builder().partNumber(thisPartNumber).eTag(resp.eTag()).build()); + partFutures.add(fut); + } + + CompletableFuture all = CompletableFuture.allOf(partFutures.toArray(new CompletableFuture[0])); + return all.thenCompose(v -> { + List completed = new ArrayList<>(partFutures.size()); + for (CompletableFuture f : partFutures) { + completed.add(f.join()); + } + CompletedMultipartUpload cmu = CompletedMultipartUpload.builder().parts(completed).build(); + CompleteMultipartUploadRequest completeReq = CompleteMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .multipartUpload(cmu) + .build(); + return s3Client.completeMultipartUpload(completeReq) + .handle((vv, t) -> { + if (t != null) { + s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder().bucket(bucket).key(key).uploadId(uploadId).build()); + throw new CompletionException(t); + } + LOG.info("S3 multipart object created, Bucket name: " + bucket + ", Object key: " + key + "."); + return (Void) null; + }); + }); + }).exceptionally(t -> { + Throwable cause = Util.unwrapFutureException(t); + if (cause instanceof SdkException) { + String errorMessage = "Failed to store the message content in an S3 multipart object."; + LOG.error(errorMessage, cause); + throw SdkException.create(errorMessage, cause); + } + throw new CompletionException(cause); + }); + } } diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java new file mode 100644 index 0000000..ef3b22a --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java @@ -0,0 +1,27 @@ +package software.amazon.payloadoffloading; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class S3BackedMultipartPayloadStore extends S3BackedPayloadStore implements MultipartPayloadStore { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedMultipartPayloadStore.class); + private final S3Dao multipartDao; + private final int partSize; + private final int threshold; + private final String s3BucketName; + + public S3BackedMultipartPayloadStore(S3Dao s3Dao, String s3BucketName, int partSize, int threshold) { + super(s3Dao, s3BucketName); + this.multipartDao = s3Dao; + this.partSize = partSize; + this.threshold = threshold; + this.s3BucketName = s3BucketName; + } + + @Override + public String storeOriginalPayloadMultipart(String payload, String s3Key) { + multipartDao.storeTextMultipartInS3(s3BucketName, s3Key, payload, partSize, threshold); + LOG.info("S3 multipart object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java new file mode 100644 index 0000000..b5f3ba0 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java @@ -0,0 +1,30 @@ +package software.amazon.payloadoffloading; + +import java.util.concurrent.CompletableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class S3BackedMultipartPayloadStoreAsync extends S3BackedPayloadStoreAsync implements MultipartPayloadStoreAsync { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedMultipartPayloadStoreAsync.class); + private final S3AsyncDao multipartDao; + private final String s3BucketName; + private final int partSize; + private final int threshold; + + public S3BackedMultipartPayloadStoreAsync(S3AsyncDao s3Dao, String s3BucketName, int partSize, int threshold) { + super(s3Dao, s3BucketName); + this.multipartDao = s3Dao; + this.s3BucketName = s3BucketName; + this.partSize = partSize; + this.threshold = threshold; + } + + @Override + public CompletableFuture storeOriginalPayloadMultipart(String payload, String s3Key) { + return multipartDao.storeTextMultipartInS3(s3BucketName, s3Key, payload, partSize, threshold) + .thenApply(v -> { + LOG.info("S3 multipart object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + }); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index 2b03dd5..b63d810 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -12,9 +12,21 @@ import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.utils.IoUtils; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; /** * Dao layer to access S3. @@ -104,4 +116,70 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) { LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); } + + public void storeTextMultipartInS3(String bucket, String key, String payloadContentStr, int partSize, int multipartThreshold) { + byte[] data = payloadContentStr.getBytes(StandardCharsets.UTF_8); + if (data.length < multipartThreshold) { + storeTextInS3(bucket, key, payloadContentStr); + return; + } + + CreateMultipartUploadRequest.Builder createBuilder = CreateMultipartUploadRequest.builder() + .bucket(bucket) + .key(key); + if (objectCannedACL != null) { + createBuilder.acl(objectCannedACL); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(createBuilder); + } + + String uploadId = null; + try { + CreateMultipartUploadResponse createResp = s3Client.createMultipartUpload(createBuilder.build()); + uploadId = createResp.uploadId(); + + int partCount = (int) Math.ceil((double) data.length / partSize); + List completedParts = new ArrayList<>(partCount); + + for (int partNumber = 1, offset = 0; offset < data.length; partNumber++) { + int currSize = Math.min(partSize, data.length - offset); + byte[] slice = Arrays.copyOfRange(data, offset, offset + currSize); + offset += currSize; + + UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .contentLength((long) slice.length) + .build(); + + UploadPartResponse upr = s3Client.uploadPart(uploadPartRequest, RequestBody.fromBytes(slice)); + completedParts.add(CompletedPart.builder().partNumber(partNumber).eTag(upr.eTag()).build()); + } + + CompletedMultipartUpload completed = CompletedMultipartUpload.builder().parts(completedParts).build(); + CompleteMultipartUploadRequest completeReq = CompleteMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .multipartUpload(completed) + .build(); + s3Client.completeMultipartUpload(completeReq); + LOG.info("S3 multipart object created, Bucket name: " + bucket + ", Object key: " + key + ", Parts: " + completedParts.size() + "."); + } catch (SdkException e) { + if (uploadId != null) { + try { + s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder().bucket(bucket).key(key).uploadId(uploadId).build()); + } catch (Exception abortEx) { + LOG.warn("Failed to abort multipart upload after failure.", abortEx); + } + } + String errorMessage = "Failed to store the message content in an S3 multipart object."; + LOG.error(errorMessage, e); + throw SdkException.create(errorMessage, e); + } + } } diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java index f385ce6..8cd4f89 100644 --- a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java +++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java @@ -1,7 +1,9 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; public interface ServerSideEncryptionStrategy { void decorate(PutObjectRequest.Builder putObjectRequestBuilder); + void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder); } diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java index 723678d..18a5c0a 100644 --- a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java +++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java @@ -2,8 +2,10 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; + import static org.junit.jupiter.api.Assertions.assertEquals; @@ -19,4 +21,16 @@ public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() { assertEquals(putObjectRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS)); } + + + @Test + public void testAwsManagedCmkStrategyMultipartSetsCorrectEncryptionValues() { + AwsManagedCmk awsManagedCmk = new AwsManagedCmk(); + + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); + awsManagedCmk.decorate(createMultipartUploadRequestBuilder); + CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build(); + + assertEquals(createMultipartUploadRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS)); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java index cc55b70..ed39cc5 100644 --- a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java +++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java @@ -2,6 +2,7 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -21,4 +22,16 @@ public void testCustomerKeyStrategySetsCorrectEncryptionValues() { assertEquals(putObjectRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS); assertEquals(putObjectRequest.ssekmsKeyId(), AWS_KMS_KEY_ID); } + + @Test + public void testCustomerKeyStrategySetsMultipartUploadEncryptionValues() { + CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID); + + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); + customerKey.decorate(createMultipartUploadRequestBuilder); + CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build(); + + assertEquals(createMultipartUploadRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS); + assertEquals(createMultipartUploadRequest.ssekmsKeyId(), AWS_KMS_KEY_ID); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java index 30cf188..78372d3 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java @@ -99,4 +99,51 @@ public void testCannedAccessControlList() { assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined()); assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL()); } + + @Test + public void testMultipartUploadEnabled() { + PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); + + payloadStorageConfiguration.setMultipartUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isMultipartUploadEnabled()); + + payloadStorageConfiguration.setMultipartUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isMultipartUploadEnabled()); + } + + @Test + public void testMultipartUploadThreshold() { + PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); + + int customThreshold = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setMultipartUploadThreshold(customThreshold); + assertEquals(customThreshold, payloadStorageConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testMultipartUploadPartSize() { + PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); + + int customPartSize = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setMultipartUploadPartSize(customPartSize); + assertEquals(customPartSize, payloadStorageConfiguration.getMultipartUploadPartSize()); + } + + @Test + public void testMultipartConfigurationInCopyConstructor() { + S3AsyncClient s3Async = mock(S3AsyncClient.class); + + PayloadStorageAsyncConfiguration original = new PayloadStorageAsyncConfiguration(); + original.withPayloadSupportEnabled(s3Async, s3BucketName) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(10 * 1024 * 1024) + .setMultipartUploadPartSize(8 * 1024 * 1024); + + PayloadStorageAsyncConfiguration copy = new PayloadStorageAsyncConfiguration(original); + + assertTrue(copy.isMultipartUploadEnabled()); + assertEquals(10 * 1024 * 1024, copy.getMultipartUploadThreshold()); + assertEquals(8 * 1024 * 1024, copy.getMultipartUploadPartSize()); + assertNotSame(copy, original); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java index f4adf77..e304666 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java @@ -99,4 +99,51 @@ public void testCannedAccessControlList() { assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined()); assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL()); } + + @Test + public void testMultipartUploadEnabled() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + payloadStorageConfiguration.setMultipartUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isMultipartUploadEnabled()); + + payloadStorageConfiguration.setMultipartUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isMultipartUploadEnabled()); + } + + @Test + public void testMultipartUploadThreshold() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + int customThreshold = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setMultipartUploadThreshold(customThreshold); + assertEquals(customThreshold, payloadStorageConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testMultipartUploadPartSize() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + int customPartSize = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setMultipartUploadPartSize(customPartSize); + assertEquals(customPartSize, payloadStorageConfiguration.getMultipartUploadPartSize()); + } + + @Test + public void testMultipartConfigurationInCopyConstructor() { + S3Client s3 = mock(S3Client.class); + + PayloadStorageConfiguration original = new PayloadStorageConfiguration(); + original.withPayloadSupportEnabled(s3, s3BucketName) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(10 * 1024 * 1024) + .setMultipartUploadPartSize(8 * 1024 * 1024); + + PayloadStorageConfiguration copy = new PayloadStorageConfiguration(original); + + assertTrue(copy.isMultipartUploadEnabled()); + assertEquals(10 * 1024 * 1024, copy.getMultipartUploadThreshold()); + assertEquals(8 * 1024 * 1024, copy.getMultipartUploadPartSize()); + assertNotSame(copy, original); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java index 1ecccb8..f3ca780 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java @@ -1,27 +1,34 @@ package software.amazon.payloadoffloading; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.*; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; + public class S3AsyncDaoTest { @@ -111,4 +118,168 @@ public void deleteTextTest() { verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class)); } + + + @Test + public void storeTextMultipartInS3_FallbackToSinglePut_WhenBelowThreshold() { + dao = new S3AsyncDao(s3AsyncClient); + String smallPayload = "Small payload"; + int multipartThreshold = 1000; + int partSize = 5 * 1024 * 1024; + + when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, smallPayload, partSize, multipartThreshold).join(); + + verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + verify(s3AsyncClient, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3AsyncClient, never()).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); + verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextMultipartInS3_UseMultipart_WhenAboveThreshold() { + dao = new S3AsyncDao(s3AsyncClient); + String largePayload = generateString(10 * 1024 * 1024); // 10MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(createResponse)); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); + + verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3AsyncClient, times(2)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); // 10MB / 5MB = 2 parts + verify(s3AsyncClient, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + verify(s3AsyncClient, never()).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + } + + @Test + public void storeTextMultipartInS3_WithSSEAndACL() { + dao = new S3AsyncDao(s3AsyncClient, serverSideEncryptionStrategy, objectCannedACL); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(createResponse)); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + ArgumentCaptor createCaptor = + ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); + + verify(s3AsyncClient, times(1)).createMultipartUpload(createCaptor.capture()); + CreateMultipartUploadRequest capturedRequest = createCaptor.getValue(); + + assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); + assertEquals(ANY_S3_KEY, capturedRequest.key()); + assertEquals(ServerSideEncryption.AWS_KMS, capturedRequest.serverSideEncryption()); + assertEquals(objectCannedACL, capturedRequest.acl()); + } + + @Test + public void storeTextMultipartInS3_VerifyCompleteRequest() { + dao = new S3AsyncDao(s3AsyncClient); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(createResponse)); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("etag-test") + .build(); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); + + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + ArgumentCaptor completeCaptor = + ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); + + verify(s3AsyncClient, times(1)).completeMultipartUpload(completeCaptor.capture()); + CompleteMultipartUploadRequest capturedRequest = completeCaptor.getValue(); + + assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); + assertEquals(ANY_S3_KEY, capturedRequest.key()); + assertEquals("test-upload-id", capturedRequest.uploadId()); + assertNotNull(capturedRequest.multipartUpload()); + assertEquals(2, capturedRequest.multipartUpload().parts().size()); + } + + @Test + public void storeTextMultipartInS3_AbortOnCompleteFailure() { + dao = new S3AsyncDao(s3AsyncClient); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(createResponse)); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) + .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); + + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(SdkException.builder().message("Complete failed").build()); + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(failedFuture); + + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + assertThrows(CompletionException.class, () -> { + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); + }); + + verify(s3AsyncClient, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java new file mode 100644 index 0000000..5fd1d60 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java @@ -0,0 +1,127 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedMultipartPayloadStoreAsyncTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + private static final int PART_SIZE = 5 * 1024 * 1024; // 5MB + private static final int THRESHOLD = 5 * 1024 * 1024; // 5MB + + private S3BackedMultipartPayloadStoreAsync multipartPayloadStore; + private S3AsyncDao s3AsyncDao; + + @BeforeEach + public void setup() { + s3AsyncDao = mock(S3AsyncDao.class); + multipartPayloadStore = new S3BackedMultipartPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME, PART_SIZE, THRESHOLD); + } + + @Test + public void testStoreOriginalPayloadMultipartOnSuccess() { + when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), + any(Integer.class), any(Integer.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor partSizeCaptor = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor thresholdCaptor = ArgumentCaptor.forClass(Integer.class); + + verify(s3AsyncDao, times(1)).storeTextMultipartInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture(), + partSizeCaptor.capture(), + thresholdCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertEquals(ANY_PAYLOAD, payloadCaptor.getValue()); + assertEquals(PART_SIZE, partSizeCaptor.getValue()); + assertEquals(THRESHOLD, thresholdCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadMultipartWithCustomPartSize() { + int customPartSize = 10 * 1024 * 1024; // 10MB + int customThreshold = 8 * 1024 * 1024; // 8MB + + multipartPayloadStore = new S3BackedMultipartPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME, customPartSize, customThreshold); + + String payload = generateString(12 * 1024 * 1024); // 12MB + String s3Key = "custom-size-key"; + + when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), + any(Integer.class), any(Integer.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + multipartPayloadStore.storeOriginalPayloadMultipart(payload, s3Key).join(); + + verify(s3AsyncDao, times(1)).storeTextMultipartInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + eq(payload), + eq(customPartSize), + eq(customThreshold) + ); + } + + @Test + public void testStoreOriginalPayloadMultipartOnS3Failure() { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(SdkException.create("S3 Exception", new Throwable())); + + when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), + any(Integer.class), any(Integer.class))) + .thenReturn(failedFuture); + + CompletionException exception = assertThrows(CompletionException.class, () -> { + multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); + }); + + assertTrue(exception.getMessage().contains("S3 Exception")); + } + + @Test + public void testStoreOriginalPayloadMultipartHandlesNullFromDao() { + when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), + any(Integer.class), any(Integer.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); + + // Should still return valid pointer even if DAO returns null/void + assertNotNull(actualPayloadPointer); + PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer); + assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName()); + assertEquals(ANY_S3_KEY, pointer.getS3Key()); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java new file mode 100644 index 0000000..0d2ca17 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java @@ -0,0 +1,102 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedMultipartPayloadStoreTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + private static final int PART_SIZE = 5 * 1024 * 1024; // 5MB + private static final int THRESHOLD = 5 * 1024 * 1024; // 5MB + + private S3BackedMultipartPayloadStore multipartPayloadStore; + private S3Dao s3Dao; + + @BeforeEach + public void setup() { + s3Dao = mock(S3Dao.class); + multipartPayloadStore = new S3BackedMultipartPayloadStore(s3Dao, S3_BUCKET_NAME, PART_SIZE, THRESHOLD); + } + + @Test + public void testStoreOriginalPayloadMultipartOnSuccess() { + String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor partSizeCaptor = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor thresholdCaptor = ArgumentCaptor.forClass(Integer.class); + + verify(s3Dao, times(1)).storeTextMultipartInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture(), + partSizeCaptor.capture(), + thresholdCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertEquals(ANY_PAYLOAD, payloadCaptor.getValue()); + assertEquals(PART_SIZE, partSizeCaptor.getValue()); + assertEquals(THRESHOLD, thresholdCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadMultipartWithCustomPartSize() { + int customPartSize = 10 * 1024 * 1024; // 10MB + int customThreshold = 8 * 1024 * 1024; // 8MB + + multipartPayloadStore = new S3BackedMultipartPayloadStore(s3Dao, S3_BUCKET_NAME, customPartSize, customThreshold); + + String payload = generateString(12 * 1024 * 1024); // 12MB + String s3Key = "custom-size-key"; + + multipartPayloadStore.storeOriginalPayloadMultipart(payload, s3Key); + + verify(s3Dao, times(1)).storeTextMultipartInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + eq(payload), + eq(customPartSize), + eq(customThreshold) + ); + } + + @Test + public void testStoreOriginalPayloadMultipartOnS3Failure() { + doThrow(SdkException.create("S3 Exception", new Throwable())) + .when(s3Dao) + .storeTextMultipartInS3( + any(String.class), + any(String.class), + any(String.class), + any(Integer.class), + any(Integer.class) + ); + + assertThrows(SdkException.class, () -> { + multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY); + }, "S3 Exception"); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java index b0a8f25..53662ce 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java @@ -8,15 +8,29 @@ import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.when; + +import java.util.List; public class S3DaoTest { @@ -75,4 +89,147 @@ public void storeTextInS3WithBothTest() { assertEquals(objectCannedACL, argument.getValue().acl()); assertEquals(S3_BUCKET_NAME, argument.getValue().bucket()); } + + @Test + public void storeTextMultipartInS3_FallbackToSinglePut_WhenBelowThreshold() { + dao = new S3Dao(s3Client); + String smallPayload = "Small payload"; + int multipartThreshold = 1000; + int partSize = 5 * 1024 * 1024; + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, smallPayload, partSize, multipartThreshold); + + verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3Client, never()).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextMultipartInS3_VerifyCompleteRequest() { + dao = new S3Dao(s3Client); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("etag-test") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + ArgumentCaptor completeCaptor = + ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); + + verify(s3Client, times(1)).completeMultipartUpload(completeCaptor.capture()); + CompleteMultipartUploadRequest capturedRequest = completeCaptor.getValue(); + + assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); + assertEquals(ANY_S3_KEY, capturedRequest.key()); + assertEquals("test-upload-id", capturedRequest.uploadId()); + assertNotNull(capturedRequest.multipartUpload()); + assertEquals(2, capturedRequest.multipartUpload().parts().size()); + } + + @Test + public void storeTextMultipartInS3_WithSSEAndACL() { + dao = new S3Dao(s3Client, serverSideEncryptionStrategy, objectCannedACL); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + ArgumentCaptor createCaptor = + ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); + + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + CreateMultipartUploadRequest capturedRequest = createCaptor.getValue(); + + assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); + assertEquals(ANY_S3_KEY, capturedRequest.key()); + assertEquals(ServerSideEncryption.AWS_KMS, capturedRequest.serverSideEncryption()); + assertEquals(objectCannedACL, capturedRequest.acl()); + } + + @Test + public void storeTextMultipartInS3_AbortOnUploadPartFailure() { + dao = new S3Dao(s3Client); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenThrow(SdkException.builder().message("Upload part failed").build()); + + assertThrows(SdkException.class, () -> { + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); + }); + + verify(s3Client, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + verify(s3Client, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextMultipartInS3_AbortOnCompleteFailure() { + dao = new S3Dao(s3Client); + String largePayload = generateString(6 * 1024 * 1024); // 6MB + int multipartThreshold = 5 * 1024 * 1024; // 5MB + int partSize = 5 * 1024 * 1024; // 5MB + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenThrow(SdkException.builder().message("Complete failed").build()); + + assertThrows(SdkException.class, () -> { + dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); + }); + + verify(s3Client, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } } \ No newline at end of file From db0d73a473e218e981dd807531a507374ee028b7 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 09:31:18 +0700 Subject: [PATCH 2/5] add stream support --- pom.xml | 4 +- .../payloadoffloading/AwsManagedCmk.java | 4 +- .../amazon/payloadoffloading/CustomerKey.java | 6 +- .../MultipartPayloadStore.java | 27 --- .../MultipartPayloadStoreAsync.java | 29 --- .../PayloadStorageAsyncConfiguration.java | 24 ++- .../PayloadStorageConfiguration.java | 59 +++++- .../PayloadStorageConfigurationBase.java | 73 ++++---- .../amazon/payloadoffloading/S3AsyncDao.java | 170 ++++++++--------- .../S3BackedMultipartPayloadStore.java | 27 --- .../S3BackedMultipartPayloadStoreAsync.java | 30 --- .../S3BackedStreamPayloadStore.java | 37 ++++ .../S3BackedStreamPayloadStoreAsync.java | 45 +++++ .../amazon/payloadoffloading/S3Dao.java | 156 +++++++++------- .../ServerSideEncryptionStrategy.java | 2 +- .../payloadoffloading/StreamPayloadStore.java | 33 ++++ .../StreamPayloadStoreAsync.java | 34 ++++ .../payloadoffloading/AwsManagedCmkTest.java | 2 +- .../payloadoffloading/CustomerKeyTest.java | 2 +- .../PayloadStorageAsyncConfigurationTest.java | 36 ++-- .../PayloadStorageConfigurationTest.java | 36 ++-- .../payloadoffloading/S3AsyncDaoTest.java | 174 ------------------ ...3BackedMultipartPayloadStoreAsyncTest.java | 127 ------------- .../S3BackedMultipartPayloadStoreTest.java | 102 ---------- .../S3BackedStreamPayloadStoreAsyncTest.java | 114 ++++++++++++ .../S3BackedStreamPayloadStoreTest.java | 90 +++++++++ .../amazon/payloadoffloading/S3DaoTest.java | 164 ++--------------- 27 files changed, 699 insertions(+), 908 deletions(-) delete mode 100644 src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java delete mode 100644 src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java delete mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java delete mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java create mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java create mode 100644 src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java create mode 100644 src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java create mode 100644 src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java delete mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java delete mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java create mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java create mode 100644 src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java diff --git a/pom.xml b/pom.xml index 0cc01eb..15052fb 100644 --- a/pom.xml +++ b/pom.xml @@ -36,13 +36,13 @@ - 2.20.130 + 2.27.14 software.amazon.awssdk - s3 + s3-transfer-manager ${aws-java-sdk.version} diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java index e4d2109..eaaa49f 100644 --- a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java +++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java @@ -11,7 +11,7 @@ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { } @Override - public void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder) { - createMultipartUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) { + createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); } } diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java index bcc3175..9a2db93 100644 --- a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java +++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java @@ -18,8 +18,8 @@ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { } @Override - public void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder) { - createMultipartUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); - createMultipartUploadRequestBuilder.ssekmsKeyId(awsKmsKeyId); + public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) { + createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + createStreamUploadRequestBuilder.ssekmsKeyId(awsKmsKeyId); } } diff --git a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java deleted file mode 100644 index ba86f63..0000000 --- a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStore.java +++ /dev/null @@ -1,27 +0,0 @@ -package software.amazon.payloadoffloading; - -/** - * Optional extension interface for {@link PayloadStore} implementations that can perform multipart - * upload and streaming retrieval for very large payloads to reduce memory pressure. - */ -public interface MultipartPayloadStore extends PayloadStore { - - /** - * Store a payload using streaming/multipart-aware logic. Default falls back to normal storage. - * @param payload UTF-8 text payload - * @param s3Key pre-generated S3 key (or equivalent object key) to use - * @return pointer string (JSON) referencing stored payload - */ - default String storeOriginalPayloadMultipart(String payload, String s3Key) { - return storeOriginalPayload(payload, s3Key); - } - - /** - * Retrieve a payload potentially using streaming/multipart-aware logic. Default falls back to normal retrieval. - * @param payloadPointer pointer JSON string - * @return original payload content - */ - default String getOriginalPayloadMultipart(String payloadPointer) { - return getOriginalPayload(payloadPointer); - } -} diff --git a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java deleted file mode 100644 index 0c7d01a..0000000 --- a/src/main/java/software/amazon/payloadoffloading/MultipartPayloadStoreAsync.java +++ /dev/null @@ -1,29 +0,0 @@ -package software.amazon.payloadoffloading; - -import java.util.concurrent.CompletableFuture; - -/** - * Optional extension interface for {@link PayloadStoreAsync} implementations that can perform multipart - * upload and streaming retrieval for very large payloads to reduce memory pressure. - */ -public interface MultipartPayloadStoreAsync extends PayloadStoreAsync { - - /** - * Store payload content using multipart semantics when possible. - * @param payload UTF-8 text payload - * @param s3Key object key to use - * @return future pointer string (JSON) referencing stored payload - */ - default CompletableFuture storeOriginalPayloadMultipart(String payload, String s3Key) { - return storeOriginalPayload(payload, s3Key); - } - - /** - * Retrieve payload using streaming/multipart logic when available. Default delegates to normal method. - * @param payloadPointer pointer JSON string - * @return future original payload content - */ - default CompletableFuture getOriginalPayloadMultipart(String payloadPointer) { - return getOriginalPayload(payloadPointer); - } -} diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java index 2cf3edd..5840971 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java @@ -152,22 +152,32 @@ public PayloadStorageAsyncConfiguration withObjectCannedACL(ObjectCannedACL obje } /** - * Enables or disables multipart upload support. - * @param enabled true to enable multipart uploads when threshold exceeded. + * Enables or disables stream upload support. + * @param enabled true to enable stream uploads when threshold exceeded. * @return updated configuration */ - public PayloadStorageAsyncConfiguration withMultipartUploadEnabled(boolean enabled) { - setMultipartUploadEnabled(enabled); + public PayloadStorageAsyncConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); return this; } /** - * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * Sets the stream upload threshold (in bytes). Only used when stream upload is enabled. * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. * @return updated configuration */ - public PayloadStorageAsyncConfiguration withMultipartUploadThreshold(int threshold) { - setMultipartUploadThreshold(threshold); + public PayloadStorageAsyncConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); + return this; + } + + public PayloadStorageAsyncConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } + + public PayloadStorageAsyncConfiguration withS3Region(String s3Region) { + setS3Region(s3Region); return this; } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index cdcd9eb..cd7b782 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -4,6 +4,7 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; @@ -35,14 +36,17 @@ public class PayloadStorageConfiguration extends PayloadStorageConfigurationBase private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class); private S3Client s3; + private S3AsyncClient s3Async; public PayloadStorageConfiguration() { s3 = null; + s3Async = null; } public PayloadStorageConfiguration(PayloadStorageConfiguration other) { super(other); this.s3 = other.getS3Client(); + this.s3Async = other.getS3AsyncClient(); } /** @@ -102,6 +106,37 @@ public S3Client getS3Client() { return s3; } + /** + * Sets the optional Amazon S3 async client to be used for TransferManager. + * This is useful for pre-configuring the async client with specific endpoint/credentials + * (e.g., for LocalStack testing). + * + * @param s3AsyncClient The S3AsyncClient to use for TransferManager operations. + */ + public void setS3AsyncClient(S3AsyncClient s3AsyncClient) { + this.s3Async = s3AsyncClient; + } + + /** + * Sets the optional Amazon S3 async client to be used for TransferManager. + * + * @param s3AsyncClient The S3AsyncClient to use for TransferManager operations. + * @return the updated PayloadStorageConfiguration object. + */ + public PayloadStorageConfiguration withS3AsyncClient(S3AsyncClient s3AsyncClient) { + setS3AsyncClient(s3AsyncClient); + return this; + } + + /** + * Gets the Amazon S3 async client which is being used for TransferManager operations. + * + * @return Reference to the Amazon S3 async client, or null if not configured. + */ + public S3AsyncClient getS3AsyncClient() { + return s3Async; + } + /** * Sets the payload size threshold for storing payloads in Amazon S3. * @@ -150,22 +185,32 @@ public PayloadStorageConfiguration withObjectCannedACL(ObjectCannedACL objectCan } /** - * Enables or disables multipart upload support. - * @param enabled true to enable multipart uploads when threshold exceeded. + * Enables or disables stream upload support. + * @param enabled true to enable stream uploads when threshold exceeded. * @return updated configuration */ - public PayloadStorageConfiguration withMultipartUploadEnabled(boolean enabled) { - setMultipartUploadEnabled(enabled); + public PayloadStorageConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); return this; } /** - * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * Sets the stream upload threshold (in bytes). Only used when stream upload is enabled. * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. * @return updated configuration */ - public PayloadStorageConfiguration withMultipartUploadThreshold(int threshold) { - setMultipartUploadThreshold(threshold); + public PayloadStorageConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); + return this; + } + + public PayloadStorageConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } + + public PayloadStorageConfiguration withS3Region(String s3Region) { + setS3Region(s3Region); return this; } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java index d9a353a..68947ae 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java @@ -18,15 +18,16 @@ public abstract class PayloadStorageConfigurationBase { private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfigurationBase.class); private String s3BucketName; + private String s3Region = "ap-southeast-1"; private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; private boolean payloadSupport = false; - // Enable multipart upload support (opt-in, default false) - private boolean multipartUploadEnabled = false; - // Threshold (bytes) above which multipart should be attempted when enabled (default 5MB) - private int multipartUploadThreshold = 5 * 1024 * 1024; - // Multipart part size (bytes). Each part (except last) must be >=5MB. Default 5MB. - private int multipartUploadPartSize = 5 * 1024 * 1024; + // Enable stream upload support (opt-in, default false) + private boolean streamUploadEnabled = false; + // Threshold (bytes) above which stream should be attempted when enabled (default 5MB) + private long streamUploadThreshold = 5 * 1024 * 1024L; + // Stream part size (bytes). Each part (except last) must be >=5MB. Default 5MB. + private long streamUploadPartSize = 5 * 1024 * 1024L; /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ @@ -49,9 +50,10 @@ public PayloadStorageConfigurationBase(PayloadStorageConfigurationBase other) { this.payloadSizeThreshold = other.getPayloadSizeThreshold(); this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy(); this.objectCannedACL = other.getObjectCannedACL(); - this.multipartUploadEnabled = other.isMultipartUploadEnabled(); - this.multipartUploadThreshold = other.getMultipartUploadThreshold(); - this.multipartUploadPartSize = other.getMultipartUploadPartSize(); + this.streamUploadEnabled = other.isStreamUploadEnabled(); + this.streamUploadThreshold = other.getStreamUploadThreshold(); + this.streamUploadPartSize = other.getStreamUploadPartSize(); + this.s3Region = other.getS3Region(); } /** @@ -185,54 +187,63 @@ public ObjectCannedACL getObjectCannedACL() { } /** - * Checks whether multipart upload support is enabled. Default: false. + * Checks whether stream upload support is enabled. Default: false. * - * @return true if multipart upload support is enabled. + * @return true if stream upload support is enabled. */ - public boolean isMultipartUploadEnabled() { return multipartUploadEnabled; } + public boolean isStreamUploadEnabled() { return streamUploadEnabled; } /** - * Enable or disable multipart upload support. When enabling, callers should ensure they provide a PayloadStore - * implementation capable of multipart operations; otherwise normal single PUT behavior will be used. + * Enable or disable stream upload support. When enabling, callers should ensure they provide a PayloadStore + * implementation capable of stream operations; otherwise normal single PUT behavior will be used. * - * @param multipartUploadEnabled flag to enable/disable multipart support. + * @param streamUploadEnabled flag to enable/disable stream support. */ - public void setMultipartUploadEnabled(boolean multipartUploadEnabled) { this.multipartUploadEnabled = multipartUploadEnabled; } + public void setStreamUploadEnabled(boolean streamUploadEnabled) { this.streamUploadEnabled = streamUploadEnabled; } /** - * Gets the multipart upload threshold in bytes. Default 5MB. + * Gets the stream upload threshold in bytes. Default 5MB. * * @return threshold in bytes. */ - public int getMultipartUploadThreshold() { return multipartUploadThreshold; } + public long getStreamUploadThreshold() { return streamUploadThreshold; } /** - * Sets the multipart upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). + * Sets the stream upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). * - * @param multipartUploadThreshold threshold in bytes + * @param streamUploadThreshold threshold in bytes */ - public void setMultipartUploadThreshold(int multipartUploadThreshold) { - if (multipartUploadThreshold <= 0) { - this.multipartUploadThreshold = 5 * 1024 * 1024; + public void setStreamUploadThreshold(long streamUploadThreshold) { + long min = 5 * 1024 * 1024L; + if (streamUploadThreshold <= min) { + this.streamUploadThreshold = min; } else { - this.multipartUploadThreshold = multipartUploadThreshold; + this.streamUploadThreshold = streamUploadThreshold; } } /** - * Gets the configured multipart upload part size (bytes). Default 5MB. + * Gets the configured stream upload part size (bytes). Default 5MB. */ - public int getMultipartUploadPartSize() { return multipartUploadPartSize; } + public long getStreamUploadPartSize() { return streamUploadPartSize; } /** - * Sets the multipart upload part size (bytes). Values < 5MB are rounded up to 5MB. + * Sets the stream upload part size (bytes). Values < 5MB are rounded up to 5MB. */ - public void setMultipartUploadPartSize(int partSize) { - int min = 5 * 1024 * 1024; + public void setStreamUploadPartSize(long partSize) { + long min = 5 * 1024 * 1024L; if (partSize < min) { - this.multipartUploadPartSize = min; + this.streamUploadPartSize = min; } else { - this.multipartUploadPartSize = partSize; + this.streamUploadPartSize = partSize; } } + + public void setS3Region(String s3Region) { + this.s3Region = s3Region; + } + + public String getS3Region() { + return this.s3Region; + } } \ No newline at end of file diff --git a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java index 60db795..92b0ef4 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java @@ -1,11 +1,13 @@ package software.amazon.payloadoffloading; import java.io.UncheckedIOException; +import java.io.InputStream; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.exception.SdkClientException; @@ -13,34 +15,39 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; - -import java.util.Arrays; -import java.util.List; -import java.util.ArrayList; -import java.nio.charset.StandardCharsets; +import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.UploadRequest; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Dao layer to access S3. */ public class S3AsyncDao { private static final Logger LOG = LoggerFactory.getLogger(S3AsyncDao.class); + private final S3AsyncClient s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; + /** + * Constructor for basic S3 operations (non-streaming). + * @param s3Client The S3 async client for standard operations + */ public S3AsyncDao(S3AsyncClient s3Client) { this(s3Client, null, null); } + /** + * Full constructor with SSE and ACL configuration. + * @param s3Client The S3 async client + * @param serverSideEncryptionStrategy Server-side encryption configuration + * @param objectCannedACL Canned ACL configuration + */ public S3AsyncDao( S3AsyncClient s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, @@ -77,6 +84,28 @@ public CompletableFuture getTextFromS3(String s3BucketName, String s3Key }); } + public CompletableFuture> getTextStreamFromS3(String s3BucketName, String s3Key) { + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + + return s3Client.getObject(getObjectRequest, AsyncResponseTransformer.toBlockingInputStream()) + .handle((v, tIn) -> { + if (tIn != null) { + Throwable t = Util.unwrapFutureException(tIn); + if (t instanceof SdkException) { + String errorMessage = "Failed to get the S3 object stream which contains the payload."; + LOG.error(errorMessage, t); + throw SdkException.create(errorMessage, t); + } + throw new CompletionException(t); + } + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return v; + }); + } + public CompletableFuture storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) { PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() .bucket(s3BucketName) @@ -129,81 +158,52 @@ public CompletableFuture deletePayloadFromS3(String s3BucketName, String s } - public CompletableFuture storeTextMultipartInS3(String bucket, String key, String payloadContentStr, int partSize, int multipartThreshold) { - byte[] data = payloadContentStr.getBytes(StandardCharsets.UTF_8); - if (data.length < multipartThreshold) { - return storeTextInS3(bucket, key, payloadContentStr); - } - - CreateMultipartUploadRequest.Builder createBuilder = CreateMultipartUploadRequest.builder() - .bucket(bucket) - .key(key); - if (objectCannedACL != null) { - createBuilder.acl(objectCannedACL); - } - // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html - if (serverSideEncryptionStrategy != null) { - serverSideEncryptionStrategy.decorate(createBuilder); - } - - return s3Client.createMultipartUpload(createBuilder.build()) - .thenCompose(createResp -> { - String uploadId = createResp.uploadId(); - int partCount = (int) Math.ceil((double) data.length / partSize); - List> partFutures = new ArrayList<>(partCount); - - int offset = 0; - int partNumber = 1; - while (offset < data.length) { - int currSize = Math.min(partSize, data.length - offset); - byte[] slice = Arrays.copyOfRange(data, offset, offset + currSize); - offset += currSize; - final int thisPartNumber = partNumber++; - - UploadPartRequest uploadPartRequest = UploadPartRequest.builder() - .bucket(bucket) - .key(key) - .uploadId(uploadId) - .partNumber(thisPartNumber) - .contentLength((long) slice.length) - .build(); - - CompletableFuture fut = s3Client.uploadPart(uploadPartRequest, AsyncRequestBody.fromBytes(slice)) - .thenApply(resp -> CompletedPart.builder().partNumber(thisPartNumber).eTag(resp.eTag()).build()); - partFutures.add(fut); - } - - CompletableFuture all = CompletableFuture.allOf(partFutures.toArray(new CompletableFuture[0])); - return all.thenCompose(v -> { - List completed = new ArrayList<>(partFutures.size()); - for (CompletableFuture f : partFutures) { - completed.add(f.join()); + /** + * Stores a stream of data to S3 using TransferManager for efficient multipart uploads. + * + * @param bucket The S3 bucket name + * @param key The S3 object key + * @param payloadStream The input stream to upload + * @return CompletableFuture that completes when upload is finished + */ + public CompletableFuture storeTextStreamInS3(String bucket, String key, InputStream payloadStream) { + S3TransferManager transferManager = S3TransferManager.create(); + ExecutorService executor = Executors.newSingleThreadExecutor(); + + try { + UploadRequest.Builder uploadBuilder = UploadRequest.builder() + .putObjectRequest(b -> { + b.bucket(bucket).key(key); + if (objectCannedACL != null) { + b.acl(objectCannedACL); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(b); } - CompletedMultipartUpload cmu = CompletedMultipartUpload.builder().parts(completed).build(); - CompleteMultipartUploadRequest completeReq = CompleteMultipartUploadRequest.builder() - .bucket(bucket) - .key(key) - .uploadId(uploadId) - .multipartUpload(cmu) - .build(); - return s3Client.completeMultipartUpload(completeReq) - .handle((vv, t) -> { - if (t != null) { - s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder().bucket(bucket).key(key).uploadId(uploadId).build()); - throw new CompletionException(t); - } - LOG.info("S3 multipart object created, Bucket name: " + bucket + ", Object key: " + key + "."); - return (Void) null; - }); + }) + .requestBody(AsyncRequestBody.fromInputStream(payloadStream, null, executor)); + + return transferManager.upload(uploadBuilder.build()).completionFuture() + .thenApply(completedUpload -> { + LOG.info("S3 stream object created from InputStream, Bucket name: " + bucket + ", Object key: " + key + "."); + return (Void) null; + }) + .exceptionally(t -> { + Throwable cause = Util.unwrapFutureException(t); + if (cause instanceof SdkException) { + String errorMessage = "Failed to store the message content from InputStream in an S3 stream object."; + LOG.error(errorMessage, cause); + throw SdkException.create(errorMessage, cause); + } + throw new CompletionException(cause); }); - }).exceptionally(t -> { - Throwable cause = Util.unwrapFutureException(t); - if (cause instanceof SdkException) { - String errorMessage = "Failed to store the message content in an S3 multipart object."; - LOG.error(errorMessage, cause); - throw SdkException.create(errorMessage, cause); - } - throw new CompletionException(cause); - }); + } catch (UnsupportedOperationException e) { + LOG.warn("TransferManager creation disabled, cannot perform streaming upload: " + e.getMessage()); + throw new CompletionException(e); + } finally { + transferManager.close(); + executor.shutdownNow(); + } } } diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java deleted file mode 100644 index ef3b22a..0000000 --- a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStore.java +++ /dev/null @@ -1,27 +0,0 @@ -package software.amazon.payloadoffloading; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class S3BackedMultipartPayloadStore extends S3BackedPayloadStore implements MultipartPayloadStore { - private static final Logger LOG = LoggerFactory.getLogger(S3BackedMultipartPayloadStore.class); - private final S3Dao multipartDao; - private final int partSize; - private final int threshold; - private final String s3BucketName; - - public S3BackedMultipartPayloadStore(S3Dao s3Dao, String s3BucketName, int partSize, int threshold) { - super(s3Dao, s3BucketName); - this.multipartDao = s3Dao; - this.partSize = partSize; - this.threshold = threshold; - this.s3BucketName = s3BucketName; - } - - @Override - public String storeOriginalPayloadMultipart(String payload, String s3Key) { - multipartDao.storeTextMultipartInS3(s3BucketName, s3Key, payload, partSize, threshold); - LOG.info("S3 multipart object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); - return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); - } -} diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java deleted file mode 100644 index b5f3ba0..0000000 --- a/src/main/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsync.java +++ /dev/null @@ -1,30 +0,0 @@ -package software.amazon.payloadoffloading; - -import java.util.concurrent.CompletableFuture; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class S3BackedMultipartPayloadStoreAsync extends S3BackedPayloadStoreAsync implements MultipartPayloadStoreAsync { - private static final Logger LOG = LoggerFactory.getLogger(S3BackedMultipartPayloadStoreAsync.class); - private final S3AsyncDao multipartDao; - private final String s3BucketName; - private final int partSize; - private final int threshold; - - public S3BackedMultipartPayloadStoreAsync(S3AsyncDao s3Dao, String s3BucketName, int partSize, int threshold) { - super(s3Dao, s3BucketName); - this.multipartDao = s3Dao; - this.s3BucketName = s3BucketName; - this.partSize = partSize; - this.threshold = threshold; - } - - @Override - public CompletableFuture storeOriginalPayloadMultipart(String payload, String s3Key) { - return multipartDao.storeTextMultipartInS3(s3BucketName, s3Key, payload, partSize, threshold) - .thenApply(v -> { - LOG.info("S3 multipart object created, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); - return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); - }); - } -} diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java new file mode 100644 index 0000000..cd3bc09 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java @@ -0,0 +1,37 @@ +package software.amazon.payloadoffloading; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import java.io.InputStream; + +public class S3BackedStreamPayloadStore extends S3BackedPayloadStore implements StreamPayloadStore { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStore.class); + private final S3Dao s3Dao; + private final String s3BucketName; + + public S3BackedStreamPayloadStore(S3Dao s3Dao, String s3BucketName) { + super(s3Dao, s3BucketName); + this.s3Dao = s3Dao; + this.s3BucketName = s3BucketName; + } + + @Override + public String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + s3Dao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream); + LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + } + + @Override + public ResponseInputStream getOriginalPayloadStreamStream(String payloadPointer) { + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer); + String s3BucketName = s3Pointer.getS3BucketName(); + String s3Key = s3Pointer.getS3Key(); + + ResponseInputStream stream = s3Dao.getTextStreamFromS3(s3BucketName, s3Key); + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return stream; + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java new file mode 100644 index 0000000..7a75e5b --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java @@ -0,0 +1,45 @@ +package software.amazon.payloadoffloading; + +import java.io.InputStream; +import java.util.concurrent.CompletableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ObjectCannedACL; + +public class S3BackedStreamPayloadStoreAsync extends S3BackedPayloadStoreAsync implements StreamPayloadStoreAsync { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStoreAsync.class); + private final S3AsyncDao streamDao; + private final String s3BucketName; + + public S3BackedStreamPayloadStoreAsync(S3AsyncDao s3Dao, String s3BucketName) { + super(s3Dao, s3BucketName); + this.streamDao = s3Dao; + this.s3BucketName = s3BucketName; + } + + @Override + public CompletableFuture storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + return streamDao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream) + .thenApply(v -> { + LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + }); + } + + @Override + public CompletableFuture> getOriginalPayloadStreamStream(String payloadPointer) { + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer); + String s3BucketName = s3Pointer.getS3BucketName(); + String s3Key = s3Pointer.getS3Key(); + + return streamDao.getTextStreamFromS3(s3BucketName, s3Key) + .thenApply(stream -> { + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return stream; + }); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index b63d810..a84f11f 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -6,47 +6,66 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.UploadRequest; import software.amazon.awssdk.utils.IoUtils; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.io.InputStream; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; /** * Dao layer to access S3. */ public class S3Dao { private static final Logger LOG = LoggerFactory.getLogger(S3Dao.class); + private final S3Client s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; - + private S3AsyncClient s3AsyncClient; + public S3Dao(S3Client s3Client) { - this(s3Client, null, null); + this.s3Client = s3Client; + this.serverSideEncryptionStrategy = null; + this.objectCannedACL = null; } + /** + * Constructor with SSE and ACL configuration (non-streaming). + * @param s3Client The S3 sync client + * @param serverSideEncryptionStrategy Server-side encryption configuration + * @param objectCannedACL Canned ACL configuration + */ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) { this.s3Client = s3Client; this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; this.objectCannedACL = objectCannedACL; } + public S3Dao(S3Client s3Client, S3AsyncClient s3AsyncClient) { + this.s3Client = s3Client; + this.s3AsyncClient = s3AsyncClient; + this.serverSideEncryptionStrategy = null; + this.objectCannedACL = null; + } + + public S3Dao(S3Client s3Client, S3AsyncClient s3AsyncClient, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) { + this.s3Client = s3Client; + this.s3AsyncClient = s3AsyncClient; + this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; + this.objectCannedACL = objectCannedACL; + } + public String getTextFromS3(String s3BucketName, String s3Key) { GetObjectRequest getObjectRequest = GetObjectRequest.builder() .bucket(s3BucketName) @@ -77,6 +96,23 @@ public String getTextFromS3(String s3BucketName, String s3Key) { return embeddedText; } + public ResponseInputStream getTextStreamFromS3(String s3BucketName, String s3Key) { + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + + try { + ResponseInputStream object = s3Client.getObject(getObjectRequest); + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return object; + } catch (SdkException e) { + String errorMessage = "Failed to get the S3 object stream which contains the payload."; + LOG.error(errorMessage, e); + throw SdkException.create(errorMessage, e); + } + } + public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) { PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() .bucket(s3BucketName) @@ -117,69 +153,47 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) { LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); } - public void storeTextMultipartInS3(String bucket, String key, String payloadContentStr, int partSize, int multipartThreshold) { - byte[] data = payloadContentStr.getBytes(StandardCharsets.UTF_8); - if (data.length < multipartThreshold) { - storeTextInS3(bucket, key, payloadContentStr); - return; - } - CreateMultipartUploadRequest.Builder createBuilder = CreateMultipartUploadRequest.builder() - .bucket(bucket) - .key(key); - if (objectCannedACL != null) { - createBuilder.acl(objectCannedACL); + /** + * Stores a stream of data to S3 using TransferManager for efficient multipart uploads. + * Requires S3AsyncClient to be provided in the constructor. + * + * @param bucket The S3 bucket name + * @param key The S3 object key + * @param payloadStream The input stream to upload + * @throws IllegalStateException if S3AsyncClient was not provided in constructor + */ + public void storeTextStreamInS3(String bucket, String key, InputStream payloadStream) { + if (s3AsyncClient == null) { + throw new IllegalStateException("S3AsyncClient must be provided in constructor for streaming uploads"); } - // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html - if (serverSideEncryptionStrategy != null) { - serverSideEncryptionStrategy.decorate(createBuilder); - } - - String uploadId = null; - try { - CreateMultipartUploadResponse createResp = s3Client.createMultipartUpload(createBuilder.build()); - uploadId = createResp.uploadId(); - - int partCount = (int) Math.ceil((double) data.length / partSize); - List completedParts = new ArrayList<>(partCount); - - for (int partNumber = 1, offset = 0; offset < data.length; partNumber++) { - int currSize = Math.min(partSize, data.length - offset); - byte[] slice = Arrays.copyOfRange(data, offset, offset + currSize); - offset += currSize; - - UploadPartRequest uploadPartRequest = UploadPartRequest.builder() - .bucket(bucket) - .key(key) - .uploadId(uploadId) - .partNumber(partNumber) - .contentLength((long) slice.length) - .build(); - UploadPartResponse upr = s3Client.uploadPart(uploadPartRequest, RequestBody.fromBytes(slice)); - completedParts.add(CompletedPart.builder().partNumber(partNumber).eTag(upr.eTag()).build()); - } + S3TransferManager transferManager = S3TransferManager.builder().s3Client(s3AsyncClient).build(); + ExecutorService executor = Executors.newSingleThreadExecutor(); - CompletedMultipartUpload completed = CompletedMultipartUpload.builder().parts(completedParts).build(); - CompleteMultipartUploadRequest completeReq = CompleteMultipartUploadRequest.builder() - .bucket(bucket) - .key(key) - .uploadId(uploadId) - .multipartUpload(completed) - .build(); - s3Client.completeMultipartUpload(completeReq); - LOG.info("S3 multipart object created, Bucket name: " + bucket + ", Object key: " + key + ", Parts: " + completedParts.size() + "."); - } catch (SdkException e) { - if (uploadId != null) { - try { - s3Client.abortMultipartUpload(AbortMultipartUploadRequest.builder().bucket(bucket).key(key).uploadId(uploadId).build()); - } catch (Exception abortEx) { - LOG.warn("Failed to abort multipart upload after failure.", abortEx); - } - } - String errorMessage = "Failed to store the message content in an S3 multipart object."; + try { + UploadRequest.Builder uploadBuilder = UploadRequest.builder() + .putObjectRequest(b -> { + b.bucket(bucket).key(key); + if (objectCannedACL != null) { + b.acl(objectCannedACL); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.htransferManagerl + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(b); + } + }) + .requestBody(AsyncRequestBody.fromInputStream(payloadStream, null, executor)); + + transferManager.upload(uploadBuilder.build()).completionFuture().join(); + LOG.info("S3 stream object created from InputStream, Bucket name: " + bucket + ", Object key: " + key + "."); + } catch (Exception e) { + String errorMessage = "Failed to store the message content from InputStream in an S3 stream object."; LOG.error(errorMessage, e); throw SdkException.create(errorMessage, e); + } finally { + transferManager.close(); + executor.shutdownNow(); } } } diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java index 8cd4f89..43aced1 100644 --- a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java +++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java @@ -5,5 +5,5 @@ public interface ServerSideEncryptionStrategy { void decorate(PutObjectRequest.Builder putObjectRequestBuilder); - void decorate(CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder); + void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder); } diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java new file mode 100644 index 0000000..ba3ac0d --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java @@ -0,0 +1,33 @@ +package software.amazon.payloadoffloading; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import java.io.InputStream; + +/** + * Optional extension interface for {@link PayloadStore} implementations that can perform stream + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface StreamPayloadStore extends PayloadStore { + + /** + * Store a payload from an InputStream using streaming-aware logic to avoid loading large payloads into memory. + * @param payloadStream InputStream containing UTF-8 text payload + * @param s3Key pre-generated S3 key (or equivalent object key) to use + * @return pointer string (JSON) referencing stored payload + * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented + */ + default String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + throw new UnsupportedOperationException("Streaming upload from InputStream not implemented"); + } + + /** + * Retrieve a payload as a stream to avoid loading large payloads into memory. + * @param payloadPointer pointer JSON string + * @return stream containing the original payload content + * @throws UnsupportedOperationException if streaming retrieval is not implemented + */ + default ResponseInputStream getOriginalPayloadStreamStream(String payloadPointer) { + throw new UnsupportedOperationException("Streaming retrieval not implemented"); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java new file mode 100644 index 0000000..0333eec --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java @@ -0,0 +1,34 @@ +package software.amazon.payloadoffloading; + +import java.io.InputStream; +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +/** + * Optional extension interface for {@link PayloadStoreAsync} implementations that can perform stream + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface StreamPayloadStoreAsync extends PayloadStoreAsync { + + /** + * Store payload content from an InputStream using stream semantics to avoid loading large payloads into memory. + * @param payloadStream InputStream containing UTF-8 text payload + * @param s3Key object key to use + * @return future pointer string (JSON) referencing stored payload + * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented + */ + default CompletableFuture storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + throw new UnsupportedOperationException("Streaming upload from InputStream not implemented"); + } + + /** + * Retrieve payload as a stream to avoid loading large payloads into memory. + * @param payloadPointer pointer JSON string + * @return future stream containing the original payload content + * @throws UnsupportedOperationException if streaming retrieval is not implemented + */ + default CompletableFuture> getOriginalPayloadStreamStream(String payloadPointer) { + throw new UnsupportedOperationException("Streaming retrieval not implemented"); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java index 18a5c0a..2140ebc 100644 --- a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java +++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java @@ -24,7 +24,7 @@ public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() { @Test - public void testAwsManagedCmkStrategyMultipartSetsCorrectEncryptionValues() { + public void testAwsManagedCmkStrategyStreamSetsCorrectEncryptionValues() { AwsManagedCmk awsManagedCmk = new AwsManagedCmk(); CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java index ed39cc5..30ee761 100644 --- a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java +++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java @@ -24,7 +24,7 @@ public void testCustomerKeyStrategySetsCorrectEncryptionValues() { } @Test - public void testCustomerKeyStrategySetsMultipartUploadEncryptionValues() { + public void testCustomerKeyStrategySetsStreamUploadEncryptionValues() { CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID); CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java index 78372d3..8aeeee7 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java @@ -101,49 +101,49 @@ public void testCannedAccessControlList() { } @Test - public void testMultipartUploadEnabled() { + public void testStreamUploadEnabled() { PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); - payloadStorageConfiguration.setMultipartUploadEnabled(true); - assertTrue(payloadStorageConfiguration.isMultipartUploadEnabled()); + payloadStorageConfiguration.setStreamUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isStreamUploadEnabled()); - payloadStorageConfiguration.setMultipartUploadEnabled(false); - assertFalse(payloadStorageConfiguration.isMultipartUploadEnabled()); + payloadStorageConfiguration.setStreamUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isStreamUploadEnabled()); } @Test - public void testMultipartUploadThreshold() { + public void testStreamUploadThreshold() { PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); int customThreshold = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setMultipartUploadThreshold(customThreshold); - assertEquals(customThreshold, payloadStorageConfiguration.getMultipartUploadThreshold()); + payloadStorageConfiguration.setStreamUploadThreshold(customThreshold); + assertEquals(customThreshold, payloadStorageConfiguration.getStreamUploadThreshold()); } @Test - public void testMultipartUploadPartSize() { + public void testStreamUploadPartSize() { PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); int customPartSize = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setMultipartUploadPartSize(customPartSize); - assertEquals(customPartSize, payloadStorageConfiguration.getMultipartUploadPartSize()); + payloadStorageConfiguration.setStreamUploadPartSize(customPartSize); + assertEquals(customPartSize, payloadStorageConfiguration.getStreamUploadPartSize()); } @Test - public void testMultipartConfigurationInCopyConstructor() { + public void testStreamConfigurationInCopyConstructor() { S3AsyncClient s3Async = mock(S3AsyncClient.class); PayloadStorageAsyncConfiguration original = new PayloadStorageAsyncConfiguration(); original.withPayloadSupportEnabled(s3Async, s3BucketName) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(10 * 1024 * 1024) - .setMultipartUploadPartSize(8 * 1024 * 1024); + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(10 * 1024 * 1024) + .setStreamUploadPartSize(8 * 1024 * 1024); PayloadStorageAsyncConfiguration copy = new PayloadStorageAsyncConfiguration(original); - assertTrue(copy.isMultipartUploadEnabled()); - assertEquals(10 * 1024 * 1024, copy.getMultipartUploadThreshold()); - assertEquals(8 * 1024 * 1024, copy.getMultipartUploadPartSize()); + assertTrue(copy.isStreamUploadEnabled()); + assertEquals(10 * 1024 * 1024, copy.getStreamUploadThreshold()); + assertEquals(8 * 1024 * 1024, copy.getStreamUploadPartSize()); assertNotSame(copy, original); } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java index e304666..d233826 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java @@ -101,49 +101,49 @@ public void testCannedAccessControlList() { } @Test - public void testMultipartUploadEnabled() { + public void testStreamUploadEnabled() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); - payloadStorageConfiguration.setMultipartUploadEnabled(true); - assertTrue(payloadStorageConfiguration.isMultipartUploadEnabled()); + payloadStorageConfiguration.setStreamUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isStreamUploadEnabled()); - payloadStorageConfiguration.setMultipartUploadEnabled(false); - assertFalse(payloadStorageConfiguration.isMultipartUploadEnabled()); + payloadStorageConfiguration.setStreamUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isStreamUploadEnabled()); } @Test - public void testMultipartUploadThreshold() { + public void testStreamUploadThreshold() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); int customThreshold = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setMultipartUploadThreshold(customThreshold); - assertEquals(customThreshold, payloadStorageConfiguration.getMultipartUploadThreshold()); + payloadStorageConfiguration.setStreamUploadThreshold(customThreshold); + assertEquals(customThreshold, payloadStorageConfiguration.getStreamUploadThreshold()); } @Test - public void testMultipartUploadPartSize() { + public void testStreamUploadPartSize() { PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); int customPartSize = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setMultipartUploadPartSize(customPartSize); - assertEquals(customPartSize, payloadStorageConfiguration.getMultipartUploadPartSize()); + payloadStorageConfiguration.setStreamUploadPartSize(customPartSize); + assertEquals(customPartSize, payloadStorageConfiguration.getStreamUploadPartSize()); } @Test - public void testMultipartConfigurationInCopyConstructor() { + public void testStreamConfigurationInCopyConstructor() { S3Client s3 = mock(S3Client.class); PayloadStorageConfiguration original = new PayloadStorageConfiguration(); original.withPayloadSupportEnabled(s3, s3BucketName) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(10 * 1024 * 1024) - .setMultipartUploadPartSize(8 * 1024 * 1024); + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(10 * 1024 * 1024) + .setStreamUploadPartSize(8 * 1024 * 1024); PayloadStorageConfiguration copy = new PayloadStorageConfiguration(original); - assertTrue(copy.isMultipartUploadEnabled()); - assertEquals(10 * 1024 * 1024, copy.getMultipartUploadThreshold()); - assertEquals(8 * 1024 * 1024, copy.getMultipartUploadPartSize()); + assertTrue(copy.isStreamUploadEnabled()); + assertEquals(10 * 1024 * 1024, copy.getStreamUploadThreshold()); + assertEquals(8 * 1024 * 1024, copy.getStreamUploadPartSize()); assertNotSame(copy, original); } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java index f3ca780..244c8b4 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java @@ -5,29 +5,19 @@ import static org.mockito.Mockito.*; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; -import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; public class S3AsyncDaoTest { @@ -118,168 +108,4 @@ public void deleteTextTest() { verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class)); } - - - @Test - public void storeTextMultipartInS3_FallbackToSinglePut_WhenBelowThreshold() { - dao = new S3AsyncDao(s3AsyncClient); - String smallPayload = "Small payload"; - int multipartThreshold = 1000; - int partSize = 5 * 1024 * 1024; - - when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, smallPayload, partSize, multipartThreshold).join(); - - verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); - verify(s3AsyncClient, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); - verify(s3AsyncClient, never()).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); - verify(s3AsyncClient, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); - } - - @Test - public void storeTextMultipartInS3_UseMultipart_WhenAboveThreshold() { - dao = new S3AsyncDao(s3AsyncClient); - String largePayload = generateString(10 * 1024 * 1024); // 10MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(createResponse)); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("test-etag") - .build(); - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) - .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); - - when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); - - verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); - verify(s3AsyncClient, times(2)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); // 10MB / 5MB = 2 parts - verify(s3AsyncClient, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); - verify(s3AsyncClient, never()).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); - } - - @Test - public void storeTextMultipartInS3_WithSSEAndACL() { - dao = new S3AsyncDao(s3AsyncClient, serverSideEncryptionStrategy, objectCannedACL); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(createResponse)); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("test-etag") - .build(); - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) - .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); - - when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - ArgumentCaptor createCaptor = - ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); - - verify(s3AsyncClient, times(1)).createMultipartUpload(createCaptor.capture()); - CreateMultipartUploadRequest capturedRequest = createCaptor.getValue(); - - assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); - assertEquals(ANY_S3_KEY, capturedRequest.key()); - assertEquals(ServerSideEncryption.AWS_KMS, capturedRequest.serverSideEncryption()); - assertEquals(objectCannedACL, capturedRequest.acl()); - } - - @Test - public void storeTextMultipartInS3_VerifyCompleteRequest() { - dao = new S3AsyncDao(s3AsyncClient); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(createResponse)); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("etag-test") - .build(); - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) - .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); - - when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - ArgumentCaptor completeCaptor = - ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); - - verify(s3AsyncClient, times(1)).completeMultipartUpload(completeCaptor.capture()); - CompleteMultipartUploadRequest capturedRequest = completeCaptor.getValue(); - - assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); - assertEquals(ANY_S3_KEY, capturedRequest.key()); - assertEquals("test-upload-id", capturedRequest.uploadId()); - assertNotNull(capturedRequest.multipartUpload()); - assertEquals(2, capturedRequest.multipartUpload().parts().size()); - } - - @Test - public void storeTextMultipartInS3_AbortOnCompleteFailure() { - dao = new S3AsyncDao(s3AsyncClient); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(createResponse)); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("test-etag") - .build(); - when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))) - .thenReturn(CompletableFuture.completedFuture(uploadPartResponse)); - - CompletableFuture failedFuture = new CompletableFuture<>(); - failedFuture.completeExceptionally(SdkException.builder().message("Complete failed").build()); - when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) - .thenReturn(failedFuture); - - when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - assertThrows(CompletionException.class, () -> { - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold).join(); - }); - - verify(s3AsyncClient, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); - } - - private String generateString(int length) { - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append('a'); - } - return sb.toString(); - } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java deleted file mode 100644 index 5fd1d60..0000000 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreAsyncTest.java +++ /dev/null @@ -1,127 +0,0 @@ -package software.amazon.payloadoffloading; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import software.amazon.awssdk.core.exception.SdkException; - -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; - -public class S3BackedMultipartPayloadStoreAsyncTest { - private static final String S3_BUCKET_NAME = "test-bucket-name"; - private static final String ANY_PAYLOAD = "AnyPayload"; - private static final String ANY_S3_KEY = "AnyS3key"; - private static final int PART_SIZE = 5 * 1024 * 1024; // 5MB - private static final int THRESHOLD = 5 * 1024 * 1024; // 5MB - - private S3BackedMultipartPayloadStoreAsync multipartPayloadStore; - private S3AsyncDao s3AsyncDao; - - @BeforeEach - public void setup() { - s3AsyncDao = mock(S3AsyncDao.class); - multipartPayloadStore = new S3BackedMultipartPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME, PART_SIZE, THRESHOLD); - } - - @Test - public void testStoreOriginalPayloadMultipartOnSuccess() { - when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), - any(Integer.class), any(Integer.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); - - ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor partSizeCaptor = ArgumentCaptor.forClass(Integer.class); - ArgumentCaptor thresholdCaptor = ArgumentCaptor.forClass(Integer.class); - - verify(s3AsyncDao, times(1)).storeTextMultipartInS3( - bucketCaptor.capture(), - keyCaptor.capture(), - payloadCaptor.capture(), - partSizeCaptor.capture(), - thresholdCaptor.capture() - ); - - assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); - assertEquals(ANY_S3_KEY, keyCaptor.getValue()); - assertEquals(ANY_PAYLOAD, payloadCaptor.getValue()); - assertEquals(PART_SIZE, partSizeCaptor.getValue()); - assertEquals(THRESHOLD, thresholdCaptor.getValue()); - - PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); - assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); - } - - @Test - public void testStoreOriginalPayloadMultipartWithCustomPartSize() { - int customPartSize = 10 * 1024 * 1024; // 10MB - int customThreshold = 8 * 1024 * 1024; // 8MB - - multipartPayloadStore = new S3BackedMultipartPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME, customPartSize, customThreshold); - - String payload = generateString(12 * 1024 * 1024); // 12MB - String s3Key = "custom-size-key"; - - when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), - any(Integer.class), any(Integer.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - multipartPayloadStore.storeOriginalPayloadMultipart(payload, s3Key).join(); - - verify(s3AsyncDao, times(1)).storeTextMultipartInS3( - eq(S3_BUCKET_NAME), - eq(s3Key), - eq(payload), - eq(customPartSize), - eq(customThreshold) - ); - } - - @Test - public void testStoreOriginalPayloadMultipartOnS3Failure() { - CompletableFuture failedFuture = new CompletableFuture<>(); - failedFuture.completeExceptionally(SdkException.create("S3 Exception", new Throwable())); - - when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), - any(Integer.class), any(Integer.class))) - .thenReturn(failedFuture); - - CompletionException exception = assertThrows(CompletionException.class, () -> { - multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); - }); - - assertTrue(exception.getMessage().contains("S3 Exception")); - } - - @Test - public void testStoreOriginalPayloadMultipartHandlesNullFromDao() { - when(s3AsyncDao.storeTextMultipartInS3(any(String.class), any(String.class), any(String.class), - any(Integer.class), any(Integer.class))) - .thenReturn(CompletableFuture.completedFuture(null)); - - String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY).join(); - - // Should still return valid pointer even if DAO returns null/void - assertNotNull(actualPayloadPointer); - PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer); - assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName()); - assertEquals(ANY_S3_KEY, pointer.getS3Key()); - } - - private String generateString(int length) { - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append('a'); - } - return sb.toString(); - } -} diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java deleted file mode 100644 index 0d2ca17..0000000 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedMultipartPayloadStoreTest.java +++ /dev/null @@ -1,102 +0,0 @@ -package software.amazon.payloadoffloading; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import software.amazon.awssdk.core.exception.SdkException; - -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; - -public class S3BackedMultipartPayloadStoreTest { - private static final String S3_BUCKET_NAME = "test-bucket-name"; - private static final String ANY_PAYLOAD = "AnyPayload"; - private static final String ANY_S3_KEY = "AnyS3key"; - private static final int PART_SIZE = 5 * 1024 * 1024; // 5MB - private static final int THRESHOLD = 5 * 1024 * 1024; // 5MB - - private S3BackedMultipartPayloadStore multipartPayloadStore; - private S3Dao s3Dao; - - @BeforeEach - public void setup() { - s3Dao = mock(S3Dao.class); - multipartPayloadStore = new S3BackedMultipartPayloadStore(s3Dao, S3_BUCKET_NAME, PART_SIZE, THRESHOLD); - } - - @Test - public void testStoreOriginalPayloadMultipartOnSuccess() { - String actualPayloadPointer = multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY); - - ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(String.class); - ArgumentCaptor partSizeCaptor = ArgumentCaptor.forClass(Integer.class); - ArgumentCaptor thresholdCaptor = ArgumentCaptor.forClass(Integer.class); - - verify(s3Dao, times(1)).storeTextMultipartInS3( - bucketCaptor.capture(), - keyCaptor.capture(), - payloadCaptor.capture(), - partSizeCaptor.capture(), - thresholdCaptor.capture() - ); - - assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); - assertEquals(ANY_S3_KEY, keyCaptor.getValue()); - assertEquals(ANY_PAYLOAD, payloadCaptor.getValue()); - assertEquals(PART_SIZE, partSizeCaptor.getValue()); - assertEquals(THRESHOLD, thresholdCaptor.getValue()); - - PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); - assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); - } - - @Test - public void testStoreOriginalPayloadMultipartWithCustomPartSize() { - int customPartSize = 10 * 1024 * 1024; // 10MB - int customThreshold = 8 * 1024 * 1024; // 8MB - - multipartPayloadStore = new S3BackedMultipartPayloadStore(s3Dao, S3_BUCKET_NAME, customPartSize, customThreshold); - - String payload = generateString(12 * 1024 * 1024); // 12MB - String s3Key = "custom-size-key"; - - multipartPayloadStore.storeOriginalPayloadMultipart(payload, s3Key); - - verify(s3Dao, times(1)).storeTextMultipartInS3( - eq(S3_BUCKET_NAME), - eq(s3Key), - eq(payload), - eq(customPartSize), - eq(customThreshold) - ); - } - - @Test - public void testStoreOriginalPayloadMultipartOnS3Failure() { - doThrow(SdkException.create("S3 Exception", new Throwable())) - .when(s3Dao) - .storeTextMultipartInS3( - any(String.class), - any(String.class), - any(String.class), - any(Integer.class), - any(Integer.class) - ); - - assertThrows(SdkException.class, () -> { - multipartPayloadStore.storeOriginalPayloadMultipart(ANY_PAYLOAD, ANY_S3_KEY); - }, "S3 Exception"); - } - - private String generateString(int length) { - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append('a'); - } - return sb.toString(); - } -} diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java new file mode 100644 index 0000000..1b5633d --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java @@ -0,0 +1,114 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedStreamPayloadStoreAsyncTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + + private S3BackedStreamPayloadStoreAsync streamPayloadStore; + private S3AsyncDao s3AsyncDao; + + @BeforeEach + public void setup() { + s3AsyncDao = mock(S3AsyncDao.class); + streamPayloadStore = new S3BackedStreamPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME); + } + + @Test + public void testStoreOriginalPayloadStreamOnSuccess() { + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class); + + verify(s3AsyncDao, times(1)).storeTextStreamInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertNotNull(payloadCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadStreamWithCustomPartSize() { + // Since part size is now configured at construction time, this test verifies + // that the store operation works correctly with the configured DAO + String payload = generateString(12 * 1024); // Small payload for testing + String s3Key = "custom-size-key"; + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key).join(); + + verify(s3AsyncDao, times(1)).storeTextStreamInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + any(java.io.InputStream.class) + ); + } + + @Test + public void testStoreOriginalPayloadStreamOnS3Failure() { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(SdkException.create("S3 Exception", new Throwable())); + + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(failedFuture); + + CompletionException exception = assertThrows(CompletionException.class, () -> { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + }); + + assertTrue(exception.getMessage().contains("S3 Exception")); + } + + @Test + public void testStoreOriginalPayloadStreamHandlesNullFromDao() { + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + + // Should still return valid pointer even if DAO returns null/void + assertNotNull(actualPayloadPointer); + PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer); + assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName()); + assertEquals(ANY_S3_KEY, pointer.getS3Key()); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java new file mode 100644 index 0000000..e9eee4e --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java @@ -0,0 +1,90 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedStreamPayloadStoreTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + + private S3BackedStreamPayloadStore streamPayloadStore; + private S3Dao s3Dao; + + @BeforeEach + public void setup() { + s3Dao = mock(S3Dao.class); + streamPayloadStore = new S3BackedStreamPayloadStore(s3Dao, S3_BUCKET_NAME); + } + + @Test + public void testStoreOriginalPayloadStreamOnSuccess() { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class); + + verify(s3Dao, times(1)).storeTextStreamInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertNotNull(payloadCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadStreamWithCustomPartSize() { + // Since part size is now configured at construction time, this test verifies + // that the store operation works correctly with the configured DAO + String payload = generateString(12 * 1024); // Small payload for testing + String s3Key = "custom-size-key"; + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key); + + verify(s3Dao, times(1)).storeTextStreamInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + any(java.io.InputStream.class) + ); + } + + @Test + public void testStoreOriginalPayloadStreamOnS3Failure() { + doThrow(SdkException.create("S3 Exception", new Throwable())) + .when(s3Dao) + .storeTextStreamInS3( + any(String.class), + any(String.class), + any(java.io.InputStream.class) + ); + + assertThrows(SdkException.class, () -> { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY); + }, "S3 Exception"); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java index 53662ce..a99a3d1 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java @@ -2,50 +2,43 @@ import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.when; -import java.util.List; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; public class S3DaoTest { - private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String ANY_PAYLOAD = "AnyPayload"; private static final String ANY_S3_KEY = "AnyS3key"; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy = ServerSideEncryptionFactory.awsManagedCmk(); private final ObjectCannedACL objectCannedACL = ObjectCannedACL.PUBLIC_READ; private S3Client s3Client; + private S3AsyncClient s3AsyncClient; private S3Dao dao; @BeforeEach public void setup() { s3Client = mock(S3Client.class); + s3AsyncClient = mock(S3AsyncClient.class); } @Test @@ -89,147 +82,28 @@ public void storeTextInS3WithBothTest() { assertEquals(objectCannedACL, argument.getValue().acl()); assertEquals(S3_BUCKET_NAME, argument.getValue().bucket()); } - + @Test - public void storeTextMultipartInS3_FallbackToSinglePut_WhenBelowThreshold() { - dao = new S3Dao(s3Client); - String smallPayload = "Small payload"; - int multipartThreshold = 1000; - int partSize = 5 * 1024 * 1024; - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, smallPayload, partSize, multipartThreshold); + public void storeTextStreamInS3_ThrowsIllegalStateWhenS3AsyncClientNotProvided() { + dao = new S3Dao(s3Client, null, null, null); - verify(s3Client, times(1)).putObject(any(PutObjectRequest.class), any(RequestBody.class)); - verify(s3Client, never()).createMultipartUpload(any(CreateMultipartUploadRequest.class)); - verify(s3Client, never()).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); - verify(s3Client, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); - } + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); - @Test - public void storeTextMultipartInS3_VerifyCompleteRequest() { - dao = new S3Dao(s3Client); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(createResponse); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("etag-test") - .build(); - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) - .thenReturn(uploadPartResponse); - - ArgumentCaptor completeCaptor = - ArgumentCaptor.forClass(CompleteMultipartUploadRequest.class); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); - - verify(s3Client, times(1)).completeMultipartUpload(completeCaptor.capture()); - CompleteMultipartUploadRequest capturedRequest = completeCaptor.getValue(); - - assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); - assertEquals(ANY_S3_KEY, capturedRequest.key()); - assertEquals("test-upload-id", capturedRequest.uploadId()); - assertNotNull(capturedRequest.multipartUpload()); - assertEquals(2, capturedRequest.multipartUpload().parts().size()); - } - - @Test - public void storeTextMultipartInS3_WithSSEAndACL() { - dao = new S3Dao(s3Client, serverSideEncryptionStrategy, objectCannedACL); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(createResponse); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("test-etag") - .build(); - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) - .thenReturn(uploadPartResponse); - - ArgumentCaptor createCaptor = - ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); - - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); - - verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); - CreateMultipartUploadRequest capturedRequest = createCaptor.getValue(); - - assertEquals(S3_BUCKET_NAME, capturedRequest.bucket()); - assertEquals(ANY_S3_KEY, capturedRequest.key()); - assertEquals(ServerSideEncryption.AWS_KMS, capturedRequest.serverSideEncryption()); - assertEquals(objectCannedACL, capturedRequest.acl()); - } - - @Test - public void storeTextMultipartInS3_AbortOnUploadPartFailure() { - dao = new S3Dao(s3Client); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(createResponse); - - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) - .thenThrow(SdkException.builder().message("Upload part failed").build()); - - assertThrows(SdkException.class, () -> { - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); + assertThrows(IllegalStateException.class, () -> { + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); }); - - verify(s3Client, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); - verify(s3Client, never()).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); } @Test - public void storeTextMultipartInS3_AbortOnCompleteFailure() { + public void storeTextStreamInS3_ThrowsExceptionWhenTransferManagerNotProvided() { dao = new S3Dao(s3Client); - String largePayload = generateString(6 * 1024 * 1024); // 6MB - int multipartThreshold = 5 * 1024 * 1024; // 5MB - int partSize = 5 * 1024 * 1024; // 5MB - - CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() - .uploadId("test-upload-id") - .build(); - when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) - .thenReturn(createResponse); - - UploadPartResponse uploadPartResponse = UploadPartResponse.builder() - .eTag("test-etag") - .build(); - when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) - .thenReturn(uploadPartResponse); - - when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) - .thenThrow(SdkException.builder().message("Complete failed").build()); - - assertThrows(SdkException.class, () -> { - dao.storeTextMultipartInS3(S3_BUCKET_NAME, ANY_S3_KEY, largePayload, partSize, multipartThreshold); - }); - verify(s3Client, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); - } + byte[] testData = "Test data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); - private String generateString(int length) { - StringBuilder sb = new StringBuilder(length); - for (int i = 0; i < length; i++) { - sb.append('a'); - } - return sb.toString(); + assertThrows(IllegalStateException.class, () -> { + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + }); } } \ No newline at end of file From 33bc16dba1ddda863ec03ecb8b8017bfdddf8327 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 09:37:45 +0700 Subject: [PATCH 3/5] cleanup --- .../software/amazon/payloadoffloading/S3AsyncDao.java | 11 ----------- .../java/software/amazon/payloadoffloading/S3Dao.java | 11 ++--------- .../amazon/payloadoffloading/AwsManagedCmkTest.java | 3 +-- .../amazon/payloadoffloading/CustomerKeyTest.java | 2 +- .../amazon/payloadoffloading/S3AsyncDaoTest.java | 9 ++++++--- .../S3BackedStreamPayloadStoreAsyncTest.java | 5 +---- .../S3BackedStreamPayloadStoreTest.java | 4 +--- 7 files changed, 12 insertions(+), 33 deletions(-) diff --git a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java index 92b0ef4..ea7b5e8 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java @@ -29,25 +29,14 @@ */ public class S3AsyncDao { private static final Logger LOG = LoggerFactory.getLogger(S3AsyncDao.class); - private final S3AsyncClient s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; - /** - * Constructor for basic S3 operations (non-streaming). - * @param s3Client The S3 async client for standard operations - */ public S3AsyncDao(S3AsyncClient s3Client) { this(s3Client, null, null); } - /** - * Full constructor with SSE and ACL configuration. - * @param s3Client The S3 async client - * @param serverSideEncryptionStrategy Server-side encryption configuration - * @param objectCannedACL Canned ACL configuration - */ public S3AsyncDao( S3AsyncClient s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index a84f11f..5c9f331 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -28,24 +28,17 @@ */ public class S3Dao { private static final Logger LOG = LoggerFactory.getLogger(S3Dao.class); - private final S3Client s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; private S3AsyncClient s3AsyncClient; - + public S3Dao(S3Client s3Client) { this.s3Client = s3Client; this.serverSideEncryptionStrategy = null; this.objectCannedACL = null; } - /** - * Constructor with SSE and ACL configuration (non-streaming). - * @param s3Client The S3 sync client - * @param serverSideEncryptionStrategy Server-side encryption configuration - * @param objectCannedACL Canned ACL configuration - */ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) { this.s3Client = s3Client; this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; @@ -155,7 +148,7 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) { /** - * Stores a stream of data to S3 using TransferManager for efficient multipart uploads. + * Stores a stream of data to S3 using TransferManager * Requires S3AsyncClient to be provided in the constructor. * * @param bucket The S3 bucket name diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java index 2140ebc..8a7c663 100644 --- a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java +++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java @@ -2,9 +2,8 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; - +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java index 30ee761..0474d30 100644 --- a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java +++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java @@ -2,8 +2,8 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java index 244c8b4..1ecccb8 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java @@ -1,8 +1,12 @@ package software.amazon.payloadoffloading; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; @@ -19,7 +23,6 @@ import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; - public class S3AsyncDaoTest { private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java index 1b5633d..1edc7f4 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java @@ -55,9 +55,7 @@ public void testStoreOriginalPayloadStreamOnSuccess() { @Test public void testStoreOriginalPayloadStreamWithCustomPartSize() { - // Since part size is now configured at construction time, this test verifies - // that the store operation works correctly with the configured DAO - String payload = generateString(12 * 1024); // Small payload for testing + String payload = generateString(12 * 1024); String s3Key = "custom-size-key"; java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); @@ -97,7 +95,6 @@ public void testStoreOriginalPayloadStreamHandlesNullFromDao() { java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); - // Should still return valid pointer even if DAO returns null/void assertNotNull(actualPayloadPointer); PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer); assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName()); diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java index e9eee4e..4679778 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java @@ -49,9 +49,7 @@ public void testStoreOriginalPayloadStreamOnSuccess() { @Test public void testStoreOriginalPayloadStreamWithCustomPartSize() { - // Since part size is now configured at construction time, this test verifies - // that the store operation works correctly with the configured DAO - String payload = generateString(12 * 1024); // Small payload for testing + String payload = generateString(12 * 1024); String s3Key = "custom-size-key"; java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); From 67750bba589111323f52750b5642c8455be750d8 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 14:45:39 +0700 Subject: [PATCH 4/5] remove transfer manager from s3Dao --- .../PayloadStorageAsyncConfiguration.java | 20 --- .../PayloadStorageConfiguration.java | 78 ++++----- .../PayloadStorageConfigurationBase.java | 50 ------ .../amazon/payloadoffloading/S3Dao.java | 156 ++++++++++++------ .../PayloadStorageAsyncConfigurationTest.java | 24 +-- .../payloadoffloading/S3AsyncDaoTest.java | 24 +++ .../amazon/payloadoffloading/S3DaoTest.java | 150 +++++++++++++++-- 7 files changed, 309 insertions(+), 193 deletions(-) diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java index 5840971..832053c 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java @@ -160,24 +160,4 @@ public PayloadStorageAsyncConfiguration withStreamUploadEnabled(boolean enabled) setStreamUploadEnabled(enabled); return this; } - - /** - * Sets the stream upload threshold (in bytes). Only used when stream upload is enabled. - * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. - * @return updated configuration - */ - public PayloadStorageAsyncConfiguration withStreamUploadThreshold(int threshold) { - setStreamUploadThreshold(threshold); - return this; - } - - public PayloadStorageAsyncConfiguration withStreamUploadPartSize(int partSize) { - setStreamUploadPartSize(partSize); - return this; - } - - public PayloadStorageAsyncConfiguration withS3Region(String s3Region) { - setS3Region(s3Region); - return this; - } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index cd7b782..06a3041 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -36,17 +36,18 @@ public class PayloadStorageConfiguration extends PayloadStorageConfigurationBase private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class); private S3Client s3; - private S3AsyncClient s3Async; + private int streamUploadPartSize; + private int streamUploadThreshold; public PayloadStorageConfiguration() { s3 = null; - s3Async = null; } public PayloadStorageConfiguration(PayloadStorageConfiguration other) { super(other); this.s3 = other.getS3Client(); - this.s3Async = other.getS3AsyncClient(); + this.streamUploadThreshold = other.getStreamUploadThreshold(); + this.streamUploadPartSize = other.getStreamUploadPartSize(); } /** @@ -106,37 +107,6 @@ public S3Client getS3Client() { return s3; } - /** - * Sets the optional Amazon S3 async client to be used for TransferManager. - * This is useful for pre-configuring the async client with specific endpoint/credentials - * (e.g., for LocalStack testing). - * - * @param s3AsyncClient The S3AsyncClient to use for TransferManager operations. - */ - public void setS3AsyncClient(S3AsyncClient s3AsyncClient) { - this.s3Async = s3AsyncClient; - } - - /** - * Sets the optional Amazon S3 async client to be used for TransferManager. - * - * @param s3AsyncClient The S3AsyncClient to use for TransferManager operations. - * @return the updated PayloadStorageConfiguration object. - */ - public PayloadStorageConfiguration withS3AsyncClient(S3AsyncClient s3AsyncClient) { - setS3AsyncClient(s3AsyncClient); - return this; - } - - /** - * Gets the Amazon S3 async client which is being used for TransferManager operations. - * - * @return Reference to the Amazon S3 async client, or null if not configured. - */ - public S3AsyncClient getS3AsyncClient() { - return s3Async; - } - /** * Sets the payload size threshold for storing payloads in Amazon S3. * @@ -209,8 +179,42 @@ public PayloadStorageConfiguration withStreamUploadPartSize(int partSize) { return this; } - public PayloadStorageConfiguration withS3Region(String s3Region) { - setS3Region(s3Region); - return this; + + /** + * Gets the stream upload threshold in bytes. Default 5MB. + * + * @return threshold in bytes. + */ + public int getStreamUploadThreshold() { return streamUploadThreshold; } + + /** + * Sets the stream upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). + * + * @param streamUploadThreshold threshold in bytes + */ + public void setStreamUploadThreshold(int streamUploadThreshold) { + int min = 5 * 1024 * 1024; + if (streamUploadThreshold <= min) { + this.streamUploadThreshold = min; + } else { + this.streamUploadThreshold = streamUploadThreshold; + } + } + + /** + * Gets the configured stream upload part size (bytes). Default 5MB. + */ + public int getStreamUploadPartSize() { return streamUploadPartSize; } + + /** + * Sets the stream upload part size (bytes). Values < 5MB are rounded up to 5MB. + */ + public void setStreamUploadPartSize(int partSize) { + int min = 5 * 1024 * 1024; + if (partSize < min) { + this.streamUploadPartSize = min; + } else { + this.streamUploadPartSize = partSize; + } } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java index 68947ae..bbb3980 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java @@ -18,7 +18,6 @@ public abstract class PayloadStorageConfigurationBase { private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfigurationBase.class); private String s3BucketName; - private String s3Region = "ap-southeast-1"; private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; private boolean payloadSupport = false; @@ -51,9 +50,6 @@ public PayloadStorageConfigurationBase(PayloadStorageConfigurationBase other) { this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy(); this.objectCannedACL = other.getObjectCannedACL(); this.streamUploadEnabled = other.isStreamUploadEnabled(); - this.streamUploadThreshold = other.getStreamUploadThreshold(); - this.streamUploadPartSize = other.getStreamUploadPartSize(); - this.s3Region = other.getS3Region(); } /** @@ -200,50 +196,4 @@ public ObjectCannedACL getObjectCannedACL() { * @param streamUploadEnabled flag to enable/disable stream support. */ public void setStreamUploadEnabled(boolean streamUploadEnabled) { this.streamUploadEnabled = streamUploadEnabled; } - - /** - * Gets the stream upload threshold in bytes. Default 5MB. - * - * @return threshold in bytes. - */ - public long getStreamUploadThreshold() { return streamUploadThreshold; } - - /** - * Sets the stream upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). - * - * @param streamUploadThreshold threshold in bytes - */ - public void setStreamUploadThreshold(long streamUploadThreshold) { - long min = 5 * 1024 * 1024L; - if (streamUploadThreshold <= min) { - this.streamUploadThreshold = min; - } else { - this.streamUploadThreshold = streamUploadThreshold; - } - } - - /** - * Gets the configured stream upload part size (bytes). Default 5MB. - */ - public long getStreamUploadPartSize() { return streamUploadPartSize; } - - /** - * Sets the stream upload part size (bytes). Values < 5MB are rounded up to 5MB. - */ - public void setStreamUploadPartSize(long partSize) { - long min = 5 * 1024 * 1024L; - if (partSize < min) { - this.streamUploadPartSize = min; - } else { - this.streamUploadPartSize = partSize; - } - } - - public void setS3Region(String s3Region) { - this.s3Region = s3Region; - } - - public String getS3Region() { - return this.s3Region; - } } \ No newline at end of file diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index 5c9f331..d51fef8 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -6,22 +6,26 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.sync.RequestBody; -import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.transfer.s3.S3TransferManager; -import software.amazon.awssdk.transfer.s3.model.UploadRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.utils.IoUtils; import java.io.IOException; import java.io.InputStream; -import java.util.concurrent.Executors; -import java.util.concurrent.ExecutorService; +import java.util.ArrayList; +import java.util.List; /** * Dao layer to access S3. @@ -31,7 +35,8 @@ public class S3Dao { private final S3Client s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; - private S3AsyncClient s3AsyncClient; + private int streamUploadPartSize = 5 * 1024 * 1024; // Default 5MB + private int streamUploadThreshold = 5 * 1024 * 1024; // Default 5MB public S3Dao(S3Client s3Client) { this.s3Client = s3Client; @@ -45,18 +50,12 @@ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptio this.objectCannedACL = objectCannedACL; } - public S3Dao(S3Client s3Client, S3AsyncClient s3AsyncClient) { + public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL, int streamUploadThreshold, int streamUploadPartSize) { this.s3Client = s3Client; - this.s3AsyncClient = s3AsyncClient; - this.serverSideEncryptionStrategy = null; - this.objectCannedACL = null; - } - - public S3Dao(S3Client s3Client, S3AsyncClient s3AsyncClient, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) { - this.s3Client = s3Client; - this.s3AsyncClient = s3AsyncClient; this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; this.objectCannedACL = objectCannedACL; + this.streamUploadThreshold = streamUploadThreshold; + this.streamUploadPartSize = streamUploadPartSize; } public String getTextFromS3(String s3BucketName, String s3Key) { @@ -148,45 +147,110 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) { /** - * Stores a stream of data to S3 using TransferManager - * Requires S3AsyncClient to be provided in the constructor. + * Stores a stream of data to S3 using multipart upload. + * This method reads the stream in chunks and uploads each part separately, + * which is suitable for large payloads without requiring the entire content in memory. * * @param bucket The S3 bucket name * @param key The S3 object key * @param payloadStream The input stream to upload - * @throws IllegalStateException if S3AsyncClient was not provided in constructor + * @throws SdkException if the upload fails */ public void storeTextStreamInS3(String bucket, String key, InputStream payloadStream) { - if (s3AsyncClient == null) { - throw new IllegalStateException("S3AsyncClient must be provided in constructor for streaming uploads"); - } - - S3TransferManager transferManager = S3TransferManager.builder().s3Client(s3AsyncClient).build(); - ExecutorService executor = Executors.newSingleThreadExecutor(); - + String uploadId = null; + List completedParts = new ArrayList<>(); + try { - UploadRequest.Builder uploadBuilder = UploadRequest.builder() - .putObjectRequest(b -> { - b.bucket(bucket).key(key); - if (objectCannedACL != null) { - b.acl(objectCannedACL); - } - // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.htransferManagerl - if (serverSideEncryptionStrategy != null) { - serverSideEncryptionStrategy.decorate(b); - } - }) - .requestBody(AsyncRequestBody.fromInputStream(payloadStream, null, executor)); - - transferManager.upload(uploadBuilder.build()).completionFuture().join(); - LOG.info("S3 stream object created from InputStream, Bucket name: " + bucket + ", Object key: " + key + "."); - } catch (Exception e) { - String errorMessage = "Failed to store the message content from InputStream in an S3 stream object."; + CreateMultipartUploadRequest.Builder createRequestBuilder = CreateMultipartUploadRequest.builder() + .bucket(bucket) + .key(key); + + if (objectCannedACL != null) { + createRequestBuilder.acl(objectCannedACL); + } + + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(createRequestBuilder); + } + + CreateMultipartUploadResponse createResponse = s3Client.createMultipartUpload(createRequestBuilder.build()); + uploadId = createResponse.uploadId(); + + byte[] buffer = new byte[streamUploadPartSize]; + int partNumber = 1; + int bytesRead; + + while ((bytesRead = payloadStream.read(buffer)) > 0) { + RequestBody requestBody = RequestBody.fromBytes(bytesRead == buffer.length ? buffer : + java.util.Arrays.copyOf(buffer, bytesRead)); + + UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .build(); + + UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadPartRequest, requestBody); + + CompletedPart completedPart = CompletedPart.builder() + .partNumber(partNumber) + .eTag(uploadPartResponse.eTag()) + .build(); + + completedParts.add(completedPart); + partNumber++; + } + + CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder() + .parts(completedParts) + .build(); + + CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .multipartUpload(completedMultipartUpload) + .build(); + + s3Client.completeMultipartUpload(completeRequest); + LOG.info("S3 stream object created from InputStream using multipart upload, Bucket name: " + bucket + ", Object key: " + key + "."); + + } catch (IOException e) { + // Abort the multipart upload if it was initiated + if (uploadId != null) { + try { + AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .build(); + s3Client.abortMultipartUpload(abortRequest); + } catch (Exception abortException) { + LOG.warn("Failed to abort multipart upload", abortException); + } + } + String errorMessage = "Failed to read from InputStream during multipart upload."; + LOG.error(errorMessage, e); + throw SdkClientException.create(errorMessage, e); + } catch (SdkException e) { + // Abort the multipart upload if it was initiated + if (uploadId != null) { + try { + AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .build(); + s3Client.abortMultipartUpload(abortRequest); + } catch (Exception abortException) { + LOG.warn("Failed to abort multipart upload", abortException); + } + } + String errorMessage = "Failed to store the message content from InputStream in an S3 object using multipart upload."; LOG.error(errorMessage, e); throw SdkException.create(errorMessage, e); - } finally { - transferManager.close(); - executor.shutdownNow(); } } } diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java index 8aeeee7..af8aa96 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java @@ -111,39 +111,17 @@ public void testStreamUploadEnabled() { assertFalse(payloadStorageConfiguration.isStreamUploadEnabled()); } - @Test - public void testStreamUploadThreshold() { - PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); - - int customThreshold = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setStreamUploadThreshold(customThreshold); - assertEquals(customThreshold, payloadStorageConfiguration.getStreamUploadThreshold()); - } - - @Test - public void testStreamUploadPartSize() { - PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); - - int customPartSize = 10 * 1024 * 1024; // 10MB - payloadStorageConfiguration.setStreamUploadPartSize(customPartSize); - assertEquals(customPartSize, payloadStorageConfiguration.getStreamUploadPartSize()); - } - @Test public void testStreamConfigurationInCopyConstructor() { S3AsyncClient s3Async = mock(S3AsyncClient.class); PayloadStorageAsyncConfiguration original = new PayloadStorageAsyncConfiguration(); original.withPayloadSupportEnabled(s3Async, s3BucketName) - .withStreamUploadEnabled(true) - .withStreamUploadThreshold(10 * 1024 * 1024) - .setStreamUploadPartSize(8 * 1024 * 1024); + .withStreamUploadEnabled(true); PayloadStorageAsyncConfiguration copy = new PayloadStorageAsyncConfiguration(original); assertTrue(copy.isStreamUploadEnabled()); - assertEquals(10 * 1024 * 1024, copy.getStreamUploadThreshold()); - assertEquals(8 * 1024 * 1024, copy.getStreamUploadPartSize()); assertNotSame(copy, original); } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java index 1ecccb8..5d19b92 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java @@ -2,23 +2,28 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; @@ -111,4 +116,23 @@ public void deleteTextTest() { verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class)); } + + @Test + public void getTextStreamFromS3Test() { + dao = new S3AsyncDao(s3AsyncClient); + + byte[] testData = ANY_PAYLOAD.getBytes(StandardCharsets.UTF_8); + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(testData); + GetObjectResponse getObjectResponse = GetObjectResponse.builder().build(); + ResponseInputStream mockResponseStream = + new ResponseInputStream<>(getObjectResponse, byteArrayInputStream); + + when(s3AsyncClient.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(mockResponseStream)); + + ResponseInputStream stream = dao.getTextStreamFromS3(S3_BUCKET_NAME, ANY_S3_KEY).join(); + + verify(s3AsyncClient, times(1)).getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)); + assertNotNull(stream); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java index a99a3d1..e0ce2be 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java @@ -2,23 +2,20 @@ import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.S3AsyncClient; -import software.amazon.awssdk.services.s3.model.ObjectCannedACL; +import software.amazon.awssdk.services.s3.model.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -32,13 +29,11 @@ public class S3DaoTest { private final ServerSideEncryptionStrategy serverSideEncryptionStrategy = ServerSideEncryptionFactory.awsManagedCmk(); private final ObjectCannedACL objectCannedACL = ObjectCannedACL.PUBLIC_READ; private S3Client s3Client; - private S3AsyncClient s3AsyncClient; private S3Dao dao; @BeforeEach public void setup() { s3Client = mock(S3Client.class); - s3AsyncClient = mock(S3AsyncClient.class); } @Test @@ -84,26 +79,147 @@ public void storeTextInS3WithBothTest() { } @Test - public void storeTextStreamInS3_ThrowsIllegalStateWhenS3AsyncClientNotProvided() { - dao = new S3Dao(s3Client, null, null, null); + public void storeTextStreamInS3WithoutSSEOrCannedTest() { + dao = new S3Dao(s3Client); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertNull(createCaptor.getValue().serverSideEncryption()); + assertNull(createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextStreamInS3WithSSETest() { + dao = new S3Dao(s3Client, serverSideEncryptionStrategy, null); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption()); + assertNull(createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextStreamInS3WithBothSSEAndCannedACLTest() { + dao = new S3Dao(s3Client, serverSideEncryptionStrategy, objectCannedACL); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); InputStream inputStream = new ByteArrayInputStream(testData); - assertThrows(IllegalStateException.class, () -> { - dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); - }); + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption()); + assertEquals(objectCannedACL, createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); } @Test - public void storeTextStreamInS3_ThrowsExceptionWhenTransferManagerNotProvided() { + public void storeTextStreamInS3WithMultiplePartsTest() { dao = new S3Dao(s3Client); - byte[] testData = "Test data".getBytes(StandardCharsets.UTF_8); + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = new byte[5 * 1024 * 1024 + 1]; InputStream inputStream = new ByteArrayInputStream(testData); - assertThrows(IllegalStateException.class, () -> { - dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); - }); + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + verify(s3Client, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3Client, times(2)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); } } \ No newline at end of file From dc79c51c89a7073f5953195acdde82b998c41ca0 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Mon, 13 Oct 2025 09:45:48 +0700 Subject: [PATCH 5/5] cleanup --- .../payloadoffloading/PayloadStorageConfiguration.java | 5 ++--- .../payloadoffloading/PayloadStorageConfigurationBase.java | 4 ---- src/main/java/software/amazon/payloadoffloading/S3Dao.java | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index 06a3041..afab682 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -4,7 +4,6 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; @@ -36,8 +35,8 @@ public class PayloadStorageConfiguration extends PayloadStorageConfigurationBase private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class); private S3Client s3; - private int streamUploadPartSize; - private int streamUploadThreshold; + private int streamUploadPartSize = 5 * 1024 * 1024; + private int streamUploadThreshold = 5 * 1024 * 1024; public PayloadStorageConfiguration() { s3 = null; diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java index bbb3980..c0679ae 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java @@ -23,10 +23,6 @@ public abstract class PayloadStorageConfigurationBase { private boolean payloadSupport = false; // Enable stream upload support (opt-in, default false) private boolean streamUploadEnabled = false; - // Threshold (bytes) above which stream should be attempted when enabled (default 5MB) - private long streamUploadThreshold = 5 * 1024 * 1024L; - // Stream part size (bytes). Each part (except last) must be >=5MB. Default 5MB. - private long streamUploadPartSize = 5 * 1024 * 1024L; /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index d51fef8..2cfe158 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -35,8 +35,8 @@ public class S3Dao { private final S3Client s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; - private int streamUploadPartSize = 5 * 1024 * 1024; // Default 5MB - private int streamUploadThreshold = 5 * 1024 * 1024; // Default 5MB + private int streamUploadPartSize = 5 * 1024 * 1024; + private int streamUploadThreshold = 5 * 1024 * 1024; public S3Dao(S3Client s3Client) { this.s3Client = s3Client;