diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index eea1f647e8e6..0c3832771777 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "java", "TagPrefix": "java/storage/azure-storage-blob", - "Tag": "java/storage/azure-storage-blob_4e6c4fe966" + "Tag": "java/storage/azure-storage-blob_1f689f90f0" } diff --git a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java index 2e51e40dc4b6..53fcb67447dc 100644 --- a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java +++ b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/implementation/util/BuilderHelper.java @@ -39,6 +39,7 @@ import com.azure.storage.common.policy.ResponseValidationPolicyBuilder; import com.azure.storage.common.policy.ScrubEtagPolicy; import com.azure.storage.common.policy.StorageBearerTokenChallengeAuthorizationPolicy; +import com.azure.storage.common.policy.StorageContentValidationDecoderPolicy; import com.azure.storage.common.policy.StorageContentValidationPolicy; import com.azure.storage.common.policy.StorageSharedKeyCredentialPolicy; @@ -117,6 +118,7 @@ public static HttpPipeline buildPipeline(StorageSharedKeyCredential storageShare policies.add(new MetadataValidationPolicy()); policies.add(new StorageContentValidationPolicy()); + policies.add(new StorageContentValidationDecoderPolicy()); if (storageSharedKeyCredential != null) { policies.add(new StorageSharedKeyCredentialPolicy(storageSharedKeyCredential)); diff --git a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java 
b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java index 49c9c625d2b9..b462df20b7b3 100644 --- a/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java +++ b/sdk/storage/azure-storage-blob/src/main/java/com/azure/storage/blob/specialized/BlobAsyncClientBase.java @@ -1261,16 +1261,20 @@ Mono downloadStreamWithResponseInternal(BlobRange ran BlobRequestConditions requestConditions, boolean getRangeContentMd5, ContentValidationAlgorithm contentValidationAlgorithm, Context context) { BlobRange finalRange = range == null ? new BlobRange(0) : range; + Boolean getMD5 = getRangeContentMd5 ? getRangeContentMd5 : null; BlobRequestConditions finalRequestConditions = requestConditions == null ? new BlobRequestConditions() : requestConditions; DownloadRetryOptions finalOptions = (options == null) ? new DownloadRetryOptions() : options; + context + = ContentValidationModeResolver.addStructuredMessageDecodingToContext(context, contentValidationAlgorithm); + // The first range should eagerly convert headers as they'll be used to create response types. Context firstRangeContext = context == null ? 
new Context("azure-eagerly-convert-headers", true) : context.addData("azure-eagerly-convert-headers", true); - + Context nextRangeContext = context; return downloadRange(finalRange, finalRequestConditions, finalRequestConditions.getIfMatch(), getMD5, firstRangeContext).map(response -> { BlobsDownloadHeaders blobsDownloadHeaders = new BlobsDownloadHeaders(response.getHeaders()); @@ -1315,7 +1319,7 @@ Mono downloadStreamWithResponseInternal(BlobRange ran try { return downloadRange(new BlobRange(initialOffset + offset, newCount), finalRequestConditions, - eTag, getMD5, context); + eTag, getMD5, nextRangeContext); } catch (Exception e) { return Mono.error(e); } diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java new file mode 100644 index 000000000000..38a96c2521a9 --- /dev/null +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncDownloadTests.java @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.storage.blob; + +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.test.utils.TestUtils; +import com.azure.core.util.BinaryData; +import com.azure.core.util.FluxUtil; +import com.azure.storage.blob.models.BlobRange; +import com.azure.storage.blob.models.DownloadRetryOptions; +import com.azure.storage.blob.options.BlobDownloadContentOptions; +import com.azure.storage.blob.options.BlobDownloadStreamOptions; +import com.azure.storage.blob.options.BlobDownloadToFileOptions; +import com.azure.storage.common.ParallelTransferOptions; +import com.azure.storage.common.ContentValidationAlgorithm; +import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.implementation.contentvalidation.StorageCrc64Calculator; +import com.azure.storage.common.test.shared.extensions.LiveOnly; +import com.azure.storage.common.test.shared.policy.MockPartialResponsePolicy; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; +import reactor.util.function.Tuples; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Async tests for structured message decoding during blob downloads using StorageContentValidationDecoderPolicy. + * These tests verify that the pipeline policy correctly decodes structured messages when content validation is enabled. 
+ */ +public class BlobContentValidationAsyncDownloadTests extends BlobTestBase { + private static final int TEN_MB = 10 * Constants.MB; + private final List createdFiles = new ArrayList<>(); + + @AfterEach + public void cleanup() { + createdFiles.forEach(File::delete); + } + + /** + * downloadStreamWithResponse with CRC64 content validation. + */ + @Test + public void downloadStreamWithResponseContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(BinaryData.fromBytes(data)).block(); + + BlobDownloadStreamOptions options + = new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + StepVerifier + .create(downloadClient.downloadStreamWithResponse(options) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(result -> TestUtils.assertArraysEqual(data, result)) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with CRC64 content validation. + */ + @Test + public void downloadContentWithResponseContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(BinaryData.fromBytes(data)).block(); + + BlobDownloadContentOptions options + = new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + StepVerifier.create(downloadClient.downloadContentWithResponse(options)) + .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadToFileWithResponse with CRC64 content validation. 
+ */ + @ParameterizedTest + @ValueSource( + ints = { + 0, // empty file + 20, // small file + 16 * 1024 * 1024, // medium file in several chunks + 8 * 1026 * 1024 + 10, // medium file not aligned to block + }) + public void downloadToFileWithResponseContentValidation(int fileSize) throws IOException { + File file = getRandomFile(fileSize); + file.deleteOnExit(); + createdFiles.add(file); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.uploadFromFile(file.toPath().toString(), true).block(); + + File outFile = new File(prefix + ".txt"); + createdFiles.add(outFile); + outFile.deleteOnExit(); + Files.deleteIfExists(outFile.toPath()); + + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + BlobDownloadToFileOptions options + = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + StepVerifier.create(downloadClient.downloadToFileWithResponse(options)) + .assertNext(r -> assertNotNull(r.getValue())) + .verifyComplete(); + + assertTrue(compareFiles(file, outFile, 0, fileSize)); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadToFileWithResponse with CRC64 content validation (parallel, multiple block sizes). 
+ */ + @LiveOnly + @ParameterizedTest + @ValueSource( + ints = { + 50 * Constants.MB, //large file requiring multiple requests + 50 * Constants.MB + 22 // large file not on MB boundary + }) + public void downloadToFileLargeWithResponseContentValidation(int fileSize) throws IOException { + File file = getRandomFile(fileSize); + file.deleteOnExit(); + createdFiles.add(file); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.uploadFromFile(file.toPath().toString(), true).block(); + + File outFile = new File(prefix + ".txt"); + createdFiles.add(outFile); + outFile.deleteOnExit(); + Files.deleteIfExists(outFile.toPath()); + + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + BlobDownloadToFileOptions options + = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + StepVerifier.create(downloadClient.downloadToFileWithResponse(options)) + .assertNext(r -> assertNotNull(r.getValue())) + .verifyComplete(); + + assertTrue(compareFiles(file, outFile, 0, fileSize)); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Range download without content validation works correctly. 
+ */ + @Test + public void downloadStreamWithResponseContentValidationRange() { + byte[] randomData = getRandomByteArray(4 * Constants.KB); + Flux input = Flux.just(ByteBuffer.wrap(randomData)); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + + BlobRange range = new BlobRange(0, 512L); + + StepVerifier.create(downloadClient.upload(input, null, true) + .then(downloadClient.downloadStreamWithResponse(range, null, null, false)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(r -> { + assertNotNull(r); + assertEquals(512, r.length); + }).verifyComplete(); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Default behavior: when no algorithm is specified, default is NONE (no validation). + */ + @Test + public void downloadStreamDefaultAlgorithmIsNone() { + byte[] data = getRandomByteArray(TEN_MB); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); + + StepVerifier.create(downloadClient.downloadStreamWithResponse(new BlobDownloadStreamOptions()) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(result -> { + assertNotNull(result); + assertEquals(data.length, result.length); + }).verifyComplete(); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * AUTO on downloadStream resolves to CRC64 behavior. 
+ */ + @Test + public void downloadStreamWithAuto() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(BinaryData.fromBytes(data)).block(); + + StepVerifier + .create(downloadClient + .downloadStreamWithResponse( + new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(result -> TestUtils.assertArraysEqual(data, result)) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with NONE: no validation triggered. + */ + @Test + public void downloadContentWithNone() { + byte[] data = getRandomByteArray(TEN_MB); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(Flux.just(ByteBuffer.wrap(data)), null, true).block(); + + StepVerifier + .create(downloadClient.downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.NONE))) + .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) + .verifyComplete(); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with AUTO resolves to CRC64 behavior. 
+ */ + @Test + public void downloadContentWithAuto() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(BinaryData.fromBytes(data)).block(); + + StepVerifier + .create(downloadClient.downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO))) + .assertNext(r -> TestUtils.assertArraysEqual(data, r.getValue().toBytes())) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Interrupt with proper rewind to segment boundary; verifies retry range headers. + */ + @Test + public void interruptAndVerifyProperRewind() { + final int segmentSize = Constants.KB; + byte[] randomData = getRandomByteArray(2 * segmentSize); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + + int interruptPos = segmentSize + (2 * (segmentSize / 4)) + 10; + MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(1, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + + BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); + + StepVerifier + .create(downloadClient + .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .doFinally( + signalType -> assertTrue(mockPolicy.getHits() > 0, "Mock interruption policy was not invoked")) + 
.flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) + .verifyComplete(); + + assertEquals(0, mockPolicy.getTriesRemaining(), "Expected the configured interruption to be consumed"); + assertTrue(mockPolicy.getRangeHeaders().size() >= 2, + "Expected at least the initial request and one retry with a range header"); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Proper decode across retries (single and multiple interrupts). + */ + @ParameterizedTest + @ValueSource(booleans = { false, true }) + public void interruptAndVerifyProperDecode(boolean multipleInterrupts) { + final int segmentSize = 128 * Constants.KB; + final int dataSize = 4 * Constants.KB; + byte[] randomData = getRandomByteArray(dataSize); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + + int interruptPos = segmentSize + (3 * (8 * Constants.KB)) + 10; + MockPartialResponsePolicy mockPolicy + = new MockPartialResponsePolicy(multipleInterrupts ? 
2 : 1, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + + BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); + + StepVerifier.create(downloadClient + .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))).assertNext(result -> { + assertEquals(dataSize, result.length, "Decoded data should have exactly " + dataSize + " bytes"); + TestUtils.assertArraysEqual(randomData, result); + }).verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * After consuming the response stream with CRC64 validation, decoded payload preserves the expected CRC64. 
+ */ + @Test + public void structuredMessageVerifiesDecodedCrc64DownloadStreaming() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient downloadClient = createBlobAsyncClientWithRequestSniffer(recorded); + downloadClient.upload(BinaryData.fromBytes(data)).block(); + + long expectedCrc = StorageCrc64Calculator.compute(data, 0); + + StepVerifier + .create(downloadClient + .downloadStreamWithResponse( + new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()).map(bytes -> Tuples.of(r, bytes)))) + .assertNext(tuple -> { + TestUtils.assertArraysEqual(data, tuple.getT2()); + long actualCrc = StorageCrc64Calculator.compute(tuple.getT2(), 0); + assertEquals(expectedCrc, actualCrc); + }) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Single interrupt with data intact: fault policy + decoder; structured message retry recovers. 
+ */ + @Test + public void interruptWithDataIntact() { + final int segmentSize = Constants.KB; + byte[] randomData = getRandomByteArray(4 * segmentSize); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + + int interruptPos = segmentSize + (3 * 128) + 10; + MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(1, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + + BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); + + StepVerifier + .create(downloadClient + .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Multiple interrupts with data intact: fault policy + decoder; structured message retry recovers. 
+ */ + @Test + public void interruptMultipleTimesWithDataIntact() { + final int segmentSize = Constants.KB; + byte[] randomData = getRandomByteArray(4 * segmentSize); + List recorded = new CopyOnWriteArrayList<>(); + BlobAsyncClient blobClient = createBlobAsyncClientWithRequestSniffer(recorded); + + int interruptPos = segmentSize + (3 * 128) + 10; + MockPartialResponsePolicy mockPolicy = new MockPartialResponsePolicy(3, interruptPos, blobClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + blobClient.upload(Flux.just(ByteBuffer.wrap(randomData)), null, true).block(); + + BlobAsyncClient downloadClient = getBlobAsyncClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + blobClient.getBlobUrl(), sniffPolicy, mockPolicy); + + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); + + StepVerifier + .create(downloadClient + .downloadStreamWithResponse(new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64)) + .flatMap(r -> FluxUtil.collectBytesInByteBufferStream(r.getValue()))) + .assertNext(result -> TestUtils.assertArraysEqual(randomData, result)) + .verifyComplete(); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + +} diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncUploadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncUploadTests.java index dc23983506b8..0d9b8e0e45e8 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncUploadTests.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationAsyncUploadTests.java @@ -55,7 +55,7 @@ public class BlobContentValidationAsyncUploadTests extends BlobTestBase { private 
static final int UNDER_4MB = 2 * Constants.MB; private static final long LARGE_UPLOAD_MIN_BYTES = 500L * Constants.MB; - private static final long LARGE_UPLOAD_MAX_BYTES = 1L * Constants.GB; + private static final long LARGE_UPLOAD_MAX_BYTES = Constants.GB; private static final long LARGE_UPLOAD_BLOCK_SIZE_BYTES = 8L * Constants.MB; private static final int LARGE_UPLOAD_MAX_CONCURRENCY = 8; @@ -163,7 +163,6 @@ public void uploadWithoutContentValidation() { /** * Blob parallel upload rejects using both computeMd5 (SDK-computed MD5) and CRC64 (transfer validation checksum algorithm) at once. */ - @SuppressWarnings("deprecation") @Test public void uploadWithComputeMd5AndCrc64Throws() { BlobAsyncClient client = createBlobAsyncClientWithRequestSniffer(new CopyOnWriteArrayList<>()); diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java new file mode 100644 index 000000000000..86b7f116a60d --- /dev/null +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobContentValidationDownloadTests.java @@ -0,0 +1,433 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.storage.blob; + +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.test.utils.TestUtils; +import com.azure.core.util.BinaryData; +import com.azure.core.util.Context; +import com.azure.storage.blob.models.BlobSeekableByteChannelReadResult; +import com.azure.storage.blob.models.BlobRange; +import com.azure.storage.blob.models.DownloadRetryOptions; +import com.azure.storage.blob.options.BlobDownloadContentOptions; +import com.azure.storage.blob.options.BlobDownloadStreamOptions; +import com.azure.storage.blob.options.BlobDownloadToFileOptions; +import com.azure.storage.blob.options.BlobInputStreamOptions; +import com.azure.storage.blob.options.BlobSeekableByteChannelReadOptions; +import com.azure.storage.blob.specialized.BlobInputStream; +import com.azure.storage.common.ParallelTransferOptions; +import com.azure.storage.common.ContentValidationAlgorithm; +import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.test.shared.extensions.LiveOnly; +import com.azure.storage.common.test.shared.policy.MockPartialResponsePolicy; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.SeekableByteChannel; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.stream.Stream; + +import static com.azure.storage.blob.specialized.BlobSeekableByteChannelTests.copy; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static 
org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Sync tests for structured message decoding during blob downloads using StorageContentValidationDecoderPolicy. + * These tests verify that the pipeline policy correctly decodes structured messages when content validation is enabled. + */ +public class BlobContentValidationDownloadTests extends BlobTestBase { + private static final int TEN_MB = 10 * Constants.MB; + private final List createdFiles = new ArrayList<>(); + + @AfterEach + public void cleanup() { + createdFiles.forEach(File::delete); + } + + /** + * downloadStreamWithResponse with CRC64 content validation. + */ + @Test + public void downloadStreamWithResponseContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + client.downloadStreamWithResponse(outputStream, + new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64), null, + Context.NONE); + + TestUtils.assertArraysEqual(data, outputStream.toByteArray()); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with CRC64 content validation. 
+ */ + @Test + public void downloadContentWithResponseContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + byte[] result + = client + .downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64), + null, Context.NONE) + .getValue() + .toBytes(); + + TestUtils.assertArraysEqual(data, result); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadToFileWithResponse with CRC64 content validation (parallel, multiple block sizes). + */ + @ParameterizedTest + @ValueSource( + ints = { + 0, // empty file + 20, // small file + 16 * 1024 * 1024, // medium file in several chunks + 8 * 1026 * 1024 + 10, // medium file not aligned to block + }) + public void downloadToFileWithResponseContentValidation(int fileSize) throws IOException { + File file = getRandomFile(fileSize); + file.deleteOnExit(); + createdFiles.add(file); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.uploadFromFile(file.toPath().toString(), true); + + File outFile = new File(prefix + ".txt"); + createdFiles.add(outFile); + outFile.deleteOnExit(); + Files.deleteIfExists(outFile.toPath()); + + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + BlobDownloadToFileOptions options + = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + assertNotNull(client.downloadToFileWithResponse(options, null, Context.NONE).getValue()); + assertTrue(compareFiles(file, outFile, 0, fileSize)); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * 
downloadToFileWithResponse with CRC64 content validation (parallel, multiple block sizes). + */ + @LiveOnly + @ParameterizedTest + @ValueSource( + ints = { + 50 * Constants.MB, //large file requiring multiple requests + 50 * Constants.MB + 22 // large file not on MB boundary + }) + public void downloadToFileLargeWithResponseContentValidation(int fileSize) throws IOException { + File file = getRandomFile(fileSize); + file.deleteOnExit(); + createdFiles.add(file); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.uploadFromFile(file.toPath().toString(), true); + + File outFile = new File(prefix + ".txt"); + createdFiles.add(outFile); + outFile.deleteOnExit(); + Files.deleteIfExists(outFile.toPath()); + + ParallelTransferOptions parallelOptions = new ParallelTransferOptions().setBlockSizeLong(4L * 1024 * 1024); + BlobDownloadToFileOptions options + = new BlobDownloadToFileOptions(outFile.toPath().toString()).setParallelTransferOptions(parallelOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + + assertNotNull(client.downloadToFileWithResponse(options, null, Context.NONE).getValue()); + assertTrue(compareFiles(file, outFile, 0, fileSize)); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Range download without content validation works correctly. 
+ */ + @Test + public void downloadStreamWithResponseContentValidationRange() { + byte[] randomData = getRandomByteArray(4 * Constants.KB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(randomData)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setRange(new BlobRange(0, 512L)); + client.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + + assertEquals(512, outputStream.toByteArray().length); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Default behavior: when no algorithm is specified, default is NONE (no validation). + */ + @Test + public void downloadStreamDefaultAlgorithmIsNone() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + client.downloadStreamWithResponse(outputStream, new BlobDownloadStreamOptions(), null, Context.NONE); + + TestUtils.assertArraysEqual(data, outputStream.toByteArray()); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * AUTO on downloadStream resolves to CRC64 behavior. 
+ */ + @Test + public void downloadStreamWithAuto() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BlobDownloadStreamOptions options + = new BlobDownloadStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO); + client.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + + TestUtils.assertArraysEqual(data, outputStream.toByteArray()); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with NONE: no validation triggered. + */ + @Test + public void downloadContentWithNone() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + byte[] result + = client + .downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.NONE), + null, Context.NONE) + .getValue() + .toBytes(); + + TestUtils.assertArraysEqual(data, result); + assertFalse(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * downloadContentWithResponse with AUTO resolves to CRC64 behavior. 
+ */ + @Test + public void downloadContentWithAuto() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + byte[] result + = client + .downloadContentWithResponse( + new BlobDownloadContentOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.AUTO), + null, Context.NONE) + .getValue() + .toBytes(); + + TestUtils.assertArraysEqual(data, result); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Interrupt with proper rewind to segment boundary; verifies retry range headers. + */ + @Test + public void interruptAndVerifyProperRewind() { + final int segmentSize = Constants.KB; + byte[] randomData = getRandomByteArray(2 * segmentSize); + List recorded = new CopyOnWriteArrayList<>(); + + BlobClient uploadClient = createBlobClientWithRequestSniffer(recorded); + uploadClient.upload(BinaryData.fromBytes(randomData)); + + int interruptPos = segmentSize + (2 * (segmentSize / 4)) + 10; + MockPartialResponsePolicy mockPolicy + = new MockPartialResponsePolicy(1, interruptPos, uploadClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + BlobClient downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + uploadClient.getBlobUrl(), sniffPolicy, mockPolicy); + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(5); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + + TestUtils.assertArraysEqual(randomData, outputStream.toByteArray()); + 
assertEquals(0, mockPolicy.getTriesRemaining(), "Expected the configured interruption to be consumed"); + assertTrue(mockPolicy.getRangeHeaders().size() >= 2, + "Expected at least the initial request and one retry with a range header"); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * Proper decode across retries (single and multiple interrupts). + */ + @ParameterizedTest + @ValueSource(booleans = { false, true }) + public void interruptAndVerifyProperDecode(boolean multipleInterrupts) { + final int segmentSize = 128 * Constants.KB; + final int dataSize = 4 * Constants.KB; + byte[] randomData = getRandomByteArray(dataSize); + List recorded = new CopyOnWriteArrayList<>(); + + BlobClient uploadClient = createBlobClientWithRequestSniffer(recorded); + uploadClient.upload(BinaryData.fromBytes(randomData)); + + int interruptPos = segmentSize + (3 * (8 * Constants.KB)) + 10; + MockPartialResponsePolicy mockPolicy + = new MockPartialResponsePolicy(multipleInterrupts ? 
2 : 1, interruptPos, uploadClient.getBlobUrl()); + HttpPipelinePolicy sniffPolicy = (context, next) -> { + recorded.add(context.getHttpRequest().getHeaders()); + return next.process(); + }; + + BlobClient downloadClient = getBlobClient(ENVIRONMENT.getPrimaryAccount().getCredential(), + uploadClient.getBlobUrl(), sniffPolicy, mockPolicy); + DownloadRetryOptions retryOptions = new DownloadRetryOptions().setMaxRetryRequests(10); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + BlobDownloadStreamOptions options = new BlobDownloadStreamOptions().setDownloadRetryOptions(retryOptions) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + downloadClient.downloadStreamWithResponse(outputStream, options, null, Context.NONE); + + byte[] result = outputStream.toByteArray(); + assertEquals(dataSize, result.length, "Decoded data should have exactly " + dataSize + " bytes"); + TestUtils.assertArraysEqual(randomData, result); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + // Only run this test in live mode as BlobOutputStream dynamically assigns blocks + @LiveOnly + @Test + public void openInputStreamContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + BlobInputStreamOptions options + = new BlobInputStreamOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64); + BlobInputStream inputStream = client.openInputStream(options, Context.NONE); + + TestUtils.assertArraysEqual(data, convertInputStreamToByteArray(inputStream)); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + // Only run this test in live mode as BlobOutputStream dynamically assigns blocks + @LiveOnly + @Test + public void openInputStreamRangeContentValidation() { + byte[] data = getRandomByteArray(TEN_MB); + + int start = Constants.MB; + 
int count = 3 * Constants.MB + 257; + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + BlobInputStreamOptions options = new BlobInputStreamOptions().setRange(new BlobRange(start, (long) count)) + .setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64) + .setBlockSize(Constants.MB); + BlobInputStream inputStream = client.openInputStream(options, Context.NONE); + + byte[] downloadedRange = convertInputStreamToByteArray(inputStream); + assertEquals(count, downloadedRange.length); + TestUtils.assertArraysEqual(data, start, downloadedRange, 0, count); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + /** + * openSeekableByteChannelRead with CRC64 content validation. + */ + @ParameterizedTest + @MethodSource("channelReadDataSupplier") + public void openSeekableByteChannelReadContentValidation(Integer streamBufferSize, int copyBufferSize, + int dataLength) throws IOException { + byte[] data = getRandomByteArray(dataLength); + + List recorded = new CopyOnWriteArrayList<>(); + BlobClient client = createBlobClientWithRequestSniffer(recorded); + client.upload(BinaryData.fromBytes(data)); + + // when: "Channel initialized" + BlobSeekableByteChannelReadOptions options + = new BlobSeekableByteChannelReadOptions().setContentValidationAlgorithm(ContentValidationAlgorithm.CRC64) + .setReadSizeInBytes(streamBufferSize); + BlobSeekableByteChannelReadResult result = client.openSeekableByteChannelRead(options, Context.NONE); + SeekableByteChannel channel = result.getChannel(); + + // then: "Channel initialized to position zero" + assertEquals(0, channel.position()); + assertNotNull(result.getProperties()); + assertEquals(data.length, result.getProperties().getBlobSize()); + + // when: "read from channel" + ByteArrayOutputStream downloadedData = new ByteArrayOutputStream(); + int copied = copy(channel, downloadedData, copyBufferSize); + + // 
then: "channel position updated accordingly" + assertEquals(dataLength, copied); + assertEquals(dataLength, channel.position()); + + // and: "expected data downloaded" + TestUtils.assertArraysEqual(data, downloadedData.toByteArray()); + assertTrue(hasOnlyStructuredMessageDownloadHeaders(recorded)); + } + + static Stream channelReadDataSupplier() { + return Stream.of(Arguments.of(50, 40, Constants.KB), Arguments.of(Constants.KB + 50, 40, Constants.KB), + Arguments.of(null, Constants.MB, TEN_MB)); + } +} diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java index e4e49ff383d9..514ff455fb90 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/BlobTestBase.java @@ -1382,6 +1382,15 @@ public static HttpPipelinePolicy getAddHeadersAndQueryPolicy(Map } protected static boolean hasOnlyStructuredMessageHeaders(List recordedRequestHeaders) { + return hasStructuredMessageRequestHeaders(recordedRequestHeaders, true); + } + + protected static boolean hasOnlyStructuredMessageDownloadHeaders(List recordedRequestHeaders) { + return hasStructuredMessageRequestHeaders(recordedRequestHeaders, false); + } + + private static boolean hasStructuredMessageRequestHeaders(List recordedRequestHeaders, + boolean requireStructuredContentLength) { if (recordedRequestHeaders == null || recordedRequestHeaders.isEmpty()) { return false; } @@ -1404,6 +1413,9 @@ protected static boolean hasOnlyStructuredMessageHeaders(List recor if (!StructuredMessageConstants.STRUCTURED_BODY_TYPE_VALUE.equals(bodyType) || contentCrc64 != null) { return false; } + if (!requireStructuredContentLength) { + return true; + } // Require non-blank content length that parses as non-negative long (same format as policy uses). 
// Rejects empty string, whitespace, or non-numeric values so we never return true when // structured message was not actually applied. diff --git a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/BlobSeekableByteChannelTests.java b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/BlobSeekableByteChannelTests.java index 4ff8054894a6..24a9ca7e781a 100644 --- a/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/BlobSeekableByteChannelTests.java +++ b/sdk/storage/azure-storage-blob/src/test/java/com/azure/storage/blob/specialized/BlobSeekableByteChannelTests.java @@ -106,7 +106,7 @@ static Stream channelReadDataSupplier() { * @param copySize Size of array to copy contents with. * @return Total number of bytes read from src. */ - private static int copy(SeekableByteChannel src, OutputStream dst, int copySize) throws IOException { + public static int copy(SeekableByteChannel src, OutputStream dst, int copySize) throws IOException { int read; int totalRead = 0; byte[] temp = new byte[copySize]; diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolver.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolver.java index 99ad8ef42e84..3c9f958e3ddb 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolver.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolver.java @@ -7,6 +7,7 @@ import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.MAXIMUM_SINGLE_SHOT_UPLOAD_SIZE_TO_USE_CRC64_HEADER; import static 
com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.USE_CRC64_CHECKSUM_HEADER_CONTEXT; import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.USE_STRUCTURED_MESSAGE_CONTEXT; +import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY; import com.azure.core.util.Context; import com.azure.core.util.FluxUtil; @@ -80,7 +81,7 @@ public static Mono addContentValidationMode(Mono mono, ContentValidati * message. */ private static String getModeForSingleShotUpload(ContentValidationAlgorithm algorithm, long length) { - if (algorithm == ContentValidationAlgorithm.CRC64 || algorithm == ContentValidationAlgorithm.AUTO) { + if (isCrc64OrAuto(algorithm)) { return length < MAXIMUM_SINGLE_SHOT_UPLOAD_SIZE_TO_USE_CRC64_HEADER ? USE_CRC64_CHECKSUM_HEADER_CONTEXT : USE_STRUCTURED_MESSAGE_CONTEXT; @@ -92,7 +93,7 @@ private static String getModeForSingleShotUpload(ContentValidationAlgorithm algo * Mode for a chunked (multi-shot) upload. Always use structured message. */ private static String getModeForChunkedUpload(ContentValidationAlgorithm algorithm) { - if (algorithm == ContentValidationAlgorithm.CRC64 || algorithm == ContentValidationAlgorithm.AUTO) { + if (isCrc64OrAuto(algorithm)) { return USE_STRUCTURED_MESSAGE_CONTEXT; } return null; @@ -139,12 +140,41 @@ public static boolean isContentValidationAlgorithmPresent(ContentValidationAlgor return contentValidationAlgorithm != null && contentValidationAlgorithm != ContentValidationAlgorithm.NONE; } + /** + * @return {@code true} when {@code algorithm} is {@link ContentValidationAlgorithm#CRC64} or + * {@link ContentValidationAlgorithm#AUTO}. Upload and download structured-message validation use this rule. 
+ */ + public static boolean isCrc64OrAuto(ContentValidationAlgorithm algorithm) { + return algorithm == ContentValidationAlgorithm.CRC64 || algorithm == ContentValidationAlgorithm.AUTO; + } + + /** + * When the transfer validation mode is {@link ContentValidationAlgorithm#CRC64} or + * {@link ContentValidationAlgorithm#AUTO}, adds + * {@link StructuredMessageConstants#STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY} so the HTTP + * pipeline can decode/validate the structured message response. For {@code null} or + * {@link ContentValidationAlgorithm#NONE}, returns the context unchanged (no key added), matching "no + * structured-message validation" for that download. + * + * @param context The base {@link Context}; null is treated as {@link Context#NONE}. + * @param contentValidationAlgorithm The algorithm from download options, or null. + * @return The same context, or a copy with the decoding key set when applicable. + */ + public static Context addStructuredMessageDecodingToContext(Context context, + ContentValidationAlgorithm contentValidationAlgorithm) { + Context base = context == null ? Context.NONE : context; + if (!isCrc64OrAuto(contentValidationAlgorithm)) { + return base; + } + return base.addData(STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY, true); + } + /** * Validates that parallel transfer progress reporting is not combined with CRC64/AUTO content validation. * * @param parallelTransferOptions May be {@code null}. * @param contentValidationAlgorithm Transfer validation algorithm from options. - * @throws IllegalArgumentException if a progress listener is set and {@link #isCrc64OrAutoContentValidation} is true. + * @throws IllegalArgumentException if a progress listener is set and {@link #isContentValidationAlgorithmPresent} is true. 
*/ public static void validateProgressWithContentValidation(ParallelTransferOptions parallelTransferOptions, ContentValidationAlgorithm contentValidationAlgorithm) { diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64Calculator.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64Calculator.java index 047b438f1a84..ed21ef153823 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64Calculator.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64Calculator.java @@ -2413,7 +2413,8 @@ public static long compute(byte[] src, long uCrc) { /** * Computes the CRC64 checksum for a region of a byte array. Avoids copying when the caller has - * a view (e.g. ByteBuffer.array() with arrayOffset + position). + * an array view (e.g. {@link ByteBuffer#array()} with array offset and position) or when combined + * with {@link #compute(byte[], long)} for whole-array input. * * @param src the byte array. * @param offset start index (inclusive). 
diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageConstants.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageConstants.java index caddb8104739..8e0ed1ff6e86 100644 --- a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageConstants.java +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageConstants.java @@ -52,4 +52,6 @@ public final class StructuredMessageConstants { public static final String USE_CRC64_CHECKSUM_HEADER_CONTEXT = "crc64ChecksumHeaderContext"; public static final String USE_STRUCTURED_MESSAGE_CONTEXT = "structuredMessageChecksumAlgorithm"; + + public static final String STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY = "azure-storage-structured-message-decoding"; } diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoder.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoder.java new file mode 100644 index 000000000000..7930067a1c5d --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoder.java @@ -0,0 +1,523 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.storage.common.implementation.contentvalidation; + +import com.azure.core.util.logging.ClientLogger; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.CRC64_LENGTH; +import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.DEFAULT_MESSAGE_VERSION; +import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.V1_HEADER_LENGTH; +import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH; + +/** + * Streaming decoder for the storage structured message format used to validate downloaded blob/file/datalake + * content with CRC64 checksums. + * + *

This class owns the actual parsing and CRC validation. The pipeline policy hands it raw {@link ByteBuffer}s as + * they arrive on the wire (via {@link #decodeChunk(ByteBuffer)}); the decoder returns only the payload bytes that + * have already been CRC-validated and tells the policy when the entire message has been consumed + * (via {@link #isComplete()}). Any malformed input or CRC mismatch surfaces as an + * {@link IllegalArgumentException} thrown from {@code decodeChunk} so the policy can translate it into a stream + * error.

+ * + *

Wire format (V1)

+ * + *

The encoded body has the following layout (all integers little-endian):

+ *
+ *   |-- message header (13 B) ----------------------------------------|
+ *   |  version (1)  |  total message length (8)  |  flags (2)  |  numSegments (2)  |
+ *
+ *   for each segment in 1..numSegments:
+ *     |-- segment header (10 B) -|
+ *     |  segNum (2)  |  segContentLen (8)  |
+ *     |-- segment payload (segContentLen B) --|
+ *     |-- segment CRC64 footer (8 B; only if STORAGE_CRC64) --|
+ *
+ *   |-- message CRC64 footer (8 B; only if STORAGE_CRC64) --|
+ * 
+ * + *

Emission guarantee

+ * + * Payload bytes for a segment are never emitted to the caller until that segment's CRC64 footer + * has been validated. This matches the emission semantics used by {@code BlobDecryptionPolicy}/{@code DecryptorV2} + * (which only emits a decrypted region after its GCM tag is verified) and ensures that no unvalidated bytes are + * exposed to consumers, even if the connection is later torn down or the download is retried. + * + *

Thread-safety

+ * + *

This class is not thread-safe. A new instance is created for every HTTP response, and the + * reactive operators in the policy ({@code concatMap}) serialize access to the single instance. Retries produce new + * HTTP responses and therefore new decoder instances, so a CRC failure on one attempt cannot pollute another.

+ */ +public class StructuredMessageDecoder { + private static final ClientLogger LOGGER = new ClientLogger(StructuredMessageDecoder.class); + + private long messageLength = -1; + private StructuredMessageFlags flags; + private int numSegments = -1; + private final long expectedEncodedMessageLength; + // Number of encoded bytes consumed so far (headers + payloads + footers). + private long messageOffset = 0; + private int currentSegmentNumber = 0; + private long currentSegmentContentLength = 0; + private long currentSegmentContentOffset = 0; + private boolean segmentHeaderRead = false; + // Running CRC64 over all payload bytes seen so far (across every segment). + private long messageCrc64 = 0; + // Running CRC64 over only the current segment's payload bytes. + private long segmentCrc64 = 0; + // Holds bytes left over from a previous decodeChunk() call when the current chunk did not contain a full + // header or footer. + private final ByteArrayOutputStream pendingBytes = new ByteArrayOutputStream(); + // Holds the payload bytes of the segment that is currently being decoded. These bytes are intentionally NOT + // emitted to the caller until the segment's CRC footer has been validated. + private final ByteArrayOutputStream currentSegmentBuffer = new ByteArrayOutputStream(); + + /** + * Constructs a new StructuredMessageDecoder. + * + * @param expectedEncodedMessageLength The expected encoded structured-message length (typically HTTP + * {@code Content-Length}). + */ + public StructuredMessageDecoder(long expectedEncodedMessageLength) { + this.expectedEncodedMessageLength = expectedEncodedMessageLength; + } + + /** + * Reads the 13-byte message header (version + total length + flags + numSegments) the first time the decoder + * sees enough bytes, and validates each field. Subsequent calls are no-ops. + * + * @param buffer The buffer to read from. 
+ * @return true if the header was successfully read (or had already been read on a previous pass); false if more + * bytes are still needed. + */ + private boolean tryReadMessageHeader(ByteBuffer buffer) { + if (messageLength != -1) { + // Header already parsed on a previous chunk; nothing to do. + return true; + } + + if (getAvailableBytes(buffer) < V1_HEADER_LENGTH) { + // Not enough bytes for the full header yet; carry over what we have. + appendToPending(buffer); + return false; + } + + ByteBuffer combined = getCombinedBuffer(buffer); + + // Byte 0: protocol version. + int messageVersion = Byte.toUnsignedInt(combined.get()); + if (messageVersion != DEFAULT_MESSAGE_VERSION) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + enrichExceptionMessage("Unsupported structured message version: " + messageVersion))); + } + + // Bytes 1-8: total encoded message length. Must be at least the header itself, and must agree with what the + // HTTP layer told us via Content-Length – any disagreement implies a truncated/extended response. + long msgLen = combined.getLong(); + if (msgLen < V1_HEADER_LENGTH) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException(enrichExceptionMessage("Message length too small: " + msgLen))); + } + if (msgLen != expectedEncodedMessageLength) { + throw LOGGER + .logExceptionAsError(new IllegalArgumentException(enrichExceptionMessage("Structured message length " + + msgLen + " did not match content length " + expectedEncodedMessageLength))); + } + + // Bytes 9-10: flags (NONE or STORAGE_CRC64). Bytes 11-12: number of segments. 
+ flags = StructuredMessageFlags.fromValue(Short.toUnsignedInt(combined.getShort())); + numSegments = Short.toUnsignedInt(combined.getShort()); + if (numSegments < 1) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + enrichExceptionMessage("Structured message must have at least one segment, got: " + numSegments))); + } + + // Commit: drop the 13 bytes we just parsed from pending/buffer and record the message length. + consumeBytes(V1_HEADER_LENGTH, buffer); + messageOffset += V1_HEADER_LENGTH; + messageLength = msgLen; + + return true; + } + + /** + * Reads the 10-byte header for the next segment (segment number + segment payload length) and resets + * per-segment state so {@link #tryReadSegmentContent(ByteBuffer)} can begin filling + * {@link #currentSegmentBuffer}. + * + *

Validates that segments arrive in order and that the declared segment size leaves enough room in the + * remaining message for this segment's footer, the headers and footers of every remaining segment, and the trailing message footer – + * this catches malformed/oversized segment lengths up front instead of waiting until we run off the end of the + * stream.

+ * + * @param buffer The buffer to read from. + * @return true if the segment header was read; false if more bytes are needed. + */ + private boolean tryReadSegmentHeader(ByteBuffer buffer) { + if (getAvailableBytes(buffer) < V1_SEGMENT_HEADER_LENGTH) { + appendToPending(buffer); + return false; + } + + ByteBuffer combined = getCombinedBuffer(buffer); + + // Bytes 0-1: segment number. Bytes 2-9: declared payload length of this segment. + int segmentNum = Short.toUnsignedInt(combined.getShort()); + long segmentSize = combined.getLong(); + + // Segments must arrive strictly in order so the running CRC and "segment N follows segment N-1" assumption + // hold. Anything else implies a malformed/reordered response. + if (segmentNum != currentSegmentNumber + 1) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException(enrichExceptionMessage( + "Unexpected segment number. Expected: " + (currentSegmentNumber + 1) + ", got: " + segmentNum))); + } + + // Compute an upper bound on the legal segment size: whatever is left in the message, minus the bytes that + // MUST still appear after this segment's payload (this segment's footer, the headers and footers of all + // remaining segments — payloads may be empty, so they add nothing to the minimum — and the trailing message footer). + long footerSize = flags == StructuredMessageFlags.STORAGE_CRC64 ? CRC64_LENGTH : 0; + long remainingSegmentsAfterThis = (long) numSegments - segmentNum; + long reservedBytes + = footerSize + remainingSegmentsAfterThis * (V1_SEGMENT_HEADER_LENGTH + footerSize) + footerSize; + long maxSegmentSize = messageLength - messageOffset - V1_SEGMENT_HEADER_LENGTH - reservedBytes; + if (segmentSize < 0 || segmentSize > maxSegmentSize) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException(enrichExceptionMessage( + "Invalid segment size detected: " + segmentSize + " (max=" + maxSegmentSize + ")"))); + } + + // Commit: drop the 10 header bytes and set up per-segment state so payload accumulation can start fresh.
+ consumeBytes(V1_SEGMENT_HEADER_LENGTH, buffer); + messageOffset += V1_SEGMENT_HEADER_LENGTH; + currentSegmentNumber = segmentNum; + currentSegmentContentLength = segmentSize; + currentSegmentContentOffset = 0; + currentSegmentBuffer.reset(); + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + // Reset only the per-segment running CRC; the message-wide running CRC keeps accumulating across all + // segments so the final message footer covers the entire payload. + segmentCrc64 = 0; + } + + return true; + } + + /** + * Pulls as many payload bytes as possible (bounded by what is still owed for the current segment) from the + * pending+buffer view into {@link #currentSegmentBuffer}, updating the running per-segment and per-message + * CRC64 values along the way. + * + *

Bytes accumulated here are not yet emitted to the caller. They are released only after + * {@link #tryReadSegmentFooter(ByteBuffer)} validates this segment's CRC. This is the mechanism that enforces + * "no unvalidated bytes ever leave the decoder".

+ * + * @param buffer The buffer to read from. + * @return The number of payload bytes read in this call (0 means we either had no bytes available or the + * current segment's payload was already complete). + */ + private int tryReadSegmentContent(ByteBuffer buffer) { + long remaining = currentSegmentContentLength - currentSegmentContentOffset; + if (remaining == 0) { + // Segment payload is already complete; nothing to do here. The caller will move on to read the footer. + return 0; + } + + int available = getAvailableBytes(buffer); + if (available == 0) { + return 0; + } + + // Read the minimum of "what's available right now" and "what's still owed for this segment" so we never + // accidentally consume the segment footer here. + int toRead = (int) Math.min(available, remaining); + ByteBuffer combined = getCombinedBuffer(buffer); + + // Materialize the bytes into a fresh array so we can both feed the CRC64 calculator and stash them in the + // per-segment buffer in one pass. + byte[] content = new byte[toRead]; + combined.get(content); + currentSegmentBuffer.write(content, 0, toRead); + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + // Update both CRCs incrementally: the segment CRC will be checked at the segment footer, and the + // message CRC accumulates across every segment to be checked at the message footer. + segmentCrc64 = StorageCrc64Calculator.compute(content, segmentCrc64); + messageCrc64 = StorageCrc64Calculator.compute(content, messageCrc64); + } + + consumeBytes(toRead, buffer); + messageOffset += toRead; + currentSegmentContentOffset += toRead; + + return toRead; + } + + /** + * Validates the 8-byte segment CRC64 footer for the segment that has just finished accumulating. Pre-condition: + * {@code currentSegmentContentOffset == currentSegmentContentLength}. + * + *

This step is intentionally separate from reading the message footer: when the CRC matches, we want to be + * able to flush the buffered segment payload to the caller right away – even if the trailing message footer is + * not yet available in the current chunk.

+ * + * @param buffer The buffer to read from. + * @return true if the footer was successfully read (or no footer is required for this message); false if more + * bytes are still needed. + */ + private boolean tryReadSegmentFooter(ByteBuffer buffer) { + if (currentSegmentContentOffset != currentSegmentContentLength) { + // Segment payload is not complete yet; wait for more content. + return true; + } + + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + return tryConsumeCrc64Footer(buffer, segmentCrc64, " in segment " + currentSegmentNumber); + } + + // No CRC was negotiated, so there is no footer to read; the caller can release the buffered payload. + return true; + } + + /** + * Validates the 8-byte message CRC64 footer that follows the last segment. + * + * @param buffer The buffer to read from. + * @return true if the footer was successfully read (or none is required); false if more bytes are still needed. + */ + private boolean tryReadMessageFooter(ByteBuffer buffer) { + if (flags == StructuredMessageFlags.STORAGE_CRC64) { + return tryConsumeCrc64Footer(buffer, messageCrc64, " in message footer"); + } + return true; + } + + /** + * Decodes as much as possible from the given buffer and returns any fully validated + * payload bytes that are now safe to emit downstream. + * + *

+ * <p>The returned buffer will only ever contain bytes from segments whose CRC (when
+ * enabled) has already been verified. If no segments have been fully validated by
+ * this invocation the method returns {@code null}. Callers distinguish "more bytes
+ * needed" from "stream complete" via {@link #isComplete()}.</p>

+ * + * @param buffer The buffer containing encoded data. + * @return Validated payload bytes ready to emit, or {@code null} if none are ready. + * @throws IllegalArgumentException if the input is malformed or a CRC64 check fails. + */ + public ByteBuffer decodeChunk(ByteBuffer buffer) { + // Decoder always reads little-endian; force the order on the caller's buffer so all our get() calls match + // the wire format regardless of how the buffer was constructed. + buffer.order(ByteOrder.LITTLE_ENDIAN); + + // Output collected during this single invocation. Each segment whose CRC validates in this call is appended + // here and ultimately returned to the policy as one ByteBuffer. + ByteArrayOutputStream validatedOutput = new ByteArrayOutputStream(); + + // Step 1: parse the message header on the first chunk that has enough bytes for it. If this chunk doesn't, + // bail out early. + if (!tryReadMessageHeader(buffer)) { + return emptyOrNull(validatedOutput); + } + + // Step 2: walk forward through the message until we either hit the end (messageOffset == messageLength) or + // we run out of bytes for the current structural element and have to wait for the next chunk. + while (messageOffset < messageLength) { + if (!segmentHeaderRead) { + // We are between segments. If every segment has been processed, only the trailing message footer + // can still appear in the stream – read it (or wait for it) and exit. + if (currentSegmentNumber == numSegments) { + if (!tryReadMessageFooter(buffer)) { + break; + } + break; + } + // Otherwise, parse the next segment's header. May return false if it is split across chunks. + if (!tryReadSegmentHeader(buffer)) { + break; + } + segmentHeaderRead = true; + } + + // Drain as many payload bytes as are available into the per-segment buffer. + int payloadRead = tryReadSegmentContent(buffer); + + if (currentSegmentContentOffset == currentSegmentContentLength) { + // Segment payload fully buffered. Validate the CRC footer (if any). 
When the footer isn't fully + // available yet, break and resume on the next chunk – currentSegmentBuffer keeps its contents so + // we can still emit them on the call where the footer arrives. + if (!tryReadSegmentFooter(buffer)) { + break; + } + // Segment passed validation: it is now safe to release the buffered payload to the caller. + try { + currentSegmentBuffer.writeTo(validatedOutput); + } catch (java.io.IOException e) { + // ByteArrayOutputStream.writeTo(ByteArrayOutputStream) does not actually throw, but the + // signature forces us to handle it. + throw LOGGER.logExceptionAsError(new IllegalStateException(e)); + } + currentSegmentBuffer.reset(); + segmentHeaderRead = false; + // Loop continues: either consume the next segment's header or the message footer. + } else if (payloadRead == 0 && getAvailableBytes(buffer) == 0) { + // Nothing left to read this pass and the segment is not complete – wait for the next chunk. + break; + } + } + + return emptyOrNull(validatedOutput); + } + + /** + * @return the total number of bytes the decoder can currently see across the carry-over {@link #pendingBytes} + * plus the unread tail of the supplied buffer. Used to decide whether a structural element (header / + * footer) can be parsed in this pass or whether we must defer to the next chunk. + */ + private int getAvailableBytes(ByteBuffer buffer) { + return pendingBytes.size() + buffer.remaining(); + } + + /** + * Returns a single read-only view that logically concatenates {@link #pendingBytes} with the unread tail of + * a buffer. + * + *

+ * <p>The position of the supplied buffer is intentionally not advanced here – reads happen on the
+ * combined view, and the original buffer's position is moved later by {@link #consumeBytes(int, ByteBuffer)}
+ * once we know the parse succeeded.</p>

+ * + *

+ * <p>When pendingBytes is empty we avoid the allocation and just return a duplicate of the buffer;
+ * otherwise we materialize a fresh array of size {@code pending + buffer.remaining()}.</p>

+ */ + private ByteBuffer getCombinedBuffer(ByteBuffer buffer) { + if (pendingBytes.size() == 0) { + ByteBuffer dup = buffer.duplicate(); + dup.order(ByteOrder.LITTLE_ENDIAN); + return dup; + } + + byte[] pending = pendingBytes.toByteArray(); + ByteBuffer combined = ByteBuffer.allocate(pending.length + buffer.remaining()); + combined.order(ByteOrder.LITTLE_ENDIAN); + combined.put(pending); + combined.put(buffer.duplicate()); + combined.flip(); + return combined; + } + + /** + * Consumes the next 8 bytes as a little-endian CRC64 footer, validates it against expectedCrc64, and + * advances {@link #messageOffset}. Used for both segment and message footers. + * + *

+ * <p>If fewer than 8 bytes are available, the remaining buffer bytes are stashed in {@link #pendingBytes} and
+ * the method returns false so the caller can break out of the decode loop and wait for the next
+ * chunk. On a CRC mismatch, an {@link IllegalArgumentException} is thrown (the decoder is then discarded by
+ * the enclosing policy).</p>

+ */ + private boolean tryConsumeCrc64Footer(ByteBuffer buffer, long expectedCrc64, String mismatchDetail) { + if (getAvailableBytes(buffer) < CRC64_LENGTH) { + // Not enough bytes yet for the footer; carry whatever we have over to the next call. + appendToPending(buffer); + return false; + } + ByteBuffer combined = getCombinedBuffer(buffer); + long reportedCrc = combined.getLong(); + if (expectedCrc64 != reportedCrc) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException(enrichExceptionMessage( + "CRC64 mismatch" + mismatchDetail + ". Expected: " + expectedCrc64 + ", got: " + reportedCrc))); + } + consumeBytes(CRC64_LENGTH, buffer); + messageOffset += CRC64_LENGTH; + return true; + } + + /** + * Drains {@code bytesToConsume} bytes from the logical pending+buffer stream that + * {@link #getCombinedBuffer(ByteBuffer)} produced. + * + *

+ * <p>Bytes are taken from {@link #pendingBytes} first, then from the live buffer. The pending stream is
+ * reset whenever it is fully drained, and any leftover (when {@code bytesToConsume} was less than what was in
+ * pending) is rewritten so the carry-over stays compact.</p>

+ */ + private void consumeBytes(int bytesToConsume, ByteBuffer buffer) { + int pendingSize = pendingBytes.size(); + if (bytesToConsume <= pendingSize) { + // The entire consume fits in pending: rewrite whatever survives back into pending after a reset. + byte[] remaining = pendingBytes.toByteArray(); + pendingBytes.reset(); + if (bytesToConsume < pendingSize) { + pendingBytes.write(remaining, bytesToConsume, pendingSize - bytesToConsume); + } + } else { + // Pending is fully drained and the remainder comes from the live buffer; advance its position directly. + int bytesFromBuffer = bytesToConsume - pendingSize; + pendingBytes.reset(); + buffer.position(buffer.position() + bytesFromBuffer); + } + } + + /** + * Stashes everything still unread in the buffer into {@link #pendingBytes} so it can be combined with the + * next chunk on the next call to {@link #decodeChunk(ByteBuffer)}. + * + *

+ * <p>This is only called when the current chunk does not contain enough bytes for the next structural element,
+ * so the carry-over is always small (bounded by the largest header size, currently 13 bytes).</p>

+ */ + private void appendToPending(ByteBuffer buffer) { + while (buffer.hasRemaining()) { + pendingBytes.write(buffer.get()); + } + } + + /** + * Wraps {@code output} as a {@link ByteBuffer}, or returns {@code null} when no bytes were emitted in this + * pass. The {@code null} return distinguishes "no validated bytes ready in this chunk" (still need more input) + * from "stream complete" (which the caller checks via {@link #isComplete()}). + */ + private static ByteBuffer emptyOrNull(ByteArrayOutputStream output) { + if (output.size() == 0) { + return null; + } + return ByteBuffer.wrap(output.toByteArray()); + } + + /** + * Reports whether the decoder has finished consuming the entire structured message and validated everything it + * was supposed to validate. Used by the pipeline policy to distinguish "stream ended cleanly" from "stream was + * truncated". + * + *

+ * <p>The check requires all of:</p>
+ * <ul>
+ *   <li>The message header has been parsed ({@code messageLength != -1}).</li>
+ *   <li>Every byte of the declared message has been consumed.</li>
+ *   <li>No carry-over bytes remain in pending.</li>
+ *   <li>No segment is currently in progress (no segment header without a matching footer).</li>
+ *   <li>The current segment's payload accumulation is itself complete.</li>
+ * </ul>
+ * + * @return true if all expected bytes have been decoded and validated; false otherwise. + */ + public boolean isComplete() { + return messageLength != -1 + && messageOffset == messageLength + && pendingBytes.size() == 0 + && !segmentHeaderRead + && currentSegmentContentOffset == currentSegmentContentLength; + } + + /** + * Appends the current decoder offset to an exception message so failures can be traced back to a specific + * point in the encoded stream. + * + * @param message The original exception message. + * @return The original message with {@code [decoderOffset=N]} appended. + */ + private String enrichExceptionMessage(String message) { + return String.format("%s [decoderOffset=%d]", message, messageOffset); + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/DecodedResponse.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/DecodedResponse.java new file mode 100644 index 000000000000..df9c0bc88c10 --- /dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/DecodedResponse.java @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.policy; + +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.FluxUtil; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.nio.ByteBuffer; +import java.nio.charset.Charset; + +/** + * {@link HttpResponse} wrapper that exposes a decoded body stream while preserving the request, status code, and + * headers of the original response. + * + *

+ * <p>The policy hands this class a Flux that already represents validated, framing-stripped bytes (produced by the
+ * decoder pipeline). This class's only job is to make that Flux look like the body of the original
+ * {@link HttpResponse}. Status code, headers, and request remain identical to the underlying response so callers
+ * cannot distinguish a validated download from a normal one – the validation is transparent.</p>

+ */ +class DecodedResponse extends HttpResponse { + private final HttpResponse originalResponse; + private final Flux decodedBody; + private final HttpHeaders httpHeaders; + private final int statusCode; + + /** + * Wraps {@code httpResponse} with a body backed by {@code decodedBody}. + * + * @param httpResponse The original response from the storage service. Its request, status code, and headers + * are preserved verbatim. + * @param decodedBody The Flux of CRC-validated, framing-stripped payload bytes produced by the decoder + * pipeline. + */ + DecodedResponse(HttpResponse httpResponse, Flux decodedBody) { + // Preserve the original request so retry policies, response models, and logging keep their reference chain + // intact. + super(httpResponse.getRequest()); + this.originalResponse = httpResponse; + this.decodedBody = decodedBody; + this.statusCode = httpResponse.getStatusCode(); + this.httpHeaders = httpResponse.getHeaders(); + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getHeaderValue(String name) { + return httpHeaders.getValue(name); + } + + @Override + public HttpHeaders getHeaders() { + return httpHeaders; + } + + @Override + public Flux getBody() { + return decodedBody; + } + + @Override + public Mono getBodyAsByteArray() { + return FluxUtil.collectBytesInByteBufferStream(decodedBody); + } + + @Override + public Mono getBodyAsString() { + return FluxUtil.collectBytesInByteBufferStream(decodedBody).map(String::new); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return FluxUtil.collectBytesInByteBufferStream(decodedBody).map(b -> new String(b, charset)); + } +} diff --git a/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java new file mode 100644 index 000000000000..647a95f71e7e --- 
/dev/null +++ b/sdk/storage/azure-storage-common/src/main/java/com/azure/storage/common/policy/StorageContentValidationDecoderPolicy.java @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.policy; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpMethod; +import com.azure.core.http.HttpPipelineCallContext; +import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpResponse; +import com.azure.core.http.policy.HttpPipelinePolicy; +import com.azure.core.util.logging.ClientLogger; +import com.azure.storage.common.implementation.Constants; +import com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants; +import com.azure.storage.common.implementation.contentvalidation.StructuredMessageDecoder; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * HTTP pipeline policy that decodes the storage structured message body returned for downloads when CRC64 + * content validation is active. + * + *

+ * <p>The policy decides when to opt in (via the context key), tells the service to
+ * encode the response (via the request header), constructs the decoder and the wrapper response, and
+ * translates decoder-level failures (malformed framing, CRC mismatch, premature end-of-stream) into reactive
+ * {@link IOException} errors.</p>

+ * + *

+ * <p>This policy uses {@link com.azure.core.http.HttpPipelinePosition#PER_RETRY PER_RETRY} semantics by default, so
+ * each retry produces a fresh response that this policy wraps with a fresh decoder. A CRC failure on one attempt
+ * cannot pollute another, and the storage download retry logic ({@code BlobAsyncClientBase.downloadStream...}) can
+ * resume by reissuing range requests; each new range response is validated end-to-end on its own.</p>

+ * + *

+ * <p>Because the wrapped {@link StructuredMessageDecoder} only releases payload bytes after the corresponding
+ * segment's CRC has been verified, the {@link DecodedResponse}'s body Flux is guaranteed to contain only validated
+ * bytes – callers never see a byte that could later fail validation, even when retries are involved.</p>

+ */ +public class StorageContentValidationDecoderPolicy implements HttpPipelinePolicy { + private static final ClientLogger LOGGER = new ClientLogger(StorageContentValidationDecoderPolicy.class); + + /** + * Creates a new instance of {@link StorageContentValidationDecoderPolicy}. + */ + public StorageContentValidationDecoderPolicy() { + } + + @Override + public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { + // Check if the decoding should be applied. + if (!shouldApplyDecoding(context)) { + return next.process(); + } + + // Tell the service we want a structured-message body. + context.getHttpRequest() + .getHeaders() + .set(Constants.HeaderConstants.STRUCTURED_BODY_TYPE_HEADER_NAME, + StructuredMessageConstants.STRUCTURED_BODY_TYPE_VALUE); + + return next.process().map(httpResponse -> { + // The HTTP Content-Length is the size of the encoded structured message body. We hand it to the + // decoder which cross-checks it against the message header. + Long contentLength = getContentLength(httpResponse.getHeaders()); + + // Only 2xx GET responses with a positive content length carry a body that we can decode. + if (!isEligibleDownload(httpResponse, contentLength)) { + return httpResponse; + } + + // Confirm the service actually honored our structured-body request before we hand the body to the decoder. + validateStructuredMessageHeaders(httpResponse); + + // Fresh decoder per response so retries each get a clean state machine. + StructuredMessageDecoder decoder = new StructuredMessageDecoder(contentLength); + + Flux decodedStream = decodeStream(httpResponse.getBody(), decoder); + return new DecodedResponse(httpResponse, decodedStream); + }); + } + + /** + * @return true when the request carries the boolean opt-in flag set + * by {@code ContentValidationModeResolver.addStructuredMessageDecodingToContext}. 
+ */ + private boolean shouldApplyDecoding(HttpPipelineCallContext context) { + return context.getData(StructuredMessageConstants.STRUCTURED_MESSAGE_DECODING_CONTEXT_KEY) + .map(value -> value instanceof Boolean && (Boolean) value) + .orElse(false); + } + + /** + * Verifies the response acknowledges the structured-body request: presence of the + * {@code x-ms-structured-body} header and the {@code x-ms-structured-content-length} + * header. If either is missing, the service is sending us a normal body and we must not run the decoder over it. + */ + private void validateStructuredMessageHeaders(HttpResponse httpResponse) { + String structuredBody + = httpResponse.getHeaders().getValue(Constants.HeaderConstants.STRUCTURED_BODY_TYPE_HEADER_NAME); + String structuredContentLength + = httpResponse.getHeaders().getValue(Constants.HeaderConstants.STRUCTURED_CONTENT_LENGTH_HEADER_NAME); + if (structuredBody == null || structuredContentLength == null) { + throw LOGGER.logExceptionAsError( + new IllegalStateException("Structured message was requested but the response did not acknowledge it.")); + } + } + + /** + * Reads {@code Content-Length} as a {@code long}, returning {@code null} when the header is missing or + * not parseable so callers can simply skip decoding for non-bodied responses. + */ + private static Long getContentLength(HttpHeaders headers) { + String value = headers.getValue(HttpHeaderName.CONTENT_LENGTH); + if (value != null) { + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + // Header invalid; treat as not eligible. + } + } + return null; + } + + /** + * @return true for a 2xx response to a GET request, the only response shape that carries a body we + * can decode. 206 (Partial Content) on retried range downloads is included. 
+ */ + private static boolean isDownloadResponse(HttpResponse response) { + return response.getRequest().getHttpMethod() == HttpMethod.GET && response.getStatusCode() / 100 == 2; + } + + /** + * @return true when the response is one we should decode: a 2xx GET with a positive, parseable + * {@code Content-Length}. + */ + private static boolean isEligibleDownload(HttpResponse response, Long contentLength) { + return isDownloadResponse(response) && contentLength != null && contentLength > 0; + } + + /** + * Builds the body-decoding Flux: each upstream {@link ByteBuffer} is fed to the decoder in order + * ({@code concatMap} preserves order and serializes access), and a deferred stream-completion check is + * appended so a truncated body raises an error instead of completing silently. + */ + private Flux decodeStream(Flux encodedFlux, StructuredMessageDecoder decoder) { + return encodedFlux.concatMap(buffer -> decodeBuffer(buffer, decoder)) + .concatWith(Mono.defer(() -> handleStreamCompletion(decoder))); + } + + /** + * Feeds a single inbound chunk to the decoder and translates its outputs into reactive emissions: + * If the decoder reports validated bytes, emit them downstream. + * If the decoder threw because the input is malformed or a CRC mismatch was detected, surface that as + * an {@link IOException}. + * If the decoder is already complete (e.g., extra trailing bytes after the message footer), drop the + * chunk silently. + */ + private Flux decodeBuffer(ByteBuffer buffer, StructuredMessageDecoder decoder) { + if (decoder.isComplete()) { + // Decoding finished on a previous chunk; ignore any trailing bytes the transport might still emit. 
+ return Flux.empty(); + } + + if (buffer == null || !buffer.hasRemaining()) { + return Flux.empty(); + } + + try { + ByteBuffer validated = decoder.decodeChunk(buffer); + return emitDecodedPayload(validated); + } catch (IllegalArgumentException e) { + return Flux.error(new IOException("Failed to decode structured message: " + e.getMessage(), e)); + } catch (Exception e) { + // Anything not foreseen by the decoder, log it. + LOGGER.error("Failed to decode structured message chunk: " + e.getMessage(), e); + return Flux.error(new IOException("Failed to decode structured message chunk: " + e.getMessage(), e)); + } + } + + /** + * Run after the upstream Flux completes. If the decoder is not in a complete state, the response body ended + * before all expected bytes arrived – surface this as an {@link IOException} so callers don't accept a + * truncated payload. + */ + private Mono handleStreamCompletion(StructuredMessageDecoder decoder) { + if (!decoder.isComplete()) { + return Mono.error(new IOException("Stream ended prematurely before structured message decoding completed")); + } + return Mono.empty(); + } + + /** + * Wraps decoder output in a Flux. 
+ */ + private static Flux emitDecodedPayload(ByteBuffer decodedPayload) { + if (decodedPayload == null || !decodedPayload.hasRemaining()) { + return Flux.empty(); + } + return Flux.just(decodedPayload); + } +} diff --git a/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockDownloadHttpResponse.java b/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockDownloadHttpResponse.java index 5e84dd31947d..ddd488ff0887 100644 --- a/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockDownloadHttpResponse.java +++ b/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockDownloadHttpResponse.java @@ -6,10 +6,10 @@ import com.azure.core.http.HttpHeaderName; import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpResponse; +import com.azure.core.util.FluxUtil; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; @@ -19,14 +19,21 @@ with than was worth it. Because this type is just for BlobDownload, we don't need to accept a header type. 
*/ public class MockDownloadHttpResponse extends HttpResponse { + private final HttpResponse originalResponse; private final int statusCode; private final HttpHeaders headers; private final Flux body; public MockDownloadHttpResponse(HttpResponse response, int statusCode, Flux body) { + this(response, statusCode, response.getHeaders(), body); + } + + public MockDownloadHttpResponse(HttpResponse response, int statusCode, HttpHeaders headers, + Flux body) { super(response.getRequest()); + this.originalResponse = response; this.statusCode = statusCode; - this.headers = response.getHeaders(); + this.headers = headers; this.body = body; } @@ -52,21 +59,26 @@ public HttpHeaders getHeaders() { @Override public Flux getBody() { - return body; + return Flux.using(() -> originalResponse, ignored -> body, HttpResponse::close); } @Override public Mono getBodyAsByteArray() { - return Mono.error(new IOException()); + return FluxUtil.collectBytesInByteBufferStream(getBody()); } @Override public Mono getBodyAsString() { - return Mono.error(new IOException()); + return getBodyAsByteArray().map(bytes -> new String(bytes, Charset.defaultCharset())); } @Override public Mono getBodyAsString(Charset charset) { - return Mono.error(new IOException()); + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + + @Override + public void close() { + originalResponse.close(); } } diff --git a/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockPartialResponsePolicy.java b/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockPartialResponsePolicy.java index 1f1109d8b38b..347d3ac11a59 100644 --- a/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockPartialResponsePolicy.java +++ b/sdk/storage/azure-storage-common/src/test-shared/java/com/azure/storage/common/test/shared/policy/MockPartialResponsePolicy.java @@ -8,6 +8,7 @@ import 
com.azure.core.http.HttpMethod; import com.azure.core.http.HttpPipelineCallContext; import com.azure.core.http.HttpPipelineNextPolicy; +import com.azure.core.http.HttpPipelinePosition; import com.azure.core.http.HttpResponse; import com.azure.core.http.policy.HttpPipelinePolicy; import reactor.core.publisher.Flux; @@ -16,50 +17,137 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; public class MockPartialResponsePolicy implements HttpPipelinePolicy { - static final HttpHeaderName RANGE_HEADER = HttpHeaderName.fromString("x-ms-range"); - private int tries; - private final List rangeHeaders = new ArrayList<>(); + static final HttpHeaderName X_MS_RANGE_HEADER = HttpHeaderName.fromString("x-ms-range"); + static final HttpHeaderName RANGE_HEADER = HttpHeaderName.RANGE; + private final AtomicInteger tries; + private final List rangeHeaders = Collections.synchronizedList(new ArrayList<>()); + private final int maxBytesPerResponse; + private final AtomicInteger hits = new AtomicInteger(); + private final String targetUrlPrefix; + /** + * Creates a MockPartialResponsePolicy that simulates network interruptions. + * + * @param tries Number of times to simulate interruptions (0 = no interruptions) + */ public MockPartialResponsePolicy(int tries) { - this.tries = tries; + this(tries, 200, null); + } + + /** + * Creates a MockPartialResponsePolicy with configurable interruption behavior. + * + * @param tries Number of times to simulate interruptions (0 = no interruptions) + * @param maxBytesPerResponse Maximum bytes to return in each interrupted response + */ + public MockPartialResponsePolicy(int tries, int maxBytesPerResponse) { + this(tries, maxBytesPerResponse, null); + } + + /** + * Creates a MockPartialResponsePolicy with configurable interruption behavior and an optional URL filter. 
+ * + * @param tries Number of times to simulate interruptions (0 = no interruptions) + * @param maxBytesPerResponse Maximum bytes to return in each interrupted response + * @param targetUrlPrefix If non-null, only requests whose URL starts with this prefix will be interrupted. + */ + public MockPartialResponsePolicy(int tries, int maxBytesPerResponse, String targetUrlPrefix) { + this.tries = new AtomicInteger(tries); + this.maxBytesPerResponse = maxBytesPerResponse; + this.targetUrlPrefix = targetUrlPrefix; + } + + @Override + public HttpPipelinePosition getPipelinePosition() { + return HttpPipelinePosition.PER_RETRY; } @Override public Mono process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) { return next.process().flatMap(response -> { HttpHeader rangeHttpHeader = response.getRequest().getHeaders().get(RANGE_HEADER); - String rangeHeader = rangeHttpHeader == null ? null : rangeHttpHeader.getValue(); + HttpHeader xMsRangeHttpHeader = response.getRequest().getHeaders().get(X_MS_RANGE_HEADER); - if (rangeHeader != null && rangeHeader.startsWith("bytes=")) { - rangeHeaders.add(rangeHeader); + if (response.getRequest().getHttpMethod() == HttpMethod.GET) { + String recordedRange = null; + if (rangeHttpHeader != null && rangeHttpHeader.getValue().startsWith("bytes=")) { + recordedRange = rangeHttpHeader.getValue(); + } else if (xMsRangeHttpHeader != null && xMsRangeHttpHeader.getValue().startsWith("bytes=")) { + recordedRange = xMsRangeHttpHeader.getValue(); + } + rangeHeaders.add(recordedRange == null ? 
"" : recordedRange); } - if ((response.getRequest().getHttpMethod() != HttpMethod.GET) || this.tries == 0) { + boolean urlMatches = targetUrlPrefix == null + || response.getRequest().getUrl().toString().startsWith(targetUrlPrefix); + + if ((response.getRequest().getHttpMethod() != HttpMethod.GET) || !urlMatches) { return Mono.just(response); } else { - this.tries -= 1; - return response.getBody().collectList().flatMap(bodyBuffers -> { - ByteBuffer firstBuffer = bodyBuffers.get(0); - byte firstByte = firstBuffer.get(); - - // Simulate partial response by returning the first byte only from the requested range and timeout - return Mono.just(new MockDownloadHttpResponse(response, 206, - Flux.just(ByteBuffer.wrap(new byte[] { firstByte })) - .concatWith(Flux.error(new IOException("Simulated timeout"))) - )); - }); + int remainingTries = this.tries.getAndUpdate(value -> value > 0 ? value - 1 : value); + if (remainingTries <= 0) { + return Mono.just(response); + } + hits.incrementAndGet(); + + Flux limitedBody = limitStreamToBytes(response.getBody(), maxBytesPerResponse); + return Mono.just( + new MockDownloadHttpResponse(response, response.getStatusCode(), response.getHeaders(), + limitedBody)); } }); } + private Flux limitStreamToBytes(Flux body, int maxBytes) { + return Flux.defer(() -> { + final long[] bytesEmitted = new long[] { 0 }; + return body.concatMap(buffer -> { + if (buffer == null || !buffer.hasRemaining()) { + return Flux.just(buffer); + } + + long remaining = maxBytes - bytesEmitted[0]; + if (remaining <= 0) { + return Flux.error(new IOException("Simulated timeout")); + } + + int bufferSize = buffer.remaining(); + if (bufferSize <= remaining) { + bytesEmitted[0] += bufferSize; + if (bytesEmitted[0] >= maxBytes) { + return Flux.just(buffer).concatWith(Flux.error(new IOException("Simulated timeout"))); + } + return Flux.just(buffer); + } else { + int bytesToEmit = (int) remaining; + ByteBuffer slice = buffer.duplicate(); + slice.limit(slice.position() + 
bytesToEmit); + + ByteBuffer limited = ByteBuffer.allocate(bytesToEmit); + limited.put(slice); + limited.flip(); + + bytesEmitted[0] += bytesToEmit; + return Flux.just(limited).concatWith(Flux.error(new IOException("Simulated timeout"))); + } + }); + }); + } + public int getTriesRemaining() { - return tries; + return tries.get(); } public List getRangeHeaders() { return rangeHeaders; } + + public int getHits() { + return hits.get(); + } } diff --git a/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolverTests.java b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolverTests.java index fb2bbb4bef12..6cba23c3d9ae 100644 --- a/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolverTests.java +++ b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/ContentValidationModeResolverTests.java @@ -22,7 +22,9 @@ import static com.azure.storage.common.implementation.contentvalidation.StructuredMessageConstants.USE_STRUCTURED_MESSAGE_CONTEXT; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class ContentValidationModeResolverTests { @@ -271,4 +273,12 @@ public void validateProgressWithContentValidationParallelOptionsDelegatesToListe assertThrows(IllegalArgumentException.class, () -> ContentValidationModeResolver .validateProgressWithContentValidation(opts, ContentValidationAlgorithm.CRC64)); } + + @Test + public void isCrc64OrAutoReflectsCrc64AndAutoOnly() { + 
assertTrue(ContentValidationModeResolver.isCrc64OrAuto(ContentValidationAlgorithm.CRC64)); + assertTrue(ContentValidationModeResolver.isCrc64OrAuto(ContentValidationAlgorithm.AUTO)); + assertFalse(ContentValidationModeResolver.isCrc64OrAuto(ContentValidationAlgorithm.NONE)); + assertFalse(ContentValidationModeResolver.isCrc64OrAuto(null)); + } } diff --git a/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64CalculatorTests.java b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64CalculatorTests.java index a637287f3ce4..a07cef1f2013 100644 --- a/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64CalculatorTests.java +++ b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StorageCrc64CalculatorTests.java @@ -211,4 +211,12 @@ private static Stream testConcatWithInitialsSupplier() { Arguments.of("889000539881195835", "2971048229276949174", "5346315327374690144", "307387", "1407121768110541356", "10535852615249992663", "741189", "3634018251978804152")); } + + @Test + void testComputeSliceMatchesFullArray() { + byte[] data = "Hello World!".getBytes(); + long expected = StorageCrc64Calculator.compute(data, 0); + long actual = StorageCrc64Calculator.compute(data, 0, data.length, 0); + assertEquals(expected, actual); + } } diff --git a/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoderTests.java b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoderTests.java new file mode 100644 index 000000000000..faa3d69cde46 --- /dev/null +++ b/sdk/storage/azure-storage-common/src/test/java/com/azure/storage/common/implementation/contentvalidation/StructuredMessageDecoderTests.java 
@@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.storage.common.implementation.contentvalidation; + +import com.azure.core.util.FluxUtil; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for StructuredMessageDecoder with focus on the validated-emission guarantee: + * payload bytes for a segment are only returned after the segment's CRC has been verified. 
+ */ +public class StructuredMessageDecoderTests { + private static final int MESSAGE_HEADER_LENGTH = 13; + private static final int SEGMENT_HEADER_LENGTH = 10; + private static final int CRC64_LENGTH = 8; + + private static ByteBuffer collectFlux(Flux flux) { + return ByteBuffer.wrap(FluxUtil.collectBytesInByteBufferStream(flux).block()).order(ByteOrder.LITTLE_ENDIAN); + } + + private static byte[] encode(byte[] originalData, int segmentLength, StructuredMessageFlags flags) + throws IOException { + StructuredMessageEncoder encoder = new StructuredMessageEncoder(originalData.length, segmentLength, flags); + ByteBuffer encoded = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + byte[] encodedBytes = new byte[encoded.remaining()]; + encoded.get(encodedBytes); + return encodedBytes; + } + + @Test + public void readsCompleteMessageInSingleChunk() throws IOException { + byte[] originalData = new byte[1024]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 512, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + ByteBuffer result = decoder.decodeChunk(encodedData); + + assertTrue(decoder.isComplete()); + assertNotNull(result); + byte[] decodedData = new byte[result.remaining()]; + result.get(decodedData); + assertArrayEquals(originalData, decodedData); + } + + @Test + public void readsMessageSplitHeaderAcrossChunks() throws IOException { + byte[] originalData = new byte[256]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 128, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int 
encodedLength = encodedData.remaining(); + byte[] encodedBytes = new byte[encodedLength]; + encodedData.get(encodedBytes); + + // Split at byte 7 (mid-header, header is 13 bytes) + ByteBuffer chunk1 = ByteBuffer.wrap(encodedBytes, 0, 7); + ByteBuffer chunk2 = ByteBuffer.wrap(encodedBytes, 7, encodedLength - 7); + chunk1.order(ByteOrder.LITTLE_ENDIAN); + chunk2.order(ByteOrder.LITTLE_ENDIAN); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + + ByteBuffer result1 = decoder.decodeChunk(chunk1); + assertNull(result1); + assertFalse(decoder.isComplete()); + + ByteBuffer result2 = decoder.decodeChunk(chunk2); + assertNotNull(result2); + assertTrue(decoder.isComplete()); + } + + @Test + public void readsSegmentHeaderSplitAcrossChunks() throws IOException { + byte[] originalData = new byte[512]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 256, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + byte[] encodedBytes = new byte[encodedLength]; + encodedData.get(encodedBytes); + + // Split after message header (13 bytes) + 5 bytes into first segment header. + int splitPoint = 18; + ByteBuffer chunk1 = ByteBuffer.wrap(encodedBytes, 0, splitPoint); + ByteBuffer chunk2 = ByteBuffer.wrap(encodedBytes, splitPoint, encodedLength - splitPoint); + chunk1.order(ByteOrder.LITTLE_ENDIAN); + chunk2.order(ByteOrder.LITTLE_ENDIAN); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + + ByteBuffer result1 = decoder.decodeChunk(chunk1); + // Only the message header is consumed; segment header is incomplete so nothing validated yet. 
+ assertNull(result1); + assertFalse(decoder.isComplete()); + + ByteBuffer result2 = decoder.decodeChunk(chunk2); + assertNotNull(result2); + assertTrue(decoder.isComplete()); + } + + @Test + public void handlesZeroLengthSegment() throws IOException { + byte[] minimalData = new byte[1]; + ThreadLocalRandom.current().nextBytes(minimalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(minimalData.length, 1024, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(minimalData))); + int encodedLength = encodedData.remaining(); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + ByteBuffer result = decoder.decodeChunk(encodedData); + + assertTrue(decoder.isComplete()); + assertNotNull(result); + assertEquals(1, result.remaining()); + } + + @Test + public void multipleChunksDecode() throws IOException { + byte[] originalData = new byte[256]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 128, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + byte[] encodedBytes = new byte[encodedLength]; + encodedData.get(encodedBytes); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + + int chunkSize = 32; + ByteArrayOutputStream output = new ByteArrayOutputStream(); + + for (int offset = 0; offset < encodedLength; offset += chunkSize) { + int len = Math.min(chunkSize, encodedLength - offset); + ByteBuffer chunk = ByteBuffer.wrap(encodedBytes, offset, len); + chunk.order(ByteOrder.LITTLE_ENDIAN); + + ByteBuffer result = decoder.decodeChunk(chunk); + if (result != null && result.hasRemaining()) { + byte[] decoded = new byte[result.remaining()]; + result.get(decoded); + output.write(decoded, 0, decoded.length); + } + if 
(decoder.isComplete()) { + break; + } + } + + assertTrue(decoder.isComplete()); + assertArrayEquals(originalData, output.toByteArray()); + } + + @Test + public void decodeWithNoCrc() throws IOException { + byte[] originalData = new byte[256]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 128, StructuredMessageFlags.NONE); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + ByteBuffer result = decoder.decodeChunk(encodedData); + + assertTrue(decoder.isComplete()); + assertNotNull(result); + byte[] decodedData = new byte[result.remaining()]; + result.get(decodedData); + assertArrayEquals(originalData, decodedData); + } + + @Test + public void handlesZeroLengthBuffer() throws IOException { + byte[] originalData = new byte[256]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 128, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + byte[] encodedBytes = new byte[encodedLength]; + encodedData.get(encodedBytes); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + + ByteBuffer emptyBuffer = ByteBuffer.allocate(0); + ByteBuffer result1 = decoder.decodeChunk(emptyBuffer); + assertNull(result1); + + ByteBuffer dataBuffer = ByteBuffer.wrap(encodedBytes); + dataBuffer.order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer result2 = decoder.decodeChunk(dataBuffer); + assertNotNull(result2); + assertTrue(decoder.isComplete()); + } + + /** + * Verifies Kyle's emission guarantee (r3120267493): payload bytes for a segment are + * not emitted until the segment's CRC footer is 
read and validated. When the decoder + * has received the full segment payload but the CRC footer is still incomplete, + * {@code decodeChunk} must return {@code null}, never the in-progress payload bytes. + */ + @Test + public void withholdsPayloadUntilSegmentFooterValidated() throws IOException { + byte[] originalData = new byte[1024]; + ThreadLocalRandom.current().nextBytes(originalData); + + StructuredMessageEncoder encoder + = new StructuredMessageEncoder(originalData.length, 1024, StructuredMessageFlags.STORAGE_CRC64); + ByteBuffer encodedData = collectFlux(encoder.encode(ByteBuffer.wrap(originalData))); + int encodedLength = encodedData.remaining(); + byte[] encodedBytes = new byte[encodedLength]; + encodedData.get(encodedBytes); + + // Layout: msgHeader(13) + segHeader(10) + payload(1024) + segCrc(8) + msgCrc(8) = 1063. + // Feed the full payload but stop 1 byte short of completing the SEGMENT CRC footer. + int segCrcAllButLast = 13 + 10 + 1024 + 7; + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedLength); + + ByteBuffer chunk1 = ByteBuffer.wrap(encodedBytes, 0, segCrcAllButLast); + chunk1.order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer partial = decoder.decodeChunk(chunk1); + assertNull(partial, "Decoder must not emit payload before segment CRC is validated"); + assertFalse(decoder.isComplete()); + + // Feed the remainder; segment CRC completes, payload is released, and message CRC completes. 
+ ByteBuffer chunk2 = ByteBuffer.wrap(encodedBytes, segCrcAllButLast, encodedLength - segCrcAllButLast); + chunk2.order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer emitted = decoder.decodeChunk(chunk2); + assertNotNull(emitted); + assertTrue(decoder.isComplete()); + + byte[] decodedData = new byte[emitted.remaining()]; + emitted.get(decodedData); + assertArrayEquals(originalData, decodedData); + } + + @Test + public void throwsOnUnsupportedStructuredMessageVersion() throws IOException { + byte[] data = new byte[64]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 64, StructuredMessageFlags.STORAGE_CRC64); + + // Corrupt message version (byte 0 of the message header). + encodedBytes[0] = (byte) (StructuredMessageConstants.DEFAULT_MESSAGE_VERSION + 1); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedBytes.length); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN))); + assertTrue(exception.getMessage().contains("Unsupported structured message version")); + } + + @Test + public void throwsOnMessageLengthMismatch() throws IOException { + byte[] data = new byte[128]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 128, StructuredMessageFlags.STORAGE_CRC64); + + // Construct decoder with wrong expected encoded length. 
+ StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedBytes.length + 1); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN))); + assertTrue(exception.getMessage().contains("did not match content length")); + } + + @Test + public void throwsOnUnexpectedSegmentNumber() throws IOException { + byte[] data = new byte[300]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 128, StructuredMessageFlags.STORAGE_CRC64); + + // Corrupt first segment number from 1 to 2 (offset 13 in v1 format). + ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN).putShort(MESSAGE_HEADER_LENGTH, (short) 2); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedBytes.length); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN))); + assertTrue(exception.getMessage().contains("Unexpected segment number")); + } + + @Test + public void throwsOnInvalidSegmentSize() throws IOException { + byte[] data = new byte[256]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 128, StructuredMessageFlags.STORAGE_CRC64); + + // Corrupt first segment size to an impossible value (offsets 15..22 in v1 format). 
+ ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN).putLong(MESSAGE_HEADER_LENGTH + 2, Long.MAX_VALUE); + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedBytes.length); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN))); + assertTrue(exception.getMessage().contains("Invalid segment size detected")); + } + + @Test + public void throwsOnSegmentCrcMismatch() throws IOException { + byte[] data = new byte[512]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 512, StructuredMessageFlags.STORAGE_CRC64); + + // Layout for one-segment message: + // messageHeader(13) + segmentHeader(10) + payload(512) + segmentCrc(8) + messageCrc(8) + int segmentCrcOffset = MESSAGE_HEADER_LENGTH + SEGMENT_HEADER_LENGTH + data.length; + encodedBytes[segmentCrcOffset] ^= 0x01; + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(encodedBytes.length); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(encodedBytes).order(ByteOrder.LITTLE_ENDIAN))); + assertTrue(exception.getMessage().contains("CRC64 mismatch in segment")); + } + + @Test + public void throwsOnMessageCrcMismatch() throws IOException { + byte[] data = new byte[512]; + ThreadLocalRandom.current().nextBytes(data); + byte[] encodedBytes = encode(data, 512, StructuredMessageFlags.STORAGE_CRC64); + + int messageCrcOffset = encodedBytes.length - CRC64_LENGTH; + byte[] corrupted = Arrays.copyOf(encodedBytes, encodedBytes.length); + corrupted[messageCrcOffset] ^= 0x01; + + StructuredMessageDecoder decoder = new StructuredMessageDecoder(corrupted.length); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, + () -> decoder.decodeChunk(ByteBuffer.wrap(corrupted).order(ByteOrder.LITTLE_ENDIAN))); + 
assertTrue(exception.getMessage().contains("CRC64 mismatch in message footer")); + } + +}