diff --git a/pom.xml b/pom.xml index 0cea9a4..15052fb 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 2.2.0 + 2.3.0 jar Payload offloading common library for AWS Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3. @@ -36,13 +36,13 @@ - 2.20.130 + 2.27.14 software.amazon.awssdk - s3 + s3-transfer-manager ${aws-java-sdk.version} diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java index ae291f4..eaaa49f 100644 --- a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java +++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java @@ -1,6 +1,7 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; public class AwsManagedCmk implements ServerSideEncryptionStrategy { @@ -8,4 +9,9 @@ public class AwsManagedCmk implements ServerSideEncryptionStrategy { public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); } + + @Override + public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) { + createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + } } diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java index 7f62d49..9a2db93 100644 --- a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java +++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java @@ -1,6 +1,7 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; public class CustomerKey implements ServerSideEncryptionStrategy { @@ -15,4 +16,10 @@ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) { putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); putObjectRequestBuilder.ssekmsKeyId(awsKmsKeyId); } + + @Override + public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) { + createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS); + createStreamUploadRequestBuilder.ssekmsKeyId(awsKmsKeyId); + } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java index 3bd8d08..832053c 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java @@ -150,4 +150,14 @@ public PayloadStorageAsyncConfiguration withObjectCannedACL(ObjectCannedACL obje setObjectCannedACL(objectCannedACL); return this; } + + /** + * Enables or disables stream upload support. + * @param enabled true to enable stream uploads when threshold exceeded. + * @return updated configuration + */ + public PayloadStorageAsyncConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); + return this; + } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java index 9ab3c10..afab682 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java @@ -35,6 +35,8 @@ public class PayloadStorageConfiguration extends PayloadStorageConfigurationBase private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class); private S3Client s3; + private int streamUploadPartSize = 5 * 1024 * 1024; + private int streamUploadThreshold = 5 * 1024 * 1024; public PayloadStorageConfiguration() { s3 = null; @@ -43,6 +45,8 @@ public PayloadStorageConfiguration() { public PayloadStorageConfiguration(PayloadStorageConfiguration other) { super(other); this.s3 = other.getS3Client(); + this.streamUploadThreshold = other.getStreamUploadThreshold(); + this.streamUploadPartSize = other.getStreamUploadPartSize(); } /** @@ -148,4 +152,68 @@ public PayloadStorageConfiguration withObjectCannedACL(ObjectCannedACL objectCan setObjectCannedACL(objectCannedACL); return this; } + + /** + * Enables or disables stream upload support. + * @param enabled true to enable stream uploads when threshold exceeded. + * @return updated configuration + */ + public PayloadStorageConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); + return this; + } + + /** + * Sets the stream upload threshold (in bytes). Only used when stream upload is enabled. + * @param threshold threshold in bytes (must be >0) otherwise default (5MB) is applied. + * @return updated configuration + */ + public PayloadStorageConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); + return this; + } + + public PayloadStorageConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } + + + /** + * Gets the stream upload threshold in bytes. Default 5MB. + * + * @return threshold in bytes. + */ + public int getStreamUploadThreshold() { return streamUploadThreshold; } + + /** + * Sets the stream upload threshold in bytes. Values less than or equal to zero will reset to default (5MB). + * + * @param streamUploadThreshold threshold in bytes + */ + public void setStreamUploadThreshold(int streamUploadThreshold) { + int min = 5 * 1024 * 1024; + if (streamUploadThreshold <= min) { + this.streamUploadThreshold = min; + } else { + this.streamUploadThreshold = streamUploadThreshold; + } + } + + /** + * Gets the configured stream upload part size (bytes). Default 5MB. + */ + public int getStreamUploadPartSize() { return streamUploadPartSize; } + + /** + * Sets the stream upload part size (bytes). Values < 5MB are rounded up to 5MB. + */ + public void setStreamUploadPartSize(int partSize) { + int min = 5 * 1024 * 1024; + if (partSize < min) { + this.streamUploadPartSize = min; + } else { + this.streamUploadPartSize = partSize; + } + } } diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java index 7d08746..c0679ae 100644 --- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java +++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java @@ -4,7 +4,6 @@ import org.slf4j.LoggerFactory; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; /** @@ -22,6 +21,8 @@ public abstract class PayloadStorageConfigurationBase { private int payloadSizeThreshold = 0; private boolean alwaysThroughS3 = false; private boolean payloadSupport = false; + // Enable stream upload support (opt-in, default false) + private boolean streamUploadEnabled = false; /** * This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS. */ @@ -44,6 +45,7 @@ public PayloadStorageConfigurationBase(PayloadStorageConfigurationBase other) { this.payloadSizeThreshold = other.getPayloadSizeThreshold(); this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy(); this.objectCannedACL = other.getObjectCannedACL(); + this.streamUploadEnabled = other.isStreamUploadEnabled(); } /** @@ -175,4 +177,19 @@ public boolean isObjectCannedACLDefined() { public ObjectCannedACL getObjectCannedACL() { return objectCannedACL; } + + /** + * Checks whether stream upload support is enabled. Default: false. + * + * @return true if stream upload support is enabled. + */ + public boolean isStreamUploadEnabled() { return streamUploadEnabled; } + + /** + * Enable or disable stream upload support. When enabling, callers should ensure they provide a PayloadStore + * implementation capable of stream operations; otherwise normal single PUT behavior will be used. + * + * @param streamUploadEnabled flag to enable/disable stream support. + */ + public void setStreamUploadEnabled(boolean streamUploadEnabled) { this.streamUploadEnabled = streamUploadEnabled; } } \ No newline at end of file diff --git a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java index a0dc868..ea7b5e8 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java @@ -1,11 +1,13 @@ package software.amazon.payloadoffloading; import java.io.UncheckedIOException; +import java.io.InputStream; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.exception.SdkClientException; @@ -13,8 +15,14 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.transfer.s3.S3TransferManager; +import software.amazon.awssdk.transfer.s3.model.UploadRequest; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Dao layer to access S3. @@ -65,6 +73,28 @@ public CompletableFuture getTextFromS3(String s3BucketName, String s3Key }); } + public CompletableFuture> getTextStreamFromS3(String s3BucketName, String s3Key) { + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + + return s3Client.getObject(getObjectRequest, AsyncResponseTransformer.toBlockingInputStream()) + .handle((v, tIn) -> { + if (tIn != null) { + Throwable t = Util.unwrapFutureException(tIn); + if (t instanceof SdkException) { + String errorMessage = "Failed to get the S3 object stream which contains the payload."; + LOG.error(errorMessage, t); + throw SdkException.create(errorMessage, t); + } + throw new CompletionException(t); + } + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return v; + }); + } + public CompletableFuture storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) { PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() .bucket(s3BucketName) @@ -115,4 +145,54 @@ public CompletableFuture deletePayloadFromS3(String s3BucketName, String s return null; }); } + + + /** + * Stores a stream of data to S3 using TransferManager for efficient multipart uploads. + * + * @param bucket The S3 bucket name + * @param key The S3 object key + * @param payloadStream The input stream to upload + * @return CompletableFuture that completes when upload is finished + */ + public CompletableFuture storeTextStreamInS3(String bucket, String key, InputStream payloadStream) { + S3TransferManager transferManager = S3TransferManager.create(); + ExecutorService executor = Executors.newSingleThreadExecutor(); + + try { + UploadRequest.Builder uploadBuilder = UploadRequest.builder() + .putObjectRequest(b -> { + b.bucket(bucket).key(key); + if (objectCannedACL != null) { + b.acl(objectCannedACL); + } + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(b); + } + }) + .requestBody(AsyncRequestBody.fromInputStream(payloadStream, null, executor)); + + return transferManager.upload(uploadBuilder.build()).completionFuture() + .thenApply(completedUpload -> { + LOG.info("S3 stream object created from InputStream, Bucket name: " + bucket + ", Object key: " + key + "."); + return (Void) null; + }) + .exceptionally(t -> { + Throwable cause = Util.unwrapFutureException(t); + if (cause instanceof SdkException) { + String errorMessage = "Failed to store the message content from InputStream in an S3 stream object."; + LOG.error(errorMessage, cause); + throw SdkException.create(errorMessage, cause); + } + throw new CompletionException(cause); + }); + } catch (UnsupportedOperationException e) { + LOG.warn("TransferManager creation disabled, cannot perform streaming upload: " + e.getMessage()); + throw new CompletionException(e); + } finally { + transferManager.close(); + executor.shutdownNow(); + } + } } diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java new file mode 100644 index 0000000..cd3bc09 --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java @@ -0,0 +1,37 @@ +package software.amazon.payloadoffloading; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import java.io.InputStream; + +public class S3BackedStreamPayloadStore extends S3BackedPayloadStore implements StreamPayloadStore { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStore.class); + private final S3Dao s3Dao; + private final String s3BucketName; + + public S3BackedStreamPayloadStore(S3Dao s3Dao, String s3BucketName) { + super(s3Dao, s3BucketName); + this.s3Dao = s3Dao; + this.s3BucketName = s3BucketName; + } + + @Override + public String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + s3Dao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream); + LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + } + + @Override + public ResponseInputStream getOriginalPayloadStreamStream(String payloadPointer) { + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer); + String s3BucketName = s3Pointer.getS3BucketName(); + String s3Key = s3Pointer.getS3Key(); + + ResponseInputStream stream = s3Dao.getTextStreamFromS3(s3BucketName, s3Key); + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return stream; + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java new file mode 100644 index 0000000..7a75e5b --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java @@ -0,0 +1,45 @@ +package software.amazon.payloadoffloading; + +import java.io.InputStream; +import java.util.concurrent.CompletableFuture; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ObjectCannedACL; + +public class S3BackedStreamPayloadStoreAsync extends S3BackedPayloadStoreAsync implements StreamPayloadStoreAsync { + private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStoreAsync.class); + private final S3AsyncDao streamDao; + private final String s3BucketName; + + public S3BackedStreamPayloadStoreAsync(S3AsyncDao s3Dao, String s3BucketName) { + super(s3Dao, s3BucketName); + this.streamDao = s3Dao; + this.s3BucketName = s3BucketName; + } + + @Override + public CompletableFuture storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + return streamDao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream) + .thenApply(v -> { + LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); + return new PayloadS3Pointer(s3BucketName, s3Key).toJson(); + }); + } + + @Override + public CompletableFuture> getOriginalPayloadStreamStream(String payloadPointer) { + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer); + String s3BucketName = s3Pointer.getS3BucketName(); + String s3Key = s3Pointer.getS3Key(); + + return streamDao.getTextStreamFromS3(s3BucketName, s3Key) + .thenApply(stream -> { + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return stream; + }); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java index 2b03dd5..2cfe158 100644 --- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java +++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java @@ -12,9 +12,20 @@ import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.utils.IoUtils; import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; /** * Dao layer to access S3. @@ -24,9 +35,13 @@ public class S3Dao { private final S3Client s3Client; private final ServerSideEncryptionStrategy serverSideEncryptionStrategy; private final ObjectCannedACL objectCannedACL; + private int streamUploadPartSize = 5 * 1024 * 1024; + private int streamUploadThreshold = 5 * 1024 * 1024; public S3Dao(S3Client s3Client) { - this(s3Client, null, null); + this.s3Client = s3Client; + this.serverSideEncryptionStrategy = null; + this.objectCannedACL = null; } public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) { @@ -35,6 +50,14 @@ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptio this.objectCannedACL = objectCannedACL; } + public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL, int streamUploadThreshold, int streamUploadPartSize) { + this.s3Client = s3Client; + this.serverSideEncryptionStrategy = serverSideEncryptionStrategy; + this.objectCannedACL = objectCannedACL; + this.streamUploadThreshold = streamUploadThreshold; + this.streamUploadPartSize = streamUploadPartSize; + } + public String getTextFromS3(String s3BucketName, String s3Key) { GetObjectRequest getObjectRequest = GetObjectRequest.builder() .bucket(s3BucketName) @@ -65,6 +88,23 @@ public String getTextFromS3(String s3BucketName, String s3Key) { return embeddedText; } + public ResponseInputStream getTextStreamFromS3(String s3BucketName, String s3Key) { + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(s3BucketName) + .key(s3Key) + .build(); + + try { + ResponseInputStream object = s3Client.getObject(getObjectRequest); + LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key); + return object; + } catch (SdkException e) { + String errorMessage = "Failed to get the S3 object stream which contains the payload."; + LOG.error(errorMessage, e); + throw SdkException.create(errorMessage, e); + } + } + public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) { PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() .bucket(s3BucketName) @@ -104,4 +144,113 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) { LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + "."); } + + + /** + * Stores a stream of data to S3 using multipart upload. + * This method reads the stream in chunks and uploads each part separately, + * which is suitable for large payloads without requiring the entire content in memory. + * + * @param bucket The S3 bucket name + * @param key The S3 object key + * @param payloadStream The input stream to upload + * @throws SdkException if the upload fails + */ + public void storeTextStreamInS3(String bucket, String key, InputStream payloadStream) { + String uploadId = null; + List completedParts = new ArrayList<>(); + + try { + CreateMultipartUploadRequest.Builder createRequestBuilder = CreateMultipartUploadRequest.builder() + .bucket(bucket) + .key(key); + + if (objectCannedACL != null) { + createRequestBuilder.acl(objectCannedACL); + } + + // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html + if (serverSideEncryptionStrategy != null) { + serverSideEncryptionStrategy.decorate(createRequestBuilder); + } + + CreateMultipartUploadResponse createResponse = s3Client.createMultipartUpload(createRequestBuilder.build()); + uploadId = createResponse.uploadId(); + + byte[] buffer = new byte[streamUploadPartSize]; + int partNumber = 1; + int bytesRead; + + while ((bytesRead = payloadStream.read(buffer)) > 0) { + RequestBody requestBody = RequestBody.fromBytes(bytesRead == buffer.length ? buffer : + java.util.Arrays.copyOf(buffer, bytesRead)); + + UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .build(); + + UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadPartRequest, requestBody); + + CompletedPart completedPart = CompletedPart.builder() + .partNumber(partNumber) + .eTag(uploadPartResponse.eTag()) + .build(); + + completedParts.add(completedPart); + partNumber++; + } + + CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder() + .parts(completedParts) + .build(); + + CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .multipartUpload(completedMultipartUpload) + .build(); + + s3Client.completeMultipartUpload(completeRequest); + LOG.info("S3 stream object created from InputStream using multipart upload, Bucket name: " + bucket + ", Object key: " + key + "."); + + } catch (IOException e) { + // Abort the multipart upload if it was initiated + if (uploadId != null) { + try { + AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .build(); + s3Client.abortMultipartUpload(abortRequest); + } catch (Exception abortException) { + LOG.warn("Failed to abort multipart upload", abortException); + } + } + String errorMessage = "Failed to read from InputStream during multipart upload."; + LOG.error(errorMessage, e); + throw SdkClientException.create(errorMessage, e); + } catch (SdkException e) { + // Abort the multipart upload if it was initiated + if (uploadId != null) { + try { + AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .build(); + s3Client.abortMultipartUpload(abortRequest); + } catch (Exception abortException) { + LOG.warn("Failed to abort multipart upload", abortException); + } + } + String errorMessage = "Failed to store the message content from InputStream in an S3 object using multipart upload."; + LOG.error(errorMessage, e); + throw SdkException.create(errorMessage, e); + } + } } diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java index f385ce6..43aced1 100644 --- a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java +++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java @@ -1,7 +1,9 @@ package software.amazon.payloadoffloading; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; public interface ServerSideEncryptionStrategy { void decorate(PutObjectRequest.Builder putObjectRequestBuilder); + void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder); } diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java new file mode 100644 index 0000000..ba3ac0d --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java @@ -0,0 +1,33 @@ +package software.amazon.payloadoffloading; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import java.io.InputStream; + +/** + * Optional extension interface for {@link PayloadStore} implementations that can perform stream + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface StreamPayloadStore extends PayloadStore { + + /** + * Store a payload from an InputStream using streaming-aware logic to avoid loading large payloads into memory. + * @param payloadStream InputStream containing UTF-8 text payload + * @param s3Key pre-generated S3 key (or equivalent object key) to use + * @return pointer string (JSON) referencing stored payload + * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented + */ + default String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + throw new UnsupportedOperationException("Streaming upload from InputStream not implemented"); + } + + /** + * Retrieve a payload as a stream to avoid loading large payloads into memory. + * @param payloadPointer pointer JSON string + * @return stream containing the original payload content + * @throws UnsupportedOperationException if streaming retrieval is not implemented + */ + default ResponseInputStream getOriginalPayloadStreamStream(String payloadPointer) { + throw new UnsupportedOperationException("Streaming retrieval not implemented"); + } +} diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java new file mode 100644 index 0000000..0333eec --- /dev/null +++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java @@ -0,0 +1,34 @@ +package software.amazon.payloadoffloading; + +import java.io.InputStream; +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +/** + * Optional extension interface for {@link PayloadStoreAsync} implementations that can perform stream + * upload and streaming retrieval for very large payloads to reduce memory pressure. + */ +public interface StreamPayloadStoreAsync extends PayloadStoreAsync { + + /** + * Store payload content from an InputStream using stream semantics to avoid loading large payloads into memory. + * @param payloadStream InputStream containing UTF-8 text payload + * @param s3Key object key to use + * @return future pointer string (JSON) referencing stored payload + * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented + */ + default CompletableFuture storeOriginalPayloadStream(InputStream payloadStream, String s3Key) { + throw new UnsupportedOperationException("Streaming upload from InputStream not implemented"); + } + + /** + * Retrieve payload as a stream to avoid loading large payloads into memory. + * @param payloadPointer pointer JSON string + * @return future stream containing the original payload content + * @throws UnsupportedOperationException if streaming retrieval is not implemented + */ + default CompletableFuture> getOriginalPayloadStreamStream(String payloadPointer) { + throw new UnsupportedOperationException("Streaming retrieval not implemented"); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java index 723678d..8a7c663 100644 --- a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java +++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java @@ -3,6 +3,7 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -19,4 +20,16 @@ public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() { assertEquals(putObjectRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS)); } + + + @Test + public void testAwsManagedCmkStrategyStreamSetsCorrectEncryptionValues() { + AwsManagedCmk awsManagedCmk = new AwsManagedCmk(); + + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); + awsManagedCmk.decorate(createMultipartUploadRequestBuilder); + CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build(); + + assertEquals(createMultipartUploadRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS)); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java index cc55b70..0474d30 100644 --- a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java +++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java @@ -3,6 +3,7 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -21,4 +22,16 @@ public void testCustomerKeyStrategySetsCorrectEncryptionValues() { assertEquals(putObjectRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS); assertEquals(putObjectRequest.ssekmsKeyId(), AWS_KMS_KEY_ID); } + + @Test + public void testCustomerKeyStrategySetsStreamUploadEncryptionValues() { + CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID); + + CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder(); + customerKey.decorate(createMultipartUploadRequestBuilder); + CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build(); + + assertEquals(createMultipartUploadRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS); + assertEquals(createMultipartUploadRequest.ssekmsKeyId(), AWS_KMS_KEY_ID); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java index 30cf188..af8aa96 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java @@ -99,4 +99,29 @@ public void testCannedAccessControlList() { assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined()); assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL()); } + + @Test + public void testStreamUploadEnabled() { + PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration(); + + payloadStorageConfiguration.setStreamUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isStreamUploadEnabled()); + + payloadStorageConfiguration.setStreamUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isStreamUploadEnabled()); + } + + @Test + public void testStreamConfigurationInCopyConstructor() { + S3AsyncClient s3Async = mock(S3AsyncClient.class); + + PayloadStorageAsyncConfiguration original = new PayloadStorageAsyncConfiguration(); + original.withPayloadSupportEnabled(s3Async, s3BucketName) + .withStreamUploadEnabled(true); + + PayloadStorageAsyncConfiguration copy = new PayloadStorageAsyncConfiguration(original); + + assertTrue(copy.isStreamUploadEnabled()); + assertNotSame(copy, original); + } } \ No newline at end of file diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java index f4adf77..d233826 100644 --- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java +++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java @@ -99,4 +99,51 @@ public void testCannedAccessControlList() { assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined()); assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL()); } + + @Test + public void testStreamUploadEnabled() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + payloadStorageConfiguration.setStreamUploadEnabled(true); + assertTrue(payloadStorageConfiguration.isStreamUploadEnabled()); + + payloadStorageConfiguration.setStreamUploadEnabled(false); + assertFalse(payloadStorageConfiguration.isStreamUploadEnabled()); + } + + @Test + public void testStreamUploadThreshold() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + int customThreshold = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setStreamUploadThreshold(customThreshold); + assertEquals(customThreshold, payloadStorageConfiguration.getStreamUploadThreshold()); + } + + @Test + public void testStreamUploadPartSize() { + PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration(); + + int customPartSize = 10 * 1024 * 1024; // 10MB + payloadStorageConfiguration.setStreamUploadPartSize(customPartSize); + assertEquals(customPartSize, payloadStorageConfiguration.getStreamUploadPartSize()); + } + + @Test + public void testStreamConfigurationInCopyConstructor() { + S3Client s3 = mock(S3Client.class); + + PayloadStorageConfiguration original = new PayloadStorageConfiguration(); + original.withPayloadSupportEnabled(s3, s3BucketName) + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(10 * 1024 * 1024) + .setStreamUploadPartSize(8 * 1024 * 1024); + + PayloadStorageConfiguration copy = new PayloadStorageConfiguration(original); + + assertTrue(copy.isStreamUploadEnabled()); + assertEquals(10 * 1024 * 1024, copy.getStreamUploadThreshold()); + assertEquals(8 * 1024 * 1024, copy.getStreamUploadPartSize()); + assertNotSame(copy, original); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java index 1ecccb8..5d19b92 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java @@ -2,23 +2,28 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectCannedACL; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; @@ -111,4 +116,23 @@ public void deleteTextTest() { verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class)); } + + @Test + public void getTextStreamFromS3Test() { + dao = new S3AsyncDao(s3AsyncClient); + + byte[] testData = ANY_PAYLOAD.getBytes(StandardCharsets.UTF_8); + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(testData); + GetObjectResponse getObjectResponse = GetObjectResponse.builder().build(); + ResponseInputStream mockResponseStream = + new ResponseInputStream<>(getObjectResponse, byteArrayInputStream); + + when(s3AsyncClient.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(mockResponseStream)); + + ResponseInputStream stream = dao.getTextStreamFromS3(S3_BUCKET_NAME, ANY_S3_KEY).join(); + + verify(s3AsyncClient, times(1)).getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)); + assertNotNull(stream); + } } diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java new file mode 100644 index 0000000..1edc7f4 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java @@ -0,0 +1,111 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedStreamPayloadStoreAsyncTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + + private S3BackedStreamPayloadStoreAsync streamPayloadStore; + private S3AsyncDao s3AsyncDao; + + @BeforeEach + public void setup() { + s3AsyncDao = mock(S3AsyncDao.class); + streamPayloadStore = new S3BackedStreamPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME); + } + + @Test + public void testStoreOriginalPayloadStreamOnSuccess() { + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class); + + verify(s3AsyncDao, times(1)).storeTextStreamInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertNotNull(payloadCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadStreamWithCustomPartSize() { + String payload = generateString(12 * 1024); + String s3Key = "custom-size-key"; + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key).join(); + + verify(s3AsyncDao, times(1)).storeTextStreamInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + any(java.io.InputStream.class) + ); + } + + @Test + public void testStoreOriginalPayloadStreamOnS3Failure() { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(SdkException.create("S3 Exception", new Throwable())); + + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(failedFuture); + + CompletionException exception = assertThrows(CompletionException.class, () -> { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + }); + + assertTrue(exception.getMessage().contains("S3 Exception")); + } + + @Test + public void testStoreOriginalPayloadStreamHandlesNullFromDao() { + when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class))) + .thenReturn(CompletableFuture.completedFuture(null)); + + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join(); + + assertNotNull(actualPayloadPointer); + PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer); + assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName()); + assertEquals(ANY_S3_KEY, pointer.getS3Key()); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java new file mode 100644 index 0000000..4679778 --- /dev/null +++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java @@ -0,0 +1,88 @@ +package software.amazon.payloadoffloading; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.exception.SdkException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class S3BackedStreamPayloadStoreTest { + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String ANY_PAYLOAD = "AnyPayload"; + private static final String ANY_S3_KEY = "AnyS3key"; + + private S3BackedStreamPayloadStore streamPayloadStore; + private S3Dao s3Dao; + + @BeforeEach + public void setup() { + s3Dao = mock(S3Dao.class); + streamPayloadStore = new S3BackedStreamPayloadStore(s3Dao, S3_BUCKET_NAME); + } + + @Test + public void testStoreOriginalPayloadStreamOnSuccess() { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY); + + ArgumentCaptor bucketCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class); + + verify(s3Dao, times(1)).storeTextStreamInS3( + bucketCaptor.capture(), + keyCaptor.capture(), + payloadCaptor.capture() + ); + + assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue()); + assertEquals(ANY_S3_KEY, keyCaptor.getValue()); + assertNotNull(payloadCaptor.getValue()); + + PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY); + assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer); + } + + @Test + public void testStoreOriginalPayloadStreamWithCustomPartSize() { + String payload = generateString(12 * 1024); + String s3Key = "custom-size-key"; + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key); + + verify(s3Dao, times(1)).storeTextStreamInS3( + eq(S3_BUCKET_NAME), + eq(s3Key), + any(java.io.InputStream.class) + ); + } + + @Test + public void testStoreOriginalPayloadStreamOnS3Failure() { + doThrow(SdkException.create("S3 Exception", new Throwable())) + .when(s3Dao) + .storeTextStreamInS3( + any(String.class), + any(String.class), + any(java.io.InputStream.class) + ); + + assertThrows(SdkException.class, () -> { + java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY); + }, "S3 Exception"); + } + + private String generateString(int length) { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) { + sb.append('a'); + } + return sb.toString(); + } +} diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java index b0a8f25..e0ce2be 100644 --- a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java +++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java @@ -2,14 +2,12 @@ import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.ObjectCannedACL; +import software.amazon.awssdk.services.s3.model.*; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; @@ -17,10 +15,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; public class S3DaoTest { - private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; private static final String S3_BUCKET_NAME = "test-bucket-name"; private static final String ANY_PAYLOAD = "AnyPayload"; private static final String ANY_S3_KEY = "AnyS3key"; @@ -75,4 +77,149 @@ public void storeTextInS3WithBothTest() { assertEquals(objectCannedACL, argument.getValue().acl()); assertEquals(S3_BUCKET_NAME, argument.getValue().bucket()); } + + @Test + public void storeTextStreamInS3WithoutSSEOrCannedTest() { + dao = new S3Dao(s3Client); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertNull(createCaptor.getValue().serverSideEncryption()); + assertNull(createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextStreamInS3WithSSETest() { + dao = new S3Dao(s3Client, serverSideEncryptionStrategy, null); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption()); + assertNull(createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextStreamInS3WithBothSSEAndCannedACLTest() { + dao = new S3Dao(s3Client, serverSideEncryptionStrategy, objectCannedACL); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8); + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + ArgumentCaptor createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class); + verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture()); + + assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket()); + assertEquals(ANY_S3_KEY, createCaptor.getValue().key()); + assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption()); + assertEquals(objectCannedACL, createCaptor.getValue().acl()); + + verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } + + @Test + public void storeTextStreamInS3WithMultiplePartsTest() { + dao = new S3Dao(s3Client); + + CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder() + .uploadId("test-upload-id") + .build(); + when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class))) + .thenReturn(createResponse); + + UploadPartResponse uploadPartResponse = UploadPartResponse.builder() + .eTag("test-etag") + .build(); + when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))) + .thenReturn(uploadPartResponse); + + CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder() + .build(); + when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))) + .thenReturn(completeResponse); + + byte[] testData = new byte[5 * 1024 * 1024 + 1]; + InputStream inputStream = new ByteArrayInputStream(testData); + + dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream); + + verify(s3Client, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3Client, times(2)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class)); + verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + } } \ No newline at end of file