diff --git a/pom.xml b/pom.xml
index 0cea9a4..15052fb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
<groupId>software.amazon.payloadoffloading</groupId>
<artifactId>payloadoffloading-common</artifactId>
- <version>2.2.0</version>
+ <version>2.3.0</version>
<packaging>jar</packaging>
<name>Payload offloading common library for AWS</name>
<description>Common library between extended Amazon AWS clients to save payloads up to 2GB on Amazon S3.</description>
@@ -36,13 +36,13 @@
- <aws-java-sdk.version>2.20.130</aws-java-sdk.version>
+ <aws-java-sdk.version>2.27.14</aws-java-sdk.version>
<groupId>software.amazon.awssdk</groupId>
- <artifactId>s3</artifactId>
+ <artifactId>s3-transfer-manager</artifactId>
<version>${aws-java-sdk.version}</version>
diff --git a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java
index ae291f4..eaaa49f 100644
--- a/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java
+++ b/src/main/java/software/amazon/payloadoffloading/AwsManagedCmk.java
@@ -1,6 +1,7 @@
package software.amazon.payloadoffloading;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
public class AwsManagedCmk implements ServerSideEncryptionStrategy {
@@ -8,4 +9,9 @@ public class AwsManagedCmk implements ServerSideEncryptionStrategy {
public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) {
putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
}
+
+ @Override
+ public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) {
+ createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+ }
}
diff --git a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java
index 7f62d49..9a2db93 100644
--- a/src/main/java/software/amazon/payloadoffloading/CustomerKey.java
+++ b/src/main/java/software/amazon/payloadoffloading/CustomerKey.java
@@ -1,6 +1,7 @@
package software.amazon.payloadoffloading;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
public class CustomerKey implements ServerSideEncryptionStrategy {
@@ -15,4 +16,10 @@ public void decorate(PutObjectRequest.Builder putObjectRequestBuilder) {
putObjectRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
putObjectRequestBuilder.ssekmsKeyId(awsKmsKeyId);
}
+
+ @Override
+ public void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder) {
+ createStreamUploadRequestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+ createStreamUploadRequestBuilder.ssekmsKeyId(awsKmsKeyId);
+ }
}
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java
index 3bd8d08..832053c 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfiguration.java
@@ -150,4 +150,14 @@ public PayloadStorageAsyncConfiguration withObjectCannedACL(ObjectCannedACL obje
setObjectCannedACL(objectCannedACL);
return this;
}
+
+ /**
+ * Enables or disables stream upload support.
+ * @param enabled true to enable stream uploads when the payload size exceeds the configured threshold.
+ * @return the updated PayloadStorageAsyncConfiguration object.
+ */
+ public PayloadStorageAsyncConfiguration withStreamUploadEnabled(boolean enabled) {
+ setStreamUploadEnabled(enabled);
+ return this;
+ }
}
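For illustration, enabling the new flag on the async configuration might look like the following sketch; the S3AsyncClient and bucket name are placeholders, not part of this change.

    S3AsyncClient s3Async = S3AsyncClient.create();
    PayloadStorageAsyncConfiguration asyncConfig = new PayloadStorageAsyncConfiguration()
            .withPayloadSupportEnabled(s3Async, "my-payload-bucket")
            .withStreamUploadEnabled(true);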
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
index 9ab3c10..afab682 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfiguration.java
@@ -35,6 +35,8 @@ public class PayloadStorageConfiguration extends PayloadStorageConfigurationBase
private static final Logger LOG = LoggerFactory.getLogger(PayloadStorageConfiguration.class);
private S3Client s3;
+ private int streamUploadPartSize = 5 * 1024 * 1024;
+ private int streamUploadThreshold = 5 * 1024 * 1024;
public PayloadStorageConfiguration() {
s3 = null;
@@ -43,6 +45,8 @@ public PayloadStorageConfiguration() {
public PayloadStorageConfiguration(PayloadStorageConfiguration other) {
super(other);
this.s3 = other.getS3Client();
+ this.streamUploadThreshold = other.getStreamUploadThreshold();
+ this.streamUploadPartSize = other.getStreamUploadPartSize();
}
/**
@@ -148,4 +152,68 @@ public PayloadStorageConfiguration withObjectCannedACL(ObjectCannedACL objectCan
setObjectCannedACL(objectCannedACL);
return this;
}
+
+ /**
+ * Enables or disables stream upload support.
+ * @param enabled true to enable stream uploads when the payload size exceeds the configured threshold.
+ * @return the updated PayloadStorageConfiguration object.
+ */
+ public PayloadStorageConfiguration withStreamUploadEnabled(boolean enabled) {
+ setStreamUploadEnabled(enabled);
+ return this;
+ }
+
+ /**
+ * Sets the stream upload threshold (in bytes). Only used when stream upload is enabled.
+ * @param threshold threshold in bytes; values smaller than 5 MB are raised to the 5 MB minimum.
+ * @return the updated PayloadStorageConfiguration object.
+ */
+ public PayloadStorageConfiguration withStreamUploadThreshold(int threshold) {
+ setStreamUploadThreshold(threshold);
+ return this;
+ }
+
+ /**
+ * Sets the stream upload part size (in bytes). Only used when stream upload is enabled.
+ * @param partSize part size in bytes; values smaller than 5 MB are raised to the 5 MB minimum.
+ * @return the updated PayloadStorageConfiguration object.
+ */
+ public PayloadStorageConfiguration withStreamUploadPartSize(int partSize) {
+ setStreamUploadPartSize(partSize);
+ return this;
+ }
+
+ /**
+ * Gets the stream upload threshold in bytes. Default 5MB.
+ *
+ * @return threshold in bytes.
+ */
+ public int getStreamUploadThreshold() { return streamUploadThreshold; }
+
+ /**
+ * Sets the stream upload threshold in bytes. Values smaller than 5 MB are raised to the 5 MB minimum.
+ *
+ * @param streamUploadThreshold threshold in bytes
+ */
+ public void setStreamUploadThreshold(int streamUploadThreshold) {
+ int min = 5 * 1024 * 1024;
+ if (streamUploadThreshold <= min) {
+ this.streamUploadThreshold = min;
+ } else {
+ this.streamUploadThreshold = streamUploadThreshold;
+ }
+ }
+
+ /**
+ * Gets the configured stream upload part size (bytes). Default 5MB.
+ */
+ public int getStreamUploadPartSize() { return streamUploadPartSize; }
+
+ /**
+ * Sets the stream upload part size (bytes). Values smaller than 5 MB are raised to the 5 MB minimum.
+ */
+ public void setStreamUploadPartSize(int partSize) {
+ int min = 5 * 1024 * 1024;
+ if (partSize < min) {
+ this.streamUploadPartSize = min;
+ } else {
+ this.streamUploadPartSize = partSize;
+ }
+ }
}
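As a sketch of how the new settings compose on the synchronous configuration (the S3Client and bucket name are placeholders; the sizes are examples, not recommendations):

    S3Client s3 = S3Client.create();
    PayloadStorageConfiguration config = new PayloadStorageConfiguration()
            .withPayloadSupportEnabled(s3, "my-payload-bucket")
            .withStreamUploadEnabled(true)
            .withStreamUploadThreshold(10 * 1024 * 1024) // stream payloads larger than 10 MB
            .withStreamUploadPartSize(8 * 1024 * 1024);  // upload in 8 MB parts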
diff --git a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java
index 7d08746..c0679ae 100644
--- a/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java
+++ b/src/main/java/software/amazon/payloadoffloading/PayloadStorageConfigurationBase.java
@@ -4,7 +4,6 @@
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.core.exception.SdkClientException;
-import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
/**
@@ -22,6 +21,8 @@ public abstract class PayloadStorageConfigurationBase {
private int payloadSizeThreshold = 0;
private boolean alwaysThroughS3 = false;
private boolean payloadSupport = false;
+ // Enable stream upload support (opt-in, default false)
+ private boolean streamUploadEnabled = false;
/**
* This field is optional, it is set only when we want to configure S3 Server Side Encryption with KMS.
*/
@@ -44,6 +45,7 @@ public PayloadStorageConfigurationBase(PayloadStorageConfigurationBase other) {
this.payloadSizeThreshold = other.getPayloadSizeThreshold();
this.serverSideEncryptionStrategy = other.getServerSideEncryptionStrategy();
this.objectCannedACL = other.getObjectCannedACL();
+ this.streamUploadEnabled = other.isStreamUploadEnabled();
}
/**
@@ -175,4 +177,19 @@ public boolean isObjectCannedACLDefined() {
public ObjectCannedACL getObjectCannedACL() {
return objectCannedACL;
}
+
+ /**
+ * Checks whether stream upload support is enabled. Default: false.
+ *
+ * @return true if stream upload support is enabled.
+ */
+ public boolean isStreamUploadEnabled() { return streamUploadEnabled; }
+
+ /**
+ * Enables or disables stream upload support. When enabling, callers should provide a PayloadStore
+ * implementation capable of stream operations (such as StreamPayloadStore); otherwise the normal single-PUT behavior is used.
+ *
+ * @param streamUploadEnabled flag to enable/disable stream support.
+ */
+ public void setStreamUploadEnabled(boolean streamUploadEnabled) { this.streamUploadEnabled = streamUploadEnabled; }
}
\ No newline at end of file
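A calling client might gate on this flag and fall back to the existing single-PUT path when the configured store does not implement the streaming extension. A sketch with hypothetical caller-side variables (config, payloadStore, payloadStream, s3Key):

    if (config.isStreamUploadEnabled() && payloadStore instanceof StreamPayloadStore) {
        // Stream the payload instead of buffering it into a String first.
        String pointer = ((StreamPayloadStore) payloadStore)
                .storeOriginalPayloadStream(payloadStream, s3Key);
    }
    // Otherwise the caller keeps using the existing, non-streaming PayloadStore methods.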
diff --git a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java
index a0dc868..ea7b5e8 100644
--- a/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java
+++ b/src/main/java/software/amazon/payloadoffloading/S3AsyncDao.java
@@ -1,11 +1,13 @@
package software.amazon.payloadoffloading;
import java.io.UncheckedIOException;
+import java.io.InputStream;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.ResponseBytes;
+import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.core.exception.SdkClientException;
@@ -13,8 +15,14 @@
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.transfer.s3.S3TransferManager;
+import software.amazon.awssdk.transfer.s3.model.UploadRequest;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/**
* Dao layer to access S3.
@@ -65,6 +73,28 @@ public CompletableFuture<String> getTextFromS3(String s3BucketName, String s3Key
});
}
+ public CompletableFuture<ResponseInputStream<GetObjectResponse>> getTextStreamFromS3(String s3BucketName, String s3Key) {
+ GetObjectRequest getObjectRequest = GetObjectRequest.builder()
+ .bucket(s3BucketName)
+ .key(s3Key)
+ .build();
+
+ return s3Client.getObject(getObjectRequest, AsyncResponseTransformer.toBlockingInputStream())
+ .handle((v, tIn) -> {
+ if (tIn != null) {
+ Throwable t = Util.unwrapFutureException(tIn);
+ if (t instanceof SdkException) {
+ String errorMessage = "Failed to get the S3 object stream which contains the payload.";
+ LOG.error(errorMessage, t);
+ throw SdkException.create(errorMessage, t);
+ }
+ throw new CompletionException(t);
+ }
+ LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key);
+ return v;
+ });
+ }
+
public CompletableFuture<Void> storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) {
PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder()
.bucket(s3BucketName)
@@ -115,4 +145,54 @@ public CompletableFuture<Void> deletePayloadFromS3(String s3BucketName, String s
return null;
});
}
+
+
+ /**
+ * Stores a stream of data to S3 using TransferManager for efficient multipart uploads.
+ *
+ * @param bucket The S3 bucket name
+ * @param key The S3 object key
+ * @param payloadStream The input stream to upload
+ * @return CompletableFuture that completes when upload is finished
+ */
+ public CompletableFuture<Void> storeTextStreamInS3(String bucket, String key, InputStream payloadStream) {
+ S3TransferManager transferManager = S3TransferManager.create();
+ ExecutorService executor = Executors.newSingleThreadExecutor();
+
+ try {
+ UploadRequest.Builder uploadBuilder = UploadRequest.builder()
+ .putObjectRequest(b -> {
+ b.bucket(bucket).key(key);
+ if (objectCannedACL != null) {
+ b.acl(objectCannedACL);
+ }
+ // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html
+ if (serverSideEncryptionStrategy != null) {
+ serverSideEncryptionStrategy.decorate(b);
+ }
+ })
+ .requestBody(AsyncRequestBody.fromInputStream(payloadStream, null, executor));
+
+ return transferManager.upload(uploadBuilder.build()).completionFuture()
+ .thenApply(completedUpload -> {
+ LOG.info("S3 stream object created from InputStream, Bucket name: " + bucket + ", Object key: " + key + ".");
+ return (Void) null;
+ })
+ .exceptionally(t -> {
+ Throwable cause = Util.unwrapFutureException(t);
+ if (cause instanceof SdkException) {
+ String errorMessage = "Failed to store the message content from InputStream in an S3 stream object.";
+ LOG.error(errorMessage, cause);
+ throw SdkException.create(errorMessage, cause);
+ }
+ throw new CompletionException(cause);
+ })
+ .whenComplete((result, error) -> {
+ // Release the transfer manager and executor only after the asynchronous upload has
+ // finished; closing them eagerly would interrupt the in-flight transfer.
+ transferManager.close();
+ executor.shutdownNow();
+ });
+ } catch (RuntimeException e) {
+ // Clean up if building or submitting the upload fails synchronously.
+ transferManager.close();
+ executor.shutdownNow();
+ throw e;
+ }
+ }
}
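A minimal sketch of driving the new async DAO methods end to end, assuming the usual java.io, java.nio, and SDK imports; the bucket, key, and client are placeholders:

    S3AsyncClient s3Async = S3AsyncClient.create();
    S3AsyncDao asyncDao = new S3AsyncDao(s3Async);
    InputStream payload = new ByteArrayInputStream("large payload".getBytes(StandardCharsets.UTF_8));

    asyncDao.storeTextStreamInS3("my-payload-bucket", "my-object-key", payload)
            .thenCompose(v -> asyncDao.getTextStreamFromS3("my-payload-bucket", "my-object-key"))
            .thenAccept(stream -> { /* consume and then close the ResponseInputStream */ })
            .join();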
diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java
new file mode 100644
index 0000000..cd3bc09
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStore.java
@@ -0,0 +1,37 @@
+package software.amazon.payloadoffloading;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import java.io.InputStream;
+
+public class S3BackedStreamPayloadStore extends S3BackedPayloadStore implements StreamPayloadStore {
+ private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStore.class);
+ private final S3Dao s3Dao;
+ private final String s3BucketName;
+
+ public S3BackedStreamPayloadStore(S3Dao s3Dao, String s3BucketName) {
+ super(s3Dao, s3BucketName);
+ this.s3Dao = s3Dao;
+ this.s3BucketName = s3BucketName;
+ }
+
+ @Override
+ public String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) {
+ s3Dao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream);
+ LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");
+ return new PayloadS3Pointer(s3BucketName, s3Key).toJson();
+ }
+
+ @Override
+ public ResponseInputStream<GetObjectResponse> getOriginalPayloadStreamStream(String payloadPointer) {
+ PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer);
+ String s3BucketName = s3Pointer.getS3BucketName();
+ String s3Key = s3Pointer.getS3Key();
+
+ ResponseInputStream<GetObjectResponse> stream = s3Dao.getTextStreamFromS3(s3BucketName, s3Key);
+ LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key);
+ return stream;
+ }
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java
new file mode 100644
index 0000000..7a75e5b
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsync.java
@@ -0,0 +1,45 @@
+package software.amazon.payloadoffloading;
+
+import java.io.InputStream;
+import java.util.concurrent.CompletableFuture;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+
+public class S3BackedStreamPayloadStoreAsync extends S3BackedPayloadStoreAsync implements StreamPayloadStoreAsync {
+ private static final Logger LOG = LoggerFactory.getLogger(S3BackedStreamPayloadStoreAsync.class);
+ private final S3AsyncDao streamDao;
+ private final String s3BucketName;
+
+ public S3BackedStreamPayloadStoreAsync(S3AsyncDao s3Dao, String s3BucketName) {
+ super(s3Dao, s3BucketName);
+ this.streamDao = s3Dao;
+ this.s3BucketName = s3BucketName;
+ }
+
+ @Override
+ public CompletableFuture<String> storeOriginalPayloadStream(InputStream payloadStream, String s3Key) {
+ return streamDao.storeTextStreamInS3(s3BucketName, s3Key, payloadStream)
+ .thenApply(v -> {
+ LOG.info("S3 stream object created from InputStream, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");
+ return new PayloadS3Pointer(s3BucketName, s3Key).toJson();
+ });
+ }
+
+ @Override
+ public CompletableFuture<ResponseInputStream<GetObjectResponse>> getOriginalPayloadStreamStream(String payloadPointer) {
+ PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(payloadPointer);
+ String s3BucketName = s3Pointer.getS3BucketName();
+ String s3Key = s3Pointer.getS3Key();
+
+ return streamDao.getTextStreamFromS3(s3BucketName, s3Key)
+ .thenApply(stream -> {
+ LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key);
+ return stream;
+ });
+ }
+}
diff --git a/src/main/java/software/amazon/payloadoffloading/S3Dao.java b/src/main/java/software/amazon/payloadoffloading/S3Dao.java
index 2b03dd5..2cfe158 100644
--- a/src/main/java/software/amazon/payloadoffloading/S3Dao.java
+++ b/src/main/java/software/amazon/payloadoffloading/S3Dao.java
@@ -12,9 +12,20 @@
import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompletedPart;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
import software.amazon.awssdk.utils.IoUtils;
import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.List;
/**
* Dao layer to access S3.
@@ -24,9 +35,13 @@ public class S3Dao {
private final S3Client s3Client;
private final ServerSideEncryptionStrategy serverSideEncryptionStrategy;
private final ObjectCannedACL objectCannedACL;
+ private int streamUploadPartSize = 5 * 1024 * 1024;
+ private int streamUploadThreshold = 5 * 1024 * 1024;
public S3Dao(S3Client s3Client) {
- this(s3Client, null, null);
+ this.s3Client = s3Client;
+ this.serverSideEncryptionStrategy = null;
+ this.objectCannedACL = null;
}
public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL) {
@@ -35,6 +50,14 @@ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptio
this.objectCannedACL = objectCannedACL;
}
+ public S3Dao(S3Client s3Client, ServerSideEncryptionStrategy serverSideEncryptionStrategy, ObjectCannedACL objectCannedACL, int streamUploadThreshold, int streamUploadPartSize) {
+ this.s3Client = s3Client;
+ this.serverSideEncryptionStrategy = serverSideEncryptionStrategy;
+ this.objectCannedACL = objectCannedACL;
+ this.streamUploadThreshold = streamUploadThreshold;
+ this.streamUploadPartSize = streamUploadPartSize;
+ }
+
public String getTextFromS3(String s3BucketName, String s3Key) {
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
.bucket(s3BucketName)
@@ -65,6 +88,23 @@ public String getTextFromS3(String s3BucketName, String s3Key) {
return embeddedText;
}
+ public ResponseInputStream<GetObjectResponse> getTextStreamFromS3(String s3BucketName, String s3Key) {
+ GetObjectRequest getObjectRequest = GetObjectRequest.builder()
+ .bucket(s3BucketName)
+ .key(s3Key)
+ .build();
+
+ try {
+ ResponseInputStream<GetObjectResponse> object = s3Client.getObject(getObjectRequest);
+ LOG.info("S3 object stream retrieved, Bucket name: " + s3BucketName + ", Object key: " + s3Key);
+ return object;
+ } catch (SdkException e) {
+ String errorMessage = "Failed to get the S3 object stream which contains the payload.";
+ LOG.error(errorMessage, e);
+ throw SdkException.create(errorMessage, e);
+ }
+ }
+
public void storeTextInS3(String s3BucketName, String s3Key, String payloadContentStr) {
PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder()
.bucket(s3BucketName)
@@ -104,4 +144,113 @@ public void deletePayloadFromS3(String s3BucketName, String s3Key) {
LOG.info("S3 object deleted, Bucket name: " + s3BucketName + ", Object key: " + s3Key + ".");
}
+
+
+ /**
+ * Stores a stream of data to S3 using multipart upload.
+ * This method reads the stream in chunks and uploads each part separately,
+ * which is suitable for large payloads without requiring the entire content in memory.
+ *
+ * @param bucket The S3 bucket name
+ * @param key The S3 object key
+ * @param payloadStream The input stream to upload
+ * @throws SdkException if the upload fails
+ */
+ public void storeTextStreamInS3(String bucket, String key, InputStream payloadStream) {
+ String uploadId = null;
+ List<CompletedPart> completedParts = new ArrayList<>();
+
+ try {
+ CreateMultipartUploadRequest.Builder createRequestBuilder = CreateMultipartUploadRequest.builder()
+ .bucket(bucket)
+ .key(key);
+
+ if (objectCannedACL != null) {
+ createRequestBuilder.acl(objectCannedACL);
+ }
+
+ // https://docs.aws.amazon.com/AmazonS3/latest/dev/kms-using-sdks.html
+ if (serverSideEncryptionStrategy != null) {
+ serverSideEncryptionStrategy.decorate(createRequestBuilder);
+ }
+
+ CreateMultipartUploadResponse createResponse = s3Client.createMultipartUpload(createRequestBuilder.build());
+ uploadId = createResponse.uploadId();
+
+ byte[] buffer = new byte[streamUploadPartSize];
+ int partNumber = 1;
+
+ while (true) {
+ // Fill the buffer completely before uploading each part: InputStream.read may return
+ // fewer bytes than requested, and every part except the last must be at least 5 MB.
+ int bytesRead = 0;
+ int read;
+ while (bytesRead < buffer.length
+ && (read = payloadStream.read(buffer, bytesRead, buffer.length - bytesRead)) != -1) {
+ bytesRead += read;
+ }
+ if (bytesRead == 0) {
+ break;
+ }
+
+ RequestBody requestBody = RequestBody.fromBytes(bytesRead == buffer.length ? buffer :
+ java.util.Arrays.copyOf(buffer, bytesRead));
+
+ UploadPartRequest uploadPartRequest = UploadPartRequest.builder()
+ .bucket(bucket)
+ .key(key)
+ .uploadId(uploadId)
+ .partNumber(partNumber)
+ .build();
+
+ UploadPartResponse uploadPartResponse = s3Client.uploadPart(uploadPartRequest, requestBody);
+
+ CompletedPart completedPart = CompletedPart.builder()
+ .partNumber(partNumber)
+ .eTag(uploadPartResponse.eTag())
+ .build();
+
+ completedParts.add(completedPart);
+ partNumber++;
+ }
+
+ CompletedMultipartUpload completedMultipartUpload = CompletedMultipartUpload.builder()
+ .parts(completedParts)
+ .build();
+
+ CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder()
+ .bucket(bucket)
+ .key(key)
+ .uploadId(uploadId)
+ .multipartUpload(completedMultipartUpload)
+ .build();
+
+ s3Client.completeMultipartUpload(completeRequest);
+ LOG.info("S3 stream object created from InputStream using multipart upload, Bucket name: " + bucket + ", Object key: " + key + ".");
+
+ } catch (IOException e) {
+ // Abort the multipart upload if it was initiated
+ if (uploadId != null) {
+ try {
+ AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder()
+ .bucket(bucket)
+ .key(key)
+ .uploadId(uploadId)
+ .build();
+ s3Client.abortMultipartUpload(abortRequest);
+ } catch (Exception abortException) {
+ LOG.warn("Failed to abort multipart upload", abortException);
+ }
+ }
+ String errorMessage = "Failed to read from InputStream during multipart upload.";
+ LOG.error(errorMessage, e);
+ throw SdkClientException.create(errorMessage, e);
+ } catch (SdkException e) {
+ // Abort the multipart upload if it was initiated
+ if (uploadId != null) {
+ try {
+ AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder()
+ .bucket(bucket)
+ .key(key)
+ .uploadId(uploadId)
+ .build();
+ s3Client.abortMultipartUpload(abortRequest);
+ } catch (Exception abortException) {
+ LOG.warn("Failed to abort multipart upload", abortException);
+ }
+ }
+ String errorMessage = "Failed to store the message content from InputStream in an S3 object using multipart upload.";
+ LOG.error(errorMessage, e);
+ throw SdkException.create(errorMessage, e);
+ }
+ }
}
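For the synchronous path, a sketch of the five-argument constructor and a streamed upload; the file path, bucket, and key are placeholders, and the usual java.io/java.nio imports are assumed:

    S3Client s3 = S3Client.create();
    // 10 MB threshold, 8 MB parts, no SSE strategy or canned ACL.
    S3Dao dao = new S3Dao(s3, null, null, 10 * 1024 * 1024, 8 * 1024 * 1024);

    try (InputStream payload = Files.newInputStream(Paths.get("large-payload.json"))) {
        dao.storeTextStreamInS3("my-payload-bucket", "my-object-key", payload);
    }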
diff --git a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java
index f385ce6..43aced1 100644
--- a/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java
+++ b/src/main/java/software/amazon/payloadoffloading/ServerSideEncryptionStrategy.java
@@ -1,7 +1,9 @@
package software.amazon.payloadoffloading;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
public interface ServerSideEncryptionStrategy {
void decorate(PutObjectRequest.Builder putObjectRequestBuilder);
+ void decorate(CreateMultipartUploadRequest.Builder createStreamUploadRequestBuilder);
}
diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java
new file mode 100644
index 0000000..ba3ac0d
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStore.java
@@ -0,0 +1,33 @@
+package software.amazon.payloadoffloading;
+
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import java.io.InputStream;
+
+/**
+ * Optional extension interface for {@link PayloadStore} implementations that can perform stream
+ * upload and streaming retrieval for very large payloads to reduce memory pressure.
+ */
+public interface StreamPayloadStore extends PayloadStore {
+
+ /**
+ * Store a payload from an InputStream using streaming-aware logic to avoid loading large payloads into memory.
+ * @param payloadStream InputStream containing UTF-8 text payload
+ * @param s3Key pre-generated S3 key (or equivalent object key) to use
+ * @return pointer string (JSON) referencing stored payload
+ * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented
+ */
+ default String storeOriginalPayloadStream(InputStream payloadStream, String s3Key) {
+ throw new UnsupportedOperationException("Streaming upload from InputStream not implemented");
+ }
+
+ /**
+ * Retrieve a payload as a stream to avoid loading large payloads into memory.
+ * @param payloadPointer pointer JSON string
+ * @return stream containing the original payload content
+ * @throws UnsupportedOperationException if streaming retrieval is not implemented
+ */
+ default ResponseInputStream<GetObjectResponse> getOriginalPayloadStreamStream(String payloadPointer) {
+ throw new UnsupportedOperationException("Streaming retrieval not implemented");
+ }
+}
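A sketch of using this interface through the S3-backed implementation added in this change; the bucket, payload, and key are placeholders, and the usual java.io, java.util, and SDK imports are assumed:

    S3Client s3 = S3Client.create();
    StreamPayloadStore store = new S3BackedStreamPayloadStore(new S3Dao(s3), "my-payload-bucket");

    InputStream payload = new ByteArrayInputStream("large payload".getBytes(StandardCharsets.UTF_8));
    String pointer = store.storeOriginalPayloadStream(payload, UUID.randomUUID().toString());

    try (ResponseInputStream<GetObjectResponse> original = store.getOriginalPayloadStreamStream(pointer)) {
        // Read the payload back without buffering it fully in memory.
    }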
diff --git a/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java
new file mode 100644
index 0000000..0333eec
--- /dev/null
+++ b/src/main/java/software/amazon/payloadoffloading/StreamPayloadStoreAsync.java
@@ -0,0 +1,34 @@
+package software.amazon.payloadoffloading;
+
+import java.io.InputStream;
+import java.util.concurrent.CompletableFuture;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+
+/**
+ * Optional extension interface for {@link PayloadStoreAsync} implementations that can perform stream
+ * upload and streaming retrieval for very large payloads to reduce memory pressure.
+ */
+public interface StreamPayloadStoreAsync extends PayloadStoreAsync {
+
+ /**
+ * Store payload content from an InputStream using stream semantics to avoid loading large payloads into memory.
+ * @param payloadStream InputStream containing UTF-8 text payload
+ * @param s3Key object key to use
+ * @return future pointer string (JSON) referencing stored payload
+ * @throws UnsupportedOperationException if streaming upload from InputStream is not implemented
+ */
+ default CompletableFuture<String> storeOriginalPayloadStream(InputStream payloadStream, String s3Key) {
+ throw new UnsupportedOperationException("Streaming upload from InputStream not implemented");
+ }
+
+ /**
+ * Retrieve payload as a stream to avoid loading large payloads into memory.
+ * @param payloadPointer pointer JSON string
+ * @return future stream containing the original payload content
+ * @throws UnsupportedOperationException if streaming retrieval is not implemented
+ */
+ default CompletableFuture<ResponseInputStream<GetObjectResponse>> getOriginalPayloadStreamStream(String payloadPointer) {
+ throw new UnsupportedOperationException("Streaming retrieval not implemented");
+ }
+}
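And the asynchronous counterpart, as a sketch with the same placeholder names and assumed imports as above:

    S3AsyncClient s3Async = S3AsyncClient.create();
    StreamPayloadStoreAsync store =
            new S3BackedStreamPayloadStoreAsync(new S3AsyncDao(s3Async), "my-payload-bucket");

    InputStream payload = new ByteArrayInputStream("large payload".getBytes(StandardCharsets.UTF_8));
    store.storeOriginalPayloadStream(payload, "my-object-key")
            .thenCompose(store::getOriginalPayloadStreamStream)
            .thenAccept(stream -> { /* stream the payload; close it when done */ })
            .join();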
diff --git a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java
index 723678d..8a7c663 100644
--- a/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/AwsManagedCmkTest.java
@@ -3,6 +3,7 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -19,4 +20,16 @@ public void testAwsManagedCmkStrategySetsCorrectEncryptionValues() {
assertEquals(putObjectRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS));
}
+
+
+ @Test
+ public void testAwsManagedCmkStrategyStreamSetsCorrectEncryptionValues() {
+ AwsManagedCmk awsManagedCmk = new AwsManagedCmk();
+
+ CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder();
+ awsManagedCmk.decorate(createMultipartUploadRequestBuilder);
+ CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build();
+
+ assertEquals(createMultipartUploadRequest.serverSideEncryption(), (ServerSideEncryption.AWS_KMS));
+ }
}
\ No newline at end of file
diff --git a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java
index cc55b70..0474d30 100644
--- a/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/CustomerKeyTest.java
@@ -3,6 +3,7 @@
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -21,4 +22,16 @@ public void testCustomerKeyStrategySetsCorrectEncryptionValues() {
assertEquals(putObjectRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS);
assertEquals(putObjectRequest.ssekmsKeyId(), AWS_KMS_KEY_ID);
}
+
+ @Test
+ public void testCustomerKeyStrategySetsStreamUploadEncryptionValues() {
+ CustomerKey customerKey = new CustomerKey(AWS_KMS_KEY_ID);
+
+ CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder();
+ customerKey.decorate(createMultipartUploadRequestBuilder);
+ CreateMultipartUploadRequest createMultipartUploadRequest = createMultipartUploadRequestBuilder.build();
+
+ assertEquals(createMultipartUploadRequest.serverSideEncryption(), ServerSideEncryption.AWS_KMS);
+ assertEquals(createMultipartUploadRequest.ssekmsKeyId(), AWS_KMS_KEY_ID);
+ }
}
\ No newline at end of file
diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java
index 30cf188..af8aa96 100644
--- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageAsyncConfigurationTest.java
@@ -99,4 +99,29 @@ public void testCannedAccessControlList() {
assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined());
assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL());
}
+
+ @Test
+ public void testStreamUploadEnabled() {
+ PayloadStorageAsyncConfiguration payloadStorageConfiguration = new PayloadStorageAsyncConfiguration();
+
+ payloadStorageConfiguration.setStreamUploadEnabled(true);
+ assertTrue(payloadStorageConfiguration.isStreamUploadEnabled());
+
+ payloadStorageConfiguration.setStreamUploadEnabled(false);
+ assertFalse(payloadStorageConfiguration.isStreamUploadEnabled());
+ }
+
+ @Test
+ public void testStreamConfigurationInCopyConstructor() {
+ S3AsyncClient s3Async = mock(S3AsyncClient.class);
+
+ PayloadStorageAsyncConfiguration original = new PayloadStorageAsyncConfiguration();
+ original.withPayloadSupportEnabled(s3Async, s3BucketName)
+ .withStreamUploadEnabled(true);
+
+ PayloadStorageAsyncConfiguration copy = new PayloadStorageAsyncConfiguration(original);
+
+ assertTrue(copy.isStreamUploadEnabled());
+ assertNotSame(copy, original);
+ }
}
\ No newline at end of file
diff --git a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
index f4adf77..d233826 100644
--- a/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/PayloadStorageConfigurationTest.java
@@ -99,4 +99,51 @@ public void testCannedAccessControlList() {
assertTrue(payloadStorageConfiguration.isObjectCannedACLDefined());
assertEquals(objectCannelACL, payloadStorageConfiguration.getObjectCannedACL());
}
+
+ @Test
+ public void testStreamUploadEnabled() {
+ PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
+
+ payloadStorageConfiguration.setStreamUploadEnabled(true);
+ assertTrue(payloadStorageConfiguration.isStreamUploadEnabled());
+
+ payloadStorageConfiguration.setStreamUploadEnabled(false);
+ assertFalse(payloadStorageConfiguration.isStreamUploadEnabled());
+ }
+
+ @Test
+ public void testStreamUploadThreshold() {
+ PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
+
+ int customThreshold = 10 * 1024 * 1024; // 10MB
+ payloadStorageConfiguration.setStreamUploadThreshold(customThreshold);
+ assertEquals(customThreshold, payloadStorageConfiguration.getStreamUploadThreshold());
+ }
+
+ @Test
+ public void testStreamUploadPartSize() {
+ PayloadStorageConfiguration payloadStorageConfiguration = new PayloadStorageConfiguration();
+
+ int customPartSize = 10 * 1024 * 1024; // 10MB
+ payloadStorageConfiguration.setStreamUploadPartSize(customPartSize);
+ assertEquals(customPartSize, payloadStorageConfiguration.getStreamUploadPartSize());
+ }
+
+ @Test
+ public void testStreamConfigurationInCopyConstructor() {
+ S3Client s3 = mock(S3Client.class);
+
+ PayloadStorageConfiguration original = new PayloadStorageConfiguration();
+ original.withPayloadSupportEnabled(s3, s3BucketName)
+ .withStreamUploadEnabled(true)
+ .withStreamUploadThreshold(10 * 1024 * 1024)
+ .setStreamUploadPartSize(8 * 1024 * 1024);
+
+ PayloadStorageConfiguration copy = new PayloadStorageConfiguration(original);
+
+ assertTrue(copy.isStreamUploadEnabled());
+ assertEquals(10 * 1024 * 1024, copy.getStreamUploadThreshold());
+ assertEquals(8 * 1024 * 1024, copy.getStreamUploadPartSize());
+ assertNotSame(copy, original);
+ }
}
diff --git a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java
index 1ecccb8..5d19b92 100644
--- a/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/S3AsyncDaoTest.java
@@ -2,23 +2,28 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import software.amazon.awssdk.core.ResponseBytes;
+import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
@@ -111,4 +116,23 @@ public void deleteTextTest() {
verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class));
}
+
+ @Test
+ public void getTextStreamFromS3Test() {
+ dao = new S3AsyncDao(s3AsyncClient);
+
+ byte[] testData = ANY_PAYLOAD.getBytes(StandardCharsets.UTF_8);
+ ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(testData);
+ GetObjectResponse getObjectResponse = GetObjectResponse.builder().build();
+ ResponseInputStream<GetObjectResponse> mockResponseStream =
+ new ResponseInputStream<>(getObjectResponse, byteArrayInputStream);
+
+ when(s3AsyncClient.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(CompletableFuture.completedFuture(mockResponseStream));
+
+ ResponseInputStream<GetObjectResponse> stream = dao.getTextStreamFromS3(S3_BUCKET_NAME, ANY_S3_KEY).join();
+
+ verify(s3AsyncClient, times(1)).getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class));
+ assertNotNull(stream);
+ }
}
diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java
new file mode 100644
index 0000000..1edc7f4
--- /dev/null
+++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreAsyncTest.java
@@ -0,0 +1,111 @@
+package software.amazon.payloadoffloading;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import software.amazon.awssdk.core.exception.SdkException;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
+
+public class S3BackedStreamPayloadStoreAsyncTest {
+ private static final String S3_BUCKET_NAME = "test-bucket-name";
+ private static final String ANY_PAYLOAD = "AnyPayload";
+ private static final String ANY_S3_KEY = "AnyS3key";
+
+ private S3BackedStreamPayloadStoreAsync streamPayloadStore;
+ private S3AsyncDao s3AsyncDao;
+
+ @BeforeEach
+ public void setup() {
+ s3AsyncDao = mock(S3AsyncDao.class);
+ streamPayloadStore = new S3BackedStreamPayloadStoreAsync(s3AsyncDao, S3_BUCKET_NAME);
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamOnSuccess() {
+ when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class)))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+ String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join();
+
+ ArgumentCaptor<String> bucketCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor<java.io.InputStream> payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class);
+
+ verify(s3AsyncDao, times(1)).storeTextStreamInS3(
+ bucketCaptor.capture(),
+ keyCaptor.capture(),
+ payloadCaptor.capture()
+ );
+
+ assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue());
+ assertEquals(ANY_S3_KEY, keyCaptor.getValue());
+ assertNotNull(payloadCaptor.getValue());
+
+ PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY);
+ assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer);
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamWithCustomPartSize() {
+ String payload = generateString(12 * 1024);
+ String s3Key = "custom-size-key";
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+
+ when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class)))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key).join();
+
+ verify(s3AsyncDao, times(1)).storeTextStreamInS3(
+ eq(S3_BUCKET_NAME),
+ eq(s3Key),
+ any(java.io.InputStream.class)
+ );
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamOnS3Failure() {
+ CompletableFuture<Void> failedFuture = new CompletableFuture<>();
+ failedFuture.completeExceptionally(SdkException.create("S3 Exception", new Throwable()));
+
+ when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class)))
+ .thenReturn(failedFuture);
+
+ CompletionException exception = assertThrows(CompletionException.class, () -> {
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+ streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join();
+ });
+
+ assertTrue(exception.getMessage().contains("S3 Exception"));
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamHandlesNullFromDao() {
+ when(s3AsyncDao.storeTextStreamInS3(any(String.class), any(String.class), any(java.io.InputStream.class)))
+ .thenReturn(CompletableFuture.completedFuture(null));
+
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+ String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY).join();
+
+ assertNotNull(actualPayloadPointer);
+ PayloadS3Pointer pointer = PayloadS3Pointer.fromJson(actualPayloadPointer);
+ assertEquals(S3_BUCKET_NAME, pointer.getS3BucketName());
+ assertEquals(ANY_S3_KEY, pointer.getS3Key());
+ }
+
+ private String generateString(int length) {
+ StringBuilder sb = new StringBuilder(length);
+ for (int i = 0; i < length; i++) {
+ sb.append('a');
+ }
+ return sb.toString();
+ }
+}
diff --git a/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java
new file mode 100644
index 0000000..4679778
--- /dev/null
+++ b/src/test/java/software/amazon/payloadoffloading/S3BackedStreamPayloadStoreTest.java
@@ -0,0 +1,88 @@
+package software.amazon.payloadoffloading;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import software.amazon.awssdk.core.exception.SdkException;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
+
+public class S3BackedStreamPayloadStoreTest {
+ private static final String S3_BUCKET_NAME = "test-bucket-name";
+ private static final String ANY_PAYLOAD = "AnyPayload";
+ private static final String ANY_S3_KEY = "AnyS3key";
+
+ private S3BackedStreamPayloadStore streamPayloadStore;
+ private S3Dao s3Dao;
+
+ @BeforeEach
+ public void setup() {
+ s3Dao = mock(S3Dao.class);
+ streamPayloadStore = new S3BackedStreamPayloadStore(s3Dao, S3_BUCKET_NAME);
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamOnSuccess() {
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+ String actualPayloadPointer = streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY);
+
+ ArgumentCaptor<String> bucketCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
+ ArgumentCaptor<java.io.InputStream> payloadCaptor = ArgumentCaptor.forClass(java.io.InputStream.class);
+
+ verify(s3Dao, times(1)).storeTextStreamInS3(
+ bucketCaptor.capture(),
+ keyCaptor.capture(),
+ payloadCaptor.capture()
+ );
+
+ assertEquals(S3_BUCKET_NAME, bucketCaptor.getValue());
+ assertEquals(ANY_S3_KEY, keyCaptor.getValue());
+ assertNotNull(payloadCaptor.getValue());
+
+ PayloadS3Pointer expectedPayloadPointer = new PayloadS3Pointer(S3_BUCKET_NAME, ANY_S3_KEY);
+ assertEquals(expectedPayloadPointer.toJson(), actualPayloadPointer);
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamWithCustomPartSize() {
+ String payload = generateString(12 * 1024);
+ String s3Key = "custom-size-key";
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+
+ streamPayloadStore.storeOriginalPayloadStream(payloadStream, s3Key);
+
+ verify(s3Dao, times(1)).storeTextStreamInS3(
+ eq(S3_BUCKET_NAME),
+ eq(s3Key),
+ any(java.io.InputStream.class)
+ );
+ }
+
+ @Test
+ public void testStoreOriginalPayloadStreamOnS3Failure() {
+ doThrow(SdkException.create("S3 Exception", new Throwable()))
+ .when(s3Dao)
+ .storeTextStreamInS3(
+ any(String.class),
+ any(String.class),
+ any(java.io.InputStream.class)
+ );
+
+ assertThrows(SdkException.class, () -> {
+ java.io.InputStream payloadStream = new java.io.ByteArrayInputStream(ANY_PAYLOAD.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+ streamPayloadStore.storeOriginalPayloadStream(payloadStream, ANY_S3_KEY);
+ }, "S3 Exception");
+ }
+
+ private String generateString(int length) {
+ StringBuilder sb = new StringBuilder(length);
+ for (int i = 0; i < length; i++) {
+ sb.append('a');
+ }
+ return sb.toString();
+ }
+}
diff --git a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java
index b0a8f25..e0ce2be 100644
--- a/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java
+++ b/src/test/java/software/amazon/payloadoffloading/S3DaoTest.java
@@ -2,14 +2,12 @@
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
-import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
+import software.amazon.awssdk.services.s3.model.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
-import software.amazon.awssdk.services.s3.model.PutObjectRequest;
-import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
@@ -17,10 +15,14 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
public class S3DaoTest {
- private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
private static final String S3_BUCKET_NAME = "test-bucket-name";
private static final String ANY_PAYLOAD = "AnyPayload";
private static final String ANY_S3_KEY = "AnyS3key";
@@ -75,4 +77,149 @@ public void storeTextInS3WithBothTest() {
assertEquals(objectCannedACL, argument.getValue().acl());
assertEquals(S3_BUCKET_NAME, argument.getValue().bucket());
}
+
+ @Test
+ public void storeTextStreamInS3WithoutSSEOrCannedTest() {
+ dao = new S3Dao(s3Client);
+
+ CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder()
+ .uploadId("test-upload-id")
+ .build();
+ when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
+ .thenReturn(createResponse);
+
+ UploadPartResponse uploadPartResponse = UploadPartResponse.builder()
+ .eTag("test-etag")
+ .build();
+ when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class)))
+ .thenReturn(uploadPartResponse);
+
+ CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder()
+ .build();
+ when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
+ .thenReturn(completeResponse);
+
+ byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8);
+ InputStream inputStream = new ByteArrayInputStream(testData);
+
+ dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream);
+
+ ArgumentCaptor<CreateMultipartUploadRequest> createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class);
+ verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture());
+
+ assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket());
+ assertEquals(ANY_S3_KEY, createCaptor.getValue().key());
+ assertNull(createCaptor.getValue().serverSideEncryption());
+ assertNull(createCaptor.getValue().acl());
+
+ verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class));
+ verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+ }
+
+ @Test
+ public void storeTextStreamInS3WithSSETest() {
+ dao = new S3Dao(s3Client, serverSideEncryptionStrategy, null);
+
+ CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder()
+ .uploadId("test-upload-id")
+ .build();
+ when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
+ .thenReturn(createResponse);
+
+ UploadPartResponse uploadPartResponse = UploadPartResponse.builder()
+ .eTag("test-etag")
+ .build();
+ when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class)))
+ .thenReturn(uploadPartResponse);
+
+ CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder()
+ .build();
+ when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
+ .thenReturn(completeResponse);
+
+ byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8);
+ InputStream inputStream = new ByteArrayInputStream(testData);
+
+ dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream);
+
+ ArgumentCaptor<CreateMultipartUploadRequest> createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class);
+ verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture());
+
+ assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket());
+ assertEquals(ANY_S3_KEY, createCaptor.getValue().key());
+ assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption());
+ assertNull(createCaptor.getValue().acl());
+
+ verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class));
+ verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+ }
+
+ @Test
+ public void storeTextStreamInS3WithBothSSEAndCannedACLTest() {
+ dao = new S3Dao(s3Client, serverSideEncryptionStrategy, objectCannedACL);
+
+ CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder()
+ .uploadId("test-upload-id")
+ .build();
+ when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
+ .thenReturn(createResponse);
+
+ UploadPartResponse uploadPartResponse = UploadPartResponse.builder()
+ .eTag("test-etag")
+ .build();
+ when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class)))
+ .thenReturn(uploadPartResponse);
+
+ CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder()
+ .build();
+ when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
+ .thenReturn(completeResponse);
+
+ byte[] testData = "Test streaming data".getBytes(StandardCharsets.UTF_8);
+ InputStream inputStream = new ByteArrayInputStream(testData);
+
+ dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream);
+
+ ArgumentCaptor<CreateMultipartUploadRequest> createCaptor = ArgumentCaptor.forClass(CreateMultipartUploadRequest.class);
+ verify(s3Client, times(1)).createMultipartUpload(createCaptor.capture());
+
+ assertEquals(S3_BUCKET_NAME, createCaptor.getValue().bucket());
+ assertEquals(ANY_S3_KEY, createCaptor.getValue().key());
+ assertEquals(ServerSideEncryption.AWS_KMS, createCaptor.getValue().serverSideEncryption());
+ assertEquals(objectCannedACL, createCaptor.getValue().acl());
+
+ verify(s3Client, times(1)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class));
+ verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+ }
+
+ @Test
+ public void storeTextStreamInS3WithMultiplePartsTest() {
+ dao = new S3Dao(s3Client);
+
+ CreateMultipartUploadResponse createResponse = CreateMultipartUploadResponse.builder()
+ .uploadId("test-upload-id")
+ .build();
+ when(s3Client.createMultipartUpload(any(CreateMultipartUploadRequest.class)))
+ .thenReturn(createResponse);
+
+ UploadPartResponse uploadPartResponse = UploadPartResponse.builder()
+ .eTag("test-etag")
+ .build();
+ when(s3Client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class)))
+ .thenReturn(uploadPartResponse);
+
+ CompleteMultipartUploadResponse completeResponse = CompleteMultipartUploadResponse.builder()
+ .build();
+ when(s3Client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
+ .thenReturn(completeResponse);
+
+ byte[] testData = new byte[5 * 1024 * 1024 + 1];
+ InputStream inputStream = new ByteArrayInputStream(testData);
+
+ dao.storeTextStreamInS3(S3_BUCKET_NAME, ANY_S3_KEY, inputStream);
+
+ verify(s3Client, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class));
+ verify(s3Client, times(2)).uploadPart(any(UploadPartRequest.class), any(RequestBody.class));
+ verify(s3Client, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+ }
}
\ No newline at end of file