diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java index 38a96c2521a9..3094ac49c5b8 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java @@ -8,7 +8,6 @@ import com.azure.core.test.utils.TestUtils; import com.azure.core.util.BinaryData; import com.azure.core.util.FluxUtil; -import com.azure.storage.blob.models.BlobRange; import com.azure.storage.blob.models.DownloadRetryOptions; import com.azure.storage.blob.options.BlobDownloadContentOptions; import com.azure.storage.blob.options.BlobDownloadStreamOptions; @@ -19,7 +18,6 @@ import com.azure.storage.common.implementation.contentvalidation.StorageCrc64Calculator; import com.azure.storage.common.test.shared.extensions.LiveOnly; import com.azure.storage.common.test.shared.policy.MockPartialResponsePolicy; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -46,11 +44,45 @@ */ public class BlobContentValidationAsyncDownloadTests extends BlobTestBase { private static final int TEN_MB = 10 * Constants.MB; + private static final int BLOCK_SIZE = 4 * Constants.MB; + private final List createdFiles = new ArrayList<>(); - @AfterEach - public void cleanup() { + private byte[] data; + private List recordedRequestHeaders; + private HttpHeaders recordedResponseHeaders; + private BlobAsyncClient blobClient; + private BlobAsyncClient downloadClient; + private File file; + private File outFile; + + @Override + public void beforeTest() { + super.beforeTest(); + data = null; + recordedRequestHeaders = new CopyOnWriteArrayList<>(); + recordedResponseHeaders = new HttpHeaders(); + blobClient = null; + downloadClient = null; + } + + @Override + protected void afterTest() { createdFiles.forEach(File::delete); + createdFiles.clear(); + data = null; + recordedRequestHeaders = new CopyOnWriteArrayList<>(); + recordedResponseHeaders = new HttpHeaders(); + blobClient = null; + downloadClient = null; + file = null; + outFile = null; + super.afterTest(); + } + + private void initializeBlobClient() { + blobClient = createBlobAsyncClientWithRequestSniffer(recordedRequestHeaders); + downloadClient = blobClient; } /** @@ -58,21 +90,18 @@ public void cleanup() { */ @Test public void downloadStreamWithResponseContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(BinaryData.fromBytes(data)).block(); BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - StepVerifier - .create(downloadClient.downloadStreamWithResponse(options) - .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) - .assertNext(result -> TestUtils.assertArraysEqual(data, result)) - .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + StepVerifier.create(downloadClient.downloadStreamWithResponse(options).flatMap(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + return FluxUtil.collectBytesInByteBufferStream(r.getValue()); + })).assertNext(result -> TestUtils.assertArraysEqual(data, result)).verifyComplete(); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -80,19 +109,18 @@ public void downloadStreamWithResponseContentValidation() { */ @Test public void downloadContentWithResponseContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(BinaryData.fromBytes(data)).block(); BlobDownloadContentOptions options = new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - StepVerifier.create(downloadClient.downloadContentWithResponse(options)) - .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) - .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + StepVerifier.create(downloadClient.downloadContentWithResponse(options)).assertNext(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + TestUtils.assertArraysEqual(data, r.getValue().toBytes()); + }).verifyComplete(); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -107,30 +135,30 @@ public void downloadContentWithResponseContentValidation() { 8 * 1026 * 1024 + 10, // medium file not aligned to block }) public void downloadToFileWithResponseContentValidation(int fileSize) throws IOException { - File file = getRandomFile(fileSize); + file = getRandomFile(fileSize); file.deleteOnExit(); createdFiles.add(file); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + initializeBlobClient(); downloadClient.uploadFromFile(file.toPath().toString(), true).block(); - File outFile = new File(prefix + ".txt"); + outFile = new File(prefix + ".txt"); createdFiles.add(outFile); outFile.deleteOnExit(); Files.deleteIfExists(outFile.toPath()); - ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong((long) BLOCK_SIZE); BlobDownloadToFileOptions options = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - StepVerifier.create(downloadClient.downloadToFileWithResponse(options)) - .assertNext(r -> assertNotNull(r.getValue())) - .verifyComplete(); + StepVerifier.create(downloadClient.downloadToFileWithResponse(options)).assertNext(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + assertNotNull(r.getValue()); + }).verifyComplete(); assertTrue(compareFiles(file, outFile, 0, fileSize)); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -144,51 +172,30 @@ public void downloadToFileWithResponseContentValidation(int fileSize) throws IOE 50 * Constants.MB + 22 // large file not on MB boundary }) public void downloadToFileLargeWithResponseContentValidation(int fileSize) throws IOException { - File file = getRandomFile(fileSize); + file = getRandomFile(fileSize); file.deleteOnExit(); createdFiles.add(file); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + initializeBlobClient(); downloadClient.uploadFromFile(file.toPath().toString(), true).block(); - File outFile = new File(prefix + ".txt"); + outFile = new File(prefix + ".txt"); createdFiles.add(outFile); outFile.deleteOnExit(); Files.deleteIfExists(outFile.toPath()); - ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong((long) BLOCK_SIZE); BlobDownloadToFileOptions options = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - StepVerifier.create(downloadClient.downloadToFileWithResponse(options)) - .assertNext(r -> assertNotNull(r.getValue())) - .verifyComplete(); + StepVerifier.create(downloadClient.downloadToFileWithResponse(options)).assertNext(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + assertNotNull(r.getValue()); + }).verifyComplete(); assertTrue(compareFiles(file, outFile, 0, fileSize)); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); - } - - /** - * Range download without content validation works correctly. - */ - @Test - public void downloadStreamWithResponseContentValidationRange() { - byte[] randomData = getRandomByteArray(4 * Constants.KB); - Flux input = Flux.just(ByteBuffer.wrap(randomData)); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); - - BlobRange range = new BlobRange(0, 512L); - - StepVerifier.create(downloadClient.upload(input, null, true) - .then(downloadClient.downloadStreamWithResponse(range, null, null, false)) - .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(r -> { - assertNotNull(r); - assertEquals(512, r.length); - }).verifyComplete(); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -196,9 +203,8 @@ public void downloadStreamWithResponseContentValidationRange() { */ @Test public void downloadStreamDefaultAlgorithmIsNone() { - byte[] data = getRandomByteArray(TEN_MB); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); StepVerifier.create(downloadClient.downloadStreamWithResponse(new BlobDownloadStreamOptions()) @@ -206,7 +212,7 @@ public void downloadStreamDefaultAlgorithmIsNone() { assertNotNull(result); assertEquals(data.length, result.length); }).verifyComplete(); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertFalse(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -214,20 +220,18 @@ public void downloadStreamDefaultAlgorithmIsNone() { */ @Test public void downloadStreamWithAuto() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(BinaryData.fromBytes(data)).block(); - StepVerifier - .create(downloadClient - .downloadStreamWithResponse( - new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO)) - .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) - .assertNext(result -> TestUtils.assertArraysEqual(data, result)) - .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + StepVerifier.create(downloadClient + .downloadStreamWithResponse( + new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO)) + .flatMap(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + return FluxUtil.collectBytesInByteBufferStream(r.getValue()); + })).assertNext(result -> TestUtils.assertArraysEqual(data, result)).verifyComplete(); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -235,9 +239,8 @@ public void downloadStreamWithAuto() { */ @Test public void downloadContentWithNone() { - byte[] data = getRandomByteArray(TEN_MB); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); StepVerifier @@ -245,7 +248,7 @@ public void downloadContentWithNone() { new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.NONE))) .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) .verifyComplete(); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertFalse(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -253,18 +256,19 @@ public void downloadContentWithNone() { */ @Test public void downloadContentWithAuto() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(BinaryData.fromBytes(data)).block(); StepVerifier .create(downloadClient.downloadContentWithResponse( new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO))) - .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) + .assertNext(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + TestUtils.assertArraysEqual(data, r.getValue().toBytes()); + }) .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -273,38 +277,35 @@ public void downloadContentWithAuto() { @Test public void interruptAndVerifyProperRewind() { final int segmentSize = Constants.KB; - byte[] randomData = getRandomByteArray(2 * segmentSize); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(2 * segmentSize); + initializeBlobClient(); int interruptPos = segmentSize + (2 * (segmentSize / 4)) + 10; MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(1, interruptPos, blobClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); - blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + blobClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); - BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); - StepVerifier - .create(downloadClient - .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) - .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) - .doFinally( - signalType -> assertTrue(mockPolicy.getHits() > 0, "Mock interruption policy was not invoked")) - .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) - .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) - .verifyComplete(); + StepVerifier.create(downloadClient + .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .doFinally(signalType -> assertTrue(mockPolicy.getHits() > 0, "Mock interruption policy was not invoked")) + .flatMap(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + return FluxUtil.collectBytesInByteBufferStream(r.getValue()); + })).assertNext(result -> TestUtils.assertArraysEqual(data, result)).verifyComplete(); assertEquals(0, mockPolicy.getTriesRemaining(), "Expected the configured interruption to be consumed"); assertTrue(mockPolicy.getRangeHeaders().size() >= 2, "Expected at least the initial request and one retry with a range header"); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadResponseHeaders(recordedResponseHeaders)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -315,33 +316,34 @@ public void interruptAndVerifyProperRewind() { public void interruptAndVerifyProperDecode(boolean multipleInterrupts) { final int segmentSize = 128 * Constants.KB; final int dataSize = 4 * Constants.KB; - byte[] randomData = getRandomByteArray(dataSize); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(dataSize); + initializeBlobClient(); int interruptPos = segmentSize + (3 * (8 * Constants.KB)) + 10; MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(multipleInterrupts ? 2 : 1, interruptPos, blobClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); - blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + blobClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); - BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); StepVerifier.create(downloadClient .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) - .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(result -> { + .flatMap(r -> { + assertTrue(hasStructuredMessageDownloadResponseHeaders(r.getHeaders())); + return FluxUtil.collectBytesInByteBufferStream(r.getValue()); + })).assertNext(result -> { assertEquals(dataSize, result.length, "Decoded data should have exactly " + dataSize + " bytes"); - TestUtils.assertArraysEqual(randomData, result); + TestUtils.assertArraysEqual(data, result); }).verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadResponseHeaders(recordedResponseHeaders)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -349,10 +351,8 @@ public void interruptAndVerifyProperDecode(boolean multipleInterrupts) { */ @Test public void structuredMessageVerifiesDecodedCrc64DownloadStreaming() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); downloadClient.upload(BinaryData.fromBytes(data)).block(); long expectedCrc = StorageCrc64Calculator.compute(data, 0); @@ -368,7 +368,7 @@ public void structuredMessageVerifiesDecodedCrc64DownloadStreaming() { assertEquals(expectedCrc, actualCrc); }) .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -377,21 +377,18 @@ public void structuredMessageVerifiesDecodedCrc64DownloadStreaming() { @Test public void interruptWithDataIntact() { final int segmentSize = Constants.KB; - byte[] randomData = getRandomByteArray(4 * segmentSize); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(4 * segmentSize); + initializeBlobClient(); int interruptPos = segmentSize + (3 * 128) + 10; MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(1, interruptPos, blobClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); - blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + blobClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); - BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); @@ -400,9 +397,9 @@ public void interruptWithDataIntact() { .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) - .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) + .assertNext(result -> TestUtils.assertArraysEqual(data, result)) .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -411,21 +408,18 @@ public void interruptWithDataIntact() { @Test public void interruptMultipleTimesWithDataIntact() { final int segmentSize = Constants.KB; - byte[] randomData = getRandomByteArray(4 * segmentSize); - List recorded = new CopyOnWriteArrayList<>(); - BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + data = getRandomByteArray(4 * segmentSize); + initializeBlobClient(); int interruptPos = segmentSize + (3 * 128) + 10; MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(3, interruptPos, blobClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); - blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + blobClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); - BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); @@ -434,9 +428,9 @@ public void interruptMultipleTimesWithDataIntact() { .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) - .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) + .assertNext(result -> TestUtils.assertArraysEqual(data, result)) .verifyComplete(); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } } diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java index 86b7f116a60d..fd84ef02cbce 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java @@ -4,10 +4,14 @@ package com.azure.storage.blob; import com.azure.core.http.HttpHeaders; +import com.azure.core.http.rest.Response; import com.azure.core.http.policy.HttpPipelinePolicy; import com.azure.core.test.utils.TestUtils; import com.azure.core.util.BinaryData; import com.azure.core.util.Context; +import com.azure.storage.blob.models.BlobDownloadContentResponse; +import com.azure.storage.blob.models.BlobDownloadResponse; +import com.azure.storage.blob.models.BlobProperties; import com.azure.storage.blob.models.BlobSeekableByteChannelReadResult; import com.azure.storage.blob.models.BlobRange; import com.azure.storage.blob.models.DownloadRetryOptions; @@ -22,7 +26,6 @@ import com.azure.storage.common.implementation.Constants; import com.azure.storage.common.test.shared.extensions.LiveOnly; import com.azure.storage.common.test.shared.policy.MockPartialResponsePolicy; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -51,11 +54,48 @@ */ public class BlobContentValidationDownloadTests extends BlobTestBase { private static final int TEN_MB = 10 * Constants.MB; + private static final int BLOCK_SIZE = 4 * Constants.MB; + private final List createdFiles = new ArrayList<>(); - @AfterEach - public void cleanup() { + private byte[] data; + private List recordedRequestHeaders; + private HttpHeaders recordedResponseHeaders; + private BlobClient blobClient; + private BlobClient downloadClient; + private File file; + private File outFile; + private ByteArrayOutputStream outputStream; + + @Override + public void beforeTest() { + super.beforeTest(); + data = null; + recordedRequestHeaders = new CopyOnWriteArrayList<>(); + recordedResponseHeaders = new HttpHeaders(); + blobClient = null; + downloadClient = null; + outputStream = null; + } + + @Override + protected void afterTest() { createdFiles.forEach(File::delete); + createdFiles.clear(); + data = null; + recordedRequestHeaders = new CopyOnWriteArrayList<>(); + recordedResponseHeaders = new HttpHeaders(); + blobClient = null; + downloadClient = null; + file = null; + outFile = null; + outputStream = null; + super.afterTest(); + } + + private void initializeBlobClient() { + blobClient = createBlobClientWithRequestSniffer(recordedRequestHeaders); + downloadClient = blobClient; } /** @@ -63,19 +103,18 @@ public void cleanup() { */ @Test public void downloadStreamWithResponseContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - client.downloadStreamWithResponse(outputStream, + outputStream = new ByteArrayOutputStream(); + BlobDownloadResponse response = downloadClient.downloadStreamWithResponse(outputStream, new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64), null, Context.NONE); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); TestUtils.assertArraysEqual(data, outputStream.toByteArray()); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -83,22 +122,18 @@ public void downloadStreamWithResponseContentValidation() { */ @Test public void downloadContentWithResponseContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - byte[] result - = client - .downloadContentWithResponse( - new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64), - null, Context.NONE) - .getValue() - .toBytes(); + BlobDownloadContentResponse response = downloadClient.downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64), null, + Context.NONE); + byte[] result = response.getValue().toBytes(); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); TestUtils.assertArraysEqual(data, result); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -113,27 +148,28 @@ public void downloadContentWithResponseContentValidation() { 8 * 1026 * 1024 + 10, // medium file not aligned to block }) public void downloadToFileWithResponseContentValidation(int fileSize) throws IOException { - File file = getRandomFile(fileSize); + file = getRandomFile(fileSize); file.deleteOnExit(); createdFiles.add(file); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.uploadFromFile(file.toPath().toString(), true); + initializeBlobClient(); + blobClient.uploadFromFile(file.toPath().toString(), true); - File outFile = new File(prefix + ".txt"); + outFile = new File(prefix + ".txt"); createdFiles.add(outFile); outFile.deleteOnExit(); Files.deleteIfExists(outFile.toPath()); - ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong((long) BLOCK_SIZE); BlobDownloadToFileOptions options = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - assertNotNull(client.downloadToFileWithResponse(options, null, Context.NONE).getValue()); + Response response = downloadClient.downloadToFileWithResponse(options, null, Context.NONE); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); + assertNotNull(response.getValue()); assertTrue(compareFiles(file, outFile, 0, fileSize)); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -147,27 +183,28 @@ public void downloadToFileWithResponseContentValidation(int fileSize) throws IOE 50 * Constants.MB + 22 // large file not on MB boundary }) public void downloadToFileLargeWithResponseContentValidation(int fileSize) throws IOException { - File file = getRandomFile(fileSize); + file = getRandomFile(fileSize); file.deleteOnExit(); createdFiles.add(file); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.uploadFromFile(file.toPath().toString(), true); + initializeBlobClient(); + blobClient.uploadFromFile(file.toPath().toString(), true); - File outFile = new File(prefix + ".txt"); + outFile = new File(prefix + ".txt"); createdFiles.add(outFile); outFile.deleteOnExit(); Files.deleteIfExists(outFile.toPath()); - ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong((long) BLOCK_SIZE); BlobDownloadToFileOptions options = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - assertNotNull(client.downloadToFileWithResponse(options, null, Context.NONE).getValue()); + Response response = downloadClient.downloadToFileWithResponse(options, null, Context.NONE); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); + assertNotNull(response.getValue()); assertTrue(compareFiles(file, outFile, 0, fileSize)); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -175,18 +212,16 @@ public void downloadToFileLargeWithResponseContentValidation(int fileSize) throw */ @Test public void downloadStreamWithResponseContentValidationRange() { - byte[] randomData = getRandomByteArray(4 * Constants.KB); + data = getRandomByteArray(4 * Constants.KB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(randomData)); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + outputStream = new ByteArrayOutputStream(); BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setRange(new BlobRange(0, 512L)); - client.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); assertEquals(512, outputStream.toByteArray().length); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertFalse(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -194,17 +229,15 @@ public void downloadStreamWithResponseContentValidationRange() { */ @Test public void downloadStreamDefaultAlgorithmIsNone() { - byte[] data = getRandomByteArray(TEN_MB); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); - - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - client.downloadStreamWithResponse(outputStream, new BlobDownloadStreamOptions(), null, Context.NONE); + outputStream = new ByteArrayOutputStream(); + downloadClient.downloadStreamWithResponse(outputStream, new BlobDownloadStreamOptions(), null, Context.NONE); TestUtils.assertArraysEqual(data, outputStream.toByteArray()); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertFalse(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -212,19 +245,19 @@ public void downloadStreamDefaultAlgorithmIsNone() { */ @Test public void downloadStreamWithAuto() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + outputStream = new ByteArrayOutputStream(); BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO); - client.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + BlobDownloadResponse response + = downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); TestUtils.assertArraysEqual(data, outputStream.toByteArray()); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -232,14 +265,12 @@ public void downloadStreamWithAuto() { */ @Test public void downloadContentWithNone() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); byte[] result - = client + = downloadClient .downloadContentWithResponse( new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.NONE), null, Context.NONE) @@ -247,7 +278,7 @@ public void downloadContentWithNone() { .toBytes(); TestUtils.assertArraysEqual(data, result); - assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertFalse(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -255,22 +286,18 @@ public void downloadContentWithNone() { */ @Test public void downloadContentWithAuto() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); - byte[] result - = client - .downloadContentWithResponse( - new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO), - null, Context.NONE) - .getValue() - .toBytes(); + BlobDownloadContentResponse response = downloadClient.downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO), null, + Context.NONE); + byte[] result = response.getValue().toBytes(); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); TestUtils.assertArraysEqual(data, result); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -279,34 +306,33 @@ public void downloadContentWithAuto() { @Test public void interruptAndVerifyProperRewind() { final int segmentSize = Constants.KB; - byte[] randomData = getRandomByteArray(2 * segmentSize); - List recorded = new CopyOnWriteArrayList<>(); + data = getRandomByteArray(2 * segmentSize); + initializeBlobClient(); - BlobClient uploadClient = createBlobClientWithRequestSniffer(recorded); - uploadClient.upload(BinaryData.fromBytes(randomData)); + blobClient.upload(BinaryData.fromBytes(data)); int interruptPos = segmentSize + (2 * (segmentSize / 4)) + 10; - MockPartialResponsePolicy mockPolicy - = new MockPartialResponsePolicy(1, interruptPos, uploadClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; - - BlobClient downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - uploadClient.getBlobUrl(), sniffPolicy, mockPolicy); + MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(1, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); + + downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + outputStream = new ByteArrayOutputStream(); BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + BlobDownloadResponse response + = downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); - TestUtils.assertArraysEqual(randomData, outputStream.toByteArray()); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); + TestUtils.assertArraysEqual(data, outputStream.toByteArray()); assertEquals(0, mockPolicy.getTriesRemaining(), "Expected the configured interruption to be consumed"); assertTrue(mockPolicy.getRangeHeaders().size() >= 2, "Expected at least the initial request and one retry with a range header"); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadResponseHeaders(recordedResponseHeaders)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -317,75 +343,72 @@ public void interruptAndVerifyProperRewind() { public void interruptAndVerifyProperDecode(boolean multipleInterrupts) { final int segmentSize = 128 * Constants.KB; final int dataSize = 4 * Constants.KB; - byte[] randomData = getRandomByteArray(dataSize); - List recorded = new CopyOnWriteArrayList<>(); + data = getRandomByteArray(dataSize); + initializeBlobClient(); - BlobClient uploadClient = createBlobClientWithRequestSniffer(recorded); - uploadClient.upload(BinaryData.fromBytes(randomData)); + blobClient.upload(BinaryData.fromBytes(data)); int interruptPos = segmentSize + (3 * (8 * Constants.KB)) + 10; MockPartialResponsePolicy mockPolicy - = new MockPartialResponsePolicy(multipleInterrupts ? 2 : 1, interruptPos, uploadClient.getBlobUrl()); - HttpPipelinePolicy sniffPolicy = (context, next) -> { - recorded.add(context.getHttpRequest().getHeaders()); - return next.process(); - }; - - BlobClient downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), - uploadClient.getBlobUrl(), sniffPolicy, mockPolicy); + = new MockPartialResponsePolicy(multipleInterrupts ? 2 : 1, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = getRequestAndResponseHeaderSniffer(blobClient.getBlobUrl(), + recordedRequestHeaders, recordedResponseHeaders); + + downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), blobClient.getBlobUrl(), + sniffPolicy, mockPolicy); DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + outputStream = new ByteArrayOutputStream(); BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + BlobDownloadResponse response + = downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + assertTrue(hasStructuredMessageDownloadResponseHeaders(response.getHeaders())); byte[] result = outputStream.toByteArray(); assertEquals(dataSize, result.length, "Decoded data should have exactly " + dataSize + " bytes"); - TestUtils.assertArraysEqual(randomData, result); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + TestUtils.assertArraysEqual(data, result); + assertTrue(hasStructuredMessageDownloadResponseHeaders(recordedResponseHeaders)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } // Only run this test in live mode as BlobOutputStream dynamically assigns blocks @LiveOnly @Test public void openInputStreamContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); - - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); + blobClient.upload(BinaryData.fromBytes(data)); BlobInputStreamOptions options = new BlobInputStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); - BlobInputStream inputStream = client.openInputStream(options, Context.NONE); + BlobInputStream inputStream = downloadClient.openInputStream(options, Context.NONE); TestUtils.assertArraysEqual(data, convertInputStreamToByteArray(inputStream)); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } // Only run this test in live mode as BlobOutputStream dynamically assigns blocks @LiveOnly @Test public void openInputStreamRangeContentValidation() { - byte[] data = getRandomByteArray(TEN_MB); + data = getRandomByteArray(TEN_MB); + initializeBlobClient(); int start = Constants.MB; int count = 3 * Constants.MB + 257; - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + blobClient.upload(BinaryData.fromBytes(data)); BlobInputStreamOptions options = new BlobInputStreamOptions().setRange(new BlobRange(start, (long) count)) .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64) .setBlockSize(Constants.MB); - BlobInputStream inputStream = client.openInputStream(options, Context.NONE); + BlobInputStream inputStream = downloadClient.openInputStream(options, Context.NONE); byte[] downloadedRange = convertInputStreamToByteArray(inputStream); assertEquals(count, downloadedRange.length); TestUtils.assertArraysEqual(data, start, downloadedRange, 0, count); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } /** @@ -395,17 +418,16 @@ public void openInputStreamRangeContentValidation() { @MethodSource("channelReadDataSupplier") public void openSeekableByteChannelReadContentValidation(Integer streamBufferSize, int copyBufferSize, int dataLength) throws IOException { - byte[] data = getRandomByteArray(dataLength); + data = getRandomByteArray(dataLength); + initializeBlobClient(); - List recorded = new CopyOnWriteArrayList<>(); - BlobClient client = createBlobClientWithRequestSniffer(recorded); - client.upload(BinaryData.fromBytes(data)); + blobClient.upload(BinaryData.fromBytes(data)); // when: "Channel initialized" BlobSeekableByteChannelReadOptions options = new BlobSeekableByteChannelReadOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64) .setReadSizeInBytes(streamBufferSize); - BlobSeekableByteChannelReadResult result = client.openSeekableByteChannelRead(options, Context.NONE); + BlobSeekableByteChannelReadResult result = downloadClient.openSeekableByteChannelRead(options, Context.NONE); SeekableByteChannel channel = result.getChannel(); // then: "Channel initialized to position zero" @@ -423,7 +445,7 @@ public void openSeekableByteChannelReadContentValidation(Integer streamBufferSiz // and: "expected data downloaded" TestUtils.assertArraysEqual(data, downloadedData.toByteArray()); - assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + assertTrue(hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, false)); } static Stream channelReadDataSupplier() { diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java index 514ff455fb90..8c8a9af0c746 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java @@ -12,6 +12,9 @@ import com.azure.core.http.HttpMethod; import com.azure.core.http.HttpPipeline; import com.azure.core.http.HttpPipelineBuilder; +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpPipelinePosition; import com.azure.core.http.HttpRequest; import com.azure.core.http.HttpResponse; import com.azure.core.http.policy.AddDatePolicy; @@ -97,6 +100,7 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.Callable; +import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -1382,14 +1386,62 @@ public static HttpPipelinePolicy getAddHeadersAndQueryPolicy(Map } protected static boolean hasOnlyStructuredMessageHeaders(List recordedRequestHeaders) { - return hasStructuredMessageRequestHeaders(recordedRequestHeaders, true); + return hasStructuredMessageDownloadRequestHeaders(recordedRequestHeaders, true); } - protected static boolean hasOnlyStructuredMessageDownloadHeaders(List recordedRequestHeaders) { - return hasStructuredMessageRequestHeaders(recordedRequestHeaders, false); + protected static boolean hasStructuredMessageDownloadRequestHeaders(HttpHeaders recordedRequestHeaders) { + if (recordedRequestHeaders == null || recordedRequestHeaders.getSize() == 0) { + return false; + } + return hasStructuredMessageDownloadRequestHeaders(Collections.singletonList(recordedRequestHeaders), false); + } + + protected static boolean hasStructuredMessageDownloadResponseHeaders(HttpHeaders headers) { + return validateBasicHeaders(headers) + && StructuredMessageConstants.STRUCTURED_BODY_TYPE_VALUE + .equalsIgnoreCase(headers.getValue(Constants.HeaderConstants.STRUCTURED_BODY_TYPE_HEADER_NAME)); + } + + protected static HttpPipelinePolicy getRequestAndResponseHeaderSniffer(String targetUrlPrefix, + HttpHeaders recordedRequestHeaders, HttpHeaders recordedResponseHeaders) { + return getRequestAndResponseHeaderSniffer(targetUrlPrefix, headers -> { + synchronized (recordedRequestHeaders) { + recordedRequestHeaders.setAllHttpHeaders(headers); + } + }, recordedResponseHeaders); + } + + protected static HttpPipelinePolicy getRequestAndResponseHeaderSniffer(String targetUrlPrefix, + List recordedRequestHeaders, HttpHeaders recordedResponseHeaders) { + return getRequestAndResponseHeaderSniffer(targetUrlPrefix, recordedRequestHeaders::add, + recordedResponseHeaders); + } + + private static HttpPipelinePolicy getRequestAndResponseHeaderSniffer(String targetUrlPrefix, + Consumer requestRecorder, HttpHeaders recordedResponseHeaders) { + return new HttpPipelinePolicy() { + @Override + public HttpPipelinePosition getPipelinePosition() { + return HttpPipelinePosition.PER_RETRY; + } + + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + requestRecorder.accept(context.getHttpRequest().getHeaders()); + return next.process().map(response -> { + if (response.getRequest().getHttpMethod() == HttpMethod.GET + && response.getRequest().getUrl().toString().startsWith(targetUrlPrefix)) { + synchronized (recordedResponseHeaders) { + recordedResponseHeaders.setAllHttpHeaders(response.getHeaders()); + } + } + return response; + }); + } + }; } - private static boolean hasStructuredMessageRequestHeaders(List recordedRequestHeaders, + protected static boolean hasStructuredMessageDownloadRequestHeaders(List recordedRequestHeaders, boolean requireStructuredContentLength) { if (recordedRequestHeaders == null || recordedRequestHeaders.isEmpty()) { return false; @@ -1473,9 +1525,9 @@ protected static boolean hasNoContentValidationHeaders(List recorde } /** - * Creates a BlobClient that records all outgoing request headers into the supplied list. - * Each test should use its own list so tests can run concurrently. - */ + * Creates a BlobClient that records all outgoing request headers into the supplied list. + * Each test should use its own list so tests can run concurrently. + */ protected BlobClient createBlobClientWithRequestSniffer(List recordedRequestHeaders) { HttpPipelinePolicy sniffPolicy = (context, next) -> { recordedRequestHeaders.add(context.getHttpRequest().getHeaders()); @@ -1487,8 +1539,8 @@ protected BlobClient createBlobClientWithRequestSniffer(List record } /** - * Creates a BlobAsyncClient that records all outgoing request headers into the supplied list. - */ + * Creates a BlobAsyncClient that records all outgoing request headers into the supplied list. + */ protected BlobAsyncClient createBlobAsyncClientWithRequestSniffer(List recordedRequestHeaders) { HttpPipelinePolicy sniffPolicy = (context, next) -> { recordedRequestHeaders.add(context.getHttpRequest().getHeaders());