From 0a7ba427010e8d2835649ae0a09985f5b605ff7f Mon Sep 17 00:00:00 2001 From: Peter Wu <162184229+weirongw23-msft@users.noreply.github.com> Date: Tue, 3 Feb 2026 11:16:23 -0500 Subject: [PATCH 01/14] [Storage] [STG 102] Create File with Data (#44901) --- sdk/storage/azure-storage-file-share/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/storage/azure-storage-file-share/CHANGELOG.md b/sdk/storage/azure-storage-file-share/CHANGELOG.md index 162222511349..374a1ba89946 100644 --- a/sdk/storage/azure-storage-file-share/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-share/CHANGELOG.md @@ -3,6 +3,10 @@ ## 12.27.0b1 (Unreleased) ### Features Added +- Added support for the keyword `file_property_semantics` in `ShareClient`'s `create_directory` and `DirectoryClient`'s +`create_directory` APIs, which specifies permissions to be configured upon directory creation. +- Added support for the keyword `data` to `FileClient`'s `create_file` API, which specifies the +optional initial data to be uploaded (up to 4MB). 
## 12.25.0 (2026-05-14) From 01f9bc24b966dbe746da062ea8e897a3cf49eadc Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Thu, 19 Feb 2026 14:58:16 -0800 Subject: [PATCH 02/14] [Storage][102] CRC64 content validation - part 1 (#45096) --- .../devtools_testutils/storage/__init__.py | 3 + .../storage/aio/__init__.py | 7 +- .../storage/aio/asyncdecorators.py | 18 + .../devtools_testutils/storage/decorators.py | 18 + .../azure/storage/blob/_shared/policies.py | 69 +- .../storage/blob/_shared/policies_async.py | 9 +- .../azure/storage/blob/_shared/streams.py | 572 ++++++++++++++++ .../azure/storage/blob/_shared/validation.py | 33 + .../azure/storage/blob/_upload_helpers.py | 14 +- .../azure/storage/blob/aio/_upload_helpers.py | 14 +- .../tests/test_content_validation.py | 357 ++++++++++ .../tests/test_content_validation_async.py | 342 ++++++++++ .../azure-storage-blob/tests/test_streams.py | 638 ++++++++++++++++++ 13 files changed, 2063 insertions(+), 31 deletions(-) create mode 100644 eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py create mode 100644 eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py create mode 100644 sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py create mode 100644 sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py create mode 100644 sdk/storage/azure-storage-blob/tests/test_content_validation.py create mode 100644 sdk/storage/azure-storage-blob/tests/test_content_validation_async.py create mode 100644 sdk/storage/azure-storage-blob/tests/test_streams.py diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py index 1482a6448440..4098c53f93dd 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py @@ -1,9 +1,12 @@ from .api_version_policy 
import ApiVersionAssertPolicy +from .decorators import GenericTestProxyParametrize1, GenericTestProxyParametrize2 from .service_versions import service_version_map, ServiceVersion, is_version_before from .testcase import StorageRecordedTestCase, LogCaptured __all__ = [ "ApiVersionAssertPolicy", + "GenericTestProxyParametrize1", + "GenericTestProxyParametrize2", "service_version_map", "StorageRecordedTestCase", "ServiceVersion", diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py index bc57c1b559e8..f1c9a35dcb20 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py @@ -1,3 +1,8 @@ from .asynctestcase import AsyncStorageRecordedTestCase +from .asyncdecorators import GenericTestProxyParametrize1, GenericTestProxyParametrize2 -__all__ = ["AsyncStorageRecordedTestCase"] +__all__ = [ + "AsyncStorageRecordedTestCase", + "GenericTestProxyParametrize1", + "GenericTestProxyParametrize2" +] \ No newline at end of file diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py new file mode 100644 index 000000000000..cc2455a69bbe --- /dev/null +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py @@ -0,0 +1,18 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +class GenericTestProxyParametrize1: + def __call__(self, fn): + async def _wrapper(test_class, a, **kwargs): + await fn(test_class, a, **kwargs) + return _wrapper + + +class GenericTestProxyParametrize2: + def __call__(self, fn): + async def _wrapper(test_class, a, b, **kwargs): + await fn(test_class, a, b, **kwargs) + return _wrapper diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py new file mode 100644 index 000000000000..20d5c2dfba10 --- /dev/null +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py @@ -0,0 +1,18 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +class GenericTestProxyParametrize1: + def __call__(self, fn): + def _wrapper(test_class, a, **kwargs): + fn(test_class, a, **kwargs) + return _wrapper + + +class GenericTestProxyParametrize2: + def __call__(self, fn): + def _wrapper(test_class, a, b, **kwargs): + fn(test_class, a, b, **kwargs) + return _wrapper diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 61a4fdb15bdd..568b7b3fdc9d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -34,6 +34,8 @@ from 
.authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE from .models import LocationMode, StorageErrorCode +from .streams import StructuredMessageEncodeStream, StructuredMessageProperties +from .validation import calculate_crc64_bytes, ChecksumAlgorithm if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +46,16 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -101,9 +110,13 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( StorageContentValidation.get_content_md5(response.http_response.body()) ) @@ -358,17 +371,11 @@ class StorageContentValidation(SansIOHTTPPolicy): This will overwrite any headers already defined in the request. 
""" - - header_name = "Content-MD5" - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super(StorageContentValidation, self).__init__() @staticmethod def get_content_md5(data): - # Since HTTP does not differentiate between no content and empty content, - # we have to perform a None check. - data = data or b"" md5 = hashlib.md5() # nosec if isinstance(data, bytes): md5.update(data) @@ -383,22 +390,48 @@ def get_content_md5(data): try: data.seek(pos, SEEK_SET) except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc + raise ValueError(CV_TYPE_ERROR_MSG) from exc else: - raise ValueError("Data should be bytes or a seekable file-like object.") + raise ValueError(CV_TYPE_ERROR_MSG) return md5.digest() def on_request(self, request: "PipelineRequest") -> None: validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 + if not validate_content: + return + + if request.http_request.method != "GET": + # Since HTTP does not differentiate between no content and empty content, + # we have to perform a None check. 
+ data = request.http_request.data or b"" + if validate_content is True or validate_content == ChecksumAlgorithm.MD5: + computed_md5 = encode_base64(StorageContentValidation.get_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 + + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + request.context["validate_content"] = validate_content def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if (validate_content is True or validate_content == ChecksumAlgorithm.MD5) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( StorageContentValidation.get_content_md5(response.http_response.body()) ) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 4cb32f23248b..e20e5db84860 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -16,6 +16,7 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .validation import ChecksumAlgorithm if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -37,8 +38,12 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py new file mode 100644 index 000000000000..f1790a2882b8 --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -0,0 +1,572 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: + return (version.to_bytes(1, 'little') + + size.to_bytes(8, 'little') + + flags.to_bytes(2, 'little') + + num_segments.to_bytes(2, 'little')) + + +def generate_segment_header(number: int, size: int) -> bytes: + return (number.to_bytes(2, 'little') + + size.to_bytes(8, 'little')) + + +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE + ) -> None: + if 
segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == 
self._num_segments: + self._current_region_length = self.content_length - \ + ((self._current_segment_number - 1) * self._segment_size) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return self._inner_stream.seekable() and self._initial_content_position is not None + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return (self._message_header_length + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return (self._message_header_length + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return (self._message_header_length + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return (self._message_header_length + self._content_offset + + self._current_segment_number * 
(self._segment_header_length + self._segment_footer_length) + + self._current_region_offset) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - (self.message_length - self._message_footer_length) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. 
Fine to use for calculating new segment number and pos + full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length + new_segment_num = 1 + (position - self._message_header_length) // full_segment_size + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = self.content_length - previous_segments_total_content_size + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = previous_segments_total_content_size + self._current_region_offset + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size + self._content_offset = previous_segments_total_content_size + new_segment_size + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b'' + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + 
remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER): + count += self._read_metadata_region(self._current_region, remaining, output) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min(self._segment_size, self.content_length - self._content_offset) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, 'little') + return b'' + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little') + return b'' + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + 
self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[self._current_region_offset: self._current_region_offset + read_size] + output.write(content) + + self._current_region_offset += read_size + if (self._current_region_offset == self._current_region_length and + self._current_region != SMRegion.MESSAGE_FOOTER): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need to have checksum calculated. + # Will always be positive as stream can only seek backwards. 
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = \ + calculate_crc64(content, self._segment_crc64s[self._current_segment_number]) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecodeStream(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_stream: IO[bytes] + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + + 
def __init__(self, inner_stream: IO[bytes], content_length: int) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_stream = inner_stream + + # Validate that inner stream is positioned at the start of the structured message + try: + initial_position = self._inner_stream.tell() + if initial_position != 0: + raise ValueError( + f"Inner stream must be positioned at the start of the structured message. " + f"Current position is {initial_position}, expected 0." + ) + except (AttributeError, UnsupportedOperation, OSError): + # Stream doesn't support tell(), assume it's at the correct position + pass + + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + super().__init__() + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b'' + if size < 0: + size = 
sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + self._read_message_footer() + return b'' + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + data = self._inner_stream.read(size) + if len(data) != size: + raise ValueError("Invalid structured message data detected. 
Stream content incomplete.") + return data + + def _read_message_header(self) -> None: + # The first byte should always be the message version + self.message_version = int.from_bytes(self._read_from_inner(1), 'little') + + if self.message_version == 1: + message_length = int.from_bytes(self._read_from_inner(8), 'little') + if message_length != self.message_length: + raise ValueError(f"Structured message length {message_length} " + f"did not match content length {self.message_length}") + + self.flags = StructuredMessageProperties(int.from_bytes(self._read_from_inner(2), 'little')) + self.num_segments = int.from_bytes(self._read_from_inner(2), 'little') + + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + else: + raise ValueError(f"The structured message version is not supported: {self.message_version}") + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, 'little'): + raise ValueError("CRC64 mismatch detected in message trailer. 
" + "All data read should be considered invalid.") + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + segment_number = int.from_bytes(self._read_from_inner(2), 'little') + if segment_number != self._segment_number + 1: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + self._segment_number = segment_number + self._segment_content_length = int.from_bytes(self._read_from_inner(8), 'little') + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): + raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid.") + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py new file mode 100644 index 000000000000..117aee73353b --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -0,0 +1,33 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +from enum import Enum +from typing import cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, 'little')) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py index 2ce55f7ab237..64b0432da803 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py @@ -123,11 +123,15 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements return cast(Dict[str, Any], response) - use_original_upload_path = blob_settings.use_byte_buffer or \ - validate_content or encryption_options.get('required') or \ - blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or \ - hasattr(stream, 'seekable') and not stream.seekable() or \ - not hasattr(stream, 'seek') or not hasattr(stream, 'tell') + use_original_upload_path = ( + blob_settings.use_byte_buffer + or validate_content is not None + or encryption_options.get('required') + or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold + or hasattr(stream, 'seekable') and not stream.seekable() + or not hasattr(stream, 'seek') + or not hasattr(stream, 'tell') + ) if 
use_original_upload_path: total_size = length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py index 794beee36e3b..dc7b35b04307 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py @@ -103,11 +103,15 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem return response - use_original_upload_path = blob_settings.use_byte_buffer or \ - validate_content or encryption_options.get('required') or \ - blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or \ - hasattr(stream, 'seekable') and not stream.seekable() or \ - not hasattr(stream, 'seek') or not hasattr(stream, 'tell') + use_original_upload_path = ( + blob_settings.use_byte_buffer + or validate_content is not None + or encryption_options.get('required') + or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold + or hasattr(stream, 'seekable') and not stream.seekable() + or not hasattr(stream, 'seek') + or not hasattr(stream, 'tell') + ) if use_original_upload_path: total_size = length diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py new file mode 100644 index 000000000000..13990403b666 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -0,0 +1,357 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.storage.blob import ( + BlobBlock, + BlobServiceClient, + BlobType, + ContainerClient +) +from devtools_testutils import recorded_by_proxy +from devtools_testutils.storage import ( + GenericTestProxyParametrize1, + GenericTestProxyParametrize2, + StorageRecordedTestCase +) + +from settings.testcase import BlobPreparer + + +def assert_content_md5(request): + if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': + assert request.http_request.headers.get('Content-MD5') is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' + assert response.http_response.headers.get('Content-MD5') is not None + + +def assert_content_crc64(request): + if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': + assert request.http_request.headers.get('x-ms-content-crc64') is not None + + +def assert_structured_message(request): + if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': + assert request.http_request.headers.get('x-ms-structured-body') is not None + + +class TestIter: + def __init__(self, data, *, chunk_size=100): + self.data = data + self.chunk_size = chunk_size + self.length = len(data) + self.offset = 0 + + def __len__(self): + return self.length + + def __iter__(self): + return self + + def __next__(self): + if self.offset >= self.length: + raise StopIteration + + result = self.data[self.offset: self.offset + self.chunk_size] + self.offset += len(result) + return result + + +class TestStorageContentValidation(StorageRecordedTestCase): + bsc: BlobServiceClient + container: ContainerClient + + def _setup(self, account_name): + 
token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container.create_container() + + def teardown_method(self, _): + if self.container: + try: + self.container.delete_container() + except: + pass + + def _get_blob_reference(self): + return self.get_resource_name('blob') + + # TODO: This test coming later + # @BlobPreparer() + # def test_encryption_blocked_crc64(self, **kwargs): + # storage_account_name = kwargs.pop("storage_account_name") + # storage_account_key = kwargs.pop("storage_account_key") + + # kek = KeyWrapper('key1') + # blob = BlobClient( + # self.account_url(storage_account_name, "blob"), + # "testing", + # "testing", + # credential=storage_account_key, + # require_encryption=True, + # encryption_version='2.0', + # key_encryption_key=kek) + + # with pytest.raises(ValueError): + # blob.upload_blob(b'123', validate_content='crc64') + + @BlobPreparer() + @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy + def test_upload_blob(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + + # Test supported data types + byte_data = b'abc' * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode('utf-8') + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert 
blob.download_blob().read() == byte_data + blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == str_data_encoded + blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy + def test_upload_blob_chunks(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.max_page_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + + # Test supported data types + byte_data = b'abc' * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode('utf-8') + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + 
assert blob.download_blob().read() == str_data_encoded + blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_blob_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + data = b'abc' * 512 + b'abcde' + io = BytesIO(data) + + # Act + blob.upload_blob(io, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = blob.download_blob() + assert content.read() == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_stage_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b'abc' * 512 + data2 = '你好世界' * 10 + + # An iterable 
with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i: i + 500] + + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) + blob.stage_block('2', data2, encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) + blob.stage_block('3', generator(), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + + # Assert + content = blob.download_blob() + assert content.read() == data1 + data2.encode('utf-8-sig') + data1 + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_stage_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abcde' * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + blob.stage_block('1', BytesIO(content), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock('1')]) + + result = blob.download_blob() + assert result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.live_test_only + def test_stage_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b'abcde' * 1024 * 1024 # 5 MiB + data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 + data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB + assert_method = 
assert_structured_message if a == 'crc64' else assert_content_md5 + + blob.stage_block('1', BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + blob.stage_block('2', BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + blob.stage_block('3', BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + + result = blob.download_blob() + assert result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b'abc' * 512 + data2 = '你好世界' * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. 
+ def generator(): + for i in range(0, len(data1), 500): + yield data1[i: i + 500] + + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + blob.create_append_blob() + blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) + blob.append_block(data2, encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + blob.append_block(generator(), validate_content=a, raw_request_hook=assert_method) + + content = blob.download_blob() + assert content.read() == data1 + data2.encode('utf-16') + data1 + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abcde' * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + blob.create_append_blob() + blob.append_block(BytesIO(content), validate_content=a, raw_request_hook=assert_method) + + result = blob.download_blob() + assert result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.live_test_only + def test_append_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b'abcde' * 1024 * 1024 # 5 MiB + data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 + data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + blob.create_append_blob() + blob.append_block(BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + blob.append_block(BytesIO(data2), 
validate_content=a, raw_request_hook=assert_method) + blob.append_block(BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + + result = blob.download_blob() + assert result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_page(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b'abc' * 512 + data2 = "你好世界abcd" * 32 + data2_encoded = data2.encode('utf-8') + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + # Act + blob.create_page_blob(5 * 1024) + blob.upload_page(data1, offset=0, length=len(data1), validate_content=a, raw_request_hook=assert_method) + blob.upload_page(data2, offset=len(data1), length=len(data2_encoded), encoding='utf-8', validate_content=a, raw_request_hook=assert_method) + + # Assert + content = blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) + assert content.read() == data1 + data2_encoded diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py new file mode 100644 index 000000000000..dc2c885adcef --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -0,0 +1,342 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobBlock, BlobType +from azure.storage.blob.aio import ( + BlobServiceClient, + ContainerClient +) +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import ( + AsyncStorageRecordedTestCase, + GenericTestProxyParametrize1, + GenericTestProxyParametrize2 +) +from settings.testcase import BlobPreparer + +from test_content_validation import ( + assert_content_crc64, + assert_content_md5, + assert_structured_message, + TestIter +) + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + bsc: BlobServiceClient + container: ContainerClient + + async def _setup(self, account_name): + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + await self.container.create_container() + except ResourceExistsError: + pass + + # TODO: Figure out how to get this to run automatically + async def _teardown(self): + if self.container: + try: + await self.container.delete_container() + except: + pass + + def _get_blob_reference(self): + return self.get_resource_name('blob') + + # TODO: This test coming later + # @BlobPreparer() + # async def test_encryption_blocked_crc64(self, **kwargs): + # storage_account_name = kwargs.pop("storage_account_name") + # storage_account_key = kwargs.pop("storage_account_key") + + # kek = KeyWrapper('key1') + # blob = BlobClient( + # self.account_url(storage_account_name, "blob"), + # "testing", + # "testing", + # credential=storage_account_key, + # require_encryption=True, + # encryption_version='2.0', + # key_encryption_key=kek) + + # with 
pytest.raises(ValueError): + # await blob.upload_blob(b'123', validate_content='crc64') + + @BlobPreparer() + @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy_async + async def test_upload_blob(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + + # Test supported data types + byte_data = b'abc' * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode('utf-8') + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + # Act / Assert + await blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == str_data_encoded + await blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == str_data_encoded + + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, 
BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy_async + async def test_upload_blob_chunks(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.max_page_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + + # Test supported data types + byte_data = b'abc' * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode('utf-8') + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + # Act / Assert + await blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == str_data_encoded + await blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert await (await blob.download_blob()).read() == str_data_encoded + + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: 
validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_blob_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + data = b'abc' * 512 + b'abcde' + io = BytesIO(data) + + # Act + await blob.upload_blob(io, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await blob.download_blob() + assert await content.read() == data + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_stage_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b'abc' * 512 + data2 = '你好世界' * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. 
@BlobPreparer()
@pytest.mark.parametrize('a', [True, 'md5', 'crc64'])  # a: validate_content
@pytest.mark.live_test_only
async def test_stage_block_streaming_large(self, a, **kwargs):
    """Stage three large streamed blocks with content validation and verify the round trip.

    Live-only: the payloads total roughly 79 MiB. For 'crc64' each request is
    checked with ``assert_structured_message``; otherwise with ``assert_content_md5``.
    """
    storage_account_name = kwargs.pop("storage_account_name")

    await self._setup(storage_account_name)
    blob = self.container.get_blob_client(self._get_blob_reference())

    data1 = b'abcde' * 1024 * 1024  # 5 MiB
    data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg'  # 10 MiB + 7
    data3 = b'12345678' * 8 * 1024 * 1024  # 64 MiB
    assert_method = assert_structured_message if a == 'crc64' else assert_content_md5

    # Act
    await blob.stage_block('1', BytesIO(data1), validate_content=a, raw_request_hook=assert_method)
    await blob.stage_block('2', BytesIO(data2), validate_content=a, raw_request_hook=assert_method)
    await blob.stage_block('3', BytesIO(data3), validate_content=a, raw_request_hook=assert_method)
    await blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')])

    # Assert
    result = await blob.download_blob()
    assert await result.read() == data1 + data2 + data3
    # Every sibling test tears down after its assertions; this one previously
    # skipped it, leaving whatever _setup created un-cleaned.
    await self._teardown()
@BlobPreparer()
@pytest.mark.parametrize('a', [True, 'md5', 'crc64'])  # a: validate_content
@GenericTestProxyParametrize1()
@recorded_by_proxy_async
async def test_append_block_streaming(self, a, **kwargs):
    """Append a single streamed block with content validation and read it back."""
    account_name = kwargs.pop("storage_account_name")
    await self._setup(account_name)
    blob_client = self.container.get_blob_client(self._get_blob_reference())

    payload = b'abcde' * 1030  # 5 KiB + 30
    if a == 'crc64':
        request_hook = assert_structured_message
    else:
        request_hook = assert_content_md5

    # Act
    await blob_client.create_append_blob()
    await blob_client.append_block(
        BytesIO(payload),
        validate_content=a,
        raw_request_hook=request_hook)

    # Assert
    downloader = await blob_client.download_blob()
    downloaded = await downloader.read()
    assert downloaded == payload
    await self._teardown()
@BlobPreparer()
@pytest.mark.parametrize('a', [True, 'md5', 'crc64'])  # a: validate_content
@GenericTestProxyParametrize1()
@recorded_by_proxy_async
async def test_upload_page(self, a, **kwargs):
    """Upload a bytes page range and a str page range with content validation."""
    account_name = kwargs.pop("storage_account_name")
    await self._setup(account_name)
    blob_client = self.container.get_blob_client(self._get_blob_reference())

    raw_pages = b'abc' * 512  # 1536 bytes
    text_pages = "你好世界abcd" * 32
    text_pages_encoded = text_pages.encode('utf-8')  # 512 bytes once encoded
    request_hook = assert_content_crc64 if a == 'crc64' else assert_content_md5

    # Act
    await blob_client.create_page_blob(5 * 1024)
    await blob_client.upload_page(
        raw_pages,
        offset=0,
        length=len(raw_pages),
        validate_content=a,
        raw_request_hook=request_hook)
    # The str payload is written directly after the bytes payload; its length
    # is given in encoded bytes, not characters.
    await blob_client.upload_page(
        text_pages,
        offset=len(raw_pages),
        length=len(text_pages_encoded),
        encoding='utf-8',
        validate_content=a,
        raw_request_hook=request_hook)

    # Assert
    total_length = len(raw_pages) + len(text_pages_encoded)
    downloader = await blob_client.download_blob(offset=0, length=total_length)
    downloaded = await downloader.read()
    assert downloaded == raw_pages + text_pages_encoded
    await self._teardown()
def _write_segment(
    number: int,
    data: bytes,
    data_crc: Optional[int],
    stream: BytesIO,
) -> None:
    """Append one structured-message segment to *stream*.

    Writes a 2-byte little-endian segment number, an 8-byte little-endian
    content length, the content itself and, when *data_crc* is not ``None``,
    the CRC as a little-endian footer of ``CRC64_LENGTH`` bytes.
    """
    stream.write(number.to_bytes(2, 'little'))  # Segment number
    stream.write(len(data).to_bytes(8, 'little'))  # Segment length
    stream.write(data)  # Segment content
    if data_crc is not None:
        stream.write(data_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little'))


def _build_structured_message(
    data: bytes,
    segment_size: Union[int, List[int]],
    flags: StructuredMessageProperties,
    invalidate_crc_segment: Optional[int] = None,
) -> Tuple[BytesIO, int]:
    """Build a reference structured message for *data* to compare against the streams under test.

    :param data: The raw content to wrap.
    :param segment_size: Either a fixed segment size, or an explicit list with one size per segment.
    :param flags: Message flags; segment and message CRC footers are emitted when ``CRC64`` is set.
    :param invalidate_crc_segment: 1-based number of a segment whose CRC should be corrupted,
        ``-1`` to corrupt the message-level CRC, or ``None`` to leave every CRC valid.
    :returns: The message stream (positioned at 0) and the message length written into its header.
    """
    has_crc = StructuredMessageProperties.CRC64 in flags
    crc_length = StructuredMessageConstants.CRC64_LENGTH
    # A corrupted CRC must still fit in the footer; without wrapping, "+ 5" on a
    # value near 2**64 would make to_bytes() raise OverflowError.
    crc_mask = (1 << (8 * crc_length)) - 1

    if isinstance(segment_size, list):
        segment_count = len(segment_size)
    else:
        # Empty content is still encoded as a single (empty) segment.
        segment_count = math.ceil(len(data) / segment_size) or 1
    segment_footer_length = crc_length if has_crc else 0
    message_footer_length = crc_length if has_crc else 0

    message_length = (
        StructuredMessageConstants.V1_HEADER_LENGTH +
        ((StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + segment_footer_length) * segment_count) +
        len(data) +
        message_footer_length)

    message = BytesIO()
    message_crc = 0

    # Message Header
    message.write(b'\x01')  # Version
    message.write(message_length.to_bytes(8, 'little'))  # Message length
    message.write(int(flags).to_bytes(2, 'little'))  # Flags
    message.write(segment_count.to_bytes(2, 'little'))  # Num segments

    # Special case for 0 length content
    if len(data) == 0:
        crc = 0 if has_crc else None
        _write_segment(1, data, crc, message)
    else:
        # Segments
        segment_sizes = segment_size if isinstance(segment_size, list) else [segment_size] * segment_count
        offset = 0
        for number, size in enumerate(segment_sizes, start=1):
            segment_data = data[offset: offset + size]
            offset += size

            segment_crc = None
            if has_crc:
                segment_crc = crc64.compute(segment_data, 0)
                if number == invalidate_crc_segment:
                    segment_crc = (segment_crc + 5) & crc_mask
                # The running message CRC is only ever emitted in the CRC64 footer.
                message_crc = crc64.compute(segment_data, message_crc)
            _write_segment(number, segment_data, segment_crc, message)

    # Message footer
    if has_crc:
        if invalidate_crc_segment == -1:
            message_crc = (message_crc + 5) & crc_mask
        message.write(message_crc.to_bytes(crc_length, 'little'))

    message.seek(0, 0)
    return message, message_length
@pytest.mark.parametrize("size, segment_size, chunk_size, flags", [
    # Every (size, segment_size, chunk_size) shape is run without and with CRC64.
    (size, segment_size, chunk_size, flags)
    for size, segment_size, chunk_size in [
        (10, 10, 1),
        (1024, 512, 512),
        (1024, 512, 123),
        (1024, 200, 512),
        (12345, 678, 90),
        (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024),
    ]
    for flags in (StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64)
])
def test_read_chunks(self, size, segment_size, chunk_size, flags):
    """Read the encoded message in fixed-size chunks and compare with the reference encoding."""
    payload = os.urandom(size)
    encoder = StructuredMessageEncodeStream(
        BytesIO(payload),
        len(payload),
        flags,
        segment_size=segment_size)

    total = len(encoder)
    consumed = 0
    pieces = []
    while consumed < total:
        piece = encoder.read(chunk_size)
        # Every chunk is full-sized except possibly the last one.
        assert len(piece) == min(chunk_size, total - consumed)

        pieces.append(piece)
        consumed += chunk_size

    reference, _ = _build_structured_message(payload, segment_size, flags)
    assert b''.join(pieces) == reference.getvalue()
@pytest.mark.parametrize("initial_read, segment_size, flags", [
    # Single segment (segment_size is larger than the 1024-byte payload)
    (5000, 2048, StructuredMessageProperties.NONE),  # Read everything -> Beginning
    (5000, 2048, StructuredMessageProperties.CRC64),  # Read everything -> Beginning
    (5, 2048, StructuredMessageProperties.NONE),  # Inside message header
    (5, 2048, StructuredMessageProperties.CRC64),
    (20, 2048, StructuredMessageProperties.NONE),  # Inside segment header
    (20, 2048, StructuredMessageProperties.CRC64),
    (100, 2048, StructuredMessageProperties.NONE),  # Early in the segment content
    (100, 2048, StructuredMessageProperties.CRC64),
    (1000, 2048, StructuredMessageProperties.NONE),  # Late in the segment content
    (1000, 2048, StructuredMessageProperties.CRC64),
    (525, 2048, StructuredMessageProperties.CRC64),  # Still segment content (only one segment here)
    (1092, 2048, StructuredMessageProperties.CRC64),  # Longer than the whole message
    # Multiple segments (500 + 500 + 24 bytes)
    (5000, 500, StructuredMessageProperties.NONE),  # Read everything -> Beginning
    (5000, 500, StructuredMessageProperties.CRC64),  # Read everything -> Beginning
    (5, 500, StructuredMessageProperties.NONE),  # Message header
    (5, 500, StructuredMessageProperties.CRC64),
    (20, 500, StructuredMessageProperties.NONE),  # Segment header
    (20, 500, StructuredMessageProperties.CRC64),
    (100, 500, StructuredMessageProperties.NONE),  # First segment content
    (100, 500, StructuredMessageProperties.CRC64),
    (1000, 500, StructuredMessageProperties.NONE),  # Second segment content
    (1000, 500, StructuredMessageProperties.CRC64),
    (525, 500, StructuredMessageProperties.CRC64),  # Segment footer
    (1092, 500, StructuredMessageProperties.CRC64),  # Message footer
])
def test_seek_reverse_beginning(self, initial_read, segment_size, flags):
    """Read a prefix of the encoded message, rewind to offset 0 and re-read the whole message."""
    payload = os.urandom(1024)
    encoder = StructuredMessageEncodeStream(
        BytesIO(payload),
        len(payload),
        flags,
        segment_size=segment_size)
    reference, _ = _build_structured_message(payload, segment_size, flags)
    encoded = reference.getvalue()

    # The prefix read leaves the encoder somewhere inside the region named in the table.
    prefix = encoder.read(initial_read)
    assert prefix == encoded[:initial_read]

    # After rewinding, the complete message must be reproduced.
    encoder.seek(0)
    replay = encoder.read()
    assert replay == encoded
@pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64])
def test_seek_reverse_random(self, flags):
    """Ten rounds of: read a random-length prefix, seek back to a random earlier offset, re-read."""
    payload = os.urandom(1024)
    reference, _ = _build_structured_message(payload, 500, flags)
    encoded = reference.getvalue()

    rounds = 10
    for _ in range(rounds):
        # A fresh encoder per round so rounds do not influence each other.
        encoder = StructuredMessageEncodeStream(
            BytesIO(payload),
            len(payload),
            flags,
            segment_size=500)

        prefix_length = random.randint(5, len(payload))
        rewind_to = random.randint(0, prefix_length)  # never beyond what was already read

        prefix = encoder.read(prefix_length)
        assert prefix == encoded[:prefix_length]

        encoder.seek(rewind_to)
        remainder = encoder.read()
        assert remainder == encoded[rewind_to:]
def test_read_past_end(self):
    """Asking the decoder for more than is available returns the content, then ``b''``."""
    payload = os.urandom(10)
    encoded_stream, encoded_length = _build_structured_message(
        payload,
        len(payload),
        StructuredMessageProperties.CRC64)
    decoder = StructuredMessageDecodeStream(encoded_stream, encoded_length)

    # First oversized read drains the whole 10-byte payload.
    first = decoder.read(100)
    assert first == payload

    # Nothing is left afterwards.
    second = decoder.read(100)
    assert second == b''
@pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024])
def test_random_reads(self, data_size):
    """Decode a CRC64 message using randomly sized reads and compare with the original content."""
    payload = os.urandom(data_size)
    largest_read = data_size // 3
    encoded_stream, encoded_length = _build_structured_message(
        payload,
        largest_read,
        StructuredMessageProperties.CRC64)
    decoder = StructuredMessageDecodeStream(encoded_stream, encoded_length)

    pieces = []
    remaining = data_size
    while remaining > 0:
        # Never request more than is still outstanding.
        request = min(random.randint(10, largest_read), remaining)
        pieces.append(decoder.read(request))
        remaining -= request

    assert b''.join(pieces) == payload
@pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1])
def test_crc64_mismatch_read_chunks(self, invalid_segment):
    """A corrupted segment or message CRC must surface as ``ValueError`` during chunked reads.

    ``invalid_segment`` is the 1-based segment whose CRC is corrupted, or ``-1``
    to corrupt the message-level CRC.
    """
    data = os.urandom(3 * 1024)
    message_stream, length = _build_structured_message(
        data,
        1024,
        StructuredMessageProperties.CRC64,
        invalid_segment)

    stream = StructuredMessageDecodeStream(message_stream, length)
    chunk_size = 512
    # One read per chunk of content plus one extra read at end-of-stream, so a
    # check deferred to the final read still has a chance to fire. The loop is
    # bounded: previously the counter was never advanced, so a decoder that
    # failed to raise made this test hang instead of fail.
    max_reads = math.ceil(len(data) / chunk_size) + 1

    # Since we only check CRC on segment borders, some reads will succeed, but we test
    # to ensure eventually the stream reading will error out.
    with pytest.raises(ValueError) as e:
        for _ in range(max_reads):
            stream.read(chunk_size)

    assert 'CRC64 mismatch' in str(e.value)
@pytest.mark.parametrize("segment_count", [2, 123])  # Correct value: 4
@pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64])
def test_incorrect_segment_count(self, segment_count, flags):
    """A header segment count that disagrees with the actual segments must fail decoding."""
    payload = os.urandom(1024)
    encoded_stream, encoded_length = _build_structured_message(payload, 256, flags)

    # The segment count follows the 1-byte version, 8-byte message length and
    # 2-byte flags written by _build_structured_message.
    count_offset = 1 + 8 + 2
    encoded_stream.seek(count_offset)
    encoded_stream.write(segment_count.to_bytes(2, 'little'))
    encoded_stream.seek(0)

    decoder = StructuredMessageDecodeStream(encoded_stream, encoded_length)
    with pytest.raises(ValueError):
        decoder.read()
@pytest.mark.parametrize("segment_size", [123, 345])  # Correct value: 256
@pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64])
def test_incorrect_segment_size_single_segment(self, segment_size, flags):
    """A wrong length in the only segment's header must make decoding fail with ``ValueError``."""
    data = os.urandom(256)
    message_stream, length = _build_structured_message(data, 256, flags)

    # The segment length field sits right after the message header and the
    # 2-byte segment number.
    position = StructuredMessageConstants.V1_HEADER_LENGTH + 2
    message_stream.seek(position)
    # The length field is 8 bytes wide (see _write_segment); overwrite all of it
    # rather than relying on the upper bytes already being zero.
    message_stream.write(segment_size.to_bytes(8, 'little'))
    message_stream.seek(0)

    stream = StructuredMessageDecodeStream(message_stream, length)
    with pytest.raises(ValueError):
        stream.read()
a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -21,6 +21,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from ._shared.constants import DEFAULT_MAX_CONCURRENCY +from ._shared.validation import is_md5_validation from ._deserialize import deserialize_blob_properties, get_page_ranges_result from ._encryption import ( adjust_blob_size_for_encryption, @@ -214,7 +215,7 @@ def _download_chunk(self, chunk_start: int, chunk_end: int) -> Tuple[bytes, int] range_header, range_validation = validate_and_format_range_headers( download_range[0], download_range[1], - check_content_md5=self.validate_content + check_content_md5=is_md5_validation(self.validate_content) ) retry_active = True @@ -358,6 +359,7 @@ def __init__( self._file_size = 0 self._non_empty_ranges = None self._encryption_data: Optional["_EncryptionData"] = None + self._is_structured_message = False # The content download offset, after any processing (decryption), in bytes self._download_offset = 0 @@ -382,11 +384,10 @@ def __init__( self._get_encryption_data_request() # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. 
- first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size - ) + first_get_size = self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + initial_request_start = self._download_start if self._end_range is not None and self._end_range - initial_request_start < first_get_size: initial_request_end = self._end_range @@ -445,7 +446,7 @@ def _get_encryption_data_request(self) -> None: @property def _download_complete(self): - if is_encryption_v2(self._encryption_data): + if is_encryption_v2(self._encryption_data) or self._is_structured_message: return self._download_offset >= self.size return self._raw_download_offset >= self.size @@ -455,7 +456,7 @@ def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) retry_active = True @@ -543,6 +544,7 @@ def _initial_request(self): except HttpResponseError: pass + self._is_structured_message = response.response.headers.get("x-ms-structured-body") is not None if not self._download_complete and self._request_options.get("modified_access_conditions"): self._request_options["modified_access_conditions"].if_match = response.properties.etag diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 568b7b3fdc9d..f3568df83140 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -32,10 +32,10 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode -from .streams import 
StructuredMessageEncodeStream, StructuredMessageProperties -from .validation import calculate_crc64_bytes, ChecksumAlgorithm +from .streams import StructuredMessageDecoder, StructuredMessageEncodeStream, StructuredMessageProperties +from .validation import calculate_crc64_bytes, ChecksumAlgorithm, is_md5_validation if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -401,11 +401,17 @@ def on_request(self, request: "PipelineRequest") -> None: if not validate_content: return - if request.http_request.method != "GET": + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. data = request.http_request.data or b"" - if validate_content is True or validate_content == ChecksumAlgorithm.MD5: + if is_md5_validation(validate_content): computed_md5 = encode_base64(StorageContentValidation.get_content_md5(data)) request.http_request.headers[MD5_HEADER] = computed_md5 request.context["validate_content_md5"] = computed_md5 @@ -431,7 +437,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") if not validate_content: return - if (validate_content is True or validate_content == ChecksumAlgorithm.MD5) and response.http_response.headers.get("content-md5"): + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( StorageContentValidation.get_content_md5(response.http_response.body()) ) @@ -444,6 +450,34 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") response=response.http_response, ) + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. 
+ sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, + ) + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = StructuredMessageDecoder(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + class StorageRetryPolicy(HTTPPolicy): """ diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py index b23f65859690..9f106333e1fb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py @@ -139,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = "true" + range_validation = True return range_header, range_validation diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index 
f1790a2882b8..40fa93c2dac5 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -8,7 +8,7 @@ import sys from enum import auto, Enum, IntFlag from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET -from typing import IO, Optional +from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -381,7 +381,7 @@ def _increment_current_segment(self): self._segment_crc64s.setdefault(self._current_segment_number, 0) -class StructuredMessageDecodeStream(IOBase): # pylint: disable=too-many-instance-attributes +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -392,34 +392,24 @@ class StructuredMessageDecodeStream(IOBase): # pylint: disable=too-many-instanc num_segments: int """The number of message segments.""" - _inner_stream: IO[bytes] + _inner_iterator: Iterator[bytes] + _buffer: bytes _message_offset: int _message_crc64: int _segment_number: int _segment_crc64: int _segment_content_length: int _segment_content_offset: int + _block_size: int - def __init__(self, inner_stream: IO[bytes], content_length: int) -> None: + def __init__(self, inner_iterator: Iterator[bytes], content_length: int, *, block_size: int = 4096) -> None: self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: raise ValueError("Content not long enough to contain a valid message header.") - self._inner_stream = inner_stream - - # Validate that inner stream is positioned at the start of the structured message - try: - initial_position = self._inner_stream.tell() - if initial_position != 0: - raise ValueError( - f"Inner stream must be positioned at the start of the structured message. 
" - f"Current position is {initial_position}, expected 0." - ) - except (AttributeError, UnsupportedOperation, OSError): - # Stream doesn't support tell(), assume it's at the correct position - pass - + self._inner_iterator = inner_iterator + self._buffer = b'' self._message_offset = 0 self._message_crc64 = 0 @@ -427,8 +417,13 @@ def __init__(self, inner_stream: IO[bytes], content_length: int) -> None: self._segment_crc64 = 0 self._segment_content_length = 0 self._segment_content_offset = 0 + self._block_size = block_size super().__init__() + @property + def content_length(self) -> int: + return self.message_length + @property def _segment_header_length(self) -> int: return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH @@ -445,16 +440,21 @@ def _message_footer_length(self) -> int: def _end_of_segment_content(self) -> bool: return self._segment_content_offset == self._segment_content_length - def close(self) -> None: - self._inner_stream.close() - super().close() - def readable(self) -> bool: return True def seekable(self) -> bool: return False + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + def read(self, size: int = -1) -> bytes: if self.closed: raise ValueError("Stream is closed") @@ -512,9 +512,18 @@ def read(self, size: int = -1) -> bytes: return content.getvalue() def _read_from_inner(self, size: int) -> bytes: - data = self._inner_stream.read(size) - if len(data) != size: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: raise ValueError("Invalid structured message data detected. 
Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] return data def _read_message_header(self) -> None: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py index 117aee73353b..3960cbfc5f22 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -6,7 +6,7 @@ # pylint: disable=c-extension-no-member from enum import Enum -from typing import cast +from typing import cast, Literal, Union from azure.core import CaseInsensitiveEnumMeta @@ -19,6 +19,12 @@ class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): CRC64 = "crc64" +def is_md5_validation(validate_content: Union[bool, Literal["md5", "crc64"]]) -> bool: + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + def calculate_crc64(data: bytes, initial_crc: int) -> int: # Locally import to avoid error if not installed. 
from azure.storage.extensions import crc64 diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index 13990403b666..4d24bd015f4f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -43,6 +43,11 @@ def assert_structured_message(request): assert request.http_request.headers.get('x-ms-structured-body') is not None +def assert_structured_message_get(response): + assert response.http_request.headers.get('x-ms-structured-body') is not None + assert response.http_response.headers.get('x-ms-structured-body') is not None + + class TestIter: def __init__(self, data, *, chunk_size=100): self.data = data @@ -355,3 +360,145 @@ def test_upload_page(self, a, **kwargs): # Assert content = blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) assert content.read() == data1 + data2_encoded + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + 
@recorded_by_proxy + def test_download_blob_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + b'abcde' + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + read_content = b'' + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content += downloader.read(100) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + b'abcde' + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(offset=10, length=1000, validate_content=a, 
raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[10:1010] + + @BlobPreparer() + @pytest.mark.live_test_only + def test_download_blob_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.container._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + blob.upload_blob(data, overwrite=True, max_concurrency=5) + + # Act + downloader = blob.download_blob(validate_content='crc64', max_concurrency=3) + content = downloader.read() + + downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 *1024 * 1024) + partial = downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024: 30 * 1024 * 1024] + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob_chars(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + + data = '你好世界' * 256 # 3 KiB + blob = self.container.get_blob_client(self._get_blob_reference()) + blob.upload_blob(data, encoding='utf-8', overwrite=True) + + stream = blob.download_blob(encoding='utf-8', validate_content=a) + assert stream.read() == data + + stream = blob.download_blob(encoding='utf-8', validate_content=a) + assert stream.read(chars=100000) == data + + result = '' + stream = blob.download_blob(encoding='utf-8', validate_content=a) + for _ in range(4): + chunk = stream.read(chars=100) + result += chunk + assert len(chunk) 
== 100 + + result += stream.readall() + assert result == data diff --git a/sdk/storage/azure-storage-blob/tests/test_streams.py b/sdk/storage/azure-storage-blob/tests/test_streams.py index 3665886aaa71..874c8e4a912f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_streams.py +++ b/sdk/storage/azure-storage-blob/tests/test_streams.py @@ -8,12 +8,12 @@ import os import random from io import BytesIO, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET -from typing import List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Tuple, Union import pytest from azure.storage.blob._shared.streams import ( StructuredMessageConstants, - StructuredMessageDecodeStream, + StructuredMessageDecoder, StructuredMessageEncodeStream, StructuredMessageProperties, ) @@ -22,6 +22,12 @@ from test_helpers import NonSeekableStream +def _iter_bytes(data: bytes, chunk_size: int = 1024) -> Iterator[bytes]: + """Convert bytes to an Iterator[bytes] with the given chunk size.""" + for i in range(0, len(data), chunk_size): + yield data[i:i + chunk_size] + + def _write_segment( number: int, data: bytes, @@ -406,17 +412,17 @@ def test_partial_stream_seek_middle(self, flags): assert result == expected[100:] -class TestStructuredMessageDecodeStream: +class TestStructuredMessageDecoder: def test_empty_inner_stream(self): with pytest.raises(ValueError): - StructuredMessageDecodeStream(BytesIO(), 0) + StructuredMessageDecoder(iter([b'']), 0) def test_read_past_end(self): data = os.urandom(10) message_stream, length = _build_structured_message(data, len(data), StructuredMessageProperties.CRC64) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) result = stream.read(100) assert result == data @@ -443,7 +449,7 @@ def test_read_all(self, size, segment_size, flags): data = os.urandom(size) message_stream, length = _build_structured_message(data, segment_size, flags) - stream = 
StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) content = stream.read() assert content == data @@ -464,7 +470,7 @@ def test_read_chunks(self, size, segment_size, chunk_size, flags): data = os.urandom(size) message_stream, length = _build_structured_message(data, segment_size, flags) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) read = 0 content = b'' while read < len(data): @@ -479,7 +485,7 @@ def test_random_reads(self, data_size): data = os.urandom(data_size) message_stream, length = _build_structured_message(data, data_size // 3, StructuredMessageProperties.CRC64) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) count = 0 content = b'' @@ -492,6 +498,36 @@ def test_random_reads(self, data_size): assert content == data + @pytest.mark.parametrize("size, segment_size, block_size, flags", [ + (0, 1, 1, StructuredMessageProperties.NONE), + (0, 1, 1, StructuredMessageProperties.CRC64), + (1, 1, 1, StructuredMessageProperties.NONE), + (1, 1, 1, StructuredMessageProperties.CRC64), + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (1024, 200, 1024, StructuredMessageProperties.NONE), + (1024, 200, 1024, StructuredMessageProperties.CRC64), + (1024, 200, 50, StructuredMessageProperties.NONE), + (1024, 200, 50, StructuredMessageProperties.CRC64), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.NONE), + 
(100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_iterate(self, size, segment_size, block_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + decoder = StructuredMessageDecoder( + _iter_bytes(message_stream.getvalue()), length, block_size=block_size) + content = b''.join(decoder) + + assert content == data + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) def test_random_segment_sizes(self, data_size): segment_sizes = [] @@ -505,7 +541,7 @@ def test_random_segment_sizes(self, data_size): data = os.urandom(data_size) message_stream, length = _build_structured_message(data, segment_sizes, StructuredMessageProperties.CRC64) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) content = stream.read() assert content == data @@ -519,7 +555,7 @@ def test_crc64_mismatch_read_all(self, invalid_segment): StructuredMessageProperties.CRC64, invalid_segment) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) with pytest.raises(ValueError) as e: stream.read() assert 'CRC64 mismatch' in str(e.value) @@ -533,7 +569,7 @@ def test_crc64_mismatch_read_chunks(self, invalid_segment): StructuredMessageProperties.CRC64, invalid_segment) - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) # Since we only check CRC on segment borders, some reads will succeed, but we test # to ensure eventually the stream reading will error out. 
with pytest.raises(ValueError) as e: @@ -548,11 +584,11 @@ def test_invalid_message_version(self, flags): data = os.urandom(1024) message_stream, length = _build_structured_message(data, 512, flags) - # Stream already set to front - message_stream.write(b'\xFF') - message_stream.seek(0) + # Corrupt the version byte + raw = bytearray(message_stream.getvalue()) + raw[0] = 0xFF - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() @@ -562,11 +598,10 @@ def test_incorrect_message_length(self, message_length, flags): data = os.urandom(1024) message_stream, length = _build_structured_message(data, 512, flags) - message_stream.seek(1) - message_stream.write(int.to_bytes(message_length, 8, 'little')) - message_stream.seek(0) + raw = bytearray(message_stream.getvalue()) + raw[1:9] = int.to_bytes(message_length, 8, 'little') - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() @@ -576,11 +611,10 @@ def test_incorrect_segment_count(self, segment_count, flags): data = os.urandom(1024) message_stream, length = _build_structured_message(data, 256, flags) - message_stream.seek(11) - message_stream.write(int.to_bytes(segment_count, 2, 'little')) - message_stream.seek(0) + raw = bytearray(message_stream.getvalue()) + raw[11:13] = int.to_bytes(segment_count, 2, 'little') - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() @@ -595,11 +629,10 @@ def test_incorrect_segment_number(self, segment_number, flags): StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + 256 + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) - message_stream.seek(position) - 
message_stream.write(int.to_bytes(segment_number, 2, 'little')) - message_stream.seek(0) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 2] = int.to_bytes(segment_number, 2, 'little') - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() @@ -615,11 +648,10 @@ def test_incorrect_segment_size(self, segment_size, flags): 256 + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0) + 2) - message_stream.seek(position) - message_stream.write(int.to_bytes(segment_size, 2, 'little')) - message_stream.seek(0) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 8] = int.to_bytes(segment_size, 8, 'little') - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() @@ -629,10 +661,9 @@ def test_incorrect_segment_size_single_segment(self, segment_size, flags): data = os.urandom(256) message_stream, length = _build_structured_message(data, 256, flags) - message_stream.seek(15) - message_stream.write(int.to_bytes(segment_size, 2, 'little')) - message_stream.seek(0) + raw = bytearray(message_stream.getvalue()) + raw[15:23] = int.to_bytes(segment_size, 8, 'little') - stream = StructuredMessageDecodeStream(message_stream, length) + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) with pytest.raises(ValueError): stream.read() From c9bc4ed3e726c456fcc68de477e66b79f8de6c8e Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Tue, 24 Mar 2026 15:25:34 -0700 Subject: [PATCH 04/14] [Storage][102] CRC64 content validation - part 3 (#45861) --- .../storage/blob/_shared/base_client_async.py | 5 +- .../azure/storage/blob/_shared/policies.py | 212 +++++----- .../storage/blob/_shared/policies_async.py | 42 
+- .../azure/storage/blob/_shared/streams.py | 51 ++- .../storage/blob/_shared/streams_async.py | 198 ++++++++++ .../azure/storage/blob/_shared/validation.py | 27 +- .../azure/storage/blob/aio/_download_async.py | 13 +- .../tests/test_content_validation_async.py | 149 +++++++ .../tests/test_streams_async.py | 372 ++++++++++++++++++ 9 files changed, 932 insertions(+), 137 deletions(-) create mode 100644 sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py create mode 100644 sdk/storage/azure-storage-blob/tests/test_streams_async.py diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 16aba3116029..2c917610eade 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -36,12 +36,11 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +129,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 
f3568df83140..b8d7ea6b9702 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -35,7 +35,13 @@ from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode from .streams import StructuredMessageDecoder, StructuredMessageEncodeStream, StructuredMessageProperties -from .validation import calculate_crc64_bytes, ChecksumAlgorithm, is_md5_validation +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -52,7 +58,6 @@ SM_HEADER = "x-ms-structured-body" SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" SM_LENGTH_HEADER = "x-ms-structured-content-length" -CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." def encode_base64(data: Union[bytes, str]) -> str: @@ -118,7 +123,7 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True @@ -365,118 +370,117 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. +def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. 
+ Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() - - @staticmethod - def get_content_md5(data): - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError(CV_TYPE_ERROR_MSG) from exc - else: - raise ValueError(CV_TYPE_ERROR_MSG) + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return + + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + + # Upload + else: + # Since HTTP does not differentiate between no content and empty content, + # we have to perform a None check. 
+ data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if not validate_content: - return + request.context["validate_content"] = validate_content - # Download - if request.http_request.method == "GET": - if validate_content == ChecksumAlgorithm.CRC64: - request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - # Upload - else: - # Since HTTP does not differentiate between no content and empty content, - # we have to perform a None check. 
- data = request.http_request.data or b"" - if is_md5_validation(validate_content): - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(data)) - request.http_request.headers[MD5_HEADER] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - - elif validate_content == ChecksumAlgorithm.CRC64: - if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) - elif hasattr(data, "read"): - content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) - # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) - request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) - request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) - request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - else: - raise ValueError(CV_TYPE_ERROR_MSG) +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. - request.context["validate_content"] = validate_content + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. 
+ """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - validate_content = response.context.get("validate_content", False) - if not validate_content: - return + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) - if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, - ) - elif validate_content == ChecksumAlgorithm.CRC64: - # For upload and download verify structured message header present in response if provided in request. 
- sm_request = request.http_request.headers.get(SM_HEADER) - sm_response = response.http_response.headers.get(SM_HEADER) - if sm_request != sm_response: - raise AzureError( - ( - f"Expected structured message header in response does not match request. " - f"Request: {sm_request}, Response: {sm_response}", - ), - response=response.http_response, - ) - - if response.http_request.method == "GET": - # Raises exception if missing - content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) - - # Patch response to return response iterator wrapped in structured message decoder - original_stream_download = response.http_response.stream_download - def wrapped_stream_download(*args, **kwargs): - iterator = original_stream_download(*args, **kwargs) - decoder = StructuredMessageDecoder(iterator, content_length, block_size=DATA_BLOCK_SIZE) - decoder.request = iterator.request # type: ignore - decoder.response = iterator.response # type: ignore - return decoder - - response.http_response.stream_download = wrapped_stream_download + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. 
When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index e20e5db84860..860f10e93089 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -15,8 +15,18 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy -from .validation import ChecksumAlgorithm +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -50,13 +60,39 @@ async def is_checksum_retry(response): except (StreamClosedError, StreamConsumedError): pass computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class 
AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index 40fa93c2dac5..e04d666eab5e 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -47,6 +47,29 @@ def generate_segment_header(number: int, size: int) -> bytes: size.to_bytes(8, 'little')) +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], 'little') + if message_length != expected_message_length: + raise ValueError(f"Structured message length {message_length} " + f"did 
not match content length {expected_message_length}") + flags = StructuredMessageProperties(int.from_bytes(data[9:11], 'little')) + num_segments = int.from_bytes(data[11:13], 'little') + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], 'little') + if segment_number != expected_segment_number: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + segment_content_length = int.from_bytes(data[2:10], 'little') + return segment_number, segment_content_length + + class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes message_version: int content_length: int @@ -527,22 +550,10 @@ def _read_from_inner(self, size: int) -> bytes: return data def _read_message_header(self) -> None: - # The first byte should always be the message version - self.message_version = int.from_bytes(self._read_from_inner(1), 'little') - - if self.message_version == 1: - message_length = int.from_bytes(self._read_from_inner(8), 'little') - if message_length != self.message_length: - raise ValueError(f"Structured message length {message_length} " - f"did not match content length {self.message_length}") - - self.flags = StructuredMessageProperties(int.from_bytes(self._read_from_inner(2), 'little')) - self.num_segments = int.from_bytes(self._read_from_inner(2), 'little') - - self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH - - else: - raise ValueError(f"The structured message version is not supported: {self.message_version}") + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: # Sanity check: There should only be 
self._message_footer_length (could be 0) bytes left to consume. @@ -560,11 +571,9 @@ def _read_message_footer(self) -> None: self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: - segment_number = int.from_bytes(self._read_from_inner(2), 'little') - if segment_number != self._segment_number + 1: - raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") - self._segment_number = segment_number - self._segment_content_length = int.from_bytes(self._read_from_inner(8), 'little') + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py new file mode 100644 index 000000000000..0bd608d02379 --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import StructuredMessageConstants, StructuredMessageProperties, parse_message_header, parse_segment_header +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__(self, inner_iterator: AsyncIterator[bytes], content_length: int, *, block_size: int = 4096) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b'' + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + 
@property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b'' + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + await self._read_message_footer() + return b'' + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + 
self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, 'little'): + raise ValueError("CRC64 mismatch detected in message trailer. 
" + "All data read should be considered invalid.") + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): + raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid.") + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py index 3960cbfc5f22..329ef7517d9b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -5,12 +5,15 @@ # -------------------------------------------------------------------------- # pylint: disable=c-extension-no-member +import hashlib from enum import Enum -from typing import cast, Literal, Union +from io import SEEK_SET +from typing import IO, cast, Literal, Union from azure.core import CaseInsensitiveEnumMeta CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." 
class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): @@ -25,6 +28,28 @@ def is_md5_validation(validate_content: Union[bool, Literal["md5", "crc64"]]) -> return validate_content == ChecksumAlgorithm.MD5 +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + def calculate_crc64(data: bytes, initial_crc: int) -> int: # Locally import to avoid error if not installed. from azure.storage.extensions import crc64 diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 30cbb0c68fbf..ed65a88f78e2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -24,6 +24,7 @@ from .._shared.request_handlers import validate_and_format_range_headers from .._shared.response_handlers import parse_length_from_content_range, process_storage_error from .._shared.constants import DEFAULT_MAX_CONCURRENCY +from .._shared.validation import is_md5_validation from .._deserialize import deserialize_blob_properties, get_page_ranges_result from .._download import process_range_and_offset, _ChunkDownloader from .._encryption import ( @@ -124,7 +125,7 @@ async def _download_chunk(self, chunk_start: int, chunk_end: int) -> Tuple[bytes range_header, range_validation = validate_and_format_range_headers( download_range[0], download_range[1], - check_content_md5=self.validate_content + 
check_content_md5=is_md5_validation(self.validate_content) ) retry_active = True @@ -267,6 +268,7 @@ def __init__( self._file_size = 0 self._non_empty_ranges = None self._encryption_data: Optional["_EncryptionData"] = None + self._is_structured_message = False # The content download offset, after any processing (decryption), in bytes self._download_offset = 0 @@ -318,10 +320,10 @@ async def _setup(self) -> None: await self._get_encryption_data_request() # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size + self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size ) initial_request_start = self._start_range if self._start_range is not None else 0 if self._end_range is not None and self._end_range - initial_request_start < first_get_size: @@ -356,7 +358,7 @@ async def _setup(self) -> None: @property def _download_complete(self): - if is_encryption_v2(self._encryption_data): + if is_encryption_v2(self._encryption_data) or self._is_structured_message: return self._download_offset >= self.size return self._raw_download_offset >= self.size @@ -366,7 +368,7 @@ async def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) retry_active = True @@ -449,6 +451,7 @@ async def _initial_request(self): except HttpResponseError: pass + self._is_structured_message = response.response.headers.get("x-ms-structured-body") is not None if not self._download_complete and self._request_options.get("modified_access_conditions"): 
self._request_options["modified_access_conditions"].if_match = response.properties.etag diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index dc2c885adcef..a0fdcbee01d8 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -24,7 +24,9 @@ from test_content_validation import ( assert_content_crc64, assert_content_md5, + assert_content_md5_get, assert_structured_message, + assert_structured_message_get, TestIter ) @@ -340,3 +342,150 @@ async def test_upload_page(self, a, **kwargs): content = await blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) assert await content.read() == data1 + data2_encoded await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def 
test_download_blob_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + b'abcde' + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = b'' + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content += await downloader.read(100) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + b'abcde' + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = 
await blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[10:1010] + await self._teardown() + + @BlobPreparer() + @pytest.mark.live_test_only + async def test_download_blob_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.container._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + await blob.upload_blob(data, overwrite=True, max_concurrency=5) + + # Act + downloader = await blob.download_blob(validate_content='crc64', max_concurrency=3) + content = await downloader.read() + + downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024) + partial = await downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024: 30 * 1024 * 1024] + await self._teardown() + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob_chars(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + + data = '你好世界' * 256 # 3 KiB + blob = self.container.get_blob_client(self._get_blob_reference()) + await blob.upload_blob(data, encoding='utf-8', overwrite=True) + + stream = await blob.download_blob(encoding='utf-8', validate_content=a) + assert await stream.read() == data + + stream = await blob.download_blob(encoding='utf-8', validate_content=a) + assert 
await stream.read(chars=100000) == data + + result = '' + stream = await blob.download_blob(encoding='utf-8', validate_content=a) + for _ in range(4): + chunk = await stream.read(chars=100) + result += chunk + assert len(chunk) == 100 + + result += await stream.readall() + assert result == data + await self._teardown() diff --git a/sdk/storage/azure-storage-blob/tests/test_streams_async.py b/sdk/storage/azure-storage-blob/tests/test_streams_async.py new file mode 100644 index 000000000000..e49cc3becd72 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_streams_async.py @@ -0,0 +1,372 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import os +import random +from io import BytesIO +from typing import AsyncIterator, List, Optional, Tuple, Union + +import pytest +import pytest_asyncio +from azure.storage.blob._shared.streams import ( + StructuredMessageConstants, + StructuredMessageProperties, +) +from azure.storage.blob._shared.streams_async import AsyncStructuredMessageDecoder +from azure.storage.extensions import crc64 + + +async def _async_iter_bytes(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]: + """Convert bytes to an AsyncIterator[bytes] with the given chunk size.""" + for i in range(0, len(data), chunk_size): + yield data[i:i + chunk_size] + + +def _write_segment( + number: int, + data: bytes, + data_crc: Optional[int], + stream: BytesIO, +) -> None: + stream.write(number.to_bytes(2, 'little')) # Segment number + stream.write(len(data).to_bytes(8, 'little')) # Segment length + stream.write(data) # Segment content + if data_crc is not None: + stream.write(data_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + +def 
_build_structured_message( + data: bytes, + segment_size: Union[int, List[int]], + flags: StructuredMessageProperties, + invalidate_crc_segment: Optional[int] = None, +) -> Tuple[BytesIO, int]: + if isinstance(segment_size, list): + segment_count = len(segment_size) + else: + segment_count = math.ceil(len(data) / segment_size) or 1 + segment_footer_length = StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0 + + message_length = ( + StructuredMessageConstants.V1_HEADER_LENGTH + + ((StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + segment_footer_length) * segment_count) + + len(data) + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + + message = BytesIO() + message_crc = 0 + + # Message Header + message.write(b'\x01') # Version + message.write(message_length.to_bytes(8, 'little')) # Message length + message.write(int(flags).to_bytes(2, 'little')) # Flags + message.write(segment_count.to_bytes(2, 'little')) # Num segments + + # Special case for 0 length content + if len(data) == 0: + crc = 0 if StructuredMessageProperties.CRC64 in flags else None + _write_segment(1, data, crc, message) + else: + # Segments + segment_sizes = segment_size if isinstance(segment_size, list) else [segment_size] * segment_count + offset = 0 + for i in range(1, segment_count + 1): + size = segment_sizes[i - 1] + segment_data = data[offset: offset + size] + offset += size + + segment_crc = None + if StructuredMessageProperties.CRC64 in flags: + segment_crc = crc64.compute(segment_data, 0) + if i == invalidate_crc_segment: + segment_crc += 5 + _write_segment(i, segment_data, segment_crc, message) + + message_crc = crc64.compute(segment_data, message_crc) + + # Message footer + if StructuredMessageProperties.CRC64 in flags: + if invalidate_crc_segment == -1: + message_crc += 5 + message.write(message_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + message.seek(0, 0) + return 
message, message_length + + +class TestAsyncStructuredMessageDecoder: + + @pytest.mark.asyncio + async def test_empty_inner_stream(self): + with pytest.raises(ValueError): + AsyncStructuredMessageDecoder(_async_iter_bytes(b''), 0) + + @pytest.mark.asyncio + async def test_read_past_end(self): + data = os.urandom(10) + message_stream, length = _build_structured_message(data, len(data), StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + result = await stream.read(100) + assert result == data + + result = await stream.read(100) + assert result == b'' + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, flags", [ + (0, 1, StructuredMessageProperties.NONE), + (0, 1, StructuredMessageProperties.CRC64), + (10, 1, StructuredMessageProperties.NONE), + (10, 1, StructuredMessageProperties.CRC64), + (1024, 1024, StructuredMessageProperties.NONE), + (1024, 1024, StructuredMessageProperties.CRC64), + (1024, 512, StructuredMessageProperties.NONE), + (1024, 512, StructuredMessageProperties.CRC64), + (1024, 200, StructuredMessageProperties.NONE), + (1024, 200, StructuredMessageProperties.CRC64), + (123456, 1234, StructuredMessageProperties.NONE), + (123456, 1234, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_read_all(self, size, segment_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + content = await stream.read() + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, chunk_size, flags", [ + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, 
StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_read_chunks(self, size, segment_size, chunk_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + read = 0 + content = b'' + while read < len(data): + chunk = await stream.read(chunk_size) + content += chunk + read += chunk_size + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + async def test_random_reads(self, data_size): + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, data_size // 3, StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + + count = 0 + content = b'' + while count < data_size: + read_size = random.randint(10, data_size // 3) + read_size = min(read_size, data_size - count) + + content += await stream.read(read_size) + count += read_size + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, block_size, flags", [ + (0, 1, 1, StructuredMessageProperties.NONE), + (0, 1, 1, StructuredMessageProperties.CRC64), + (1, 1, 1, StructuredMessageProperties.NONE), + (1, 1, 1, StructuredMessageProperties.CRC64), + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, 
StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (1024, 200, 1024, StructuredMessageProperties.NONE), + (1024, 200, 1024, StructuredMessageProperties.CRC64), + (1024, 200, 50, StructuredMessageProperties.NONE), + (1024, 200, 50, StructuredMessageProperties.CRC64), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.NONE), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_iterate(self, size, segment_size, block_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + decoder = AsyncStructuredMessageDecoder( + _async_iter_bytes(message_stream.getvalue()), length, block_size=block_size) + content = b'' + async for chunk in decoder: + content += chunk + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + async def test_random_segment_sizes(self, data_size): + segment_sizes = [] + count = 0 + while count < data_size: + size = random.randint(10, data_size // 3) + size = min(size, data_size - count) + segment_sizes.append(size) + count += size + + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, segment_sizes, StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + content = await stream.read() + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + async def test_crc64_mismatch_read_all(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + 
invalid_segment) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + with pytest.raises(ValueError) as e: + await stream.read() + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + async def test_crc64_mismatch_read_chunks(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + invalid_segment) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + # Since we only check CRC on segment borders, some reads will succeed, but we test + # to ensure eventually the stream reading will error out. + with pytest.raises(ValueError) as e: + # Read until EOF (b'') so a missed mismatch fails the test instead of looping forever. + while await stream.read(512): + pass + + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_invalid_message_version(self, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + # Corrupt the version byte + raw = bytearray(message_stream.getvalue()) + raw[0] = 0xFF + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("message_length", [100, 1234567]) # Correct value: 1057 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_message_length(self, message_length, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + raw = bytearray(message_stream.getvalue()) + raw[1:9] = int.to_bytes(message_length, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) +
with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_count", [2, 123]) # Correct value: 4 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_count(self, segment_count, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[11:13] = int.to_bytes(segment_count, 2, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_number", [1, 123]) # Correct value: 2 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_number(self, segment_number, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the incorrect number + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 2] = int.to_bytes(segment_number, 2, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_size(self, segment_size, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the 
incorrect size + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0) + + 2) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 8] = int.to_bytes(segment_size, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_size_single_segment(self, segment_size, flags): + data = os.urandom(256) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[15:23] = int.to_bytes(segment_size, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() From fdd9b13e74e92e7f1a8825bdebd5f0e163054fd3 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Fri, 27 Mar 2026 12:10:57 -0700 Subject: [PATCH 05/14] [Storage][102] CRC64 content validation - part 4 (#45949) --- .../azure/storage/blob/_blob_client.py | 75 +- .../azure/storage/blob/_blob_client.pyi | 14 +- .../storage/blob/_blob_client_helpers.py | 16 +- .../azure/storage/blob/_container_client.py | 28 +- .../azure/storage/blob/_container_client.pyi | 9 +- .../azure/storage/blob/_download.py | 6 +- .../azure/storage/blob/_shared/policies.py | 208 ++++-- .../storage/blob/_shared/policies_async.py | 118 ++- .../azure/storage/blob/_shared/streams.py | 285 ++++--- .../storage/blob/_shared/streams_async.py | 106 ++- .../azure/storage/blob/_shared/validation.py | 47 +- .../azure/storage/blob/_upload_helpers.py | 10 +- 
.../storage/blob/aio/_blob_client_async.py | 75 +- .../storage/blob/aio/_blob_client_async.pyi | 14 +- .../blob/aio/_container_client_async.py | 28 +- .../blob/aio/_container_client_async.pyi | 9 +- .../azure/storage/blob/aio/_download_async.py | 4 +- .../azure/storage/blob/aio/_upload_helpers.py | 10 +- .../azure-storage-blob/dev_requirements.txt | 1 + .../tests/test_content_validation.py | 59 +- .../tests/test_content_validation_async.py | 56 +- .../storage/filedatalake/_shared/policies.py | 369 ++++++--- .../filedatalake/_shared/policies_async.py | 165 +++- .../storage/filedatalake/_shared/streams.py | 703 ++++++++++++++++++ .../filedatalake/_shared/streams_async.py | 248 ++++++ .../filedatalake/_shared/validation.py | 105 +++ .../storage/fileshare/_shared/policies.py | 369 ++++++--- .../fileshare/_shared/policies_async.py | 165 +++- .../storage/fileshare/_shared/streams.py | 703 ++++++++++++++++++ .../fileshare/_shared/streams_async.py | 248 ++++++ .../storage/fileshare/_shared/validation.py | 105 +++ .../azure/storage/queue/_shared/policies.py | 325 +++++--- .../storage/queue/_shared/policies_async.py | 139 +++- .../azure/storage/queue/_shared/streams.py | 703 ++++++++++++++++++ .../storage/queue/_shared/streams_async.py | 248 ++++++ .../azure/storage/queue/_shared/validation.py | 105 +++ 36 files changed, 5089 insertions(+), 789 deletions(-) create mode 100644 sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py create mode 100644 sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py create mode 100644 sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py create mode 100644 sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py create mode 100644 sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py create mode 100644 sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py 
create mode 100644 sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py create mode 100644 sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py create mode 100644 sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index b3dafa603afd..89b49dac7a94 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -65,6 +65,7 @@ from ._quick_query_helper import BlobQueryReader from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin, TransportWrapper from ._shared.response_handlers import process_storage_error, return_response_headers +from ._shared.validation import ChecksumAlgorithm, parse_validation_option from ._serialize import ( get_access_conditions, get_api_version, @@ -505,15 +506,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. 
+ Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. Value can be a BlobLeaseClient object @@ -616,6 +613,9 @@ def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -627,6 +627,7 @@ def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -683,15 +684,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. 
+ :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -765,6 +762,9 @@ def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -778,6 +778,7 @@ def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2009,15 +2010,11 @@ def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. 
Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2850,13 +2847,11 @@ def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. 
@@ -3157,13 +3152,11 @@ def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. 
If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi index fee0f4757f79..f5679a595ad8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi @@ -173,7 +173,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -200,7 +200,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -222,7 +222,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -244,7 +244,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = 
None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -486,7 +486,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -671,7 +671,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -741,7 +741,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index 33d5dfc7f0b2..1bbfb33901d3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -8,7 +8,7 @@ from io import BytesIO from typing import ( Any, AnyStr, AsyncGenerator, AsyncIterable, cast, - Dict, IO, Iterable, List, Optional, Tuple, Union, + Dict, IO, Iterable, List, Literal, Optional, Tuple, Union, TYPE_CHECKING ) from 
urllib.parse import quote, unquote, urlparse @@ -58,6 +58,7 @@ from ._shared.response_handlers import return_headers_and_deserialized, return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer +from ._shared.validation import parse_validation_option from ._upload_helpers import _any_conditions if TYPE_CHECKING: @@ -110,6 +111,7 @@ def _upload_blob_options( # pylint:disable=too-many-statements length: Optional[int], metadata: Optional[Dict[str, str]], encryption_options: Dict[str, Any], + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -135,7 +137,6 @@ def _upload_blob_options( # pylint:disable=too-many-statements else: raise TypeError(f"Unsupported data type: {type(data)}") - validate_content = kwargs.pop('validate_content', False) content_settings = kwargs.pop('content_settings', None) overwrite = kwargs.pop('overwrite', False) max_concurrency = kwargs.pop('max_concurrency', None) @@ -258,6 +259,7 @@ def _download_blob_options( length: Optional[int], encoding: Optional[str], encryption_options: Dict[str, Any], + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -279,6 +281,8 @@ def _download_blob_options( Encoding to decode the downloaded bytes. Default is None, i.e. no decoding. :param Dict[str, Any] encryption_options: The options for encryption, if enabled. + :param validate_content: + Enables checksum validation for the transfer. Already parsed via parse_validation_option. :param StorageConfiguration config: The Storage configuration options. 
:param str sdk_moniker: @@ -292,8 +296,6 @@ def _download_blob_options( if offset is None: raise ValueError("Offset must be provided if length is provided.") length = offset + length - 1 # Service actually uses an end-range inclusive index - - validate_content = kwargs.pop('validate_content', False) access_conditions = get_access_conditions(kwargs.pop('lease', None)) mod_conditions = get_modify_conditions(kwargs) @@ -721,7 +723,7 @@ def _stage_block_options( if isinstance(data, bytes): data = data[:length] - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk_scope_info = get_cpk_scope_info(kwargs) cpk = kwargs.pop('cpk', None) cpk_info = None @@ -1004,7 +1006,7 @@ def _upload_page_options( ) mod_conditions = get_modify_conditions(kwargs) cpk_scope_info = get_cpk_scope_info(kwargs) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk = kwargs.pop('cpk', None) cpk_info = None if cpk: @@ -1149,7 +1151,7 @@ def _append_block_options( appendpos_condition = kwargs.pop('appendpos_condition', None) maxsize_condition = kwargs.pop('maxsize_condition', None) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) append_conditions = None if maxsize_condition or appendpos_condition is not None: append_conditions = AppendPositionAccessConditions( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 6d1a96d4bb88..dfba8f54eb63 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -1010,15 +1010,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings 
object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1255,15 +1251,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. 
Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi index 9ce6d9b7acdb..2ad0514290d3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi @@ -13,6 +13,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Iterator, @@ -246,7 +247,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -288,7 +289,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: 
Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -309,7 +310,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -331,7 +332,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index 47b49365a6b1..ba7018772413 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -11,7 +11,7 @@ from io import BytesIO, StringIO from typing import ( Any, Callable, cast, Dict, Generator, - Generic, IO, Iterator, List, Optional, + Generic, IO, Iterator, List, Literal, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -92,7 +92,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], encryption_options: Dict[str, Any], encryption_data: Optional["_EncryptionData"] = None, stream: Any = None, @@ -330,7 +330,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = 
None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index b8d7ea6b9702..63a517a4e305 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -34,7 +34,11 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode -from .streams import StructuredMessageDecoder, StructuredMessageEncodeStream, StructuredMessageProperties +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) from .validation import ( CV_TYPE_ERROR_MSG, calculate_content_md5, @@ -69,7 +73,12 @@ def encode_base64(data: Union[bytes, str]) -> str: # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -78,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? 
(Based on allowlists and control @@ -98,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -122,9 +135,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - calculate_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -153,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -193,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != 
location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -211,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -255,7 +274,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -286,7 +314,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -315,7 +345,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -333,36 +365,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": 
data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and 
upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -397,13 +443,21 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: elif validate_content == ChecksumAlgorithm.CRC64: if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) elif hasattr(data, "read"): - content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + content_length = int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -426,7 +480,9 @@ def _validate_content_response( if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + if 
is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -458,9 +514,12 @@ def _validate_content_response( # Patch response to return response iterator wrapped in structured message decoder original_stream_download = response.http_response.stream_download + def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE + ) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -473,13 +532,16 @@ class StorageContentValidation(SansIOHTTPPolicy): This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
""" + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super().__init__() def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -507,7 +569,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. @@ -527,7 +591,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. 
@@ -546,7 +610,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -555,7 +621,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -567,7 +635,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -618,7 +686,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -641,7 +711,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. 
@@ -653,13 +723,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -667,9 +744,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -722,7 +806,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -735,8 +821,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -774,7 +866,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -789,7 +883,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -797,16 +895,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, 
audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 860f10e93089..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -11,7 +11,10 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE @@ -42,9 +45,17 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - 
settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): @@ -59,9 +70,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - calculate_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -72,6 +83,7 @@ class AsyncContentValidationPolicy(AsyncHTTPPolicy): This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
""" + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument super().__init__() @@ -106,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = 
int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -164,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -178,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if 
retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -235,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -248,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -287,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -302,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the 
backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -310,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index e04d666eab5e..712f4e90af69 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -35,16 +35,19 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, 
num_segments: int) -> bytes: - return (version.to_bytes(1, 'little') + - size.to_bytes(8, 'little') + - flags.to_bytes(2, 'little') + - num_segments.to_bytes(2, 'little')) +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) def generate_segment_header(number: int, size: int) -> bytes: - return (number.to_bytes(2, 'little') + - size.to_bytes(8, 'little')) + return number.to_bytes(2, "little") + size.to_bytes(8, "little") def parse_message_header( @@ -53,24 +56,30 @@ def parse_message_header( version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") - message_length = int.from_bytes(data[1:9], 'little') + message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: - raise ValueError(f"Structured message length {message_length} " - f"did not match content length {expected_message_length}") - flags = StructuredMessageProperties(int.from_bytes(data[9:11], 'little')) - num_segments = int.from_bytes(data[11:13], 'little') + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") return version, flags, num_segments def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: - segment_number = int.from_bytes(data[0:2], 'little') + segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") - segment_content_length = int.from_bytes(data[2:10], 'little') + raise ValueError( + f"Structured message 
segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -95,11 +104,12 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _segment_crc64s: dict[int, int] def __init__( - self, inner_stream: IO[bytes], + self, + inner_stream: IO[bytes], content_length: int, flags: StructuredMessageProperties, *, - segment_size: int = DEFAULT_SEGMENT_SIZE + segment_size: int = DEFAULT_SEGMENT_SIZE, ) -> None: if segment_size < 1: raise ValueError("Segment size must be greater than 0.") @@ -141,11 +151,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -155,8 +173,9 @@ def _update_current_region_length(self) -> None: elif self._current_region == SMRegion.SEGMENT_CONTENT: # Last segment size is remaining content if self._current_segment_number == self._num_segments: - self._current_region_length = self.content_length - \ - ((self._current_segment_number - 1) * self._segment_size) + self._current_region_length = self.content_length - ( + 
(self._current_segment_number - 1) * self._segment_size + ) else: self._current_region_length = self._segment_size elif self._current_region == SMRegion.SEGMENT_FOOTER: @@ -179,7 +198,10 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return self._inner_stream.seekable() and self._initial_content_position is not None + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) except (AttributeError, UnsupportedOperation, OSError): return False @@ -187,22 +209,38 @@ def tell(self) -> int: if self._current_region == SMRegion.MESSAGE_HEADER: return self._current_region_offset if self._current_region == SMRegion.SEGMENT_HEADER: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) if self._current_region == SMRegion.SEGMENT_CONTENT: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._segment_header_length) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) if self._current_region == SMRegion.SEGMENT_FOOTER: - return (self._message_header_length + self._content_offset + - (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + - self._segment_header_length + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) 
+ * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) if self._current_region == SMRegion.MESSAGE_FOOTER: - return (self._message_header_length + self._content_offset + - self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + - self._current_region_offset) + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) raise ValueError(f"Invalid SMRegion {self._current_region}") @@ -233,21 +271,33 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -258,17 +308,25 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + 
previous_segments_total_content_size + new_segment_size + ) self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) return position def read(self, size: int = -1) -> bytes: @@ -276,7 +334,7 @@ def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -286,11 +344,14 @@ def read(self, size: int = -1) -> bytes: while count < size and self.tell() < self.message_length: remaining = size - count if self._current_region in ( - SMRegion.MESSAGE_HEADER, - SMRegion.SEGMENT_HEADER, - SMRegion.SEGMENT_FOOTER, - SMRegion.MESSAGE_FOOTER): - count += self._read_metadata_region(self._current_region, remaining, output) + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -300,7 +361,9 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -311,22 +374,28 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: self.message_version, self.message_length, self.flags, - self._num_segments) + self._num_segments, + ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min(self._segment_size, self.content_length - self._content_offset) + segment_size = min( + self._segment_size, 
self.content_length - self._content_offset + ) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: return self._segment_crc64s[self._current_segment_number].to_bytes( - StructuredMessageConstants.CRC64_LENGTH, 'little') - return b'' + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little') - return b'' + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" raise ValueError(f"Invalid metadata SMRegion {self._current_region}") @@ -352,16 +421,22 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[self._current_region_offset: self._current_region_offset + read_size] + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] output.write(content) self._current_region_offset += read_size - if (self._current_region_offset == self._current_region_length and - self._current_region != SMRegion.MESSAGE_FOOTER): + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): self._advance_region(region) return read_size @@ -383,8 +458,9 @@ def _read_content(self, size: int, output: BytesIO) -> int: if StructuredMessageProperties.CRC64 in self.flags: if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = \ - calculate_crc64(content, 
self._segment_crc64s[self._current_segment_number]) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size @@ -425,14 +501,22 @@ class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-att _segment_content_offset: int _block_size: int - def __init__(self, inner_iterator: Iterator[bytes], content_length: int, *, block_size: int = 4096) -> None: + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError("Content not long enough to contain a valid message header.") + raise ValueError( + "Content not long enough to contain a valid message header." 
+ ) self._inner_iterator = inner_iterator - self._buffer = b'' + self._buffer = b"" self._message_offset = 0 self._message_crc64 = 0 @@ -453,11 +537,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _end_of_segment_content(self) -> bool: @@ -483,7 +575,7 @@ def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -496,17 +588,23 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError("First message segment was empty but more segments were detected.") + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) self._read_message_footer() - return b'' + return b"" count = 0 content = BytesIO() - while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = self._segment_content_length - self._segment_content_offset + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -514,8 +612,12 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) - self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) self._segment_content_offset += read_size self._message_offset += read_size @@ -529,7 +631,10 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if self._message_offset == self.message_length and self._segment_number != self.num_segments: + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -543,7 +648,9 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError("Invalid structured message data detected. Stream content incomplete.") + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -552,7 +659,8 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length) + header_data, self.message_length + ) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -564,16 +672,19 @@ def _read_message_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - if self._message_crc64 != int.from_bytes(message_crc, 'little'): - raise ValueError("CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid.") + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1) + header_data, self._segment_number + 1 + ) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -583,8 +694,10 @@ def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): - raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. 
" - f"All data read should be considered invalid.") + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py index 0bd608d02379..ee7d92d14d77 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py @@ -8,11 +8,18 @@ from io import BytesIO, IOBase from typing import AsyncIterator -from .streams import StructuredMessageConstants, StructuredMessageProperties, parse_message_header, parse_segment_header +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -33,14 +40,22 @@ class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instanc _segment_content_offset: int _block_size: int - def __init__(self, inner_iterator: AsyncIterator[bytes], content_length: int, *, block_size: int = 4096) -> None: + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError("Content not long enough to contain a valid message header.") + raise ValueError( + "Content 
not long enough to contain a valid message header." + ) self._inner_iterator = inner_iterator - self._buffer = b'' + self._buffer = b"" self._message_offset = 0 self._message_crc64 = 0 @@ -61,11 +76,19 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _message_footer_length(self) -> int: - return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) @property def _end_of_segment_content(self) -> bool: @@ -91,7 +114,7 @@ async def read(self, size: int = -1) -> bytes: raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: - return b'' + return b"" if size < 0: size = sys.maxsize @@ -104,17 +127,23 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError("First message segment was empty but more segments were detected.") + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) await self._read_message_footer() - return b'' + return b"" count = 0 content = BytesIO() - while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = self._segment_content_length - self._segment_content_offset + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -122,8 +151,12 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) - self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) self._segment_content_offset += read_size self._message_offset += read_size @@ -137,7 +170,10 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if self._message_offset == self.message_length and self._segment_number != self.num_segments: + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -151,16 +187,21 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError("Invalid structured message data detected. 
Stream content incomplete.") + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." + ) data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length) + header_data, self.message_length + ) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -170,18 +211,23 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) - if self._message_crc64 != int.from_bytes(message_crc, 'little'): - raise ValueError("CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid.") + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." 
+ ) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1) + header_data, self._segment_number + 1 + ) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -189,10 +235,14 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) - - if self._segment_crc64 != int.from_bytes(segment_crc, 'little'): - raise ValueError(f"CRC64 mismatch detected in segment {self._segment_number}. " - f"All data read should be considered invalid.") + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py index 329ef7517d9b..5370d9dd669c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -8,7 +8,7 @@ import hashlib from enum import Enum from io import SEEK_SET -from typing import IO, cast, Literal, Union +from typing import IO, Literal, Optional, Union, cast from azure.core import CaseInsensitiveEnumMeta @@ -21,8 +21,49 @@ class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): MD5 = "md5" CRC64 = "crc64" + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) -def is_md5_validation(validate_content: Union[bool, Literal["md5", "crc64"]]) -> bool: + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False if isinstance(validate_content, bool): return validate_content return validate_content == ChecksumAlgorithm.MD5 @@ -61,4 +102,4 @@ def calculate_crc64_bytes(data: bytes) -> bytes: # Locally import to avoid error if not installed. 
from azure.storage.extensions import crc64 - return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, 'little')) + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py index 64b0432da803..6873b93bb4e6 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py @@ -5,7 +5,7 @@ # -------------------------------------------------------------------------- from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING from azure.core.exceptions import ResourceExistsError, ResourceModifiedError, HttpResponseError @@ -71,7 +71,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -125,7 +125,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements use_original_upload_path = ( blob_settings.use_byte_buffer - or validate_content is not None + or validate_content not in (None, False) or encryption_options.get('required') or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or hasattr(stream, 'seekable') and not stream.seekable() @@ -213,7 +213,7 @@ def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) 
-> Dict[str, Any]: @@ -291,7 +291,7 @@ def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index b73b5691a6ca..3152bda21036 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -77,6 +77,7 @@ from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error, return_response_headers +from .._shared.validation import ChecksumAlgorithm, parse_validation_option if TYPE_CHECKING: from azure.core import MatchConditions @@ -516,15 +517,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. 
+ :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. @@ -629,6 +626,9 @@ async def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -640,6 +640,7 @@ async def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -696,15 +697,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. 
Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -778,6 +775,9 @@ async def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -791,6 +791,7 @@ async def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2053,15 +2054,11 @@ async def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. 
The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2896,13 +2893,11 @@ async def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.aio.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. 
+ NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. @@ -3204,13 +3199,11 @@ async def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. 
If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi index 9c4a37007f65..ba9a7460425c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi @@ -175,7 +175,7 @@ class BlobClient( # type: ignore[misc] tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -202,7 +202,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -224,7 +224,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -246,7 +246,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: 
Optional[datetime] = None, @@ -470,7 +470,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -655,7 +655,7 @@ class BlobClient( # type: ignore[misc] length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -725,7 +725,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 32260b62b6b5..413ba2cd0f0c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -1015,15 +1015,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. 
The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1261,15 +1257,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. 
+ Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi index 56063ac847ff..49362aac8058 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi @@ -16,6 +16,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Optional, @@ -254,7 +255,7 @@ class ContainerClient( # type: ignore[misc] *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -296,7 +297,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -318,7 +319,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, 
Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -340,7 +341,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index ed65a88f78e2..7e5bbb0918b3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -15,7 +15,7 @@ from typing import ( Any, AsyncIterator, Awaitable, Generator, Callable, cast, Dict, - Generic, IO, Optional, overload, + Generic, IO, Literal, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -239,7 +239,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py index dc7b35b04307..5b551fdec2fb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py @@ 
-6,7 +6,7 @@ import inspect from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING from azure.core.exceptions import HttpResponseError, ResourceModifiedError @@ -47,7 +47,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -105,7 +105,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem use_original_upload_path = ( blob_settings.use_byte_buffer - or validate_content is not None + or validate_content not in (None, False) or encryption_options.get('required') or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or hasattr(stream, 'seekable') and not stream.seekable() @@ -193,7 +193,7 @@ async def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -271,7 +271,7 @@ async def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/dev_requirements.txt b/sdk/storage/azure-storage-blob/dev_requirements.txt index d6946c1ccfc1..9657377026f0 100644 --- a/sdk/storage/azure-storage-blob/dev_requirements.txt +++ 
b/sdk/storage/azure-storage-blob/dev_requirements.txt @@ -1,6 +1,7 @@ -e ../../../eng/tools/azure-sdk-tools ../../core/azure-core ../../identity/azure-identity +../azure-storage-extensions azure-mgmt-storage==20.1.0 aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index 4d24bd015f4f..5d670ba330fb 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -9,6 +9,7 @@ import pytest from azure.storage.blob import ( BlobBlock, + BlobClient, BlobServiceClient, BlobType, ContainerClient @@ -20,6 +21,7 @@ StorageRecordedTestCase ) +from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer @@ -90,28 +92,29 @@ def teardown_method(self, _): def _get_blob_reference(self): return self.get_resource_name('blob') - # TODO: This test coming later - # @BlobPreparer() - # def test_encryption_blocked_crc64(self, **kwargs): - # storage_account_name = kwargs.pop("storage_account_name") - # storage_account_key = kwargs.pop("storage_account_key") - - # kek = KeyWrapper('key1') - # blob = BlobClient( - # self.account_url(storage_account_name, "blob"), - # "testing", - # "testing", - # credential=storage_account_key, - # require_encryption=True, - # encryption_version='2.0', - # key_encryption_key=kek) - - # with pytest.raises(ValueError): - # blob.upload_blob(b'123', validate_content='crc64') + @BlobPreparer() + def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper('key1') + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient), + require_encryption=True, + encryption_version='2.0', + key_encryption_key=kek) + + with pytest.raises(ValueError): + blob.upload_blob(b'123', 
validate_content='crc64') + + # Needed for teardown + self.container = None @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize('b', [True, 'auto','md5', 'crc64']) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy def test_upload_blob(self, a, b, **kwargs): @@ -119,7 +122,7 @@ def test_upload_blob(self, a, b, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 # Test supported data types byte_data = b'abc' * 512 @@ -200,7 +203,7 @@ def test_upload_blob_substream(self, a, **kwargs): assert content.read() == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_stage_block(self, a, **kwargs): @@ -217,7 +220,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) blob.stage_block('2', data2, encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) @@ -270,7 +273,7 @@ def test_stage_block_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() 
@recorded_by_proxy def test_append_block(self, a, **kwargs): @@ -287,7 +290,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 blob.create_append_blob() blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) @@ -339,7 +342,7 @@ def test_append_block_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_page(self, a, **kwargs): @@ -350,7 +353,7 @@ def test_upload_page(self, a, **kwargs): data1 = b'abc' * 512 data2 = "你好世界abcd" * 32 data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act blob.create_page_blob(5 * 1024) @@ -362,7 +365,7 @@ def test_upload_page(self, a, **kwargs): assert content.read() == data1 + data2_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob(self, a, **kwargs): @@ -372,7 +375,7 @@ def test_download_blob(self, a, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) data = b'abc' * 512 blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get # Act downloader = blob.download_blob(validate_content=a, 
raw_response_hook=assert_method) diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index a0fdcbee01d8..b9de57c7015d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -10,6 +10,7 @@ from azure.core.exceptions import ResourceExistsError from azure.storage.blob import BlobBlock, BlobType from azure.storage.blob.aio import ( + BlobClient, BlobServiceClient, ContainerClient ) @@ -19,6 +20,7 @@ GenericTestProxyParametrize1, GenericTestProxyParametrize2 ) +from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer from test_content_validation import ( @@ -55,28 +57,26 @@ async def _teardown(self): def _get_blob_reference(self): return self.get_resource_name('blob') - # TODO: This test coming later - # @BlobPreparer() - # async def test_encryption_blocked_crc64(self, **kwargs): - # storage_account_name = kwargs.pop("storage_account_name") - # storage_account_key = kwargs.pop("storage_account_key") - - # kek = KeyWrapper('key1') - # blob = BlobClient( - # self.account_url(storage_account_name, "blob"), - # "testing", - # "testing", - # credential=storage_account_key, - # require_encryption=True, - # encryption_version='2.0', - # key_encryption_key=kek) - - # with pytest.raises(ValueError): - # await blob.upload_blob(b'123', validate_content='crc64') + @BlobPreparer() + async def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper('key1') + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient, is_async=True), + require_encryption=True, + encryption_version='2.0', + key_encryption_key=kek) + + with pytest.raises(ValueError): + await blob.upload_blob(b'123', 
validate_content='crc64') @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize('b', [True, "auto", 'md5', 'crc64']) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy_async async def test_upload_blob(self, a, b, **kwargs): @@ -84,7 +84,7 @@ async def test_upload_blob(self, a, b, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 # Test supported data types byte_data = b'abc' * 512 @@ -172,7 +172,7 @@ async def test_upload_blob_substream(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_stage_block(self, a, **kwargs): @@ -189,7 +189,7 @@ def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) @@ -246,7 +246,7 @@ async def test_stage_block_streaming_large(self, a, **kwargs): assert await result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_block(self, a, **kwargs): @@ -263,7 +263,7 @@ 
def generator(): for i in range(0, len(data1), 500): yield data1[i: i + 500] - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.create_append_blob() @@ -320,7 +320,7 @@ async def test_append_block_streaming_large(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_page(self, a, **kwargs): @@ -331,7 +331,7 @@ async def test_upload_page(self, a, **kwargs): data1 = b'abc' * 512 data2 = "你好世界abcd" * 32 data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 # Act await blob.create_page_blob(5 * 1024) @@ -344,7 +344,7 @@ async def test_upload_page(self, a, **kwargs): await self._teardown() @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob(self, a, **kwargs): @@ -354,7 +354,7 @@ async def test_download_blob(self, a, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) data = b'abc' * 512 await blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get # Act downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) diff --git 
a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 61a4fdb15bdd..63a517a4e305 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +73,12 @@ def encode_base64(data): # Are we out of retries? 
def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -64,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -84,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -101,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or 
encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -135,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -175,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -193,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -237,7 +274,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + 
urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -268,7 +314,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -297,7 +345,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -315,36 +365,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or 
request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -352,64 +416,133 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. 
+def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. + Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) + elif hasattr(data, "read"): + content_length = 
int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + request.context["validate_content"] = validate_content - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. + + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. 
+ """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." 
- ), - response=response.http_response, + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE ) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -436,7 +569,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -456,7 +591,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. @@ -475,7 +610,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -484,7 +621,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -496,7 +635,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. 
:type transport: @@ -547,7 +686,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -570,7 +711,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. @@ -582,13 +723,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -596,9 +744,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise 
err @@ -651,7 +806,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -664,8 +821,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -703,7 +866,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -718,7 +883,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + 
random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -726,16 +895,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. 
:rtype: bool """ diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 4cb32f23248b..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -31,27 +45,66 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): - # retry if invalid content md5 - if 
response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -65,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") 
+ ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -123,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + 
response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -137,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -194,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -207,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 
0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -246,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -261,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -269,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", 
response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py new file mode 100644 index 000000000000..712f4e90af69 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -0,0 +1,703 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, 
StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError( + f"Structured message segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + 
self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + 
(self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + 
self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. 
Fine to use for calculating new segment number and pos + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + previous_segments_total_content_size + new_segment_size + ) + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() 
+ + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min( + self._segment_size, self.content_length - self._content_offset + ) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: 
+ self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need to have checksum calculated. + # Will always be positive as stream can only seek backwards. 
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + 
_segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + 
self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." + ) + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. 
Stream content incomplete." + ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py new file mode 100644 index 000000000000..ee7d92d14d77 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + 
"Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py new file mode 100644 index 000000000000..5370d9dd669c --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from enum import Enum +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. 
+ from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 61a4fdb15bdd..63a517a4e305 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +73,12 @@ def encode_base64(data): # Are we out of retries? 
def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -64,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -84,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -101,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or 
encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -135,7 +166,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -175,7 +208,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -193,7 +228,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -237,7 +274,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + 
urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -268,7 +314,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -297,7 +345,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if request_callback: request_callback(request) @@ -315,36 +365,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or 
request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -352,64 +416,133 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. 
+def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. + Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) + elif hasattr(data, "read"): + content_length = 
int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + request.context["validate_content"] = validate_content - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. + + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. 
+ """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." 
- ), - response=response.http_response, + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE ) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -436,7 +569,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -456,7 +591,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. @@ -475,7 +610,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -484,7 +621,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -496,7 +635,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. 
:type transport: @@ -547,7 +686,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -570,7 +711,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. @@ -582,13 +723,20 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -596,9 +744,16 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise 
err @@ -651,7 +806,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -664,8 +821,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -703,7 +866,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -718,7 +883,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + 
random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -726,16 +895,22 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. 
:rtype: bool """ diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 4cb32f23248b..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -31,27 +45,66 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", 
False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -65,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") 
+ ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -123,13 +190,20 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + 
response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -137,9 +211,16 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err @@ -194,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -207,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 
0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -246,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -261,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -269,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", 
response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py new file mode 100644 index 000000000000..712f4e90af69 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -0,0 +1,703 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, StructuredMessageProperties, int]: + 
version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError( + f"Structured message segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length 
+ self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + 
else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + 
) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. 
Fine to use for calculating new segment number and pos + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + previous_segments_total_content_size + new_segment_size + ) + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() 
+ + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min( + self._segment_size, self.content_length - self._content_offset + ) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: 
+ self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need to have checksum calculated. + # Will always be positive as stream can only seek backwards. 
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + 
_segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + 
self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." + ) + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. 
Stream content incomplete." + ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py new file mode 100644 index 000000000000..ee7d92d14d77 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough 
to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py new file mode 100644 index 000000000000..5370d9dd669c --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from enum import Enum +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. 
+ from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index cc275ae4af61..61d6a8b3dc9a 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -12,7 +12,7 @@ import uuid from io import SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +32,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_md5_validation, + ChecksumAlgorithm, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +56,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -69,7 +87,9 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + 
settings["hook"]( + retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs + ) # Is this method/status code retryable? (Based on allowlists and control @@ -89,7 +109,9 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get("x-ms-copy-source-error-code") + error_code = response.http_response.headers.get( + "x-ms-copy-source-error-code" + ) if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -106,12 +128,16 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -146,7 +172,9 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str( + uuid.uuid1() + ) # 
def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -186,7 +214,9 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError(f"Attempting to use undefined host location {use_location}") + raise ValueError( + f"Attempting to use undefined host location {use_location}" + ) if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -204,7 +234,9 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) + super(StorageLoggingPolicy, self).__init__( + logging_enable=logging_enable, **kwargs + ) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -288,7 +320,9 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get("content-type", "") + resp_content_type = response.http_response.headers.get( + "content-type", "" + ) if header and pattern.match(header): filename = header.partition("=")[2] @@ -317,7 +351,9 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop("raw_request_hook", self._request_callback) + request_callback = request.context.options.pop( + "raw_request_hook", self._request_callback + ) if 
request_callback: request_callback(request) @@ -335,36 +371,50 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") + ) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + 
content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -372,64 +422,133 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. +def _prepare_content_validation(request: "PipelineRequest") -> None: + """Shared request-side logic for content validation. - This will overwrite any headers already defined in the request. + Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 + validation, and stores the validation mode in the request context. """ + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if validate_content == ChecksumAlgorithm.CRC64: + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. 
- data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif validate_content == ChecksumAlgorithm.CRC64: + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64( + calculate_crc64_bytes(data) + ) + elif hasattr(data, "read"): + content_length = int( + request.http_request.headers.get(CONTENT_LENGTH_HEADER) + ) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream( + data, content_length, StructuredMessageProperties.CRC64 + ) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str( + len(sm_stream) + ) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) - def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = 
computed_md5 - request.context["validate_content"] = validate_content + request.context["validate_content"] = validate_content - def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + """Shared response-side logic for content validation. + + Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches + ``stream_download`` to wrap the iterator in the given *decoder_cls*. + """ + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get( + "content-md5" + ): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif validate_content == ChecksumAlgorithm.CRC64: + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. 
" + f"Request: {sm_request}, Response: {sm_response}", + ), + response=response.http_response, ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls( + iterator, content_length, block_size=DATA_BLOCK_SIZE ) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + def on_request(self, request: "PipelineRequest") -> None: + _prepare_content_validation(request) + + def on_response( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> None: + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -456,7 +575,9 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: + def _set_next_host_location( + self, settings: Dict[str, Any], request: "PipelineRequest" + ) -> None: """ A function which sets the next host location on the request, if applicable. @@ -495,7 +616,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), + "retry_secondary": options.pop( + "retry_to_secondary", self.retry_to_secondary + ), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -504,7 +627,9 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument + def get_backoff_time( + self, settings: Dict[str, Any] + ) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. 
@@ -567,7 +692,9 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append(RequestHistory(request, http_response=response)) + settings["history"].append( + RequestHistory(request, http_response=response) + ) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -602,7 +729,9 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry( + response + ): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -621,7 +750,9 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: retry_hook( retry_settings, @@ -681,7 +812,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -694,8 +827,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else 
pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -733,7 +872,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -748,7 +889,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -756,10 +901,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "TokenCredential", audience: str, **kwargs: Any + ) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - def on_challenge(self, request: 
"PipelineRequest", response: "PipelineResponse") -> bool: + def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index f4d235c082d8..14ce070e47ff 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,11 +19,17 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE from .policies import ( + _prepare_content_validation, + _validate_content_response, encode_base64, is_retry, - StorageContentValidation, StorageRetryPolicy, ) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -39,27 +45,66 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + await settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) else: - settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) + settings["hook"]( + retry_count=settings["count"] - 1, + location_mode=settings["mode"], + **kwargs + ) async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code 
- evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) + computed_md5 = response.http_request.headers.get( + "content-md5", None + ) or encode_base64(calculate_content_md5(response.http_response.body())) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. 
+ """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -73,36 +118,50 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop("download_stream_current", None) + download_stream_current = request.context.options.pop( + "download_stream_current", None + ) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop("upload_stream_current", None) + upload_stream_current = request.context.options.pop( + "upload_stream_current", None + ) - response_callback = request.context.get("response_callback") or request.context.options.pop( - "raw_response_hook", self._response_callback - ) + response_callback = request.context.get( + "response_callback" + ) or request.context.options.pop("raw_response_hook", self._response_callback) response = await self.next.send(request) - will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) + will_retry = is_retry( + response, request.context.options.get("mode") 
+ ) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) + download_stream_current += int( + response.http_response.headers.get("Content-Length", 0) + ) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) + data_stream_total = int( + content_range.split(" ", 1)[1].split("/", 1)[1] + ) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) + upload_stream_current += int( + response.http_request.headers.get("Content-Length", 0) + ) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = download_stream_current + pipeline_obj.context["download_stream_current"] = ( + download_stream_current + ) pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -131,7 +190,9 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): + if is_retry( + response, retry_settings["mode"] + ) or await is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -150,7 +211,9 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): 
raise - retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) + retries_remaining = self.increment( + retry_settings, request=request.http_request, error=err + ) if retries_remaining: await retry_hook( retry_settings, @@ -212,7 +275,9 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(ExponentialRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -225,8 +290,14 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) - random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 + backoff = self.initial_backoff + ( + 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) + ) + random_range_start = ( + backoff - self.random_jitter_range + if backoff > self.random_jitter_range + else 0 + ) random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -264,7 +335,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) + super(LinearRetry, self).__init__( + retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs + ) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -279,7 +352,11 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not 
change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 + random_range_start = ( + self.backoff - self.random_jitter_range + if self.backoff > self.random_jitter_range + else 0 + ) random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -287,10 +364,16 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) + def __init__( + self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any + ) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( + credential, audience, **kwargs + ) - async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: + async def on_challenge( + self, request: "PipelineRequest", response: "PipelineResponse" + ) -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py new file mode 100644 index 000000000000..712f4e90af69 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py @@ -0,0 +1,703 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header( + version: int, size: int, flags: StructuredMessageProperties, num_segments: int +) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header( + data: bytes, expected_message_length: int +) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " + f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if 
segment_number != expected_segment_number: + raise ValueError( + f"Structured message segment number invalid or out of order {segment_number}" + ) + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream( + IOBase +): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _checksum_offset: int + """Tracks the offset the checksum has been calculated up to for seeking purposes""" + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._checksum_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. 
If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + def close(self) -> None: + self._inner_stream.close() + super().close() + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could 
get its initial position + return ( + self._inner_stream.seekable() + and self._initial_content_position is not None + ) + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) + * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number + * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence == SEEK_SET: + position = offset + elif whence == SEEK_CUR: + position = self.tell() + offset + elif whence == SEEK_END: + position = self.message_length + offset + else: + raise ValueError(f"Invalid value for whence: {whence}") + + if position < 0: + raise ValueError(f"Cannot seek to negative position: {position}") + if position > self.tell(): + raise UnsupportedOperation("This stream only 
supports seeking backwards.") + + # MESSAGE_HEADER + if position < self._message_header_length: + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_offset = position + self._content_offset = 0 + self._current_segment_number = 0 + # MESSAGE_FOOTER + elif position >= self.message_length - self._message_footer_length: + self._current_region = SMRegion.MESSAGE_FOOTER + self._current_region_offset = position - ( + self.message_length - self._message_footer_length + ) + self._content_offset = self.content_length + self._current_segment_number = self._num_segments + else: + # The size of a "full" segment. Fine to use for calculating new segment number and pos + full_segment_size = ( + self._segment_header_length + + self._segment_size + + self._segment_footer_length + ) + new_segment_num = ( + 1 + (position - self._message_header_length) // full_segment_size + ) + segment_pos = (position - self._message_header_length) % full_segment_size + previous_segments_total_content_size = ( + new_segment_num - 1 + ) * self._segment_size + + # We need the size of the segment we are seeking to for some of the calculations below + new_segment_size = self._segment_size + if new_segment_num == self._num_segments: + # The last segment size is the remaining content length + new_segment_size = ( + self.content_length - previous_segments_total_content_size + ) + + # SEGMENT_HEADER + if segment_pos < self._segment_header_length: + self._current_region = SMRegion.SEGMENT_HEADER + self._current_region_offset = segment_pos + self._content_offset = previous_segments_total_content_size + # SEGMENT_CONTENT + elif segment_pos < self._segment_header_length + new_segment_size: + self._current_region = SMRegion.SEGMENT_CONTENT + self._current_region_offset = segment_pos - self._segment_header_length + self._content_offset = ( + previous_segments_total_content_size + self._current_region_offset + ) + # SEGMENT_FOOTER + else: + self._current_region = SMRegion.SEGMENT_FOOTER + 
self._current_region_offset = ( + segment_pos - self._segment_header_length - new_segment_size + ) + self._content_offset = ( + previous_segments_total_content_size + new_segment_size + ) + + self._current_segment_number = new_segment_num + + self._update_current_region_length() + self._inner_stream.seek( + (self._initial_content_position or 0) + self._content_offset + ) + return position + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region( + self._current_region, remaining, output + ) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += ( + self._segment_header_length + self._segment_footer_length + ) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min( + self._segment_size, self.content_length - self._content_offset + ) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return 
self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region( + self, region: SMRegion, size: int, output: BytesIO + ) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[ + self._current_region_offset : self._current_region_offset + read_size + ] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + # Will be non-zero if there is data to read that does not need to have checksum calculated. + # Will always be positive as stream can only seek backwards. 
+ checksum_offset = self._checksum_offset - self._content_offset + + read_size = min(size, self._current_region_length - self._current_region_offset) + if checksum_offset != 0: + # Only read up to checksum offset this iteration + read_size = min(read_size, checksum_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + if checksum_offset == 0: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + # Only update the checksum offset if we've read new data + if self._content_offset > self._checksum_offset: + self._checksum_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + 
_segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + 
self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." + ) + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. 
Stream content incomplete." + ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py new file mode 100644 index 000000000000..ee7d92d14d77 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder( + IOBase +): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError( + "Content not long enough to contain a valid message 
header." + ) + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _message_footer_length(self) -> int: + return ( + StructuredMessageConstants.CRC64_LENGTH + if StructuredMessageProperties.CRC64 in self.flags + else 0 + ) + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError( + "First message segment was empty but more segments were detected." 
+ ) + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not ( + self._end_of_segment_content and self._message_offset == self.message_length + ): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = ( + self._segment_content_length - self._segment_content_offset + ) + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64( + segment_content, self._segment_crc64 + ) + self._message_crc64 = calculate_crc64( + segment_content, self._message_crc64 + ) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if ( + self._message_offset == self.message_length + and self._segment_number != self.num_segments + ): + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError( + "Invalid structured message data detected. Stream content incomplete." 
+ ) + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner( + StructuredMessageConstants.V1_HEADER_LENGTH + ) + self.message_version, self.flags, self.num_segments = parse_message_header( + header_data, self.message_length + ) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. " + "All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header( + header_data, self._segment_number + 1 + ) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner( + StructuredMessageConstants.CRC64_LENGTH + ) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." 
+ ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py new file mode 100644 index 000000000000..5370d9dd669c --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py @@ -0,0 +1,105 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from enum import Enum +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +from azure.core import CaseInsensitiveEnumMeta + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + + +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + AUTO = "auto" + MD5 = "md5" + CRC64 = "crc64" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." 
+ ) from exc + + +def parse_validation_option( + validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], +) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + if validate_content not in (ChecksumAlgorithm.list()): + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if validate_content == ChecksumAlgorithm.AUTO: + validate_content = ChecksumAlgorithm.CRC64.value + + if validate_content == ChecksumAlgorithm.CRC64: + _verify_extensions("crc64") + + return validate_content + + +def is_md5_validation( + validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == ChecksumAlgorithm.MD5 + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. 
+ from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) From c5f2a8c9123ea160879a461875570443186560b8 Mon Sep 17 00:00:00 2001 From: Peter Wu <162184229+weirongw23-msft@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:08:55 -0400 Subject: [PATCH 06/14] [Storage] [STG 102] Added Support for IPv6 Accounts + Data Lake Tag Tests (#45766) --- sdk/storage/azure-storage-blob/CHANGELOG.md | 12 ++++++++++++ .../azure/storage/blob/_shared/base_client.py | 1 + sdk/storage/azure-storage-file-datalake/CHANGELOG.md | 8 ++++++++ sdk/storage/azure-storage-file-share/CHANGELOG.md | 6 ++++++ sdk/storage/azure-storage-queue/CHANGELOG.md | 3 +++ 5 files changed, 30 insertions(+) diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 52e79b9fe146..1073d9149a0c 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -3,6 +3,18 @@ ## 12.31.0b1 (Unreleased) ### Features Added +- Added support for service version 2026-06-06. +- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes +for `BlobServiceClient`, `ContainerClient`, and `BlobClient`. +- Added support for `create` permission in `BlobSasPermissions` for `stage_block`, +`stage_block_from_url`, and `commit_block_list`. +- Added support for a new `Smart` access tier to `StandardBlobTier` used in `BlobClient.set_standard_blob_tier`, +which is optimized to automatically determine the most cost-effective access with no performance impact. +When set, `BlobProperties.smart_access_tier` will reveal the service's current access +tier choice between `Hot`, `Cool`, and `Archive`. + +### Other Changes +- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. 
## 12.29.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 8fd641acd2c2..0c509131cc65 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -33,6 +33,7 @@ RedirectPolicy, UserAgentPolicy, ) +from tests.settings.settings_real import SECONDARY_STORAGE_ACCOUNT_NAME from .authentication import SharedKeyCredentialPolicy from .constants import ( diff --git a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md index bc59d313ba2a..5a1c45f8691b 100644 --- a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md @@ -3,6 +3,14 @@ ## 12.26.0b1 (Unreleased) ### Features Added +- Added support for service version 2026-06-06. +- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes +for `DataLakeServiceClient`, `FileSystemClient`, `DataLakeDirectoryClient`, and `DataLakeFileClient`. +- Added support for `DataLakeDirectoryClient` and `DataLakeFileClient`'s `set_tags` and `get_tags` APIs +to conditionally set and get tags associated with a directory or file client, respectively. + +### Other Changes +- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. ## 12.24.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-file-share/CHANGELOG.md b/sdk/storage/azure-storage-file-share/CHANGELOG.md index 374a1ba89946..470d6eb8bbe9 100644 --- a/sdk/storage/azure-storage-file-share/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-share/CHANGELOG.md @@ -3,10 +3,16 @@ ## 12.27.0b1 (Unreleased) ### Features Added +- Added support for service version 2026-06-06. 
- Added support for the keyword `file_property_semantics` in `ShareClient`'s `create_directory` and `DirectoryClient`'s `create_directory` APIs, which specifies permissions to be configured upon directory creation. - Added support for the keyword `data` to `FileClient`'s `create_file` API, which specifies the optional initial data to be uploaded (up to 4MB). +- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes +for `ShareClient`, `ShareDirectoryClient`, and `ShareFileClient`. + +### Other Changes +- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. ## 12.25.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-queue/CHANGELOG.md b/sdk/storage/azure-storage-queue/CHANGELOG.md index b1b136d2741d..b886658fe3cb 100644 --- a/sdk/storage/azure-storage-queue/CHANGELOG.md +++ b/sdk/storage/azure-storage-queue/CHANGELOG.md @@ -3,6 +3,9 @@ ## 12.18.0b1 (Unreleased) ### Features Added +- Added support for service version 2026-06-06. +- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes +for `QueueServiceClient` and `QueueClient`. 
## 12.16.0 (2026-05-14) From 98afe7b84257f9562f13753e953fb433b82bab69 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 26 Mar 2026 01:10:26 -0400 Subject: [PATCH 07/14] Fixed test collection bug --- .../azure-storage-blob/azure/storage/blob/_shared/base_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 0c509131cc65..8fd641acd2c2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -33,7 +33,6 @@ RedirectPolicy, UserAgentPolicy, ) -from tests.settings.settings_real import SECONDARY_STORAGE_ACCOUNT_NAME from .authentication import SharedKeyCredentialPolicy from .constants import ( From a5c2e39bfd792f00df0e8802829c8829e6136e63 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Wed, 1 Apr 2026 16:04:36 -0700 Subject: [PATCH 08/14] [Storage][102] CRC64 content validation - part 5 - datalake (#46034) --- .../tests/test_block_blob.py | 4 +- .../tests/test_block_blob_async.py | 2 +- .../tests/test_content_validation.py | 4 +- .../tests/test_content_validation_async.py | 36 +- .../filedatalake/_data_lake_file_client.py | 31 +- .../filedatalake/_data_lake_file_client.pyi | 5 +- .../_data_lake_file_client_helpers.py | 7 +- .../filedatalake/_shared/base_client_async.py | 5 +- .../aio/_data_lake_file_client_async.py | 31 +- .../aio/_data_lake_file_client_async.pyi | 5 +- .../dev_requirements.txt | 1 + .../tests/test_content_validation.py | 309 ++++++++++++++++++ .../tests/test_content_validation_async.py | 297 +++++++++++++++++ .../fileshare/_shared/base_client_async.py | 5 +- .../queue/_shared/base_client_async.py | 4 +- 15 files changed, 673 insertions(+), 73 deletions(-) create mode 100644 
sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py create mode 100644 sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob.py b/sdk/storage/azure-storage-blob/tests/test_block_blob.py index 2f939e5e5665..d2278573ac80 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob.py @@ -33,7 +33,7 @@ ImmutabilityPolicy, StandardBlobTier, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -555,7 +555,7 @@ def test_upload_blob_from_url_with_source_content_md5(self, **kwargs): source_blob = self._create_blob(data=b"This is test data to be copied over.") source_blob_props = source_blob.get_blob_properties() source_md5 = source_blob_props.content_settings.content_md5 - bad_source_md5 = StorageContentValidation.get_content_md5(b"this is bad data") + bad_source_md5 = calculate_content_md5(b"this is bad data") sas = self.generate_sas( generate_blob_sas, account_name=storage_account_name, diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py index c52a301cba20..17278ca4dfb3 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py @@ -580,7 +580,7 @@ async def test_upload_blob_from_url_with_source_content_md5(self, **kwargs): source_blob = await self._create_blob(data=b"This is test data to be copied over.") source_blob_props = await source_blob.get_blob_properties() source_md5 = source_blob_props.content_settings.content_md5 - bad_source_md5 = StorageContentValidation.get_content_md5(b"this is bad data") + bad_source_md5 = calculate_content_md5(b"this 
is bad data") sas = self.generate_sas( generate_blob_sas, account_name=storage_account_name, diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index 5d670ba330fb..e17777f65301 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -465,10 +465,10 @@ def test_download_blob_large_chunks(self, **kwargs): blob.upload_blob(data, overwrite=True, max_concurrency=5) # Act - downloader = blob.download_blob(validate_content='crc64', max_concurrency=3) + downloader = blob.download_blob(validate_content='crc64', max_concurrency=5) content = downloader.read() - downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 *1024 * 1024) + downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') partial = downloader.read() # Assert diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index b9de57c7015d..80b6cfe3308f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -8,7 +8,7 @@ import pytest from azure.core.exceptions import ResourceExistsError -from azure.storage.blob import BlobBlock, BlobType +from azure.storage.blob import BlobBlock, BlobType, ContainerClient as SyncContainerClient from azure.storage.blob.aio import ( BlobClient, BlobServiceClient, @@ -46,11 +46,16 @@ async def _setup(self, account_name): except ResourceExistsError: pass - # TODO: Figure out how to get this to run automatically - async def _teardown(self): + def teardown_method(self, _): + # Use sync client as teardown_method must be sync if self.container: + sync_credential = self.get_credential(SyncContainerClient, is_async=False) + sync_container = 
SyncContainerClient.from_container_url( + self.container.url, + credential=sync_credential) + try: - await self.container.delete_container() + sync_container.delete_container() except: pass @@ -74,6 +79,9 @@ async def test_encryption_blocked_crc64(self, **kwargs): with pytest.raises(ValueError): await blob.upload_blob(b'123', validate_content='crc64') + # Needed for teardown + self.container = None + @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type @pytest.mark.parametrize('b', [True, "auto", 'md5', 'crc64']) # b: validate_content @@ -106,8 +114,6 @@ async def test_upload_blob(self, a, b, **kwargs): await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) assert await (await blob.download_blob()).read() == str_data_encoded - await self._teardown() - @BlobPreparer() @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content @@ -143,8 +149,6 @@ async def test_upload_blob_chunks(self, a, b, **kwargs): await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) assert await (await blob.download_blob()).read() == str_data_encoded - await self._teardown() - @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @GenericTestProxyParametrize1() @@ -169,7 +173,6 @@ async def test_upload_blob_substream(self, a, **kwargs): # Assert content = await blob.download_blob() assert await content.read() == data - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @@ -200,7 +203,6 @@ def generator(): # Assert content = await blob.download_blob() assert await content.read() == 
data1 + data2.encode('utf-8-sig') + data1 - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -221,7 +223,6 @@ async def test_stage_block_streaming(self, a, **kwargs): # Assert result = await blob.download_blob() assert await result.read() == content - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -274,7 +275,6 @@ def generator(): # Assert content = await blob.download_blob() assert await content.readall() == data1 + data2.encode('utf-16') + data1 - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -294,7 +294,6 @@ async def test_append_block_streaming(self, a, **kwargs): result = await blob.download_blob() assert await result.read() == content - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -317,7 +316,6 @@ async def test_append_block_streaming_large(self, a, **kwargs): result = await blob.download_blob() assert await result.read() == data1 + data2 + data3 - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @@ -341,7 +339,6 @@ async def test_upload_page(self, a, **kwargs): # Assert content = await blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) assert await content.read() == data1 + data2_encoded - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content @@ -368,7 +365,6 @@ async def test_download_blob(self, a, **kwargs): # Assert assert content == data assert stream.read() == data - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -403,7 +399,6 @@ async def test_download_blob_chunks(self, a, **kwargs): assert content == data assert stream.read() == data assert 
read_content == data - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -432,7 +427,6 @@ async def test_download_blob_chunks_partial(self, a, **kwargs): # Assert assert content == data[10:1010] assert stream.read() == data[10:1010] - await self._teardown() @BlobPreparer() @pytest.mark.live_test_only @@ -447,16 +441,15 @@ async def test_download_blob_large_chunks(self, **kwargs): await blob.upload_blob(data, overwrite=True, max_concurrency=5) # Act - downloader = await blob.download_blob(validate_content='crc64', max_concurrency=3) + downloader = await blob.download_blob(validate_content='crc64', max_concurrency=5) content = await downloader.read() - downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024) + downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') partial = await downloader.read() # Assert assert content == data assert partial == data[5 * 1024 * 1024: 30 * 1024 * 1024] - await self._teardown() @BlobPreparer() @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content @@ -488,4 +481,3 @@ async def test_download_blob_chars(self, a, **kwargs): result += await stream.readall() assert result == data - await self._teardown() diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py index 6641ea74f21f..7991643287f0 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py @@ -427,15 +427,11 @@ def upload_data( If a date is passed in without timezone info, it is assumed to be UTC. Specify this header to perform the operation only if the resource has not been modified since the specified date/time. 
- :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. @@ -497,13 +493,11 @@ def append_data( :type length: int or None :keyword bool flush: If true, will commit the data after it is appended. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. 
+ NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease_action: Used to perform lease operations along with appending data. @@ -714,6 +708,11 @@ def download_file( :paramtype progress_hook: ~typing.Callable[[int, int], None] :keyword bool decompress: If True, any compressed content, identified by the Content-Encoding header, will be decompressed automatically before being returned. Default value is True. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-blob-service-operations. 
diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi index c7e7594894b3..020a89378f46 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi @@ -149,7 +149,7 @@ class DataLakeFileClient(PathClient): if_unmodified_since: Optional[datetime] = None, etag: Optional[str] = None, match_condition: Optional[MatchConditions] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, chunk_size: Optional[int] = None, @@ -166,7 +166,7 @@ class DataLakeFileClient(PathClient): length: Optional[int] = None, *, flush: Optional[bool] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease_action: Optional[Literal["acquire", "auto-renew", "release", "acquire-release"]] = None, lease_duration: int = -1, lease: Optional[Union[DataLakeLeaseClient, str]] = None, @@ -207,6 +207,7 @@ class DataLakeFileClient(PathClient): cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, progress_hook: Optional[Callable[[int, int], None]] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> StorageStreamDownloader: ... 
diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py index 86a0a521d15d..5cf92d7bcfce 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py @@ -24,6 +24,7 @@ from ._shared.response_handlers import return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer +from ._shared.validation import parse_validation_option if TYPE_CHECKING: from ._generated.operations import PathOperations @@ -47,6 +48,8 @@ def _append_data_options( if isinstance(data, bytes): data = data[:length] + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + cpk_info = get_cpk_info(scheme, kwargs) kwargs.update(get_lease_action_properties(kwargs)) @@ -54,7 +57,7 @@ def _append_data_options( 'body': data, 'position': offset, 'content_length': length, - 'validate_content': kwargs.pop('validate_content', False), + 'validate_content': validate_content, 'cpk_info': cpk_info, 'timeout': kwargs.pop('timeout', None), 'cls': return_response_headers @@ -122,7 +125,7 @@ def _upload_options( else: raise TypeError(f"Unsupported data type: {type(data)}") - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) max_concurrency = kwargs.pop('max_concurrency', None) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py index 16aba3116029..2c917610eade 100644 --- 
a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py @@ -36,12 +36,11 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +129,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py index 1edf3832a690..86695a19ec10 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py @@ -443,15 +443,11 @@ async def upload_data( If a date is passed in without timezone info, it is assumed to be UTC. Specify this header to perform the operation only if the resource has not been modified since the specified date/time. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. 
This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. @@ -513,13 +509,11 @@ async def append_data( :type length: int or None :keyword bool flush: If true, will commit the data after it is appended. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed.
+ :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease_action: Used to perform lease operations along with appending data. @@ -730,6 +724,11 @@ async def download_file( :paramtype progress_hook: ~typing.Callable[[int, Optional[int]], Awaitable[None]] :keyword bool decompress: If True, any compressed content, identified by the Content-Encoding header, will be decompressed automatically before being returned. Default value is True. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-blob-service-operations. 
diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi index 8da588a32f51..ff21c42b73e6 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi @@ -148,7 +148,7 @@ class DataLakeFileClient(PathClient): if_unmodified_since: Optional[datetime] = None, etag: Optional[str] = None, match_condition: Optional[MatchConditions] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, chunk_size: Optional[int] = None, @@ -165,7 +165,7 @@ class DataLakeFileClient(PathClient): length: Optional[int] = None, *, flush: Optional[bool] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease_action: Optional[Literal["acquire", "auto-renew", "release", "acquire-release"]] = None, lease_duration: int = -1, lease: Optional[Union[DataLakeLeaseClient, str]] = None, @@ -206,6 +206,7 @@ class DataLakeFileClient(PathClient): cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> StorageStreamDownloader: ... 
diff --git a/sdk/storage/azure-storage-file-datalake/dev_requirements.txt b/sdk/storage/azure-storage-file-datalake/dev_requirements.txt index 60d588f5e1e1..b42a7dff1a22 100644 --- a/sdk/storage/azure-storage-file-datalake/dev_requirements.txt +++ b/sdk/storage/azure-storage-file-datalake/dev_requirements.txt @@ -2,4 +2,5 @@ ../../core/azure-core ../../identity/azure-identity ../azure-storage-blob +../azure-storage-extensions aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py new file mode 100644 index 000000000000..87d8170e6035 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py @@ -0,0 +1,309 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.storage.filedatalake import ( + DataLakeServiceClient +) + +from devtools_testutils import recorded_by_proxy +from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase +from settings.testcase import DataLakePreparer + + +def assert_content_md5(request): + if request.http_request.query.get('action') == 'append': + assert request.http_request.headers.get('Content-MD5') is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' + assert response.http_response.headers.get('Content-MD5') is not None + + +def assert_content_crc64(request): + if request.http_request.query.get('action') == 'append': + assert request.http_request.headers.get('x-ms-content-crc64') is not None + + +def assert_structured_message(request): + if request.http_request.query.get('action') == 'append': + assert request.http_request.headers.get('x-ms-structured-body') is not None + + +def assert_structured_message_get(response): + assert response.http_request.headers.get('x-ms-structured-body') is not None + assert response.http_response.headers.get('x-ms-structured-body') is not None + + +class TestStorageContentValidation(StorageRecordedTestCase): + def _setup(self, account_name): + token_credential = self.get_credential(DataLakeServiceClient) + self.dsc = DataLakeServiceClient(self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name('filesystem')) + self.file_system.create_file_system() + + def teardown_method(self, _): + if self.file_system: + try: + self.file_system.delete_file_system() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name('file') + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 
'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + + # Act + file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abcde' * 512 + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + # Act + file.upload_data(data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk size to enable substream + file = self.file_system.get_file_client(self._get_file_reference()) + assert_method = assert_content_crc64 if a == 
'crc64' else assert_content_md5 + + data = b'abc' * 512 + b'abcde' + io = BytesIO(data) + + # Act + file.upload_data(io, overwrite=True, validate_content=a, chunk_size=512, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data1 = b'abcde' * 512 + data2 = '你好世界' * 10 + encoded2 = data2.encode('utf-8-sig') + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i: i + 500] + + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + + # Act + file.create_file() + file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) + file.append_data(data2, len(data1), encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) + file.append_data(generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method) + file.flush_data(len(data1) + len(encoded2) + len(data1)) + + # Assert + content = file.download_file() + assert content.read() == data1 + encoded2 + data1 + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_data_streaming(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + content 
= b'abcde' * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + file.create_file() + file.append_data(BytesIO(content), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + result = file.download_file() + assert result.read() == content + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.live_test_only + def test_append_data_streaming_large(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + data1 = b'abcde' * 1024 * 1024 # 5 MiB + data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 + data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + file.create_file() + file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + file.append_data(BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method) + file.append_data(BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method) + file.flush_data(len(data1) + len(data2) + len(data3)) + + # Assert + result = file.download_file() + assert result.read() == data1 + data2 + data3 + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a in 
('auto', 'crc64') else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content.extend(downloader.read(100)) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks_partial(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + 
self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @DataLakePreparer() + @pytest.mark.live_test_only + def test_download_file_large_chunks(self, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + file.upload_data(data, overwrite=True, max_concurrency=5) + + # Act + downloader = file.download_file(validate_content='crc64', max_concurrency=5) + content = downloader.read() + + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + partial = downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py new file mode 100644 index 000000000000..4c9f1af53013 --- /dev/null +++ 
b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py @@ -0,0 +1,297 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.storage.filedatalake import FileSystemClient as SyncFileSystemClient +from azure.storage.filedatalake.aio import DataLakeServiceClient + +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 +from settings.testcase import DataLakePreparer +from test_content_validation import ( + assert_content_crc64, + assert_content_md5, + assert_content_md5_get, + assert_structured_message, + assert_structured_message_get +) + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + async def _setup(self, account_name): + token_credential = self.get_credential(DataLakeServiceClient, is_async=True) + self.dsc = DataLakeServiceClient(self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name('filesystem')) + await self.file_system.create_file_system() + + def teardown_method(self, _): + # Use sync client as teardown_method must be sync + if self.file_system: + sync_credential = self.get_credential(SyncFileSystemClient, is_async=False) + sync_file_system = SyncFileSystemClient( + self.account_url(self.file_system.account_name, "dfs"), + self.file_system.file_system_name, + credential=sync_credential) + + try: + sync_file_system.delete_file_system() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name('file') + + @DataLakePreparer() + @pytest.mark.parametrize('a', 
[True]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + + # Act + await file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abcde' * 512 + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + # Act + await file.upload_data(data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk size to enable substream + file = 
self.file_system.get_file_client(self._get_file_reference()) + assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + + data = b'abc' * 512 + b'abcde' + io = BytesIO(data) + + # Act + await file.upload_data(io, overwrite=True, validate_content=a, chunk_size=512, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data1 = b'abcde' * 512 + data2 = '你好世界' * 10 + encoded2 = data2.encode('utf-8-sig') + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. 
+ def generator(): + for i in range(0, len(data1), 500): + yield data1[i: i + 500] + + assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) + await file.append_data(data2, len(data1), encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) + await file.append_data(generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method) + await file.flush_data(len(data1) + len(encoded2) + len(data1)) + + # Assert + content = await file.download_file() + assert await content.read() == data1 + encoded2 + data1 + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_data_streaming(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + content = b'abcde' * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(BytesIO(content), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + result = await file.download_file() + assert await result.read() == content + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.live_test_only + async def test_append_data_streaming_large(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + data1 = b'abcde' * 1024 * 1024 # 5 MiB + data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 
7 + data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + await file.append_data(BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method) + await file.append_data(BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method) + await file.flush_data(len(data1) + len(data2) + len(data3)) + + # Assert + result = await file.download_file() + assert await result.read() == data1 + data2 + data3 + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await 
self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content.extend(await downloader.read(100)) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @DataLakePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks_partial(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file(offset=512, length=1024, validate_content=a, 
raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @DataLakePreparer() + @pytest.mark.live_test_only + async def test_download_file_large_chunks(self, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + await file.upload_data(data, overwrite=True, max_concurrency=5) + + # Act + downloader = await file.download_file(validate_content='crc64', max_concurrency=5) + content = await downloader.read() + + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + partial = await downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py index 16aba3116029..2c917610eade 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py @@ -36,12 +36,11 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import AsyncStorageBearerTokenCredentialPolicy, 
AsyncContentValidationPolicy, AsyncStorageResponseHook from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +129,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 54446f7fc5b4..993c2cfb354f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -36,12 +36,12 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) from .policies_async import ( + AsyncContentValidationPolicy, AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook, ) @@ -161,7 +161,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), From f4a5b09f4e9ed47c8db60b5cb1a4d5609911a3fd Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:09:11 -0700 Subject: [PATCH 09/14] [Storage][102] CRC64 content validation - part 6 - file-share (#46262) --- .../azure/storage/blob/_blob_client.py | 6 +- .../storage/blob/_blob_client_helpers.py | 8 +- .../azure/storage/blob/_download.py | 8 +- .../azure/storage/blob/_shared/policies.py | 15 +- 
.../azure/storage/blob/_shared/validation.py | 47 +-- .../azure/storage/blob/_upload_helpers.py | 9 +- .../storage/blob/aio/_blob_client_async.py | 6 +- .../azure/storage/blob/aio/_download_async.py | 6 +- .../azure/storage/blob/aio/_upload_helpers.py | 9 +- .../tests/test_append_blob.py | 10 +- .../tests/test_append_blob_async.py | 7 +- .../tests/test_block_blob_sync_copy.py | 6 +- .../tests/test_block_blob_sync_copy_async.py | 6 +- .../tests/test_page_blob.py | 16 +- .../tests/test_page_blob_async.py | 6 +- .../dev_requirements.txt | 2 +- .../storage/filedatalake/_shared/policies.py | 15 +- .../filedatalake/_shared/validation.py | 47 +-- .../azure/storage/fileshare/_download.py | 16 +- .../azure/storage/fileshare/_file_client.py | 67 +++-- .../azure/storage/fileshare/_file_client.pyi | 7 +- .../storage/fileshare/_shared/policies.py | 15 +- .../storage/fileshare/_shared/validation.py | 47 +-- .../storage/fileshare/aio/_download_async.py | 15 +- .../fileshare/aio/_file_client_async.py | 67 +++-- .../fileshare/aio/_file_client_async.pyi | 7 +- .../dev_requirements.txt | 1 + .../tests/test_content_validation.py | 270 ++++++++++++++++++ .../tests/test_content_validation_async.py | 261 +++++++++++++++++ .../azure/storage/queue/_shared/policies.py | 15 +- .../azure/storage/queue/_shared/validation.py | 47 +-- 31 files changed, 834 insertions(+), 230 deletions(-) create mode 100644 sdk/storage/azure-storage-file-share/tests/test_content_validation.py create mode 100644 sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index 89b49dac7a94..04ae539788b3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -65,7 +65,7 @@ from ._quick_query_helper import BlobQueryReader from ._shared.base_client import 
parse_connection_str, StorageAccountHostsMixin, TransportWrapper from ._shared.response_handlers import process_storage_error, return_response_headers -from ._shared.validation import ChecksumAlgorithm, parse_validation_option +from ._shared.validation import is_crc64_validation, parse_validation_option from ._serialize import ( get_access_conditions, get_api_version, @@ -614,7 +614,7 @@ def upload_blob( if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") validate_content = parse_validation_option(kwargs.pop('validate_content', None)) - if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + if is_crc64_validation(validate_content) and self.key_encryption_key: raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, @@ -763,7 +763,7 @@ def download_blob( if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") validate_content = parse_validation_option(kwargs.pop('validate_content', None)) - if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + if is_crc64_validation(validate_content) and self.key_encryption_key: raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index 1bbfb33901d3..ef628ecbb316 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -8,7 +8,7 @@ from io import BytesIO from typing import ( Any, AnyStr, AsyncGenerator, AsyncIterable, cast, - Dict, IO, Iterable, List, Literal, Optional, 
Tuple, Union, + Dict, IO, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING ) from urllib.parse import quote, unquote, urlparse @@ -58,7 +58,7 @@ from ._shared.response_handlers import return_headers_and_deserialized, return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer -from ._shared.validation import parse_validation_option +from ._shared.validation import CV_TYPE_PARSED, parse_validation_option from ._upload_helpers import _any_conditions if TYPE_CHECKING: @@ -111,7 +111,7 @@ def _upload_blob_options( # pylint:disable=too-many-statements length: Optional[int], metadata: Optional[Dict[str, str]], encryption_options: Dict[str, Any], - validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], + validate_content: CV_TYPE_PARSED, config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -259,7 +259,7 @@ def _download_blob_options( length: Optional[int], encoding: Optional[str], encryption_options: Dict[str, Any], - validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]], + validate_content: CV_TYPE_PARSED, config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index ba7018772413..7be6d68858d7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -11,7 +11,7 @@ from io import BytesIO, StringIO from typing import ( Any, Callable, cast, Dict, Generator, - Generic, IO, Iterator, List, Literal, Optional, + Generic, IO, Iterator, List, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -21,7 +21,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from 
._shared.constants import DEFAULT_MAX_CONCURRENCY -from ._shared.validation import is_md5_validation +from ._shared.validation import is_md5_validation, CV_TYPE_PARSED from ._deserialize import deserialize_blob_properties, get_page_ranges_result from ._encryption import ( adjust_blob_size_for_encryption, @@ -92,7 +92,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], + validate_content: CV_TYPE_PARSED, encryption_options: Dict[str, Any], encryption_data: Optional["_EncryptionData"] = None, stream: Any = None, @@ -330,7 +330,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + validate_content: CV_TYPE_PARSED = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 63a517a4e305..9b10ece3de79 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -5,12 +5,11 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( @@ -43,8 +42,8 @@ CV_TYPE_ERROR_MSG, calculate_content_md5, calculate_crc64_bytes, + is_crc64_validation, is_md5_validation, - ChecksumAlgorithm, ) if TYPE_CHECKING: @@ -428,7 +427,7 @@ def 
_prepare_content_validation(request: "PipelineRequest") -> None: # Download if request.http_request.method == "GET": - if validate_content == ChecksumAlgorithm.CRC64: + if is_crc64_validation(validate_content): request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 # Upload @@ -441,7 +440,11 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: request.http_request.headers[MD5_HEADER] = computed_md5 request.context["validate_content_md5"] = computed_md5 - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + if isinstance(data, bytes): request.http_request.headers[CRC64_HEADER] = encode_base64( calculate_crc64_bytes(data) @@ -495,7 +498,7 @@ def _validate_content_response( response=response.http_response, ) - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): # For upload and download verify structured message header present in response if provided in request. sm_request = request.http_request.headers.get(SM_HEADER) sm_response = response.http_response.headers.get(SM_HEADER) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py index 5370d9dd669c..21d3b081d8cc 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -6,24 +6,16 @@ # pylint: disable=c-extension-no-member import hashlib -from enum import Enum from io import SEEK_SET from typing import IO, Literal, Optional, Union, cast -from azure.core import CaseInsensitiveEnumMeta - CRC64_LENGTH = 8 CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." 
+_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") -class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): - AUTO = "auto" - MD5 = "md5" - CRC64 = "crc64" - - @classmethod - def list(cls): - return list(map(lambda c: c.value, cls)) +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] def _verify_extensions(module: str) -> None: @@ -37,8 +29,10 @@ def _verify_extensions(module: str) -> None: def parse_validation_option( - validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], -) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: if validate_content is None: return None @@ -46,27 +40,40 @@ def parse_validation_option( if isinstance(validate_content, bool): return validate_content - if validate_content not in (ChecksumAlgorithm.list()): + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: raise ValueError("Invalid value for `validate_content` specified.") # Resolve auto - if validate_content == ChecksumAlgorithm.AUTO: - validate_content = ChecksumAlgorithm.CRC64.value + if parsed == "auto": + parsed = "crc64" - if validate_content == ChecksumAlgorithm.CRC64: + if parsed == "crc64": _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" - return validate_content + return cast(CV_TYPE_PARSED, parsed) def is_md5_validation( - validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], + validate_content: CV_TYPE_PARSED, ) -> bool: if validate_content is None: return False if isinstance(validate_content, bool): return validate_content - return validate_content == ChecksumAlgorithm.MD5 + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + 
return False + return validate_content in ("crc64", "crc64-sm") def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py index 6873b93bb4e6..1b782ca1a8d8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py @@ -5,7 +5,7 @@ # -------------------------------------------------------------------------- from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING from azure.core.exceptions import ResourceExistsError, ResourceModifiedError, HttpResponseError @@ -32,6 +32,7 @@ upload_data_chunks, upload_substream_blocks ) +from ._shared.validation import CV_TYPE_PARSED if TYPE_CHECKING: from ._generated.operations import AppendBlobOperations, BlockBlobOperations, PageBlobOperations @@ -71,7 +72,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], + validate_content: CV_TYPE_PARSED, max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -213,7 +214,7 @@ def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -291,7 +292,7 @@ def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + 
validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index 3152bda21036..a050d8206df2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -77,7 +77,7 @@ from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error, return_response_headers -from .._shared.validation import ChecksumAlgorithm, parse_validation_option +from .._shared.validation import is_crc64_validation, parse_validation_option if TYPE_CHECKING: from azure.core import MatchConditions @@ -627,7 +627,7 @@ async def upload_blob( if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") validate_content = parse_validation_option(kwargs.pop('validate_content', None)) - if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + if is_crc64_validation(validate_content) and self.key_encryption_key: raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, @@ -776,7 +776,7 @@ async def download_blob( if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") validate_content = parse_validation_option(kwargs.pop('validate_content', None)) - if validate_content == ChecksumAlgorithm.CRC64 and self.key_encryption_key: + if is_crc64_validation(validate_content) and self.key_encryption_key: raise ValueError("Using encryption and content validation together 
is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 7e5bbb0918b3..83124c72a6fc 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -15,7 +15,7 @@ from typing import ( Any, AsyncIterator, Awaitable, Generator, Callable, cast, Dict, - Generic, IO, Literal, Optional, overload, + Generic, IO, Optional, overload, Tuple, TypeVar, Union, TYPE_CHECKING ) @@ -24,7 +24,7 @@ from .._shared.request_handlers import validate_and_format_range_headers from .._shared.response_handlers import parse_length_from_content_range, process_storage_error from .._shared.constants import DEFAULT_MAX_CONCURRENCY -from .._shared.validation import is_md5_validation +from .._shared.validation import is_md5_validation, CV_TYPE_PARSED from .._deserialize import deserialize_blob_properties, get_page_ranges_result from .._download import process_range_and_offset, _ChunkDownloader from .._encryption import ( @@ -239,7 +239,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + validate_content: CV_TYPE_PARSED = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py index 5b551fdec2fb..befdc7f505a0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py +++ 
b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py @@ -6,7 +6,7 @@ import inspect from io import SEEK_SET, UnsupportedOperation -from typing import Any, cast, Dict, IO, Literal, Optional, TypeVar, Union, TYPE_CHECKING +from typing import Any, cast, Dict, IO, Optional, TypeVar, TYPE_CHECKING from azure.core.exceptions import HttpResponseError, ResourceModifiedError @@ -32,6 +32,7 @@ upload_data_chunks, upload_substream_blocks ) +from .._shared.validation import CV_TYPE_PARSED from .._upload_helpers import _any_conditions, _convert_mod_error if TYPE_CHECKING: @@ -47,7 +48,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]], + validate_content: CV_TYPE_PARSED, max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -193,7 +194,7 @@ async def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -271,7 +272,7 @@ async def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[Union[bool, Literal['crc64', 'md5']]] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/tests/test_append_blob.py b/sdk/storage/azure-storage-blob/tests/test_append_blob.py index d642a6edf2ec..65da562e69e5 100644 --- a/sdk/storage/azure-storage-blob/tests/test_append_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_append_blob.py @@ -28,7 +28,7 @@ generate_blob_sas, ImmutabilityPolicy, ) -from 
azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -375,7 +375,7 @@ def test_append_block_from_url_and_validate_content_md5(self, **kwargs): self._setup(bsc) source_blob_data = self.get_random_bytes(LARGE_BLOB_SIZE) source_blob_client = self._create_source_blob(source_blob_data, bsc) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -405,9 +405,9 @@ def test_append_block_from_url_and_validate_content_md5(self, **kwargs): # Act part 2: put block from url with wrong md5 with pytest.raises(HttpResponseError): - destination_blob_client.append_block_from_url(source_blob_client.url + '?' + sas, - source_content_md5=StorageContentValidation.get_content_md5( - b"POTATO")) + destination_blob_client.append_block_from_url( + source_blob_client.url + '?' 
+ sas, + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy diff --git a/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py index 2e2bf71bc765..710abb01db9f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py @@ -27,7 +27,7 @@ from azure.mgmt.storage.aio import StorageManagementClient from azure.storage.blob import BlobImmutabilityPolicyMode, BlobSasPermissions, generate_blob_sas, ImmutabilityPolicy from azure.storage.blob import BlobType -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient @@ -379,7 +379,7 @@ async def test_append_block_from_url_and_validate_content_md5(self, **kwargs): await self._setup(bsc) source_blob_data = self.get_random_bytes(LARGE_BLOB_SIZE) source_blob_client = await self._create_source_blob(source_blob_data, bsc) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -411,8 +411,7 @@ async def test_append_block_from_url_and_validate_content_md5(self, **kwargs): with pytest.raises(HttpResponseError): await destination_blob_client.append_block_from_url( source_blob_client.url + '?' 
+ sas, - source_content_md5=StorageContentValidation.get_content_md5(b"POTATO") - ) + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy_async diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py index 1aa1bf73a6a6..d13a16cb7fa5 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py @@ -22,7 +22,7 @@ StandardBlobTier, StorageErrorCode ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -206,7 +206,7 @@ def test_put_block_from_url_and_validate_content_md5(self, **kwargs): self._setup(storage_account_name, storage_account_key) dest_blob_name = self.get_resource_name('destblob') dest_blob = self.bsc.get_blob_client(self.container_name, dest_blob_name) - src_md5 = StorageContentValidation.get_content_md5(self.source_blob_data) + src_md5 = calculate_content_md5(self.source_blob_data) # Act part 1: put block from url with md5 validation dest_blob.stage_block_from_url( @@ -222,7 +222,7 @@ def test_put_block_from_url_and_validate_content_md5(self, **kwargs): assert len(committed) == 0 # Act part 2: put block from url with wrong md5 - fake_md5 = StorageContentValidation.get_content_md5(b"POTATO") + fake_md5 = calculate_content_md5(b"POTATO") with pytest.raises(HttpResponseError) as error: dest_blob.stage_block_from_url( block_id=2, diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py index 03928dbc08fb..46b004f12c29 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py +++ 
b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py @@ -17,7 +17,7 @@ from azure.core.exceptions import HttpResponseError, ResourceExistsError from azure.storage.blob import BlobSasPermissions, StandardBlobTier, StorageErrorCode, generate_blob_sas from azure.storage.blob.aio import BlobClient, BlobServiceClient -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -167,7 +167,7 @@ async def test_put_block_from_url_and_vldte_content_md5(self, **kwargs): await self._setup(storage_account_name, storage_account_key) dest_blob_name = self.get_resource_name('destblob') dest_blob = self.bsc.get_blob_client(self.container_name, dest_blob_name) - src_md5 = StorageContentValidation.get_content_md5(self.source_blob_data) + src_md5 = calculate_content_md5(self.source_blob_data) # Act part 1: put block from url with md5 validation await dest_blob.stage_block_from_url( @@ -183,7 +183,7 @@ async def test_put_block_from_url_and_vldte_content_md5(self, **kwargs): assert len(committed) == 0 # Act part 2: put block from url with wrong md5 - fake_md5 = StorageContentValidation.get_content_md5(b"POTATO") + fake_md5 = calculate_content_md5(b"POTATO") with pytest.raises(HttpResponseError) as error: await dest_blob.stage_block_from_url( block_id=2, diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob.py b/sdk/storage/azure-storage-blob/tests/test_page_blob.py index 8caaf49412f8..f246c400f24c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob.py @@ -22,7 +22,7 @@ BlobClient, BlobImmutabilityPolicyMode, BlobProperties, BlobSasPermissions, BlobServiceClient, BlobType, generate_blob_sas, ImmutabilityPolicy, PremiumPageBlobTier, SequenceNumberAction, ) -from azure.storage.blob._shared.policies import 
StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -639,7 +639,7 @@ def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): self._setup(bsc) source_blob_data = self.get_random_bytes(SOURCE_BLOB_SIZE) source_blob_client = self._create_source_blob(bsc, source_blob_data, 0, SOURCE_BLOB_SIZE) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -669,12 +669,12 @@ def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): # Act part 2: put block from url with wrong md5 with pytest.raises(HttpResponseError): - destination_blob_client.upload_pages_from_url(source_blob_client.url + "?" + sas, - offset=0, - length=SOURCE_BLOB_SIZE, - source_offset=0, - source_content_md5=StorageContentValidation.get_content_md5( - b"POTATO")) + destination_blob_client.upload_pages_from_url( + source_blob_client.url + "?" 
+ sas, + offset=0, + length=SOURCE_BLOB_SIZE, + source_offset=0, + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py index bc7b7f6f8fe4..af6ab9c7c46c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py @@ -22,7 +22,7 @@ BlobImmutabilityPolicyMode, BlobProperties, BlobSasPermissions, BlobType, generate_blob_sas, ImmutabilityPolicy, PremiumPageBlobTier, SequenceNumberAction, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient @@ -619,7 +619,7 @@ async def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): await self._setup(bsc) source_blob_data = self.get_random_bytes(SOURCE_BLOB_SIZE) source_blob_client = await self._create_source_blob(bsc, source_blob_data, 0, SOURCE_BLOB_SIZE) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -654,7 +654,7 @@ async def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): 0, SOURCE_BLOB_SIZE, 0, - source_content_md5=StorageContentValidation.get_content_md5(b"POTATO") + source_content_md5=calculate_content_md5(b"POTATO") ) @BlobPreparer() diff --git a/sdk/storage/azure-storage-extensions/dev_requirements.txt b/sdk/storage/azure-storage-extensions/dev_requirements.txt index b18a83fbb955..7d496b4d1cc1 100644 --- a/sdk/storage/azure-storage-extensions/dev_requirements.txt +++ b/sdk/storage/azure-storage-extensions/dev_requirements.txt @@ -1 +1 @@ --e ../../../eng/tools/azure-sdk-tools \ No newline at end of file +-e 
../../../eng/tools/azure-sdk-tools diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 63a517a4e305..9b10ece3de79 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -5,12 +5,11 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( @@ -43,8 +42,8 @@ CV_TYPE_ERROR_MSG, calculate_content_md5, calculate_crc64_bytes, + is_crc64_validation, is_md5_validation, - ChecksumAlgorithm, ) if TYPE_CHECKING: @@ -428,7 +427,7 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: # Download if request.http_request.method == "GET": - if validate_content == ChecksumAlgorithm.CRC64: + if is_crc64_validation(validate_content): request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 # Upload @@ -441,7 +440,11 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: request.http_request.headers[MD5_HEADER] = computed_md5 request.context["validate_content_md5"] = computed_md5 - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + if isinstance(data, bytes): request.http_request.headers[CRC64_HEADER] = encode_base64( calculate_crc64_bytes(data) @@ -495,7 +498,7 @@ def _validate_content_response( response=response.http_response, ) - elif validate_content == 
ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): # For upload and download verify structured message header present in response if provided in request. sm_request = request.http_request.headers.get(SM_HEADER) sm_response = response.http_response.headers.get(SM_HEADER) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py index 5370d9dd669c..21d3b081d8cc 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py @@ -6,24 +6,16 @@ # pylint: disable=c-extension-no-member import hashlib -from enum import Enum from io import SEEK_SET from typing import IO, Literal, Optional, Union, cast -from azure.core import CaseInsensitiveEnumMeta - CRC64_LENGTH = 8 CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." 
+_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") -class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): - AUTO = "auto" - MD5 = "md5" - CRC64 = "crc64" - - @classmethod - def list(cls): - return list(map(lambda c: c.value, cls)) +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] def _verify_extensions(module: str) -> None: @@ -37,8 +29,10 @@ def _verify_extensions(module: str) -> None: def parse_validation_option( - validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], -) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: if validate_content is None: return None @@ -46,27 +40,40 @@ def parse_validation_option( if isinstance(validate_content, bool): return validate_content - if validate_content not in (ChecksumAlgorithm.list()): + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: raise ValueError("Invalid value for `validate_content` specified.") # Resolve auto - if validate_content == ChecksumAlgorithm.AUTO: - validate_content = ChecksumAlgorithm.CRC64.value + if parsed == "auto": + parsed = "crc64" - if validate_content == ChecksumAlgorithm.CRC64: + if parsed == "crc64": _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" - return validate_content + return cast(CV_TYPE_PARSED, parsed) def is_md5_validation( - validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], + validate_content: CV_TYPE_PARSED, ) -> bool: if validate_content is None: return False if isinstance(validate_content, bool): return validate_content - return validate_content == ChecksumAlgorithm.MD5 + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + 
return False + return validate_content in ("crc64", "crc64-sm") def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py index 935b17ebfde9..d0e2d19d7f47 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py @@ -18,6 +18,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from ._shared.constants import DEFAULT_MAX_CONCURRENCY +from ._shared.validation import is_md5_validation, CV_TYPE, CV_TYPE_PARSED if TYPE_CHECKING: from ._generated.operations import FileOperations @@ -43,7 +44,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: bool, + validate_content: CV_TYPE_PARSED, etag: str, stream: Any = None, parallel: Optional[int] = None, @@ -120,7 +121,9 @@ def _write_to_stream(self, chunk_data: bytes, chunk_start: int) -> None: def _download_chunk(self, chunk_start: int, chunk_end: int) -> bytes: range_header, range_validation = validate_and_format_range_headers( - chunk_start, chunk_end, check_content_md5=self.validate_content + chunk_start, + chunk_end, + check_content_md5=is_md5_validation(self.validate_content) ) try: @@ -217,7 +220,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] path: str = None, # type: ignore [assignment] @@ -247,11 +250,12 @@ def __init__( self._etag = "" # The service only provides transactional MD5s for chunks 
under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. self._first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size + self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size ) + initial_request_start = self._start_range or 0 if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size: initial_request_end = self._end_range @@ -292,7 +296,7 @@ def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) try: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py index 78bfa03a6945..2dc032a9eb43 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py @@ -45,6 +45,7 @@ from ._shared.request_handlers import add_metadata_headers, get_length from ._shared.response_handlers import return_response_headers, process_storage_error from ._shared.uploads import IterStreamer, FileChunkUploader, upload_data_chunks +from ._shared.validation import CV_TYPE_PARSED, parse_validation_option if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -58,7 +59,7 @@ def _upload_file_helper( size: Optional[int], metadata: Optional[Dict[str, str]], content_settings: Optional["ContentSettings"], - validate_content: bool, + validate_content: CV_TYPE_PARSED, timeout: Optional[int], max_concurrency: int, file_settings: 
"StorageConfiguration", @@ -447,8 +448,13 @@ def create_file( Restore - apply changes without further modification. :paramtype file_property_semantics: Optional[Literal["New", "Restore"]] - :keyword data: Optional initial data to upload, up to 4MB. - :paramtype data: bytes + :keyword bytes data: Optional initial data to upload, up to 4MB. + :keyword validate_content: + Only applicable when `data` is provided. + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-file-service-operations. @@ -471,6 +477,10 @@ def create_file( content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) headers = kwargs.pop('headers', {}) headers.update(add_metadata_headers(metadata)) data = kwargs.pop('data', None) @@ -498,6 +508,7 @@ def create_file( file_permission_key=permission_key, file_http_headers=file_http_headers, optionalbody=data, + validate_content=validate_content, content_length=len(data) if data is not None else None, lease_access_conditions=access_conditions, headers=headers, @@ -562,13 +573,11 @@ def upload_file( :keyword ~azure.storage.fileshare.ContentSettings content_settings: ContentSettings object used to set file properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each range of the file. 
The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int max_concurrency: Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may @@ -610,7 +619,10 @@ def upload_file( max_concurrency = kwargs.pop('max_concurrency', None) if max_concurrency is None: max_concurrency = DEFAULT_MAX_CONCURRENCY - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) progress_hook = kwargs.pop('progress_hook', None) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') @@ -877,15 +889,11 @@ def download_file( Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may require a separate configuration of the connection pool. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. 
Note that this MD5 hash is not stored with the - file. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the file has an active lease. Value can be a ShareLeaseClient object or the lease ID as a string. @@ -925,12 +933,14 @@ def download_file( range_end = offset + length - 1 # Service actually uses an end-range inclusive index access_conditions = get_access_conditions(kwargs.pop('lease', None)) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) return StorageStreamDownloader( client=self._client.file, config=self._config, start_range=offset, end_range=range_end, + validate_content=validate_content, name=self.file_name, path='/'.join(self.file_path), share=self.share_name, @@ -1296,13 +1306,11 @@ def upload_range( :param int length: Number of bytes to use for uploading a section of the file. The range can be up to 4 MB in size. - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. 
Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword file_last_write_mode: If the file last write time should be preserved or overwritten. Possible values are "preserve" or "now". If not specified, file last write time will be changed to @@ -1331,7 +1339,10 @@ def upload_range( :returns: File-updated property dict (Etag and last modified). :rtype: Dict[str, Any] """ - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') file_last_write_mode = kwargs.pop('file_last_write_mode', None) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi index 22d645d975cc..a517773f19ca 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi @@ -134,6 +134,7 @@ class ShareFileClient(StorageAccountHostsMixin): file_mode: Optional[str] = None, file_property_semantics: Optional[Literal["New", "Restore"]] = None, data: Optional[bytes] = None, + validate_content: Optional[Literal['auto', 'crc64', 'md5']] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: ... 
@@ -151,7 +152,7 @@ class ShareFileClient(StorageAccountHostsMixin): file_change_time: Optional[Union[str, datetime]] = None, metadata: Optional[Dict[str, str]] = None, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], None]] = None, @@ -199,7 +200,7 @@ class ShareFileClient(StorageAccountHostsMixin): length: Optional[int] = None, *, max_concurrency: Optional[int] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], None]] = None, decompress: Optional[bool] = None, @@ -270,7 +271,7 @@ class ShareFileClient(StorageAccountHostsMixin): offset: int, length: int, *, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, encoding: str = "UTF-8", file_last_write_mode: Optional[Literal["preserve", "now"]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 63a517a4e305..9b10ece3de79 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -5,12 +5,11 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time from typing import Any, 
Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( @@ -43,8 +42,8 @@ CV_TYPE_ERROR_MSG, calculate_content_md5, calculate_crc64_bytes, + is_crc64_validation, is_md5_validation, - ChecksumAlgorithm, ) if TYPE_CHECKING: @@ -428,7 +427,7 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: # Download if request.http_request.method == "GET": - if validate_content == ChecksumAlgorithm.CRC64: + if is_crc64_validation(validate_content): request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 # Upload @@ -441,7 +440,11 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: request.http_request.headers[MD5_HEADER] = computed_md5 request.context["validate_content_md5"] = computed_md5 - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + if isinstance(data, bytes): request.http_request.headers[CRC64_HEADER] = encode_base64( calculate_crc64_bytes(data) @@ -495,7 +498,7 @@ def _validate_content_response( response=response.http_response, ) - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): # For upload and download verify structured message header present in response if provided in request. 
sm_request = request.http_request.headers.get(SM_HEADER) sm_response = response.http_response.headers.get(SM_HEADER) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py index 5370d9dd669c..21d3b081d8cc 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py @@ -6,24 +6,16 @@ # pylint: disable=c-extension-no-member import hashlib -from enum import Enum from io import SEEK_SET from typing import IO, Literal, Optional, Union, cast -from azure.core import CaseInsensitiveEnumMeta - CRC64_LENGTH = 8 CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." +_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") -class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): - AUTO = "auto" - MD5 = "md5" - CRC64 = "crc64" - - @classmethod - def list(cls): - return list(map(lambda c: c.value, cls)) +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] def _verify_extensions(module: str) -> None: @@ -37,8 +29,10 @@ def _verify_extensions(module: str) -> None: def parse_validation_option( - validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], -) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: if validate_content is None: return None @@ -46,27 +40,40 @@ def parse_validation_option( if isinstance(validate_content, bool): return validate_content - if validate_content not in (ChecksumAlgorithm.list()): + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: raise ValueError("Invalid value for `validate_content` specified.") # Resolve auto - if 
validate_content == ChecksumAlgorithm.AUTO: - validate_content = ChecksumAlgorithm.CRC64.value + if parsed == "auto": + parsed = "crc64" - if validate_content == ChecksumAlgorithm.CRC64: + if parsed == "crc64": _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" - return validate_content + return cast(CV_TYPE_PARSED, parsed) def is_md5_validation( - validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], + validate_content: CV_TYPE_PARSED, ) -> bool: if validate_content is None: return False if isinstance(validate_content, bool): return validate_content - return validate_content == ChecksumAlgorithm.MD5 + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return False + return validate_content in ("crc64", "crc64-sm") def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 731e8c86bd92..ae8fe9962fcd 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -22,6 +22,7 @@ from .._shared.request_handlers import validate_and_format_range_headers from .._shared.response_handlers import parse_length_from_content_range, process_storage_error from .._shared.constants import DEFAULT_MAX_CONCURRENCY +from .._shared.validation import is_md5_validation, CV_TYPE_PARSED if TYPE_CHECKING: from .._generated.aio.operations import FileOperations @@ -82,7 +83,7 @@ async def _download_chunk(self, chunk_start: int, chunk_end: int) -> bytes: range_header, range_validation = validate_and_format_range_headers( chunk_start, chunk_end, - check_content_md5=self.validate_content + 
check_content_md5=is_md5_validation(self.validate_content) ) try: _, response = await cast(Awaitable[Any], self.client.download( @@ -178,7 +179,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] path: str = None, # type: ignore [assignment] @@ -208,10 +209,12 @@ def __init__( self._etag = "" # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. - self._first_get_size = self._config.max_single_get_size if not self._validate_content \ - else self._config.max_chunk_get_size + self._first_get_size = ( + self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + ) + initial_request_start = self._start_range or 0 if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size: initial_request_end = self._end_range @@ -253,7 +256,7 @@ async def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content) + check_content_md5=is_md5_validation(self._validate_content)) try: location_mode, response = cast(Tuple[Optional[str], Any], await self._client.download( diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py index 45b1a96fefb9..636228460008 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py +++ 
b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py @@ -48,6 +48,7 @@ from .._shared.request_handlers import add_metadata_headers, get_length from .._shared.response_handlers import process_storage_error, return_response_headers from .._shared.uploads_async import AsyncIterStreamer, FileChunkUploader, IterStreamer, upload_data_chunks +from .._shared.validation import CV_TYPE_PARSED, parse_validation_option from ._download_async import StorageStreamDownloader from ._lease_async import ShareLeaseClient from ._models import FileProperties, Handle, HandlesPaged @@ -65,7 +66,7 @@ async def _upload_file_helper( size: Optional[int], metadata: Optional[Dict[str, str]], content_settings: Optional["ContentSettings"], - validate_content: bool, + validate_content: CV_TYPE_PARSED, timeout: Optional[int], max_concurrency: int, file_settings: "StorageConfiguration", @@ -442,8 +443,13 @@ async def create_file( Restore - apply changes without further modification. :paramtype file_property_semantics: Optional[Literal["New", "Restore"]] - :keyword data: Optional initial data to upload, up to 4MB. - :paramtype data: bytes + :keyword bytes data: Optional initial data to upload, up to 4MB. + :keyword validate_content: + Only applicable when `data` is provided. + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-file-service-operations. 
@@ -466,6 +472,10 @@ async def create_file( content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) data = kwargs.pop('data', None) @@ -493,6 +503,7 @@ async def create_file( file_permission_key=permission_key, file_http_headers=file_http_headers, optionalbody=data, + validate_content=validate_content, content_length=len(data) if data is not None else None, lease_access_conditions=access_conditions, headers=headers, @@ -556,13 +567,11 @@ async def upload_file( :keyword ~azure.storage.fileshare.ContentSettings content_settings: ContentSettings object used to set file properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each range of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int max_concurrency: Maximum number of parallel connections to use when transferring the file in chunks. 
This option does not affect the underlying connection pool, and may @@ -604,7 +613,10 @@ async def upload_file( max_concurrency = kwargs.pop('max_concurrency', None) if max_concurrency is None: max_concurrency = DEFAULT_MAX_CONCURRENCY - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) progress_hook = kwargs.pop('progress_hook', None) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') @@ -873,15 +885,11 @@ async def download_file( Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may require a separate configuration of the connection pool. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the file has an active lease. Value can be a ShareLeaseClient object or the lease ID as a string. 
@@ -924,12 +932,14 @@ async def download_file( range_end = offset + length - 1 # Service actually uses an end-range inclusive index access_conditions = get_access_conditions(kwargs.pop('lease', None)) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) downloader = StorageStreamDownloader( client=self._client.file, config=self._config, start_range=offset, end_range=range_end, + validate_content=validate_content, name=self.file_name, path='/'.join(self.file_path), share=self.share_name, @@ -1295,13 +1305,11 @@ async def upload_range( :param int length: Number of bytes to use for uploading a section of the file. The range can be up to 4 MB in size. - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword file_last_write_mode: If the file last write time should be preserved or overwritten. Possible values are "preserve" or "now". If not specified, file last write time will be changed to @@ -1330,7 +1338,10 @@ async def upload_range( :returns: File-updated property dict (Etag and last modified). 
:rtype: Dict[str, Any] """ - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') file_last_write_mode = kwargs.pop('file_last_write_mode', None) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi index 3f70def69364..905f5a2bb002 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi @@ -135,6 +135,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): file_mode: Optional[str] = None, file_property_semantics: Optional[Literal["New", "Restore"]] = None, data: Optional[bytes] = None, + validate_content: Optional[Literal['auto', 'crc64', 'md5']] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: ... 
@@ -151,7 +152,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): file_change_time: Optional[Union[str, datetime]] = None, metadata: Optional[Dict[str, str]] = None, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, @@ -199,7 +200,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): length: Optional[int] = None, *, max_concurrency: Optional[int] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, decompress: Optional[bool] = None, @@ -270,7 +271,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): offset: int, length: int, *, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, file_last_write_mode: Optional[Literal["preserve", "now"]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, encoding: str = "UTF-8", diff --git a/sdk/storage/azure-storage-file-share/dev_requirements.txt b/sdk/storage/azure-storage-file-share/dev_requirements.txt index 60d588f5e1e1..b42a7dff1a22 100644 --- a/sdk/storage/azure-storage-file-share/dev_requirements.txt +++ b/sdk/storage/azure-storage-file-share/dev_requirements.txt @@ -2,4 +2,5 @@ ../../core/azure-core ../../identity/azure-identity ../azure-storage-blob +../azure-storage-extensions aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py new file mode 
100644 index 000000000000..d6fc92168363 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py @@ -0,0 +1,270 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.storage.fileshare import ShareClient, ShareServiceClient + +from devtools_testutils import recorded_by_proxy +from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase +from settings.testcase import FileSharePreparer + + +def assert_content_md5(request): + if request.http_request.query.get('comp') == 'range': + assert request.http_request.headers.get('Content-MD5') is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' + assert response.http_response.headers.get('Content-MD5') is not None + + +def assert_structured_message(request): + if request.http_request.query.get('comp') == 'range': + assert request.http_request.headers.get('x-ms-structured-body') is not None + + +def assert_structured_message_get(response): + assert response.http_request.headers.get('x-ms-structured-body') is not None + assert response.http_response.headers.get('x-ms-structured-body') is not None + + +class TestStorageContentValidation(StorageRecordedTestCase): + share_client: ShareClient + + def _setup(self, account_name): + token_credential = self.get_credential(ShareServiceClient) + self.ssc = ShareServiceClient(self.account_url(account_name, "file"), credential=token_credential, token_intent="backup", logging_enable=True) + self.share_client = self.ssc.get_share_client(self.get_resource_name('utshare')) + self.share_client.create_share() + + def 
teardown_method(self, _): + if self.share_client: + try: + self.share_client.delete_share() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name('file') + + @FileSharePreparer() + @pytest.mark.parametrize('a', ['auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_create_file_with_data(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_range_size = 1024 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abcde' * 512 + assert_method = 
assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto','md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_range(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data1 = b'abcde' * 512 + data2 = '你好世界' * 10 + encoded2 = data2.encode('utf-16') + + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + file.create_file(len(data1) + len(encoded2)) + file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) + file.upload_range(data2, len(data1), len(encoded2), encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data1 + encoded2 + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_range_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + + data = b'abcd' * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + file.create_file(len(data)) + # This is not an officially supported data type for upload_range, but we should still be able to handle it + # as there are probably users out there using it. 
+ file.upload_range(BytesIO(data), 0, len(data), validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + file.upload_file(data) + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + file.upload_file(data) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + 
read_content = bytearray() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + for chunk in downloader.chunks(): + read_content.extend(chunk) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + file.upload_file(data) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @FileSharePreparer() + @pytest.mark.live_test_only + def test_download_file_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + file.upload_file(data, max_concurrency=5) + + # Act + downloader = file.download_file(validate_content='crc64', max_concurrency=5) + content = 
downloader.readall() + + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + partial = downloader.readall() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py new file mode 100644 index 000000000000..d5cce63c3900 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py @@ -0,0 +1,261 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from azure.storage.fileshare import ShareClient as SyncShareClient +from azure.storage.fileshare.aio import ShareServiceClient + +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 +from settings.testcase import FileSharePreparer +from test_content_validation import ( + assert_content_md5, + assert_content_md5_get, + assert_structured_message, + assert_structured_message_get +) + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + async def _setup(self, account_name): + token_credential = self.get_credential(ShareServiceClient, is_async=True) + self.ssc = ShareServiceClient(self.account_url(account_name, "file"), credential=token_credential, token_intent="backup", logging_enable=True) + self.share_client = self.ssc.get_share_client(self.get_resource_name('utshare')) + await self.share_client.create_share() + + def teardown_method(self, _): + if self.share_client: + sync_credential = 
self.get_credential(SyncShareClient, is_async=False) + sync_share_client = SyncShareClient( + self.account_url(self.share_client.account_name, "file"), + self.share_client.share_name, + credential=sync_credential, + token_intent="backup") + try: + sync_share_client.delete_share() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name('file') + + @FileSharePreparer() + @pytest.mark.parametrize('a', ['auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_create_file_with_data(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + await file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_file_chunks(self, a, 
**kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_range_size = 1024 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abcde' * 512 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Act + await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_range(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data1 = b'abcde' * 512 + data2 = '你好世界' * 10 + encoded2 = data2.encode('utf-16') + + assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + + # Act + await file.create_file(len(data1) + len(encoded2)) + await file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) + await file.upload_range(data2, len(data1), len(encoded2), encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data1 + encoded2 + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_range_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + + data = b'abcd' * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == 
'crc64' else assert_content_md5 + + # Act + await file.create_file(len(data)) + # This is not an officially supported data type for upload_range, but we should still be able to handle it + # as there are probably users out there using it. + await file.upload_range(BytesIO(data), 0, len(data), validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + await file.upload_file(data) + assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + await file.upload_file(data) + assert_method = assert_structured_message_get if a == 
'crc64' else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + async for chunk in downloader.chunks(): + read_content.extend(chunk) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @FileSharePreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abc' * 512 + b'abcde' + await file.upload_file(data) + assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Act + downloader = await file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @FileSharePreparer() + @pytest.mark.live_test_only + async def test_download_file_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + # The service will 
use 4 MiB for structured message chunk size, so make chunk size larger + self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB + file = self.share_client.get_file_client(self._get_file_reference()) + data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + await file.upload_file(data, max_concurrency=5) + + # Act + downloader = await file.download_file(validate_content='crc64', max_concurrency=5) + content = await downloader.readall() + + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + partial = await downloader.readall() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 61d6a8b3dc9a..69c95e6bb6b7 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -5,12 +5,11 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( @@ -43,8 +42,8 @@ CV_TYPE_ERROR_MSG, calculate_content_md5, calculate_crc64_bytes, + is_crc64_validation, is_md5_validation, - ChecksumAlgorithm, ) if TYPE_CHECKING: @@ -434,7 +433,7 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: # Download if request.http_request.method == "GET": - if validate_content == ChecksumAlgorithm.CRC64: + if is_crc64_validation(validate_content): request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 # Upload @@ -447,7 +446,11 @@ def _prepare_content_validation(request: 
"PipelineRequest") -> None: request.http_request.headers[MD5_HEADER] = computed_md5 request.context["validate_content_md5"] = computed_md5 - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + if isinstance(data, bytes): request.http_request.headers[CRC64_HEADER] = encode_base64( calculate_crc64_bytes(data) @@ -501,7 +504,7 @@ def _validate_content_response( response=response.http_response, ) - elif validate_content == ChecksumAlgorithm.CRC64: + elif is_crc64_validation(validate_content): # For upload and download verify structured message header present in response if provided in request. sm_request = request.http_request.headers.get(SM_HEADER) sm_response = response.http_response.headers.get(SM_HEADER) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py index 5370d9dd669c..21d3b081d8cc 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py @@ -6,24 +6,16 @@ # pylint: disable=c-extension-no-member import hashlib -from enum import Enum from io import SEEK_SET from typing import IO, Literal, Optional, Union, cast -from azure.core import CaseInsensitiveEnumMeta - CRC64_LENGTH = 8 CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." 
+_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") -class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): - AUTO = "auto" - MD5 = "md5" - CRC64 = "crc64" - - @classmethod - def list(cls): - return list(map(lambda c: c.value, cls)) +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] def _verify_extensions(module: str) -> None: @@ -37,8 +29,10 @@ def _verify_extensions(module: str) -> None: def parse_validation_option( - validate_content: Optional[Union[bool, Literal["auto", "crc64", "md5"]]], -) -> Optional[Union[bool, Literal["auto", "crc64", "md5"]]]: + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: if validate_content is None: return None @@ -46,27 +40,40 @@ def parse_validation_option( if isinstance(validate_content, bool): return validate_content - if validate_content not in (ChecksumAlgorithm.list()): + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: raise ValueError("Invalid value for `validate_content` specified.") # Resolve auto - if validate_content == ChecksumAlgorithm.AUTO: - validate_content = ChecksumAlgorithm.CRC64.value + if parsed == "auto": + parsed = "crc64" - if validate_content == ChecksumAlgorithm.CRC64: + if parsed == "crc64": _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" - return validate_content + return cast(CV_TYPE_PARSED, parsed) def is_md5_validation( - validate_content: Optional[Union[bool, Literal["md5", "crc64"]]], + validate_content: CV_TYPE_PARSED, ) -> bool: if validate_content is None: return False if isinstance(validate_content, bool): return validate_content - return validate_content == ChecksumAlgorithm.MD5 + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + 
return False + return validate_content in ("crc64", "crc64-sm") def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: From f3d2300cb2b1ed3e5093764027321f3b95333e2e Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Fri, 24 Apr 2026 11:33:49 -0700 Subject: [PATCH 10/14] [Storage][102] CRC64 content validation - part 7 - record tests (#46461) --- .../devtools_testutils/storage/decorators.py | 4 +- .../storage/blob/_blob_client_helpers.py | 27 --- .../azure/storage/blob/_download.py | 6 +- .../storage/blob/_shared/base_client_async.py | 6 +- .../azure/storage/blob/_shared/policies.py | 168 +++++------------- .../storage/blob/_shared/policies_async.py | 96 +++------- .../azure/storage/blob/_shared/streams.py | 168 +++++------------- .../storage/blob/_shared/streams_async.py | 74 ++------ .../azure/storage/blob/aio/_download_async.py | 5 +- .../tests/test_content_validation.py | 68 ++++++- .../tests/test_content_validation_async.py | 61 +++++++ .../dev_requirements.txt | 2 +- .../filedatalake/_shared/base_client_async.py | 6 +- .../storage/filedatalake/_shared/policies.py | 166 +++++------------ .../filedatalake/_shared/policies_async.py | 96 +++------- .../storage/filedatalake/_shared/streams.py | 168 +++++------------- .../filedatalake/_shared/streams_async.py | 74 ++------ .../aio/_data_lake_file_client_async.py | 2 +- .../tests/test_content_validation.py | 4 +- .../tests/test_content_validation_async.py | 3 + .../azure-storage-file-share/assets.json | 2 +- .../azure/storage/fileshare/_download.py | 6 +- .../fileshare/_shared/base_client_async.py | 6 +- .../storage/fileshare/_shared/policies.py | 166 +++++------------ .../fileshare/_shared/policies_async.py | 96 +++------- .../storage/fileshare/_shared/streams.py | 168 +++++------------- .../fileshare/_shared/streams_async.py | 74 ++------ .../storage/fileshare/aio/_download_async.py | 4 +- .../tests/test_content_validation.py | 4 +- 
.../tests/test_content_validation_async.py | 3 + .../azure/storage/queue/_shared/policies.py | 166 +++++------------ .../storage/queue/_shared/policies_async.py | 96 +++------- .../azure/storage/queue/_shared/streams.py | 168 +++++------------- .../storage/queue/_shared/streams_async.py | 74 ++------ .../azure-storage-queue/dev_requirements.txt | 1 + 35 files changed, 674 insertions(+), 1564 deletions(-) diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py index 20d5c2dfba10..45f1db5c588c 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py @@ -7,12 +7,12 @@ class GenericTestProxyParametrize1: def __call__(self, fn): def _wrapper(test_class, a, **kwargs): - fn(test_class, a, **kwargs) + return fn(test_class, a, **kwargs) return _wrapper class GenericTestProxyParametrize2: def __call__(self, fn): def _wrapper(test_class, a, b, **kwargs): - fn(test_class, a, b, **kwargs) + return fn(test_class, a, b, **kwargs) return _wrapper diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index ef628ecbb316..430a437bf886 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -265,33 +265,6 @@ def _download_blob_options( client: "AzureBlobStorage", **kwargs ) -> Dict[str, Any]: - """Creates a dictionary containing the options for a download blob operation. - - :param str blob_name: - The name of the blob. - :param str container_name: - The name of the container. - :param Optional[str] version_id: - The version id parameter is a value that, when present, specifies the version of the blob to download. 
- :param Optional[int] offset: - Start of byte range to use for downloading a section of the blob. Must be set if length is provided. - :param Optional[int] length: - Number of bytes to read from the stream. This is optional, but should be supplied for optimal performance. - :param Optional[str] encoding: - Encoding to decode the downloaded bytes. Default is None, i.e. no decoding. - :param Dict[str, Any] encryption_options: - The options for encryption, if enabled. - :param validate_content: - Enables checksum validation for the transfer. Already parsed via parse_validation_option. - :param StorageConfiguration config: - The Storage configuration options. - :param str sdk_moniker: - The string representing the SDK package version. - :param AzureBlobStorage client: - The generated Blob Storage client. - :return: A dictionary containing the download blob options. - :rtype: Dict[str, Any] - """ if length is not None: if offset is None: raise ValueError("Offset must be provided if length is provided.") diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index 7be6d68858d7..17304f6bed8f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -386,7 +386,11 @@ def __init__( # The service only provides transactional MD5s for chunks under 4MB. # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. 
- first_get_size = self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + first_get_size = ( + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size + ) initial_request_start = self._download_start if self._end_range is not None and self._end_range - initial_request_start < first_get_size: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 2c917610eade..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -40,7 +40,11 @@ StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 9b10ece3de79..832717f7457a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -86,9 +86,7 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"]( - retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -108,9 +106,7 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get( - "x-ms-copy-source-error-code" - ) + error_code = response.http_response.headers.get("x-ms-copy-source-error-code") if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -134,9 +130,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -165,9 +161,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str( - uuid.uuid1() - ) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -207,9 +201,7 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError( - f"Attempting to use undefined host location {use_location}" - ) + raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != 
location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -227,9 +219,7 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__( - logging_enable=logging_enable, **kwargs - ) + super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -313,9 +303,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get( - "content-type", "" - ) + resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): filename = header.partition("=")[2] @@ -344,9 +332,7 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop( - "raw_request_hook", self._request_callback - ) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -364,50 +350,36 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = request.context.options.pop("download_stream_current", None) upload_stream_current = 
request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - 
pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -416,11 +388,6 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": def _prepare_content_validation(request: "PipelineRequest") -> None: - """Shared request-side logic for content validation. - - Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 - validation, and stores the validation mode in the request context. - """ validate_content = request.context.options.pop("validate_content", False) if not validate_content: return @@ -446,21 +413,13 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: data = BytesIO(data) if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64( - calculate_crc64_bytes(data) - ) + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) elif hasattr(data, "read"): - content_length = int( - request.http_request.headers.get(CONTENT_LENGTH_HEADER) - ) + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream( - data, content_length, StructuredMessageProperties.CRC64 - ) + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str( - len(sm_stream) - ) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -474,18 +433,11 @@ def _validate_content_response( response: "PipelineResponse", decoder_cls: type, ) -> None: - 
"""Shared response-side logic for content validation. - - Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches - ``stream_download`` to wrap the iterator in the given *decoder_cls*. - """ validate_content = response.context.get("validate_content", False) if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get( - "content-md5" - ): + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -520,9 +472,7 @@ def _validate_content_response( def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls( - iterator, content_length, block_size=DATA_BLOCK_SIZE - ) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -542,9 +492,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -572,9 +520,7 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location( - self, settings: Dict[str, Any], request: "PipelineRequest" - ) -> None: + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -613,9 +559,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop( - "retry_to_secondary", self.retry_to_secondary - ), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -624,9 +568,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time( - self, settings: Dict[str, Any] - ) -> float: # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -689,9 +631,7 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append( - RequestHistory(request, http_response=response) - ) + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -726,9 +666,7 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry( - response - ): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -747,9 +685,7 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, 
request=request.http_request, error=err) if retries_remaining: retry_hook( retry_settings, @@ -809,9 +745,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -824,14 +758,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -869,11 +797,9 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings: Dict[str, Any]) -> float: + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """ Calculates how long to sleep before retrying. 
@@ -886,11 +812,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -898,16 +820,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "TokenCredential", audience: str, **kwargs: Any - ) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. 
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 14ce070e47ff..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -45,17 +45,9 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): @@ -70,9 +62,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -118,50 +110,36 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = request.context.options.pop("download_stream_current", 
None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or await is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = 
data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -190,9 +168,7 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry( - response, retry_settings["mode"] - ) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -211,9 +187,7 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: await retry_hook( retry_settings, @@ -275,9 +249,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -290,14 +262,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, 
settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -335,9 +301,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -352,11 +316,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -364,16 +324,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any - ) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - async def on_challenge( - self, request: 
"PipelineRequest", response: "PipelineResponse" - ) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index 712f4e90af69..27272fdac592 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -35,9 +35,7 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header( - version: int, size: int, flags: StructuredMessageProperties, num_segments: int -) -> bytes: +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: return ( version.to_bytes(1, "little") + size.to_bytes(8, "little") @@ -50,17 +48,14 @@ def generate_segment_header(number: int, size: int) -> bytes: return number.to_bytes(2, "little") + size.to_bytes(8, "little") -def parse_message_header( - data: bytes, expected_message_length: int -) -> tuple[int, StructuredMessageProperties, int]: +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: raise ValueError( - f"Structured message length {message_length} " - f"did not match content length {expected_message_length}" + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" ) flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) num_segments = int.from_bytes(data[11:13], "little") @@ -70,16 +65,12 @@ def 
parse_message_header( def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError( - f"Structured message segment number invalid or out of order {segment_number}" - ) + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream( - IOBase -): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -151,19 +142,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -198,10 +181,7 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return ( - self._inner_stream.seekable() - and self._initial_content_position is not None - ) + return self._inner_stream.seekable() and self._initial_content_position is not None except (AttributeError, UnsupportedOperation, OSError): return False @@ -212,24 +192,21 @@ def tell(self) -> int: return ( self._message_header_length + 
self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) if self._current_region == SMRegion.SEGMENT_CONTENT: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length ) if self._current_region == SMRegion.SEGMENT_FOOTER: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length + self._current_region_offset ) @@ -237,8 +214,7 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + self._current_segment_number - * (self._segment_header_length + self._segment_footer_length) + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) @@ -271,33 +247,21 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - ( - self.message_length - self._message_footer_length - ) + self._current_region_offset = position - (self.message_length - self._message_footer_length) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = ( - self._segment_header_length - + self._segment_size - + self._segment_footer_length - ) - new_segment_num = ( - 1 + (position - self._message_header_length) // full_segment_size - ) + full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length + new_segment_num = 1 + (position - self._message_header_length) // full_segment_size segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = ( - new_segment_num - 1 - ) * self._segment_size + previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = ( - self.content_length - previous_segments_total_content_size - ) + new_segment_size = self.content_length - previous_segments_total_content_size # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -308,25 +272,17 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = ( - previous_segments_total_content_size + self._current_region_offset - ) + self._content_offset = previous_segments_total_content_size + self._current_region_offset # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = ( - segment_pos - self._segment_header_length - new_segment_size - ) - self._content_offset = ( - previous_segments_total_content_size + new_segment_size - ) + self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size + self._content_offset = 
previous_segments_total_content_size + new_segment_size self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek( - (self._initial_content_position or 0) + self._content_offset - ) + self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) return position def read(self, size: int = -1) -> bytes: @@ -349,9 +305,7 @@ def read(self, size: int = -1) -> bytes: SMRegion.SEGMENT_FOOTER, SMRegion.MESSAGE_FOOTER, ): - count += self._read_metadata_region( - self._current_region, remaining, output - ) + count += self._read_metadata_region(self._current_region, remaining, output) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -361,9 +315,7 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += ( - self._segment_header_length + self._segment_footer_length - ) * self._num_segments + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -378,9 +330,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min( - self._segment_size, self.content_length - self._content_offset - ) + segment_size = min(self._segment_size, self.content_length - self._content_offset) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: @@ -392,9 +342,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes( - StructuredMessageConstants.CRC64_LENGTH, "little" - ) + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") return b"" raise ValueError(f"Invalid metadata SMRegion 
{self._current_region}") @@ -421,15 +369,11 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region( - self, region: SMRegion, size: int, output: BytesIO - ) -> int: + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[ - self._current_region_offset : self._current_region_offset + read_size - ] + content = metadata[self._current_region_offset : self._current_region_offset + read_size] output.write(content) self._current_region_offset += read_size @@ -511,9 +455,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." - ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -537,19 +479,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -571,7 +505,7 @@ def __next__(self) -> bytes: return data def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise 
ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -588,23 +522,17 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." - ) + raise ValueError("First message segment was empty but more segments were detected.") self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -612,12 +540,8 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -631,10 +555,7 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. 
- if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -648,9 +569,7 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -658,9 +577,7 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -674,17 +591,14 @@ def _read_message_footer(self) -> None: if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py index ee7d92d14d77..9fedc055a623 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py @@ -17,9 +17,7 @@ from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder( - IOBase -): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -50,9 +48,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." 
- ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -76,19 +72,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -110,7 +98,7 @@ async def __anext__(self) -> bytes: return data async def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -127,23 +115,17 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." 
- ) + raise ValueError("First message segment was empty but more segments were detected.") await self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -151,12 +133,8 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -170,10 +148,7 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -187,21 +162,15 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. 
Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner( - StructuredMessageConstants.V1_HEADER_LENGTH - ) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -211,23 +180,18 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -235,9 +199,7 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._segment_crc64 != int.from_bytes(segment_crc, "little"): raise ValueError( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 83124c72a6fc..2f7d2a7f1e95 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -323,8 +323,11 @@ async def _setup(self) -> None: # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. 
first_get_size = ( - self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) + initial_request_start = self._start_range if self._start_range is not None else 0 if self._end_range is not None and self._end_range - initial_request_start < first_get_size: initial_request_end = self._end_range diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index e17777f65301..e2f65b03771e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -7,6 +7,7 @@ from io import BytesIO import pytest +from azure.core.exceptions import ResourceExistsError from azure.storage.blob import ( BlobBlock, BlobClient, @@ -14,7 +15,7 @@ BlobType, ContainerClient ) -from devtools_testutils import recorded_by_proxy +from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import ( GenericTestProxyParametrize1, GenericTestProxyParametrize2, @@ -80,10 +81,13 @@ def _setup(self, account_name): token_credential = self.get_credential(BlobServiceClient) self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) - self.container.create_container() + try: + self.container.create_container() + except ResourceExistsError: + pass def teardown_method(self, _): - if self.container: + if self.container and is_live(): try: self.container.delete_container() except: @@ -505,3 +509,61 @@ def test_download_blob_chars(self, a, **kwargs): result += stream.readall() assert result == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + 
@GenericTestProxyParametrize1() + @recorded_by_proxy + def test_content_validation_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + + # Determine the appropriate assert methods based on validation mode + upload_assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + download_assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Test upload with retry + upload_call_count = 0 + def upload_hook_fail_once(response): + nonlocal upload_call_count + upload_call_count += 1 + # Assert content validation headers are present on both attempts + upload_assert_method(response) + if upload_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + blob.upload_blob(data, validate_content=a, overwrite=True, raw_response_hook=upload_hook_fail_once) + assert upload_call_count == 2 # Original + retry + assert blob.download_blob().read() == data + + # Test download with retry + download_call_count = 0 + def download_hook_fail_once(response): + nonlocal download_call_count + download_call_count += 1 + # Assert content validation headers are present on both attempts + download_assert_method(response) + if download_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + downloader = blob.download_blob(validate_content=a, raw_response_hook=download_hook_fail_once) + 
content = downloader.read() + assert download_call_count == 2 # Original + retry + assert content == data diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index 80b6cfe3308f..6b1114b81a44 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -14,6 +14,7 @@ BlobServiceClient, ContainerClient ) +from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import ( AsyncStorageRecordedTestCase, @@ -47,6 +48,8 @@ async def _setup(self, account_name): pass def teardown_method(self, _): + if not is_live(): + return # Use sync client as teardown_method must be sync if self.container: sync_credential = self.get_credential(SyncContainerClient, is_async=False) @@ -481,3 +484,61 @@ async def test_download_blob_chars(self, a, **kwargs): result += await stream.readall() assert result == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_content_validation_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + await self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b'abc' * 512 + + # Determine the appropriate assert methods based on validation mode + 
upload_assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + download_assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + + # Test upload with retry + upload_call_count = 0 + def upload_hook_fail_once(response): + nonlocal upload_call_count + upload_call_count += 1 + # Assert content validation headers are present on both attempts + upload_assert_method(response) + if upload_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + await blob.upload_blob(data, validate_content=a, overwrite=True, raw_response_hook=upload_hook_fail_once) + assert upload_call_count == 2 # Original + retry + assert await (await blob.download_blob()).read() == data + + # Test download with retry + download_call_count = 0 + def download_hook_fail_once(response): + nonlocal download_call_count + download_call_count += 1 + # Assert content validation headers are present on both attempts + download_assert_method(response) + if download_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + downloader = await blob.download_blob(validate_content=a, raw_response_hook=download_hook_fail_once) + content = await downloader.read() + assert download_call_count == 2 # Original + retry + assert content == data diff --git a/sdk/storage/azure-storage-extensions/dev_requirements.txt b/sdk/storage/azure-storage-extensions/dev_requirements.txt index 7d496b4d1cc1..b18a83fbb955 100644 --- a/sdk/storage/azure-storage-extensions/dev_requirements.txt +++ b/sdk/storage/azure-storage-extensions/dev_requirements.txt @@ -1 +1 @@ --e ../../../eng/tools/azure-sdk-tools +-e ../../../eng/tools/azure-sdk-tools \ No newline at end of file diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py index 
2c917610eade..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py @@ -40,7 +40,11 @@ StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 9b10ece3de79..7a02e4479149 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -86,9 +86,7 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"]( - retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? (Based on allowlists and control @@ -108,9 +106,7 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. 
return True if status >= 400: - error_code = response.http_response.headers.get( - "x-ms-copy-source-error-code" - ) + error_code = response.http_response.headers.get("x-ms-copy-source-error-code") if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -134,9 +130,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -165,9 +161,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str( - uuid.uuid1() - ) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -207,9 +201,7 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError( - f"Attempting to use undefined host location {use_location}" - ) + raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -227,9 +219,7 @@ class 
StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__( - logging_enable=logging_enable, **kwargs - ) + super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -313,9 +303,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get( - "content-type", "" - ) + resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): filename = header.partition("=")[2] @@ -344,9 +332,7 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop( - "raw_request_hook", self._request_callback - ) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -364,50 +350,36 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = request.context.options.pop("download_stream_current", None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + 
upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = 
upload_stream_current if response_callback: response_callback(response) @@ -416,11 +388,6 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": def _prepare_content_validation(request: "PipelineRequest") -> None: - """Shared request-side logic for content validation. - - Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 - validation, and stores the validation mode in the request context. - """ validate_content = request.context.options.pop("validate_content", False) if not validate_content: return @@ -446,21 +413,13 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: data = BytesIO(data) if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64( - calculate_crc64_bytes(data) - ) + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) elif hasattr(data, "read"): - content_length = int( - request.http_request.headers.get(CONTENT_LENGTH_HEADER) - ) + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream( - data, content_length, StructuredMessageProperties.CRC64 - ) + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str( - len(sm_stream) - ) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -474,18 +433,11 @@ def _validate_content_response( response: "PipelineResponse", decoder_cls: type, ) -> None: - """Shared response-side logic for content validation. - - Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches - ``stream_download`` to wrap the iterator in the given *decoder_cls*. 
- """ validate_content = response.context.get("validate_content", False) if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get( - "content-md5" - ): + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -520,9 +472,7 @@ def _validate_content_response( def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls( - iterator, content_length, block_size=DATA_BLOCK_SIZE - ) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -542,9 +492,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -572,9 +520,7 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location( - self, settings: Dict[str, Any], request: "PipelineRequest" - ) -> None: + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -613,9 +559,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop( - "retry_to_secondary", self.retry_to_secondary - ), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -624,9 +568,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time( - self, settings: Dict[str, Any] - ) -> float: # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -689,9 +631,7 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append( - RequestHistory(request, http_response=response) - ) + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -726,9 +666,7 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry( - response - ): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -747,9 +685,7 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, 
request=request.http_request, error=err) if retries_remaining: retry_hook( retry_settings, @@ -809,9 +745,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -824,14 +758,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -869,9 +797,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -886,11 +812,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - 
self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -898,16 +820,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "TokenCredential", audience: str, **kwargs: Any - ) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. 
diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 14ce070e47ff..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -45,17 +45,9 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): @@ -70,9 +62,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -118,50 +110,36 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + 
download_stream_current = request.context.options.pop("download_stream_current", None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or await is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if 
hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -190,9 +168,7 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry( - response, retry_settings["mode"] - ) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -211,9 +187,7 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: await retry_hook( retry_settings, @@ -275,9 +249,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -290,14 +262,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = 
self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -335,9 +301,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -352,11 +316,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -364,16 +324,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any - ) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, 
self).__init__(credential, audience, **kwargs) - async def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py index 712f4e90af69..27272fdac592 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -35,9 +35,7 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header( - version: int, size: int, flags: StructuredMessageProperties, num_segments: int -) -> bytes: +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: return ( version.to_bytes(1, "little") + size.to_bytes(8, "little") @@ -50,17 +48,14 @@ def generate_segment_header(number: int, size: int) -> bytes: return number.to_bytes(2, "little") + size.to_bytes(8, "little") -def parse_message_header( - data: bytes, expected_message_length: int -) -> tuple[int, StructuredMessageProperties, int]: +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: raise ValueError( - f"Structured message length {message_length} " - f"did not match content length {expected_message_length}" + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" ) flags = 
StructuredMessageProperties(int.from_bytes(data[9:11], "little")) num_segments = int.from_bytes(data[11:13], "little") @@ -70,16 +65,12 @@ def parse_message_header( def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError( - f"Structured message segment number invalid or out of order {segment_number}" - ) + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream( - IOBase -): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -151,19 +142,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -198,10 +181,7 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return ( - self._inner_stream.seekable() - and self._initial_content_position is not None - ) + return self._inner_stream.seekable() and self._initial_content_position is not None except (AttributeError, 
UnsupportedOperation, OSError): return False @@ -212,24 +192,21 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) if self._current_region == SMRegion.SEGMENT_CONTENT: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length ) if self._current_region == SMRegion.SEGMENT_FOOTER: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length + self._current_region_offset ) @@ -237,8 +214,7 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + self._current_segment_number - * (self._segment_header_length + self._segment_footer_length) + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) @@ -271,33 +247,21 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - ( - self.message_length - self._message_footer_length - ) + self._current_region_offset = position - (self.message_length - self._message_footer_length) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = ( - self._segment_header_length - + self._segment_size - + self._segment_footer_length - ) - new_segment_num = ( - 1 + (position - self._message_header_length) // full_segment_size - ) + full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length + new_segment_num = 1 + (position - self._message_header_length) // full_segment_size segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = ( - new_segment_num - 1 - ) * self._segment_size + previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = ( - self.content_length - previous_segments_total_content_size - ) + new_segment_size = self.content_length - previous_segments_total_content_size # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -308,25 +272,17 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = ( - previous_segments_total_content_size + self._current_region_offset - ) + self._content_offset = previous_segments_total_content_size + self._current_region_offset # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = ( - segment_pos - self._segment_header_length - new_segment_size - ) - self._content_offset = ( - previous_segments_total_content_size + new_segment_size - ) + self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size + self._content_offset = 
previous_segments_total_content_size + new_segment_size self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek( - (self._initial_content_position or 0) + self._content_offset - ) + self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) return position def read(self, size: int = -1) -> bytes: @@ -349,9 +305,7 @@ def read(self, size: int = -1) -> bytes: SMRegion.SEGMENT_FOOTER, SMRegion.MESSAGE_FOOTER, ): - count += self._read_metadata_region( - self._current_region, remaining, output - ) + count += self._read_metadata_region(self._current_region, remaining, output) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -361,9 +315,7 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += ( - self._segment_header_length + self._segment_footer_length - ) * self._num_segments + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -378,9 +330,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min( - self._segment_size, self.content_length - self._content_offset - ) + segment_size = min(self._segment_size, self.content_length - self._content_offset) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: @@ -392,9 +342,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes( - StructuredMessageConstants.CRC64_LENGTH, "little" - ) + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") return b"" raise ValueError(f"Invalid metadata SMRegion 
{self._current_region}") @@ -421,15 +369,11 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region( - self, region: SMRegion, size: int, output: BytesIO - ) -> int: + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[ - self._current_region_offset : self._current_region_offset + read_size - ] + content = metadata[self._current_region_offset : self._current_region_offset + read_size] output.write(content) self._current_region_offset += read_size @@ -511,9 +455,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." - ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -537,19 +479,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -571,7 +505,7 @@ def __next__(self) -> bytes: return data def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise 
ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -588,23 +522,17 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." - ) + raise ValueError("First message segment was empty but more segments were detected.") self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -612,12 +540,8 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -631,10 +555,7 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. 
- if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -648,9 +569,7 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -658,9 +577,7 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -674,17 +591,14 @@ def _read_message_footer(self) -> None: if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py index ee7d92d14d77..9fedc055a623 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py @@ -17,9 +17,7 @@ from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder( - IOBase -): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -50,9 +48,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." 
- ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -76,19 +72,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -110,7 +98,7 @@ async def __anext__(self) -> bytes: return data async def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -127,23 +115,17 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." 
- ) + raise ValueError("First message segment was empty but more segments were detected.") await self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -151,12 +133,8 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -170,10 +148,7 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -187,21 +162,15 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. 
Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner( - StructuredMessageConstants.V1_HEADER_LENGTH - ) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -211,23 +180,18 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -235,9 +199,7 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._segment_crc64 != int.from_bytes(segment_crc, "little"): raise ValueError( diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py index 86695a19ec10..f13042518f98 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py @@ -509,7 +509,7 @@ async def append_data( :type length: int or None :keyword bool flush: If true, will commit the data after it is appended. - ::keyword validate_content: + :keyword validate_content: Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. 
diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py index 87d8170e6035..77f81647e956 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py @@ -11,7 +11,7 @@ DataLakeServiceClient ) -from devtools_testutils import recorded_by_proxy +from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase from settings.testcase import DataLakePreparer @@ -49,7 +49,7 @@ def _setup(self, account_name): self.file_system.create_file_system() def teardown_method(self, _): - if self.file_system: + if self.file_system and is_live(): try: self.file_system.delete_file_system() except: diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py index 4c9f1af53013..b137e11cfba1 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py @@ -10,6 +10,7 @@ from azure.storage.filedatalake import FileSystemClient as SyncFileSystemClient from azure.storage.filedatalake.aio import DataLakeServiceClient +from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 from settings.testcase import DataLakePreparer @@ -30,6 +31,8 @@ async def _setup(self, account_name): await self.file_system.create_file_system() def teardown_method(self, _): + if not is_live(): + return # Use sync client as teardown_method must be sync if self.file_system: sync_credential = self.get_credential(SyncFileSystemClient, is_async=False) diff --git 
a/sdk/storage/azure-storage-file-share/assets.json b/sdk/storage/azure-storage-file-share/assets.json index f8436a861d3c..79ed1f733678 100644 --- a/sdk/storage/azure-storage-file-share/assets.json +++ b/sdk/storage/azure-storage-file-share/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-file-share", - "Tag": "python/storage/azure-storage-file-share_4afd6de033" + "Tag": "python/storage/azure-storage-file-share_d5f376c42f" } diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py index d0e2d19d7f47..c087390eb09b 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py @@ -18,7 +18,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from ._shared.constants import DEFAULT_MAX_CONCURRENCY -from ._shared.validation import is_md5_validation, CV_TYPE, CV_TYPE_PARSED +from ._shared.validation import is_md5_validation, CV_TYPE_PARSED if TYPE_CHECKING: from ._generated.operations import FileOperations @@ -253,7 +253,9 @@ def __init__( # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. 
self._first_get_size = ( - self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) initial_request_start = self._start_range or 0 diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py index 2c917610eade..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py @@ -40,7 +40,11 @@ StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 9b10ece3de79..7a02e4479149 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -86,9 +86,7 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"]( - retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? 
(Based on allowlists and control @@ -108,9 +106,7 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. return True if status >= 400: - error_code = response.http_response.headers.get( - "x-ms-copy-source-error-code" - ) + error_code = response.http_response.headers.get("x-ms-copy-source-error-code") if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -134,9 +130,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -165,9 +161,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str( - uuid.uuid1() - ) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -207,9 +201,7 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError( - f"Attempting to use undefined host location {use_location}" - ) + raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != 
location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -227,9 +219,7 @@ class StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__( - logging_enable=logging_enable, **kwargs - ) + super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -313,9 +303,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get( - "content-type", "" - ) + resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): filename = header.partition("=")[2] @@ -344,9 +332,7 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop( - "raw_request_hook", self._request_callback - ) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -364,50 +350,36 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = request.context.options.pop("download_stream_current", None) upload_stream_current = 
request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - 
pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: response_callback(response) @@ -416,11 +388,6 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": def _prepare_content_validation(request: "PipelineRequest") -> None: - """Shared request-side logic for content validation. - - Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 - validation, and stores the validation mode in the request context. - """ validate_content = request.context.options.pop("validate_content", False) if not validate_content: return @@ -446,21 +413,13 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: data = BytesIO(data) if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64( - calculate_crc64_bytes(data) - ) + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) elif hasattr(data, "read"): - content_length = int( - request.http_request.headers.get(CONTENT_LENGTH_HEADER) - ) + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream( - data, content_length, StructuredMessageProperties.CRC64 - ) + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str( - len(sm_stream) - ) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -474,18 +433,11 @@ def _validate_content_response( response: "PipelineResponse", decoder_cls: type, ) -> None: - 
"""Shared response-side logic for content validation. - - Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches - ``stream_download`` to wrap the iterator in the given *decoder_cls*. - """ validate_content = response.context.get("validate_content", False) if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get( - "content-md5" - ): + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -520,9 +472,7 @@ def _validate_content_response( def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls( - iterator, content_length, block_size=DATA_BLOCK_SIZE - ) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -542,9 +492,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -572,9 +520,7 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location( - self, settings: Dict[str, Any], request: "PipelineRequest" - ) -> None: + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -613,9 +559,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop( - "retry_to_secondary", self.retry_to_secondary - ), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -624,9 +568,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time( - self, settings: Dict[str, Any] - ) -> float: # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -689,9 +631,7 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append( - RequestHistory(request, http_response=response) - ) + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -726,9 +666,7 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry( - response - ): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -747,9 +685,7 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, 
request=request.http_request, error=err) if retries_remaining: retry_hook( retry_settings, @@ -809,9 +745,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -824,14 +758,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -869,9 +797,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -886,11 +812,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - 
self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -898,16 +820,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "TokenCredential", audience: str, **kwargs: Any - ) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. 
diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 14ce070e47ff..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -45,17 +45,9 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): @@ -70,9 +62,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -118,50 +110,36 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = 
request.context.options.pop("download_stream_current", None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or await is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): 
pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -190,9 +168,7 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry( - response, retry_settings["mode"] - ) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -211,9 +187,7 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: await retry_hook( retry_settings, @@ -275,9 +249,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -290,14 +262,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -335,9 +301,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -352,11 +316,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -364,16 +324,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any - ) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, 
**kwargs) - async def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py index 712f4e90af69..27272fdac592 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -35,9 +35,7 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header( - version: int, size: int, flags: StructuredMessageProperties, num_segments: int -) -> bytes: +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: return ( version.to_bytes(1, "little") + size.to_bytes(8, "little") @@ -50,17 +48,14 @@ def generate_segment_header(number: int, size: int) -> bytes: return number.to_bytes(2, "little") + size.to_bytes(8, "little") -def parse_message_header( - data: bytes, expected_message_length: int -) -> tuple[int, StructuredMessageProperties, int]: +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: raise ValueError( - f"Structured message length {message_length} " - f"did not match content length {expected_message_length}" + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" ) flags = StructuredMessageProperties(int.from_bytes(data[9:11], 
"little")) num_segments = int.from_bytes(data[11:13], "little") @@ -70,16 +65,12 @@ def parse_message_header( def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError( - f"Structured message segment number invalid or out of order {segment_number}" - ) + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream( - IOBase -): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -151,19 +142,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -198,10 +181,7 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return ( - self._inner_stream.seekable() - and self._initial_content_position is not None - ) + return self._inner_stream.seekable() and self._initial_content_position is not None except (AttributeError, UnsupportedOperation, OSError): return False @@ -212,24 
+192,21 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) if self._current_region == SMRegion.SEGMENT_CONTENT: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length ) if self._current_region == SMRegion.SEGMENT_FOOTER: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length + self._current_region_offset ) @@ -237,8 +214,7 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + self._current_segment_number - * (self._segment_header_length + self._segment_footer_length) + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) @@ -271,33 +247,21 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - ( - self.message_length - self._message_footer_length - ) + self._current_region_offset = position - (self.message_length - self._message_footer_length) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = ( - self._segment_header_length - + self._segment_size - + self._segment_footer_length - ) - new_segment_num = ( - 1 + (position - self._message_header_length) // full_segment_size - ) + full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length + new_segment_num = 1 + (position - self._message_header_length) // full_segment_size segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = ( - new_segment_num - 1 - ) * self._segment_size + previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = ( - self.content_length - previous_segments_total_content_size - ) + new_segment_size = self.content_length - previous_segments_total_content_size # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -308,25 +272,17 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = ( - previous_segments_total_content_size + self._current_region_offset - ) + self._content_offset = previous_segments_total_content_size + self._current_region_offset # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = ( - segment_pos - self._segment_header_length - new_segment_size - ) - self._content_offset = ( - previous_segments_total_content_size + new_segment_size - ) + self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size + self._content_offset = 
previous_segments_total_content_size + new_segment_size self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek( - (self._initial_content_position or 0) + self._content_offset - ) + self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) return position def read(self, size: int = -1) -> bytes: @@ -349,9 +305,7 @@ def read(self, size: int = -1) -> bytes: SMRegion.SEGMENT_FOOTER, SMRegion.MESSAGE_FOOTER, ): - count += self._read_metadata_region( - self._current_region, remaining, output - ) + count += self._read_metadata_region(self._current_region, remaining, output) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -361,9 +315,7 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += ( - self._segment_header_length + self._segment_footer_length - ) * self._num_segments + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -378,9 +330,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min( - self._segment_size, self.content_length - self._content_offset - ) + segment_size = min(self._segment_size, self.content_length - self._content_offset) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: @@ -392,9 +342,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes( - StructuredMessageConstants.CRC64_LENGTH, "little" - ) + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") return b"" raise ValueError(f"Invalid metadata SMRegion 
{self._current_region}") @@ -421,15 +369,11 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region( - self, region: SMRegion, size: int, output: BytesIO - ) -> int: + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[ - self._current_region_offset : self._current_region_offset + read_size - ] + content = metadata[self._current_region_offset : self._current_region_offset + read_size] output.write(content) self._current_region_offset += read_size @@ -511,9 +455,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." - ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -537,19 +479,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -571,7 +505,7 @@ def __next__(self) -> bytes: return data def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise 
ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -588,23 +522,17 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." - ) + raise ValueError("First message segment was empty but more segments were detected.") self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -612,12 +540,8 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -631,10 +555,7 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. 
- if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -648,9 +569,7 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -658,9 +577,7 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -674,17 +591,14 @@ def _read_message_footer(self) -> None: if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py index ee7d92d14d77..9fedc055a623 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py @@ -17,9 +17,7 @@ from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder( - IOBase -): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -50,9 +48,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." 
- ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -76,19 +72,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -110,7 +98,7 @@ async def __anext__(self) -> bytes: return data async def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -127,23 +115,17 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." 
- ) + raise ValueError("First message segment was empty but more segments were detected.") await self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -151,12 +133,8 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -170,10 +148,7 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -187,21 +162,15 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. 
Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner( - StructuredMessageConstants.V1_HEADER_LENGTH - ) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -211,23 +180,18 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -235,9 +199,7 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._segment_crc64 != int.from_bytes(segment_crc, "little"): raise ValueError( diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index ae8fe9962fcd..b3a2a1750f04 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -212,7 +212,9 @@ def __init__( # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. 
self._first_get_size = ( - self._config.max_single_get_size if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) initial_request_start = self._start_range or 0 diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py index d6fc92168363..c318c22acd2c 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py @@ -9,7 +9,7 @@ import pytest from azure.storage.fileshare import ShareClient, ShareServiceClient -from devtools_testutils import recorded_by_proxy +from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase from settings.testcase import FileSharePreparer @@ -44,7 +44,7 @@ def _setup(self, account_name): self.share_client.create_share() def teardown_method(self, _): - if self.share_client: + if self.share_client and is_live(): try: self.share_client.delete_share() except: diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py index d5cce63c3900..a8a8d82477eb 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py @@ -10,6 +10,7 @@ from azure.storage.fileshare import ShareClient as SyncShareClient from azure.storage.fileshare.aio import ShareServiceClient +from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 from settings.testcase import FileSharePreparer 
@@ -29,6 +30,8 @@ async def _setup(self, account_name): await self.share_client.create_share() def teardown_method(self, _): + if not is_live(): + return if self.share_client: sync_credential = self.get_credential(SyncShareClient, is_async=False) sync_share_client = SyncShareClient( diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 69c95e6bb6b7..068bef5601f3 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -86,9 +86,7 @@ def is_exhausted(settings): def retry_hook(settings, **kwargs): if settings["hook"]: - settings["hook"]( - retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) # Is this method/status code retryable? (Based on allowlists and control @@ -108,9 +106,7 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements # Response code 408 is a timeout and should be retried. 
return True if status >= 400: - error_code = response.http_response.headers.get( - "x-ms-copy-source-error-code" - ) + error_code = response.http_response.headers.get("x-ms-copy-source-error-code") if error_code in [ StorageErrorCode.OPERATION_TIMED_OUT, StorageErrorCode.INTERNAL_ERROR, @@ -134,9 +130,9 @@ def is_checksum_retry(response) -> bool: # Legacy code - evaluate retry only on validate_content=True if validate_content is True and response.http_response.headers.get("content-md5"): - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -171,9 +167,7 @@ def on_request(self, request: "PipelineRequest") -> None: request.http_request.headers["x-ms-date"] = current_time custom_id = request.context.options.pop("client_request_id", None) - request.http_request.headers["x-ms-client-request-id"] = custom_id or str( - uuid.uuid1() - ) + request.http_request.headers["x-ms-client-request-id"] = custom_id or str(uuid.uuid1()) # def on_response(self, request, response): # # raise exception if the echoed client request id from the service is not identical to the one we sent @@ -213,9 +207,7 @@ def on_request(self, request: "PipelineRequest") -> None: # Lock retries to the specific location request.context.options["retry_to_secondary"] = False if use_location not in self.hosts: - raise ValueError( - f"Attempting to use undefined host location {use_location}" - ) + raise ValueError(f"Attempting to use undefined host location {use_location}") if use_location != location_mode: # Update request URL to use the specified location updated = parsed_url._replace(netloc=self.hosts[use_location]) @@ -233,9 +225,7 @@ class 
StorageLoggingPolicy(NetworkTraceLoggingPolicy): def __init__(self, logging_enable: bool = False, **kwargs) -> None: self.logging_body = kwargs.pop("logging_body", False) - super(StorageLoggingPolicy, self).__init__( - logging_enable=logging_enable, **kwargs - ) + super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs) def on_request(self, request: "PipelineRequest") -> None: http_request = request.http_request @@ -319,9 +309,7 @@ def on_response(self, request: "PipelineRequest", response: "PipelineResponse") _LOGGER.debug("Response content:") pattern = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE) header = response.http_response.headers.get("content-disposition") - resp_content_type = response.http_response.headers.get( - "content-type", "" - ) + resp_content_type = response.http_response.headers.get("content-type", "") if header and pattern.match(header): filename = header.partition("=")[2] @@ -350,9 +338,7 @@ def __init__(self, **kwargs): super(StorageRequestHook, self).__init__() def on_request(self, request: "PipelineRequest") -> None: - request_callback = request.context.options.pop( - "raw_request_hook", self._request_callback - ) + request_callback = request.context.options.pop("raw_request_hook", self._request_callback) if request_callback: request_callback(request) @@ -370,50 +356,36 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = request.context.options.pop("download_stream_current", None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + 
upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = 
upload_stream_current if response_callback: response_callback(response) @@ -422,11 +394,6 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": def _prepare_content_validation(request: "PipelineRequest") -> None: - """Shared request-side logic for content validation. - - Pops 'validate_content' from options, sets up headers/streams for MD5 or CRC64 - validation, and stores the validation mode in the request context. - """ validate_content = request.context.options.pop("validate_content", False) if not validate_content: return @@ -452,21 +419,13 @@ def _prepare_content_validation(request: "PipelineRequest") -> None: data = BytesIO(data) if isinstance(data, bytes): - request.http_request.headers[CRC64_HEADER] = encode_base64( - calculate_crc64_bytes(data) - ) + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) elif hasattr(data, "read"): - content_length = int( - request.http_request.headers.get(CONTENT_LENGTH_HEADER) - ) + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) # Wrap data in structured message stream and adjust HTTP request - sm_stream = StructuredMessageEncodeStream( - data, content_length, StructuredMessageProperties.CRC64 - ) + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) request.http_request.data = sm_stream - request.http_request.headers[CONTENT_LENGTH_HEADER] = str( - len(sm_stream) - ) + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 else: @@ -480,18 +439,11 @@ def _validate_content_response( response: "PipelineResponse", decoder_cls: type, ) -> None: - """Shared response-side logic for content validation. - - Checks MD5 or CRC64 validation on the response. For CRC64 GET responses, patches - ``stream_download`` to wrap the iterator in the given *decoder_cls*. 
- """ validate_content = response.context.get("validate_content", False) if not validate_content: return - if is_md5_validation(validate_content) and response.http_response.headers.get( - "content-md5" - ): + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): computed_md5 = request.context.get("validate_content_md5") or encode_base64( calculate_content_md5(response.http_response.body()) ) @@ -526,9 +478,7 @@ def _validate_content_response( def wrapped_stream_download(*args, **kwargs): iterator = original_stream_download(*args, **kwargs) - decoder = decoder_cls( - iterator, content_length, block_size=DATA_BLOCK_SIZE - ) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) decoder.request = iterator.request # type: ignore decoder.response = iterator.response # type: ignore return decoder @@ -548,9 +498,7 @@ def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument def on_request(self, request: "PipelineRequest") -> None: _prepare_content_validation(request) - def on_response( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> None: + def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: _validate_content_response(request, response, StructuredMessageDecoder) @@ -578,9 +526,7 @@ def __init__(self, **kwargs: Any) -> None: self.retry_to_secondary = kwargs.pop("retry_to_secondary", False) super(StorageRetryPolicy, self).__init__() - def _set_next_host_location( - self, settings: Dict[str, Any], request: "PipelineRequest" - ) -> None: + def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None: """ A function which sets the next host location on the request, if applicable. 
@@ -619,9 +565,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "connect": options.pop("retry_connect", self.connect_retries), "read": options.pop("retry_read", self.read_retries), "status": options.pop("retry_status", self.status_retries), - "retry_secondary": options.pop( - "retry_to_secondary", self.retry_to_secondary - ), + "retry_secondary": options.pop("retry_to_secondary", self.retry_to_secondary), "mode": options.pop("location_mode", LocationMode.PRIMARY), "hosts": options.pop("hosts", None), "hook": options.pop("retry_hook", None), @@ -630,9 +574,7 @@ def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: "history": [], } - def get_backoff_time( - self, settings: Dict[str, Any] - ) -> float: # pylint: disable=unused-argument + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """Formula for computing the current backoff. Should be calculated by child class. @@ -695,9 +637,7 @@ def increment( # status_forcelist and a the given method is in the allowlist if response: settings["status"] -= 1 - settings["history"].append( - RequestHistory(request, http_response=response) - ) + settings["history"].append(RequestHistory(request, http_response=response)) if not is_exhausted(settings): if request.method not in ["PUT"] and settings["retry_secondary"]: @@ -732,9 +672,7 @@ def send(self, request): while retries_remaining: try: response = self.next.send(request) - if is_retry(response, retry_settings["mode"]) or is_checksum_retry( - response - ): + if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -753,9 +691,7 @@ def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, 
request=request.http_request, error=err) if retries_remaining: retry_hook( retry_settings, @@ -815,9 +751,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -830,14 +764,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: float """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -875,9 +803,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -892,11 +818,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - 
self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -904,16 +826,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "TokenCredential", audience: str, **kwargs: Any - ) -> None: - super(StorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None: + super(StorageBearerTokenCredentialPolicy, self).__init__(credential, audience, **kwargs) - def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. :param request: The request object. 
diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 14ce070e47ff..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -45,17 +45,9 @@ async def retry_hook(settings, **kwargs): if settings["hook"]: if asyncio.iscoroutine(settings["hook"]): - await settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + await settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) else: - settings["hook"]( - retry_count=settings["count"] - 1, - location_mode=settings["mode"], - **kwargs - ) + settings["hook"](retry_count=settings["count"] - 1, location_mode=settings["mode"], **kwargs) async def is_checksum_retry(response): @@ -70,9 +62,9 @@ async def is_checksum_retry(response): await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass - computed_md5 = response.http_request.headers.get( - "content-md5", None - ) or encode_base64(calculate_content_md5(response.http_response.body())) + computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( + calculate_content_md5(response.http_response.body()) + ) if response.http_response.headers["content-md5"] != computed_md5: return True return False @@ -118,50 +110,36 @@ async def send(self, request: "PipelineRequest") -> "PipelineResponse": data_stream_total = request.context.options.pop("data_stream_total", None) download_stream_current = request.context.get("download_stream_current") if download_stream_current is None: - download_stream_current = request.context.options.pop( - "download_stream_current", None - ) + download_stream_current = 
request.context.options.pop("download_stream_current", None) upload_stream_current = request.context.get("upload_stream_current") if upload_stream_current is None: - upload_stream_current = request.context.options.pop( - "upload_stream_current", None - ) + upload_stream_current = request.context.options.pop("upload_stream_current", None) - response_callback = request.context.get( - "response_callback" - ) or request.context.options.pop("raw_response_hook", self._response_callback) + response_callback = request.context.get("response_callback") or request.context.options.pop( + "raw_response_hook", self._response_callback + ) response = await self.next.send(request) - will_retry = is_retry( - response, request.context.options.get("mode") - ) or await is_checksum_retry(response) + will_retry = is_retry(response, request.context.options.get("mode")) or await is_checksum_retry(response) # Auth error could come from Bearer challenge, in which case this request will be made again is_auth_error = response.http_response.status_code == 401 should_update_counts = not (will_retry or is_auth_error) if should_update_counts and download_stream_current is not None: - download_stream_current += int( - response.http_response.headers.get("Content-Length", 0) - ) + download_stream_current += int(response.http_response.headers.get("Content-Length", 0)) if data_stream_total is None: content_range = response.http_response.headers.get("Content-Range") if content_range: - data_stream_total = int( - content_range.split(" ", 1)[1].split("/", 1)[1] - ) + data_stream_total = int(content_range.split(" ", 1)[1].split("/", 1)[1]) else: data_stream_total = download_stream_current elif should_update_counts and upload_stream_current is not None: - upload_stream_current += int( - response.http_request.headers.get("Content-Length", 0) - ) + upload_stream_current += int(response.http_request.headers.get("Content-Length", 0)) for pipeline_obj in [request, response]: if hasattr(pipeline_obj, "context"): 
pipeline_obj.context["data_stream_total"] = data_stream_total - pipeline_obj.context["download_stream_current"] = ( - download_stream_current - ) + pipeline_obj.context["download_stream_current"] = download_stream_current pipeline_obj.context["upload_stream_current"] = upload_stream_current if response_callback: if asyncio.iscoroutine(response_callback): @@ -190,9 +168,7 @@ async def send(self, request): while retries_remaining: try: response = await self.next.send(request) - if is_retry( - response, retry_settings["mode"] - ) or await is_checksum_retry(response): + if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( retry_settings, request=request.http_request, @@ -211,9 +187,7 @@ async def send(self, request): except AzureError as err: if isinstance(err, AzureSigningError): raise - retries_remaining = self.increment( - retry_settings, request=request.http_request, error=err - ) + retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: await retry_hook( retry_settings, @@ -275,9 +249,7 @@ def __init__( self.initial_backoff = initial_backoff self.increment_base = increment_base self.random_jitter_range = random_jitter_range - super(ExponentialRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(ExponentialRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -290,14 +262,8 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: :rtype: int or None """ random_generator = random.Random() - backoff = self.initial_backoff + ( - 0 if settings["count"] == 0 else pow(self.increment_base, settings["count"]) - ) - random_range_start = ( - backoff - self.random_jitter_range - if backoff > self.random_jitter_range - else 0 - ) + backoff = self.initial_backoff + (0 if 
settings["count"] == 0 else pow(self.increment_base, settings["count"])) + random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0 random_range_end = backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -335,9 +301,7 @@ def __init__( """ self.backoff = backoff self.random_jitter_range = random_jitter_range - super(LinearRetry, self).__init__( - retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs - ) + super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) def get_backoff_time(self, settings: Dict[str, Any]) -> float: """ @@ -352,11 +316,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: random_generator = random.Random() # the backoff interval normally does not change, however there is the possibility # that it was modified by accessing the property directly after initializing the object - random_range_start = ( - self.backoff - self.random_jitter_range - if self.backoff > self.random_jitter_range - else 0 - ) + random_range_start = self.backoff - self.random_jitter_range if self.backoff > self.random_jitter_range else 0 random_range_end = self.backoff + self.random_jitter_range return random_generator.uniform(random_range_start, random_range_end) @@ -364,16 +324,10 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom Bearer token credential policy for following Storage Bearer challenges""" - def __init__( - self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any - ) -> None: - super(AsyncStorageBearerTokenCredentialPolicy, self).__init__( - credential, audience, **kwargs - ) + def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None: + super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, audience, 
**kwargs) - async def on_challenge( - self, request: "PipelineRequest", response: "PipelineResponse" - ) -> bool: + async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: try: auth_header = response.http_response.headers.get("WWW-Authenticate") challenge = StorageHttpChallenge(auth_header) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py index 712f4e90af69..27272fdac592 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py @@ -35,9 +35,7 @@ class SMRegion(Enum): MESSAGE_FOOTER = 5 -def generate_message_header( - version: int, size: int, flags: StructuredMessageProperties, num_segments: int -) -> bytes: +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: return ( version.to_bytes(1, "little") + size.to_bytes(8, "little") @@ -50,17 +48,14 @@ def generate_segment_header(number: int, size: int) -> bytes: return number.to_bytes(2, "little") + size.to_bytes(8, "little") -def parse_message_header( - data: bytes, expected_message_length: int -) -> tuple[int, StructuredMessageProperties, int]: +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: version = data[0] if version != 1: raise ValueError(f"The structured message version is not supported: {version}") message_length = int.from_bytes(data[1:9], "little") if message_length != expected_message_length: raise ValueError( - f"Structured message length {message_length} " - f"did not match content length {expected_message_length}" + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" ) flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) num_segments = 
int.from_bytes(data[11:13], "little") @@ -70,16 +65,12 @@ def parse_message_header( def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: segment_number = int.from_bytes(data[0:2], "little") if segment_number != expected_segment_number: - raise ValueError( - f"Structured message segment number invalid or out of order {segment_number}" - ) + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") segment_content_length = int.from_bytes(data[2:10], "little") return segment_number, segment_content_length -class StructuredMessageEncodeStream( - IOBase -): # pylint: disable=too-many-instance-attributes +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes message_version: int content_length: int message_length: int @@ -151,19 +142,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 def _update_current_region_length(self) -> None: if self._current_region == SMRegion.MESSAGE_HEADER: @@ -198,10 +181,7 @@ def readable(self) -> bool: def seekable(self) -> bool: try: # Only seekable if the inner stream is and we could get its initial position - return ( - self._inner_stream.seekable() - and self._initial_content_position is not None - ) + return self._inner_stream.seekable() and self._initial_content_position is not None except (AttributeError, UnsupportedOperation, OSError): return False @@ -212,24 +192,21 @@ def tell(self) 
-> int: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) if self._current_region == SMRegion.SEGMENT_CONTENT: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length ) if self._current_region == SMRegion.SEGMENT_FOOTER: return ( self._message_header_length + self._content_offset - + (self._current_segment_number - 1) - * (self._segment_header_length + self._segment_footer_length) + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + self._segment_header_length + self._current_region_offset ) @@ -237,8 +214,7 @@ def tell(self) -> int: return ( self._message_header_length + self._content_offset - + self._current_segment_number - * (self._segment_header_length + self._segment_footer_length) + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + self._current_region_offset ) @@ -271,33 +247,21 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: # MESSAGE_FOOTER elif position >= self.message_length - self._message_footer_length: self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - ( - self.message_length - self._message_footer_length - ) + self._current_region_offset = position - (self.message_length - self._message_footer_length) self._content_offset = self.content_length self._current_segment_number = self._num_segments else: # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = ( - self._segment_header_length - + self._segment_size - + self._segment_footer_length - ) - new_segment_num = ( - 1 + (position - self._message_header_length) // full_segment_size - ) + full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length + new_segment_num = 1 + (position - self._message_header_length) // full_segment_size segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = ( - new_segment_num - 1 - ) * self._segment_size + previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size # We need the size of the segment we are seeking to for some of the calculations below new_segment_size = self._segment_size if new_segment_num == self._num_segments: # The last segment size is the remaining content length - new_segment_size = ( - self.content_length - previous_segments_total_content_size - ) + new_segment_size = self.content_length - previous_segments_total_content_size # SEGMENT_HEADER if segment_pos < self._segment_header_length: @@ -308,25 +272,17 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: elif segment_pos < self._segment_header_length + new_segment_size: self._current_region = SMRegion.SEGMENT_CONTENT self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = ( - previous_segments_total_content_size + self._current_region_offset - ) + self._content_offset = previous_segments_total_content_size + self._current_region_offset # SEGMENT_FOOTER else: self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = ( - segment_pos - self._segment_header_length - new_segment_size - ) - self._content_offset = ( - previous_segments_total_content_size + new_segment_size - ) + self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size + self._content_offset = 
previous_segments_total_content_size + new_segment_size self._current_segment_number = new_segment_num self._update_current_region_length() - self._inner_stream.seek( - (self._initial_content_position or 0) + self._content_offset - ) + self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) return position def read(self, size: int = -1) -> bytes: @@ -349,9 +305,7 @@ def read(self, size: int = -1) -> bytes: SMRegion.SEGMENT_FOOTER, SMRegion.MESSAGE_FOOTER, ): - count += self._read_metadata_region( - self._current_region, remaining, output - ) + count += self._read_metadata_region(self._current_region, remaining, output) elif self._current_region == SMRegion.SEGMENT_CONTENT: count += self._read_content(remaining, output) else: @@ -361,9 +315,7 @@ def read(self, size: int = -1) -> bytes: def _calculate_message_length(self) -> int: length = self._message_header_length - length += ( - self._segment_header_length + self._segment_footer_length - ) * self._num_segments + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments length += self.content_length length += self._message_footer_length return length @@ -378,9 +330,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: ) if region == SMRegion.SEGMENT_HEADER: - segment_size = min( - self._segment_size, self.content_length - self._content_offset - ) + segment_size = min(self._segment_size, self.content_length - self._content_offset) return generate_segment_header(self._current_segment_number, segment_size) if region == SMRegion.SEGMENT_FOOTER: @@ -392,9 +342,7 @@ def _get_metadata_region(self, region: SMRegion) -> bytes: if region == SMRegion.MESSAGE_FOOTER: if StructuredMessageProperties.CRC64 in self.flags: - return self._message_crc64.to_bytes( - StructuredMessageConstants.CRC64_LENGTH, "little" - ) + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") return b"" raise ValueError(f"Invalid metadata SMRegion 
{self._current_region}") @@ -421,15 +369,11 @@ def _advance_region(self, current: SMRegion): self._update_current_region_length() - def _read_metadata_region( - self, region: SMRegion, size: int, output: BytesIO - ) -> int: + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: metadata = self._get_metadata_region(region) read_size = min(size, self._current_region_length - self._current_region_offset) - content = metadata[ - self._current_region_offset : self._current_region_offset + read_size - ] + content = metadata[self._current_region_offset : self._current_region_offset + read_size] output.write(content) self._current_region_offset += read_size @@ -511,9 +455,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." - ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -537,19 +479,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -571,7 +505,7 @@ def __next__(self) -> bytes: return data def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise 
ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -588,23 +522,17 @@ def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." - ) + raise ValueError("First message segment was empty but more segments were detected.") self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = self._read_from_inner(read_size) @@ -612,12 +540,8 @@ def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -631,10 +555,7 @@ def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. 
- if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -648,9 +569,7 @@ def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] @@ -658,9 +577,7 @@ def _read_from_inner(self, size: int) -> bytes: def _read_message_header(self) -> None: header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH def _read_message_footer(self) -> None: @@ -674,17 +591,14 @@ def _read_message_footer(self) -> None: if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length def _read_segment_header(self) -> None: header_data = self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py index ee7d92d14d77..9fedc055a623 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py @@ -17,9 +17,7 @@ from .validation import calculate_crc64 -class AsyncStructuredMessageDecoder( - IOBase -): # pylint: disable=too-many-instance-attributes +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes message_version: int """The version of the structured message.""" @@ -50,9 +48,7 @@ def __init__( self.message_length = content_length # The stream should be at least long enough to hold minimum header length if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: - raise ValueError( - "Content not long enough to contain a valid message header." 
- ) + raise ValueError("Content not long enough to contain a valid message header.") self._inner_iterator = inner_iterator self._buffer = b"" @@ -76,19 +72,11 @@ def _segment_header_length(self) -> int: @property def _segment_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _message_footer_length(self) -> int: - return ( - StructuredMessageConstants.CRC64_LENGTH - if StructuredMessageProperties.CRC64 in self.flags - else 0 - ) + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 @property def _end_of_segment_content(self) -> bool: @@ -110,7 +98,7 @@ async def __anext__(self) -> bytes: return data async def read(self, size: int = -1) -> bytes: - if self.closed: + if self.closed: # pylint: disable=using-constant-test raise ValueError("Stream is closed") if size == 0 or self._message_offset >= self.message_length: @@ -127,23 +115,17 @@ async def read(self, size: int = -1) -> bytes: if self._end_of_segment_content: await self._read_segment_footer() if self.num_segments > 1: - raise ValueError( - "First message segment was empty but more segments were detected." 
- ) + raise ValueError("First message segment was empty but more segments were detected.") await self._read_message_footer() return b"" count = 0 content = BytesIO() - while count < size and not ( - self._end_of_segment_content and self._message_offset == self.message_length - ): + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): if self._end_of_segment_content: await self._read_segment_header() - segment_remaining = ( - self._segment_content_length - self._segment_content_offset - ) + segment_remaining = self._segment_content_length - self._segment_content_offset read_size = min(segment_remaining, size - count) segment_content = await self._read_from_inner(read_size) @@ -151,12 +133,8 @@ async def read(self, size: int = -1) -> bytes: # Update the running CRC64 for the segment and message if StructuredMessageProperties.CRC64 in self.flags: - self._segment_crc64 = calculate_crc64( - segment_content, self._segment_crc64 - ) - self._message_crc64 = calculate_crc64( - segment_content, self._message_crc64 - ) + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) self._segment_content_offset += read_size self._message_offset += read_size @@ -170,10 +148,7 @@ async def read(self, size: int = -1) -> bytes: # One final check to ensure if we think we've reached the end of the stream # that the current segment number matches the total. - if ( - self._message_offset == self.message_length - and self._segment_number != self.num_segments - ): + if self._message_offset == self.message_length and self._segment_number != self.num_segments: raise ValueError("Invalid structured message data detected.") return content.getvalue() @@ -187,21 +162,15 @@ async def _read_from_inner(self, size: int) -> bytes: self._buffer += chunk if len(self._buffer) < size: - raise ValueError( - "Invalid structured message data detected. 
Stream content incomplete." - ) + raise ValueError("Invalid structured message data detected. Stream content incomplete.") data = self._buffer[:size] self._buffer = self._buffer[size:] return data async def _read_message_header(self) -> None: - header_data = await self._read_from_inner( - StructuredMessageConstants.V1_HEADER_LENGTH - ) - self.message_version, self.flags, self.num_segments = parse_message_header( - header_data, self.message_length - ) + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH async def _read_message_footer(self) -> None: @@ -211,23 +180,18 @@ async def _read_message_footer(self) -> None: raise ValueError("Invalid structured message data detected.") if StructuredMessageProperties.CRC64 in self.flags: - message_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._message_crc64 != int.from_bytes(message_crc, "little"): raise ValueError( - "CRC64 mismatch detected in message trailer. " - "All data read should be considered invalid." + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." 
) self._message_offset += self._message_footer_length async def _read_segment_header(self) -> None: header_data = await self._read_from_inner(self._segment_header_length) - self._segment_number, self._segment_content_length = parse_segment_header( - header_data, self._segment_number + 1 - ) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) self._message_offset += self._segment_header_length self._segment_content_offset = 0 @@ -235,9 +199,7 @@ async def _read_segment_header(self) -> None: async def _read_segment_footer(self) -> None: if StructuredMessageProperties.CRC64 in self.flags: - segment_crc = await self._read_from_inner( - StructuredMessageConstants.CRC64_LENGTH - ) + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) if self._segment_crc64 != int.from_bytes(segment_crc, "little"): raise ValueError( diff --git a/sdk/storage/azure-storage-queue/dev_requirements.txt b/sdk/storage/azure-storage-queue/dev_requirements.txt index c9cf9f09f32d..b8770ca0db3a 100644 --- a/sdk/storage/azure-storage-queue/dev_requirements.txt +++ b/sdk/storage/azure-storage-queue/dev_requirements.txt @@ -1,5 +1,6 @@ -e ../../../eng/tools/azure-sdk-tools ../../core/azure-core ../../identity/azure-identity +../azure-storage-extensions azure-mgmt-storage==20.1.0 aiohttp>=3.13.5 From 32548665b431a731aa298831be64c31deaf4ec40 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Tue, 28 Apr 2026 16:44:02 -0700 Subject: [PATCH 11/14] [Storage] Simplify Encoder seek, fix SM streaming retry (#46564) --- .../azure/storage/blob/_shared/streams.py | 108 +++++----------- .../tests/test_content_validation.py | 44 +++++++ .../tests/test_content_validation_async.py | 45 +++++++ .../azure-storage-blob/tests/test_streams.py | 119 +++--------------- .../storage/filedatalake/_shared/streams.py | 108 +++++----------- .../storage/fileshare/_shared/streams.py | 108 
+++++----------- .../azure/storage/queue/_shared/streams.py | 108 +++++----------- 7 files changed, 216 insertions(+), 424 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. 
+ pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = 
SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git 
a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index e2f65b03771e..f4e8fd759221 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -567,3 +567,47 @@ def download_hook_fail_once(response): content = downloader.read() assert download_call_count == 2 # Original + retry assert content == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abc' * 512 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + call_count = 0 + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + blob.commit_block_list([BlobBlock('1')]) + result = blob.download_blob() + assert result.read() == content diff --git 
a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index 6b1114b81a44..727fa121a9fc 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -542,3 +542,48 @@ def download_hook_fail_once(response): content = await downloader.read() assert download_call_count == 2 # Original + retry assert content == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + await self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abc' * 512 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Test stage_block streaming with retry + call_count = 0 + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + await blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + await 
blob.commit_block_list([BlobBlock('1')]) + result = await blob.download_blob() + assert await result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_streams.py b/sdk/storage/azure-storage-blob/tests/test_streams.py index 874c8e4a912f..e6bbb0d6414c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_streams.py +++ b/sdk/storage/azure-storage-blob/tests/test_streams.py @@ -107,10 +107,12 @@ def test_close(self): assert not stream.closed assert not inner.closed - stream.close() - assert stream.closed - assert inner.closed + stream.close() # no-op + assert not stream.closed + assert not inner.closed + assert stream.read(1) is not None + inner.close() # closing inner will block reads with pytest.raises(ValueError): stream.read(0) @@ -226,28 +228,24 @@ def test_not_seekable(self): with pytest.raises(UnsupportedOperation): sm_stream.seek(0) - def test_seek_whence(self): + def test_seek(self): data = os.urandom(10) inner_stream = BytesIO(data) sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) - # Read so we can seek backwards sm_stream.read(25) - pos = sm_stream.seek(10, SEEK_SET) - assert pos == 10 - pos = sm_stream.seek(-len(sm_stream) + 9, SEEK_END) - assert pos == 9 - pos = sm_stream.seek(-5, SEEK_CUR) - assert pos == 4 + with pytest.raises(UnsupportedOperation): + sm_stream.seek(5) - def test_seek_forward(self): - data = os.urandom(10) - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + with pytest.raises(UnsupportedOperation): + sm_stream.seek(0, SEEK_CUR) - sm_stream.read(5) with pytest.raises(UnsupportedOperation): - sm_stream.seek(10) + sm_stream.seek(0, SEEK_END) + + # Only SEEK_SET to 0 is supported + pos = sm_stream.seek(0, SEEK_SET) + assert pos == 0 @pytest.mark.parametrize("initial_read, segment_size, flags", [ # Single segment @@ -290,74 +288,6 @@ def test_seek_reverse_beginning(self, 
initial_read, segment_size, flags): result = sm_stream.read() assert result == expected - @pytest.mark.parametrize("initial_read, seek_offset, segment_size, flags", [ - # Single segment - (10, 5, 2048, StructuredMessageProperties.NONE), # Message header -> Message header - (10, 5, 2048, StructuredMessageProperties.CRC64), - (20, 15, 2048, StructuredMessageProperties.NONE), # Segment header -> Segment header - (20, 15, 2048, StructuredMessageProperties.CRC64), - (100, 50, 2048, StructuredMessageProperties.NONE), # First segment content -> First segment content - (100, 50, 2048, StructuredMessageProperties.CRC64), - (1000, 900, 2048, StructuredMessageProperties.NONE), # Second segment content -> Second segment content - (1000, 900, 2048, StructuredMessageProperties.CRC64), - (530, 525, 2048, StructuredMessageProperties.CRC64), # Segment footer -> Segment footer - (1060, 1050, 2048, StructuredMessageProperties.CRC64), # Message footer -> Segment footer - (1000, 100, 2048, StructuredMessageProperties.NONE), # Second segment content -> First segment content - (1000, 100, 2048, StructuredMessageProperties.CRC64), - (1000, 20, 2048, StructuredMessageProperties.NONE), # Second segment content -> First segment header - (1000, 20, 2048, StructuredMessageProperties.CRC64), - (1000, 530, 2048, StructuredMessageProperties.CRC64), # Second segment content -> First segment footer - (1097, 100, 2048, StructuredMessageProperties.CRC64), # Message footer -> First segment content - # Multiple segments - (10, 5, 500, StructuredMessageProperties.NONE), # Message header -> Message header - (10, 5, 500, StructuredMessageProperties.CRC64), - (20, 15, 500, StructuredMessageProperties.NONE), # Segment header -> Segment header - (20, 15, 500, StructuredMessageProperties.CRC64), - (100, 50, 500, StructuredMessageProperties.NONE), # First segment content -> First segment content - (100, 50, 500, StructuredMessageProperties.CRC64), - (1000, 900, 500, StructuredMessageProperties.NONE), # Second 
segment content -> Second segment content - (1000, 900, 500, StructuredMessageProperties.CRC64), - (530, 525, 500, StructuredMessageProperties.CRC64), # Segment footer -> Segment footer - (1097, 1090, 500, StructuredMessageProperties.CRC64), # Message footer -> Segment footer - (1000, 100, 500, StructuredMessageProperties.NONE), # Second segment content -> First segment content - (1000, 100, 500, StructuredMessageProperties.CRC64), - (1000, 20, 500, StructuredMessageProperties.NONE), # Second segment content -> First segment header - (1000, 20, 500, StructuredMessageProperties.CRC64), - (1000, 530, 500, StructuredMessageProperties.CRC64), # Second segment content -> First segment footer - (1097, 100, 500, StructuredMessageProperties.CRC64), # Message footer -> First segment content - ]) - def test_seek_reverse_middle(self, initial_read, seek_offset, segment_size, flags): - data = os.urandom(1024) - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=segment_size) - expected = _build_structured_message(data, segment_size, flags)[0].getvalue() - - initial = sm_stream.read(initial_read) - assert initial == expected[:initial_read] - - sm_stream.seek(seek_offset) - result = sm_stream.read() - assert result == expected[seek_offset:] - - @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) - def test_seek_reverse_random(self, flags): - data = os.urandom(1024) - expected = _build_structured_message(data, 500, flags)[0].getvalue() - - for _ in range(10): - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=500) - - initial_read = random.randint(5, len(data)) - seek_offset = random.randint(0, initial_read) - - initial = sm_stream.read(initial_read) - assert initial == expected[:initial_read] - - sm_stream.seek(seek_offset) - result = sm_stream.read() - assert result == expected[seek_offset:] - 
@pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) def test_partial_stream_read(self, flags): data = os.urandom(1024) @@ -390,26 +320,7 @@ def test_partial_stream_seek_beginning(self, flags): result = sm_stream.read() assert result == expected - @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) - def test_partial_stream_seek_middle(self, flags): - data = os.urandom(1024) - partial_read = 100 - - inner_stream = BytesIO(data) - inner_stream.seek(partial_read) - expected = _build_structured_message(data[partial_read:], 500, flags)[0].getvalue() - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data) - partial_read, flags, segment_size=500) - initial = sm_stream.read(501) - assert initial == expected[:501] - - sm_stream.seek(100) - assert inner_stream.tell() == partial_read + (100 - - StructuredMessageConstants.V1_HEADER_LENGTH - - StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH) - - result = sm_stream.read() - assert result == expected[100:] class TestStructuredMessageDecoder: diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - 
_checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments 
- else: - # The size of a "full" segment. Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + 
self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: 
self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. 
+ pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = 
SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git 
a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. 
+ pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. 
Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = 
SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) From 
82f70d9803660a5c386992aed83a804bb0951bf3 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon <96087589+jalauzon-msft@users.noreply.github.com> Date: Tue, 19 May 2026 20:33:33 +0000 Subject: [PATCH 12/14] [Storage] Cleanup and prepare content validation for merge to main (#46971) --- sdk/storage/azure-storage-blob/assets.json | 2 +- .../tests/test_block_blob_async.py | 2 +- .../tests/test_content_validation.py | 288 ++++++++++-------- .../tests/test_content_validation_async.py | 287 +++++++++-------- .../azure-storage-blob/tests/test_streams.py | 8 +- .../tests/test_streams_async.py | 5 +- .../azure-storage-file-datalake/assets.json | 2 +- .../tests/test_content_validation.py | 112 +++---- .../tests/test_content_validation_async.py | 121 ++++---- .../tests/test_content_validation.py | 97 +++--- .../tests/test_content_validation_async.py | 94 +++--- 11 files changed, 565 insertions(+), 453 deletions(-) diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index 007e0506d1c8..93dd0f1e298c 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-blob", - "Tag": "python/storage/azure-storage-blob_28cfcca089" + "Tag": "python/storage/azure-storage-blob_b09e37b521" } diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py index 17278ca4dfb3..4f672d5e1b9f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py @@ -36,7 +36,7 @@ ImmutabilityPolicy, StandardBlobTier, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient diff --git 
a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index f4e8fd759221..c4327291e3e0 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -7,48 +7,51 @@ from io import BytesIO import pytest -from azure.core.exceptions import ResourceExistsError -from azure.storage.blob import ( - BlobBlock, - BlobClient, - BlobServiceClient, - BlobType, - ContainerClient -) from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import ( GenericTestProxyParametrize1, GenericTestProxyParametrize2, - StorageRecordedTestCase + StorageRecordedTestCase, ) - from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobBlock, BlobClient, BlobServiceClient, BlobType, ContainerClient + def assert_content_md5(request): - if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': - assert request.http_request.headers.get('Content-MD5') is not None + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("Content-MD5") is not None def assert_content_md5_get(response): - assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' - assert response.http_response.headers.get('Content-MD5') is not None + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None def assert_content_crc64(request): - if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': - assert 
request.http_request.headers.get('x-ms-content-crc64') is not None + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("x-ms-content-crc64") is not None def assert_structured_message(request): - if request.http_request.query.get('comp') in ('block', 'page') or request.http_request.headers.get('x-ms-blob-type') == 'BlockBlob': - assert request.http_request.headers.get('x-ms-structured-body') is not None + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("x-ms-structured-body") is not None def assert_structured_message_get(response): - assert response.http_request.headers.get('x-ms-structured-body') is not None - assert response.http_response.headers.get('x-ms-structured-body') is not None + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None class TestIter: @@ -68,7 +71,7 @@ def __next__(self): if self.offset >= self.length: raise StopIteration - result = self.data[self.offset: self.offset + self.chunk_size] + result = self.data[self.offset : self.offset + self.chunk_size] self.offset += len(result) return result @@ -80,7 +83,7 @@ class TestStorageContentValidation(StorageRecordedTestCase): def _setup(self, account_name): token_credential = self.get_credential(BlobServiceClient) self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: self.container.create_container() except ResourceExistsError: @@ -94,31 +97,32 @@ def teardown_method(self, _): pass def 
_get_blob_reference(self): - return self.get_resource_name('blob') + return self.get_resource_name("blob") @BlobPreparer() def test_encryption_blocked_crc64(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") blob = BlobClient( self.account_url(storage_account_name, "blob"), "testing", "testing", credential=self.get_credential(BlobServiceClient), require_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) with pytest.raises(ValueError): - blob.upload_blob(b'123', validate_content='crc64') + blob.upload_blob(b"123", validate_content="crc64") # Needed for teardown self.container = None @BlobPreparer() - @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'auto','md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "auto", "md5", "crc64"]) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy def test_upload_blob(self, a, b, **kwargs): @@ -126,30 +130,40 @@ def test_upload_blob(self, a, b, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if b in ("auto", "crc64") else assert_content_md5 # Test supported data types - byte_data = b'abc' * 512 + byte_data = b"abc" * 512 str_data = "你好世界abcd" * 32 - str_data_encoded = str_data.encode('utf-8') + str_data_encoded = str_data.encode("utf-8") byte_stream = BytesIO(byte_data) byte_iter = TestIter(byte_data) str_iter = TestIter(str_data) blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) 
assert blob.download_blob().read() == byte_data - blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert blob.download_blob().read() == str_data_encoded blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) assert blob.download_blob().read() == byte_data blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) assert blob.download_blob().read() == byte_data - blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) assert blob.download_blob().read() == str_data_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "md5", "crc64"]) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy def test_upload_blob_chunks(self, a, b, **kwargs): @@ -160,54 +174,64 @@ def test_upload_blob_chunks(self, a, b, **kwargs): self.container._config.max_block_size = 512 self.container._config.max_page_size = 512 blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b == "crc64" else assert_content_md5 # Test supported data types - byte_data = b'abc' * 
512 + byte_data = b"abc" * 512 str_data = "你好世界abcd" * 32 - str_data_encoded = str_data.encode('utf-8') + str_data_encoded = str_data.encode("utf-8") byte_stream = BytesIO(byte_data) byte_iter = TestIter(byte_data) str_iter = TestIter(str_data) blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) assert blob.download_blob().read() == byte_data - blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert blob.download_blob().read() == str_data_encoded blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) assert blob.download_blob().read() == byte_data blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) assert blob.download_blob().read() == byte_data - blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) assert blob.download_blob().read() == str_data_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_blob_substream(self, a, **kwargs): # Substream is disabled when using content validation so this will behave like regular upload (buffer) storage_account_name = kwargs.pop("storage_account_name") - + self._setup(storage_account_name) self.container._config.max_single_put_size = 512 self.container._config.max_block_size = 512 
self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 - - data = b'abc' * 512 + b'abcde' + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + data = b"abc" * 512 + b"abcde" io = BytesIO(data) - + # Act blob.upload_blob(io, validate_content=a, raw_request_hook=assert_method) - + # Assert content = blob.download_blob() assert content.read() == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_stage_block(self, a, **kwargs): @@ -215,28 +239,28 @@ def test_stage_block(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 - data2 = '你好世界' * 10 + data1 = b"abc" * 512 + data2 = "你好世界" * 10 # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. 
def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 - blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) - blob.stage_block('2', data2, encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) - blob.stage_block('3', generator(), validate_content=a, raw_request_hook=assert_method) - blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + blob.stage_block("1", data1, validate_content=a, raw_request_hook=assert_method) + blob.stage_block("2", data2, encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) + blob.stage_block("3", generator(), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) # Assert content = blob.download_blob() - assert content.read() == data1 + data2.encode('utf-8-sig') + data1 + assert content.read() == data1 + data2.encode("utf-8-sig") + data1 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_stage_block_streaming(self, a, **kwargs): @@ -245,17 +269,17 @@ def test_stage_block_streaming(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 - blob.stage_block('1', BytesIO(content), validate_content=a, raw_request_hook=assert_method) - blob.commit_block_list([BlobBlock('1')]) 
+ blob.stage_block("1", BytesIO(content), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1")]) result = blob.download_blob() assert result.read() == content @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only def test_stage_block_streaming_large(self, a, **kwargs): storage_account_name = kwargs.pop("storage_account_name") @@ -263,21 +287,21 @@ def test_stage_block_streaming_large(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 - blob.stage_block('1', BytesIO(data1), validate_content=a, raw_request_hook=assert_method) - blob.stage_block('2', BytesIO(data2), validate_content=a, raw_request_hook=assert_method) - blob.stage_block('3', BytesIO(data3), validate_content=a, raw_request_hook=assert_method) - blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + blob.stage_block("1", BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + blob.stage_block("2", BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + blob.stage_block("3", BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) result = blob.download_blob() assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 
'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_append_block(self, a, **kwargs): @@ -285,27 +309,27 @@ def test_append_block(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 - data2 = '你好世界' * 10 + data1 = b"abc" * 512 + data2 = "你好世界" * 10 # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 blob.create_append_blob() blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) - blob.append_block(data2, encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + blob.append_block(data2, encoding="utf-16", validate_content=a, raw_request_hook=assert_method) blob.append_block(generator(), validate_content=a, raw_request_hook=assert_method) content = blob.download_blob() - assert content.read() == data1 + data2.encode('utf-16') + data1 + assert content.read() == data1 + data2.encode("utf-16") + data1 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_append_block_streaming(self, a, **kwargs): @@ -314,8 +338,8 @@ def test_append_block_streaming(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else 
assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 blob.create_append_blob() blob.append_block(BytesIO(content), validate_content=a, raw_request_hook=assert_method) @@ -324,7 +348,7 @@ def test_append_block_streaming(self, a, **kwargs): assert result.read() == content @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only def test_append_block_streaming_large(self, a, **kwargs): storage_account_name = kwargs.pop("storage_account_name") @@ -332,10 +356,10 @@ def test_append_block_streaming_large(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 blob.create_append_blob() blob.append_block(BytesIO(data1), validate_content=a, raw_request_hook=assert_method) @@ -346,7 +370,7 @@ def test_append_block_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_page(self, a, **kwargs): @@ -354,22 +378,29 @@ def test_upload_page(self, a, **kwargs): self._setup(storage_account_name) blob = 
self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 + data1 = b"abc" * 512 data2 = "你好世界abcd" * 32 - data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + data2_encoded = data2.encode("utf-8") + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act blob.create_page_blob(5 * 1024) blob.upload_page(data1, offset=0, length=len(data1), validate_content=a, raw_request_hook=assert_method) - blob.upload_page(data2, offset=len(data1), length=len(data2_encoded), encoding='utf-8', validate_content=a, raw_request_hook=assert_method) + blob.upload_page( + data2, + offset=len(data1), + length=len(data2_encoded), + encoding="utf-8", + validate_content=a, + raw_request_hook=assert_method, + ) # Assert content = blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) assert content.read() == data1 + data2_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob(self, a, **kwargs): @@ -377,9 +408,9 @@ def test_download_blob(self, a, **kwargs): self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + data = b"abc" * 512 blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) @@ -395,7 +426,7 @@ def test_download_blob(self, a, **kwargs): assert stream.read() == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + 
@pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob_chunks(self, a, **kwargs): @@ -405,9 +436,9 @@ def test_download_blob_chunks(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) @@ -418,7 +449,7 @@ def test_download_blob_chunks(self, a, **kwargs): downloader.readinto(stream) stream.seek(0) - read_content = b'' + read_content = b"" downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) for _ in range(len(data) // 100 + 1): read_content += downloader.read(100) @@ -429,7 +460,7 @@ def test_download_blob_chunks(self, a, **kwargs): assert read_content == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob_chunks_partial(self, a, **kwargs): @@ -439,9 +470,9 @@ def test_download_blob_chunks_partial(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else 
assert_content_md5_get # Act downloader = blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) @@ -465,22 +496,22 @@ def test_download_blob_large_chunks(self, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.container._config.max_chunk_get_size = 10 * 1024 * 1024 - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 blob.upload_blob(data, overwrite=True, max_concurrency=5) # Act - downloader = blob.download_blob(validate_content='crc64', max_concurrency=5) + downloader = blob.download_blob(validate_content="crc64", max_concurrency=5) content = downloader.read() - downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = downloader.read() # Assert assert content == data - assert partial == data[5 * 1024 * 1024: 30 * 1024 * 1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_blob_chars(self, a, **kwargs): @@ -490,18 +521,18 @@ def test_download_blob_chars(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 - data = '你好世界' * 256 # 3 KiB + data = "你好世界" * 256 # 3 KiB blob = self.container.get_blob_client(self._get_blob_reference()) - blob.upload_blob(data, encoding='utf-8', overwrite=True) + blob.upload_blob(data, encoding="utf-8", overwrite=True) - stream = blob.download_blob(encoding='utf-8', validate_content=a) + stream = 
blob.download_blob(encoding="utf-8", validate_content=a) assert stream.read() == data - stream = blob.download_blob(encoding='utf-8', validate_content=a) + stream = blob.download_blob(encoding="utf-8", validate_content=a) assert stream.read(chars=100000) == data - result = '' - stream = blob.download_blob(encoding='utf-8', validate_content=a) + result = "" + stream = blob.download_blob(encoding="utf-8", validate_content=a) for _ in range(4): chunk = stream.read(chars=100) result += chunk @@ -511,7 +542,7 @@ def test_download_blob_chars(self, a, **kwargs): assert result == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_content_validation_with_retry(self, a, **kwargs): @@ -525,22 +556,23 @@ def test_content_validation_with_retry(self, a, **kwargs): retry_total=1, initial_backoff=0.1, increment_base=0.1, - logging_enable=True + logging_enable=True, ) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: self.container.create_container() except ResourceExistsError: pass blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + data = b"abc" * 512 # Determine the appropriate assert methods based on validation mode - upload_assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 - download_assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + upload_assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + download_assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Test upload with retry upload_call_count = 0 + def upload_hook_fail_once(response): nonlocal upload_call_count upload_call_count += 1 @@ -555,6 
+587,7 @@ def upload_hook_fail_once(response): # Test download with retry download_call_count = 0 + def download_hook_fail_once(response): nonlocal download_call_count download_call_count += 1 @@ -569,7 +602,7 @@ def download_hook_fail_once(response): assert content == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_streaming_with_retry(self, a, **kwargs): @@ -583,19 +616,20 @@ def test_streaming_with_retry(self, a, **kwargs): retry_total=1, initial_backoff=0.1, increment_base=0.1, - logging_enable=True + logging_enable=True, ) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: self.container.create_container() except ResourceExistsError: pass blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abc' * 512 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abc" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 call_count = 0 + def hook_fail_once(response): nonlocal call_count call_count += 1 @@ -605,9 +639,9 @@ def hook_fail_once(response): response.http_response.status_code = 408 # Request Timeout - triggers retry # Use stage_block to test structured message streaming - blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + blob.stage_block("1", BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) assert call_count == 2 # Original + retry - - blob.commit_block_list([BlobBlock('1')]) + + blob.commit_block_list([BlobBlock("1")]) result = blob.download_blob() assert result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py 
b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index 727fa121a9fc..e22a03b1fe75 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -7,32 +7,28 @@ from io import BytesIO import pytest -from azure.core.exceptions import ResourceExistsError -from azure.storage.blob import BlobBlock, BlobType, ContainerClient as SyncContainerClient -from azure.storage.blob.aio import ( - BlobClient, - BlobServiceClient, - ContainerClient -) from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import ( AsyncStorageRecordedTestCase, GenericTestProxyParametrize1, - GenericTestProxyParametrize2 + GenericTestProxyParametrize2, ) from encryption_test_helper import KeyWrapper from settings.testcase import BlobPreparer - from test_content_validation import ( assert_content_crc64, assert_content_md5, assert_content_md5_get, assert_structured_message, assert_structured_message_get, - TestIter + TestIter, ) +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobBlock, BlobType, ContainerClient as SyncContainerClient +from azure.storage.blob.aio import BlobClient, BlobServiceClient, ContainerClient + class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): bsc: BlobServiceClient @@ -41,7 +37,7 @@ class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): async def _setup(self, account_name): token_credential = self.get_credential(BlobServiceClient, is_async=True) self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: await self.container.create_container() except ResourceExistsError: @@ -53,9 
+49,7 @@ def teardown_method(self, _): # Use sync client as teardown_method must be sync if self.container: sync_credential = self.get_credential(SyncContainerClient, is_async=False) - sync_container = SyncContainerClient.from_container_url( - self.container.url, - credential=sync_credential) + sync_container = SyncContainerClient.from_container_url(self.container.url, credential=sync_credential) try: sync_container.delete_container() @@ -63,31 +57,32 @@ def teardown_method(self, _): pass def _get_blob_reference(self): - return self.get_resource_name('blob') + return self.get_resource_name("blob") @BlobPreparer() async def test_encryption_blocked_crc64(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") - kek = KeyWrapper('key1') + kek = KeyWrapper("key1") blob = BlobClient( self.account_url(storage_account_name, "blob"), "testing", "testing", credential=self.get_credential(BlobServiceClient, is_async=True), require_encryption=True, - encryption_version='2.0', - key_encryption_key=kek) + encryption_version="2.0", + key_encryption_key=kek, + ) with pytest.raises(ValueError): - await blob.upload_blob(b'123', validate_content='crc64') + await blob.upload_blob(b"123", validate_content="crc64") # Needed for teardown self.container = None @BlobPreparer() - @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, "auto", 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "auto", "md5", "crc64"]) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy_async async def test_upload_blob(self, a, b, **kwargs): @@ -95,31 +90,47 @@ async def test_upload_blob(self, a, b, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = 
assert_content_crc64 if b in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if b in ("auto", "crc64") else assert_content_md5 # Test supported data types - byte_data = b'abc' * 512 + byte_data = b"abc" * 512 str_data = "你好世界abcd" * 32 - str_data_encoded = str_data.encode('utf-8') + str_data_encoded = str_data.encode("utf-8") byte_stream = BytesIO(byte_data) byte_iter = TestIter(byte_data) str_iter = TestIter(str_data) # Act / Assert - await blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == str_data_encoded - await blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + 
validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) assert await (await blob.download_blob()).read() == str_data_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type - @pytest.mark.parametrize('b', [True, 'md5', 'crc64']) # b: validate_content + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "md5", "crc64"]) # b: validate_content @GenericTestProxyParametrize2() @recorded_by_proxy_async async def test_upload_blob_chunks(self, a, b, **kwargs): @@ -130,30 +141,46 @@ async def test_upload_blob_chunks(self, a, b, **kwargs): self.container._config.max_block_size = 512 self.container._config.max_page_size = 512 blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if b == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if b == "crc64" else assert_content_md5 # Test supported data types - byte_data = b'abc' * 512 + byte_data = b"abc" * 512 str_data = "你好世界abcd" * 32 - str_data_encoded = str_data.encode('utf-8') + str_data_encoded = str_data.encode("utf-8") byte_stream = BytesIO(byte_data) byte_iter = TestIter(byte_data) str_iter = TestIter(str_data) # Act / Assert - await blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(str_data, blob_type=a, encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == str_data_encoded - 
await blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) assert await (await blob.download_blob()).read() == byte_data - await blob.upload_blob(str_iter, blob_type=a, length=len(str_data_encoded), encoding='utf-8', validate_content=b, overwrite=True, raw_request_hook=assert_method) + await blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) assert await (await blob.download_blob()).read() == str_data_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_blob_substream(self, a, **kwargs): @@ -165,9 +192,9 @@ async def test_upload_blob_substream(self, a, **kwargs): self.container._config.max_block_size = 512 self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream blob = self.container.get_blob_client(self._get_blob_reference()) - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" io = BytesIO(data) # Act @@ -178,7 +205,7 @@ async def test_upload_blob_substream(self, a, **kwargs): assert await content.read() == data @BlobPreparer() - 
@pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_stage_block(self, a, **kwargs): @@ -186,29 +213,29 @@ async def test_stage_block(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 - data2 = '你好世界' * 10 + data1 = b"abc" * 512 + data2 = "你好世界" * 10 # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act - await blob.stage_block('1', data1, validate_content=a, raw_request_hook=assert_method) - await blob.stage_block('2', data2, encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) - await blob.stage_block('3', generator(), validate_content=a, raw_request_hook=assert_method) - await blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + await blob.stage_block("1", data1, validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("2", data2, encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("3", generator(), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) # Assert content = await blob.download_blob() - assert await content.read() == data1 + data2.encode('utf-8-sig') + data1 + assert await content.read() == data1 + data2.encode("utf-8-sig") + data1 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: 
validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_stage_block_streaming(self, a, **kwargs): @@ -217,18 +244,18 @@ async def test_stage_block_streaming(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 - await blob.stage_block('1', BytesIO(content), validate_content=a, raw_request_hook=assert_method) - await blob.commit_block_list([BlobBlock('1')]) + await blob.stage_block("1", BytesIO(content), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1")]) # Assert result = await blob.download_blob() assert await result.read() == content @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only async def test_stage_block_streaming_large(self, a, **kwargs): storage_account_name = kwargs.pop("storage_account_name") @@ -236,21 +263,21 @@ async def test_stage_block_streaming_large(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else 
assert_content_md5 - await blob.stage_block('1', BytesIO(data1), validate_content=a, raw_request_hook=assert_method) - await blob.stage_block('2', BytesIO(data2), validate_content=a, raw_request_hook=assert_method) - await blob.stage_block('3', BytesIO(data3), validate_content=a, raw_request_hook=assert_method) - await blob.commit_block_list([BlobBlock('1'), BlobBlock('2'), BlobBlock('3')]) + await blob.stage_block("1", BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("2", BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("3", BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) result = await blob.download_blob() assert await result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_block(self, a, **kwargs): @@ -258,29 +285,29 @@ async def test_append_block(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 - data2 = '你好世界' * 10 + data1 = b"abc" * 512 + data2 = "你好世界" * 10 # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. 
def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act await blob.create_append_blob() await blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) - await blob.append_block(data2, encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + await blob.append_block(data2, encoding="utf-16", validate_content=a, raw_request_hook=assert_method) await blob.append_block(generator(), validate_content=a, raw_request_hook=assert_method) # Assert content = await blob.download_blob() - assert await content.readall() == data1 + data2.encode('utf-16') + data1 + assert await content.readall() == data1 + data2.encode("utf-16") + data1 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_block_streaming(self, a, **kwargs): @@ -289,8 +316,8 @@ async def test_append_block_streaming(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 await blob.create_append_blob() await blob.append_block(BytesIO(content), validate_content=a, raw_request_hook=assert_method) @@ -299,7 +326,7 @@ async def test_append_block_streaming(self, a, **kwargs): assert await result.read() == content @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, 
"md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only async def test_append_block_streaming_large(self, a, **kwargs): storage_account_name = kwargs.pop("storage_account_name") @@ -307,10 +334,10 @@ async def test_append_block_streaming_large(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 await blob.create_append_blob() await blob.append_block(BytesIO(data1), validate_content=a, raw_request_hook=assert_method) @@ -321,7 +348,7 @@ async def test_append_block_streaming_large(self, a, **kwargs): assert await result.read() == data1 + data2 + data3 @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_page(self, a, **kwargs): @@ -329,22 +356,29 @@ async def test_upload_page(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data1 = b'abc' * 512 + data1 = b"abc" * 512 data2 = "你好世界abcd" * 32 - data2_encoded = data2.encode('utf-8') - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + data2_encoded = data2.encode("utf-8") + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act await blob.create_page_blob(5 * 1024) await blob.upload_page(data1, offset=0, 
length=len(data1), validate_content=a, raw_request_hook=assert_method) - await blob.upload_page(data2, offset=len(data1), length=len(data2_encoded), encoding='utf-8', validate_content=a, raw_request_hook=assert_method) + await blob.upload_page( + data2, + offset=len(data1), + length=len(data2_encoded), + encoding="utf-8", + validate_content=a, + raw_request_hook=assert_method, + ) # Assert content = await blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) assert await content.read() == data1 + data2_encoded @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob(self, a, **kwargs): @@ -352,9 +386,9 @@ async def test_download_blob(self, a, **kwargs): await self._setup(storage_account_name) blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + data = b"abc" * 512 await blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) @@ -370,7 +404,7 @@ async def test_download_blob(self, a, **kwargs): assert stream.read() == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob_chunks(self, a, **kwargs): @@ -380,9 +414,9 @@ async def test_download_blob_chunks(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 blob = 
self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) @@ -393,7 +427,7 @@ async def test_download_blob_chunks(self, a, **kwargs): await downloader.readinto(stream) stream.seek(0) - read_content = b'' + read_content = b"" downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) for _ in range(len(data) // 100 + 1): read_content += await downloader.read(100) @@ -404,7 +438,7 @@ async def test_download_blob_chunks(self, a, **kwargs): assert read_content == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob_chunks_partial(self, a, **kwargs): @@ -414,16 +448,20 @@ async def test_download_blob_chunks_partial(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await blob.upload_blob(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act - downloader = await blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + downloader = await blob.download_blob( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) content = await 
downloader.read() stream = BytesIO() - downloader = await blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + downloader = await blob.download_blob( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) await downloader.readinto(stream) stream.seek(0) @@ -440,22 +478,22 @@ async def test_download_blob_large_chunks(self, **kwargs): blob = self.container.get_blob_client(self._get_blob_reference()) # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.container._config.max_chunk_get_size = 10 * 1024 * 1024 - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 await blob.upload_blob(data, overwrite=True, max_concurrency=5) # Act - downloader = await blob.download_blob(validate_content='crc64', max_concurrency=5) + downloader = await blob.download_blob(validate_content="crc64", max_concurrency=5) content = await downloader.read() - downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = await downloader.read() # Assert assert content == data - assert partial == data[5 * 1024 * 1024: 30 * 1024 * 1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_blob_chars(self, a, **kwargs): @@ -465,18 +503,18 @@ async def test_download_blob_chars(self, a, **kwargs): self.container._config.max_single_get_size = 512 self.container._config.max_chunk_get_size = 512 - data = '你好世界' * 256 # 3 KiB + data = "你好世界" * 256 # 3 KiB blob = 
self.container.get_blob_client(self._get_blob_reference()) - await blob.upload_blob(data, encoding='utf-8', overwrite=True) + await blob.upload_blob(data, encoding="utf-8", overwrite=True) - stream = await blob.download_blob(encoding='utf-8', validate_content=a) + stream = await blob.download_blob(encoding="utf-8", validate_content=a) assert await stream.read() == data - stream = await blob.download_blob(encoding='utf-8', validate_content=a) + stream = await blob.download_blob(encoding="utf-8", validate_content=a) assert await stream.read(chars=100000) == data - result = '' - stream = await blob.download_blob(encoding='utf-8', validate_content=a) + result = "" + stream = await blob.download_blob(encoding="utf-8", validate_content=a) for _ in range(4): chunk = await stream.read(chars=100) result += chunk @@ -486,7 +524,7 @@ async def test_download_blob_chars(self, a, **kwargs): assert result == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_content_validation_with_retry(self, a, **kwargs): @@ -500,22 +538,23 @@ async def test_content_validation_with_retry(self, a, **kwargs): retry_total=1, initial_backoff=0.1, increment_base=0.1, - logging_enable=True + logging_enable=True, ) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: await self.container.create_container() except ResourceExistsError: pass blob = self.container.get_blob_client(self._get_blob_reference()) - data = b'abc' * 512 + data = b"abc" * 512 # Determine the appropriate assert methods based on validation mode - upload_assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 - download_assert_method = assert_structured_message_get if a == 'crc64' else 
assert_content_md5_get + upload_assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + download_assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Test upload with retry upload_call_count = 0 + def upload_hook_fail_once(response): nonlocal upload_call_count upload_call_count += 1 @@ -530,6 +569,7 @@ def upload_hook_fail_once(response): # Test download with retry download_call_count = 0 + def download_hook_fail_once(response): nonlocal download_call_count download_call_count += 1 @@ -544,7 +584,7 @@ def download_hook_fail_once(response): assert content == data @BlobPreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_streaming_with_retry(self, a, **kwargs): @@ -558,20 +598,21 @@ async def test_streaming_with_retry(self, a, **kwargs): retry_total=1, initial_backoff=0.1, increment_base=0.1, - logging_enable=True + logging_enable=True, ) - self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) try: await self.container.create_container() except ResourceExistsError: pass blob = self.container.get_blob_client(self._get_blob_reference()) - content = b'abc' * 512 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abc" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Test stage_block streaming with retry call_count = 0 + def hook_fail_once(response): nonlocal call_count call_count += 1 @@ -581,9 +622,9 @@ def hook_fail_once(response): response.http_response.status_code = 408 # Request Timeout - triggers retry # Use stage_block to test structured message streaming - await blob.stage_block('1', BytesIO(content), 
validate_content=a, raw_response_hook=hook_fail_once) + await blob.stage_block("1", BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) assert call_count == 2 # Original + retry - - await blob.commit_block_list([BlobBlock('1')]) + + await blob.commit_block_list([BlobBlock("1")]) result = await blob.download_blob() assert await result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_streams.py b/sdk/storage/azure-storage-blob/tests/test_streams.py index e6bbb0d6414c..11f63b3be962 100644 --- a/sdk/storage/azure-storage-blob/tests/test_streams.py +++ b/sdk/storage/azure-storage-blob/tests/test_streams.py @@ -11,6 +11,8 @@ from typing import Iterator, List, Optional, Tuple, Union import pytest +from test_helpers import NonSeekableStream + from azure.storage.blob._shared.streams import ( StructuredMessageConstants, StructuredMessageDecoder, @@ -19,8 +21,6 @@ ) from azure.storage.extensions import crc64 -from test_helpers import NonSeekableStream - def _iter_bytes(data: bytes, chunk_size: int = 1024) -> Iterator[bytes]: """Convert bytes to an Iterator[bytes] with the given chunk size.""" @@ -83,12 +83,12 @@ def _build_structured_message( segment_crc = None if StructuredMessageProperties.CRC64 in flags: - segment_crc = crc64.compute(segment_data, 0) + segment_crc = crc64.compute(segment_data, 0) # pylint: disable=I1101 if i == invalidate_crc_segment: segment_crc += 5 _write_segment(i, segment_data, segment_crc, message) - message_crc = crc64.compute(segment_data, message_crc) + message_crc = crc64.compute(segment_data, message_crc) # pylint: disable=I1101 # Message footer if StructuredMessageProperties.CRC64 in flags: diff --git a/sdk/storage/azure-storage-blob/tests/test_streams_async.py b/sdk/storage/azure-storage-blob/tests/test_streams_async.py index e49cc3becd72..1d1bb3aff654 100644 --- a/sdk/storage/azure-storage-blob/tests/test_streams_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_streams_async.py @@ -11,7 
+11,6 @@ from typing import AsyncIterator, List, Optional, Tuple, Union import pytest -import pytest_asyncio from azure.storage.blob._shared.streams import ( StructuredMessageConstants, StructuredMessageProperties, @@ -81,12 +80,12 @@ def _build_structured_message( segment_crc = None if StructuredMessageProperties.CRC64 in flags: - segment_crc = crc64.compute(segment_data, 0) + segment_crc = crc64.compute(segment_data, 0) # pylint: disable=I1101 if i == invalidate_crc_segment: segment_crc += 5 _write_segment(i, segment_data, segment_crc, message) - message_crc = crc64.compute(segment_data, message_crc) + message_crc = crc64.compute(segment_data, message_crc) # pylint: disable=I1101 # Message footer if StructuredMessageProperties.CRC64 in flags: diff --git a/sdk/storage/azure-storage-file-datalake/assets.json b/sdk/storage/azure-storage-file-datalake/assets.json index 73f72bbd8f6d..a0c79aa64e2b 100644 --- a/sdk/storage/azure-storage-file-datalake/assets.json +++ b/sdk/storage/azure-storage-file-datalake/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-file-datalake", - "Tag": "python/storage/azure-storage-file-datalake_c9ea6b56de" + "Tag": "python/storage/azure-storage-file-datalake_34be4c4176" } diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py index 77f81647e956..d0668fd75a8d 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py @@ -7,45 +7,45 @@ from io import BytesIO import pytest -from azure.storage.filedatalake import ( - DataLakeServiceClient -) - from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase from settings.testcase import DataLakePreparer +from 
azure.storage.filedatalake import DataLakeServiceClient + def assert_content_md5(request): - if request.http_request.query.get('action') == 'append': - assert request.http_request.headers.get('Content-MD5') is not None + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("Content-MD5") is not None def assert_content_md5_get(response): - assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' - assert response.http_response.headers.get('Content-MD5') is not None + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None def assert_content_crc64(request): - if request.http_request.query.get('action') == 'append': - assert request.http_request.headers.get('x-ms-content-crc64') is not None + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("x-ms-content-crc64") is not None def assert_structured_message(request): - if request.http_request.query.get('action') == 'append': - assert request.http_request.headers.get('x-ms-structured-body') is not None + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("x-ms-structured-body") is not None def assert_structured_message_get(response): - assert response.http_request.headers.get('x-ms-structured-body') is not None - assert response.http_response.headers.get('x-ms-structured-body') is not None + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None class TestStorageContentValidation(StorageRecordedTestCase): def _setup(self, account_name): token_credential = self.get_credential(DataLakeServiceClient) - self.dsc = DataLakeServiceClient(self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True) - self.file_system = 
self.dsc.get_file_system_client(self.get_resource_name('filesystem')) + self.dsc = DataLakeServiceClient( + self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True + ) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name("filesystem")) self.file_system.create_file_system() def teardown_method(self, _): @@ -56,10 +56,10 @@ def teardown_method(self, _): pass def _get_file_reference(self): - return self.get_resource_name('file') + return self.get_resource_name("file") @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_data(self, a, **kwargs): @@ -67,8 +67,8 @@ def test_upload_data(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) @@ -78,7 +78,7 @@ def test_upload_data(self, a, **kwargs): assert content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_data_chunks(self, a, **kwargs): @@ -86,8 +86,8 @@ def test_upload_data_chunks(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abcde' * 512 - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + data = b"abcde" * 512 + assert_method = 
assert_content_crc64 if a == "crc64" else assert_content_md5 # Act file.upload_data(data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method) @@ -97,7 +97,7 @@ def test_upload_data_chunks(self, a, **kwargs): assert content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_data_substream(self, a, **kwargs): @@ -107,9 +107,9 @@ def test_upload_data_substream(self, a, **kwargs): self._setup(datalake_storage_account_name) self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk size to enable substream file = self.file_system.get_file_client(self._get_file_reference()) - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" io = BytesIO(data) # Act @@ -120,7 +120,7 @@ def test_upload_data_substream(self, a, **kwargs): assert content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_append_data(self, a, **kwargs): @@ -128,22 +128,22 @@ def test_append_data(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 512 - data2 = '你好世界' * 10 - encoded2 = data2.encode('utf-8-sig') + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-8-sig") # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. 
def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act file.create_file() file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) - file.append_data(data2, len(data1), encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) + file.append_data(data2, len(data1), encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) file.append_data(generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method) file.flush_data(len(data1) + len(encoded2) + len(data1)) @@ -152,7 +152,7 @@ def generator(): assert content.read() == data1 + encoded2 + data1 @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_append_data_streaming(self, a, **kwargs): @@ -161,8 +161,8 @@ def test_append_data_streaming(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act file.create_file() @@ -173,7 +173,7 @@ def test_append_data_streaming(self, a, **kwargs): assert result.read() == content @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only def test_append_data_streaming_large(self, a, **kwargs): datalake_storage_account_name = 
kwargs.pop("datalake_storage_account_name") @@ -181,16 +181,18 @@ def test_append_data_streaming_large(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act file.create_file() file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) file.append_data(BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method) - file.append_data(BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method) + file.append_data( + BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method + ) file.flush_data(len(data1) + len(data2) + len(data3)) # Assert @@ -198,7 +200,7 @@ def test_append_data_streaming_large(self, a, **kwargs): assert result.read() == data1 + data2 + data3 @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file(self, a, **kwargs): @@ -206,9 +208,9 @@ def test_download_file(self, a, **kwargs): self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + data = b"abc" * 512 file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') 
else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -224,7 +226,7 @@ def test_download_file(self, a, **kwargs): assert stream.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file_chunks(self, a, **kwargs): @@ -234,9 +236,9 @@ def test_download_file_chunks(self, a, **kwargs): self.file_system._config.max_single_get_size = 512 self.file_system._config.max_chunk_get_size = 512 file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -258,7 +260,7 @@ def test_download_file_chunks(self, a, **kwargs): assert read_content == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file_chunks_partial(self, a, **kwargs): @@ -268,9 +270,9 @@ def test_download_file_chunks_partial(self, a, **kwargs): self.file_system._config.max_single_get_size = 512 self.file_system._config.max_chunk_get_size = 512 file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a == 
'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) @@ -294,16 +296,16 @@ def test_download_file_large_chunks(self, **kwargs): file = self.file_system.get_file_client(self._get_file_reference()) # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 file.upload_data(data, overwrite=True, max_concurrency=5) # Act - downloader = file.download_file(validate_content='crc64', max_concurrency=5) + downloader = file.download_file(validate_content="crc64", max_concurrency=5) content = downloader.read() - downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = downloader.read() # Assert assert content == data - assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py index b137e11cfba1..a9bcf2c2ec50 100644 --- a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py @@ -7,9 +7,6 @@ from io import BytesIO import pytest -from azure.storage.filedatalake import FileSystemClient as SyncFileSystemClient -from azure.storage.filedatalake.aio import DataLakeServiceClient - from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from 
devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 @@ -19,15 +16,20 @@ assert_content_md5, assert_content_md5_get, assert_structured_message, - assert_structured_message_get + assert_structured_message_get, ) +from azure.storage.filedatalake import FileSystemClient as SyncFileSystemClient +from azure.storage.filedatalake.aio import DataLakeServiceClient + class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): async def _setup(self, account_name): token_credential = self.get_credential(DataLakeServiceClient, is_async=True) - self.dsc = DataLakeServiceClient(self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True) - self.file_system = self.dsc.get_file_system_client(self.get_resource_name('filesystem')) + self.dsc = DataLakeServiceClient( + self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True + ) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name("filesystem")) await self.file_system.create_file_system() def teardown_method(self, _): @@ -39,7 +41,8 @@ def teardown_method(self, _): sync_file_system = SyncFileSystemClient( self.account_url(self.file_system.account_name, "dfs"), self.file_system.file_system_name, - credential=sync_credential) + credential=sync_credential, + ) try: sync_file_system.delete_file_system() @@ -47,10 +50,10 @@ def teardown_method(self, _): pass def _get_file_reference(self): - return self.get_resource_name('file') - + return self.get_resource_name("file") + @DataLakePreparer() - @pytest.mark.parametrize('a', [True]) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_data(self, a, **kwargs): @@ -58,8 +61,8 @@ async def test_upload_data(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = 
self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act await file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) @@ -69,7 +72,7 @@ async def test_upload_data(self, a, **kwargs): assert await content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_data_chunks(self, a, **kwargs): @@ -77,18 +80,20 @@ async def test_upload_data_chunks(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abcde' * 512 - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + data = b"abcde" * 512 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 # Act - await file.upload_data(data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method) + await file.upload_data( + data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method + ) # Assert content = await file.download_file() assert await content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_data_substream(self, a, **kwargs): @@ -98,9 +103,9 @@ async def test_upload_data_substream(self, a, **kwargs): await self._setup(datalake_storage_account_name) self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk 
size to enable substream file = self.file_system.get_file_client(self._get_file_reference()) - assert_method = assert_content_crc64 if a == 'crc64' else assert_content_md5 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" io = BytesIO(data) # Act @@ -111,7 +116,7 @@ async def test_upload_data_substream(self, a, **kwargs): assert await content.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_data(self, a, **kwargs): @@ -119,23 +124,27 @@ async def test_append_data(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 512 - data2 = '你好世界' * 10 - encoded2 = data2.encode('utf-8-sig') + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-8-sig") # An iterable with no length will be read into bytes and therefore will behave like # bytes when it comes to testing content validation. 
def generator(): for i in range(0, len(data1), 500): - yield data1[i: i + 500] + yield data1[i : i + 500] - assert_method = assert_content_crc64 if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 # Act await file.create_file() await file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) - await file.append_data(data2, len(data1), encoding='utf-8-sig', validate_content=a, raw_request_hook=assert_method) - await file.append_data(generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method) + await file.append_data( + data2, len(data1), encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method + ) + await file.append_data( + generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method + ) await file.flush_data(len(data1) + len(encoded2) + len(data1)) # Assert @@ -143,7 +152,7 @@ def generator(): assert await content.read() == data1 + encoded2 + data1 @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_append_data_streaming(self, a, **kwargs): @@ -152,8 +161,8 @@ async def test_append_data_streaming(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - content = b'abcde' * 1030 # 5 KiB + 30 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act await file.create_file() @@ -164,7 +173,7 @@ async def test_append_data_streaming(self, a, **kwargs): assert await result.read() == content @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 
'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @pytest.mark.live_test_only async def test_append_data_streaming_large(self, a, **kwargs): datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") @@ -172,16 +181,20 @@ async def test_append_data_streaming_large(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 1024 * 1024 # 5 MiB - data2 = b'12345' * 2 * 1024 * 1024 + b'abcdefg' # 10 MiB + 7 - data3 = b'12345678' * 8 * 1024 * 1024 # 64 MiB - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act await file.create_file() await file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) - await file.append_data(BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method) - await file.append_data(BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method) + await file.append_data( + BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method + ) + await file.append_data( + BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method + ) await file.flush_data(len(data1) + len(data2) + len(data3)) # Assert @@ -189,7 +202,7 @@ async def test_append_data_streaming_large(self, a, **kwargs): assert await result.read() == data1 + data2 + data3 @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content 
@GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file(self, a, **kwargs): @@ -197,9 +210,9 @@ async def test_download_file(self, a, **kwargs): await self._setup(datalake_storage_account_name) file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + data = b"abc" * 512 await file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -215,7 +228,7 @@ async def test_download_file(self, a, **kwargs): assert stream.read() == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file_chunks(self, a, **kwargs): @@ -225,9 +238,9 @@ async def test_download_file_chunks(self, a, **kwargs): self.file_system._config.max_single_get_size = 512 self.file_system._config.max_chunk_get_size = 512 file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -249,7 +262,7 @@ async def test_download_file_chunks(self, a, **kwargs): assert read_content == data @DataLakePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content 
@GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file_chunks_partial(self, a, **kwargs): @@ -259,16 +272,20 @@ async def test_download_file_chunks_partial(self, a, **kwargs): self.file_system._config.max_single_get_size = 512 self.file_system._config.max_chunk_get_size = 512 file = self.file_system.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await file.upload_data(data, overwrite=True) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act - downloader = await file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + downloader = await file.download_file( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) content = await downloader.read() stream = BytesIO() - downloader = await file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader = await file.download_file( + offset=512, length=1024, validate_content=a, raw_response_hook=assert_method + ) await downloader.readinto(stream) stream.seek(0) @@ -285,16 +302,16 @@ async def test_download_file_large_chunks(self, **kwargs): file = self.file_system.get_file_client(self._get_file_reference()) # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 await file.upload_data(data, overwrite=True, max_concurrency=5) # Act - downloader = await file.download_file(validate_content='crc64', max_concurrency=5) + downloader = await file.download_file(validate_content="crc64", max_concurrency=5) content = await downloader.read() - downloader = await 
file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = await downloader.read() # Assert assert content == data - assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py index c318c22acd2c..6afdfa1194ac 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py @@ -7,31 +7,31 @@ from io import BytesIO import pytest -from azure.storage.fileshare import ShareClient, ShareServiceClient - from devtools_testutils import is_live, recorded_by_proxy from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase from settings.testcase import FileSharePreparer +from azure.storage.fileshare import ShareClient, ShareServiceClient + def assert_content_md5(request): - if request.http_request.query.get('comp') == 'range': - assert request.http_request.headers.get('Content-MD5') is not None + if request.http_request.query.get("comp") == "range": + assert request.http_request.headers.get("Content-MD5") is not None def assert_content_md5_get(response): - assert response.http_request.headers.get('x-ms-range-get-content-md5') == 'true' - assert response.http_response.headers.get('Content-MD5') is not None + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None def assert_structured_message(request): - if request.http_request.query.get('comp') == 'range': - assert request.http_request.headers.get('x-ms-structured-body') is not None + if request.http_request.query.get("comp") == "range": + 
assert request.http_request.headers.get("x-ms-structured-body") is not None def assert_structured_message_get(response): - assert response.http_request.headers.get('x-ms-structured-body') is not None - assert response.http_response.headers.get('x-ms-structured-body') is not None + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None class TestStorageContentValidation(StorageRecordedTestCase): @@ -39,8 +39,13 @@ class TestStorageContentValidation(StorageRecordedTestCase): def _setup(self, account_name): token_credential = self.get_credential(ShareServiceClient) - self.ssc = ShareServiceClient(self.account_url(account_name, "file"), credential=token_credential, token_intent="backup", logging_enable=True) - self.share_client = self.ssc.get_share_client(self.get_resource_name('utshare')) + self.ssc = ShareServiceClient( + self.account_url(account_name, "file"), + credential=token_credential, + token_intent="backup", + logging_enable=True, + ) + self.share_client = self.ssc.get_share_client(self.get_resource_name("utshare")) self.share_client.create_share() def teardown_method(self, _): @@ -51,10 +56,10 @@ def teardown_method(self, _): pass def _get_file_reference(self): - return self.get_resource_name('file') + return self.get_resource_name("file") @FileSharePreparer() - @pytest.mark.parametrize('a', ['auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", ["auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_create_file_with_data(self, a, **kwargs): @@ -62,8 +67,8 @@ def test_create_file_with_data(self, a, **kwargs): self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = 
assert_structured_message if a in ("auto", "crc64") else assert_content_md5 # Act file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) @@ -73,7 +78,7 @@ def test_create_file_with_data(self, a, **kwargs): assert content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_file(self, a, **kwargs): @@ -81,8 +86,8 @@ def test_upload_file(self, a, **kwargs): self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 # Act file.upload_file(data, validate_content=a, raw_request_hook=assert_method) @@ -92,7 +97,7 @@ def test_upload_file(self, a, **kwargs): assert content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_file_chunks(self, a, **kwargs): @@ -101,8 +106,8 @@ def test_upload_file_chunks(self, a, **kwargs): self._setup(storage_account_name) self.share_client._config.max_range_size = 1024 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcde' * 512 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data = b"abcde" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act file.upload_file(data, validate_content=a, raw_request_hook=assert_method) @@ -112,7 +117,7 @@ def test_upload_file_chunks(self, a, **kwargs): assert 
content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto','md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_range(self, a, **kwargs): @@ -120,23 +125,25 @@ def test_upload_range(self, a, **kwargs): self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 512 - data2 = '你好世界' * 10 - encoded2 = data2.encode('utf-16') + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-16") + + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 - # Act file.create_file(len(data1) + len(encoded2)) file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) - file.upload_range(data2, len(data1), len(encoded2), encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + file.upload_range( + data2, len(data1), len(encoded2), encoding="utf-16", validate_content=a, raw_request_hook=assert_method + ) # Assert content = file.download_file() assert content.readall() == data1 + encoded2 @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_upload_range_streaming(self, a, **kwargs): @@ -145,8 +152,8 @@ def test_upload_range_streaming(self, a, **kwargs): self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcd' * 1030 # 4 KiB + 24 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data = b"abcd" * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == "crc64" else 
assert_content_md5 # Act file.create_file(len(data)) @@ -159,7 +166,7 @@ def test_upload_range_streaming(self, a, **kwargs): assert content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file(self, a, **kwargs): @@ -167,9 +174,9 @@ def test_download_file(self, a, **kwargs): self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + data = b"abc" * 512 file.upload_file(data) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -185,7 +192,7 @@ def test_download_file(self, a, **kwargs): assert stream.read() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file_chunks(self, a, **kwargs): @@ -195,9 +202,9 @@ def test_download_file_chunks(self, a, **kwargs): self.share_client._config.max_single_get_size = 512 self.share_client._config.max_chunk_get_size = 512 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" file.upload_file(data) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -219,7 +226,7 @@ def 
test_download_file_chunks(self, a, **kwargs): assert read_content == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy def test_download_file_chunks_partial(self, a, **kwargs): @@ -229,9 +236,9 @@ def test_download_file_chunks_partial(self, a, **kwargs): self.share_client._config.max_single_get_size = 512 self.share_client._config.max_chunk_get_size = 512 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" file.upload_file(data) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) @@ -255,16 +262,16 @@ def test_download_file_large_chunks(self, **kwargs): # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 file.upload_file(data, max_concurrency=5) # Act - downloader = file.download_file(validate_content='crc64', max_concurrency=5) + downloader = file.download_file(validate_content="crc64", max_concurrency=5) content = downloader.readall() - downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = downloader.readall() # Assert assert content == data - assert partial == data[5 * 1024 * 1024:30 * 1024 * 
1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py index a8a8d82477eb..628356cbca07 100644 --- a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py @@ -7,9 +7,6 @@ from io import BytesIO import pytest -from azure.storage.fileshare import ShareClient as SyncShareClient -from azure.storage.fileshare.aio import ShareServiceClient - from devtools_testutils import is_live from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 @@ -18,15 +15,23 @@ assert_content_md5, assert_content_md5_get, assert_structured_message, - assert_structured_message_get + assert_structured_message_get, ) +from azure.storage.fileshare import ShareClient as SyncShareClient +from azure.storage.fileshare.aio import ShareServiceClient + class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): async def _setup(self, account_name): token_credential = self.get_credential(ShareServiceClient, is_async=True) - self.ssc = ShareServiceClient(self.account_url(account_name, "file"), credential=token_credential, token_intent="backup", logging_enable=True) - self.share_client = self.ssc.get_share_client(self.get_resource_name('utshare')) + self.ssc = ShareServiceClient( + self.account_url(account_name, "file"), + credential=token_credential, + token_intent="backup", + logging_enable=True, + ) + self.share_client = self.ssc.get_share_client(self.get_resource_name("utshare")) await self.share_client.create_share() def teardown_method(self, _): @@ -38,17 +43,18 @@ def teardown_method(self, _): self.account_url(self.share_client.account_name, "file"), self.share_client.share_name, credential=sync_credential, - 
token_intent="backup") + token_intent="backup", + ) try: sync_share_client.delete_share() except: pass def _get_file_reference(self): - return self.get_resource_name('file') + return self.get_resource_name("file") @FileSharePreparer() - @pytest.mark.parametrize('a', ['auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", ["auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_create_file_with_data(self, a, **kwargs): @@ -56,8 +62,8 @@ async def test_create_file_with_data(self, a, **kwargs): await self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 # Act await file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) @@ -67,7 +73,7 @@ async def test_create_file_with_data(self, a, **kwargs): assert await content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_file(self, a, **kwargs): @@ -75,8 +81,8 @@ async def test_upload_file(self, a, **kwargs): await self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 # Act await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) @@ -86,7 +92,7 @@ async def test_upload_file(self, a, **kwargs): 
assert await content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_file_chunks(self, a, **kwargs): @@ -95,8 +101,8 @@ async def test_upload_file_chunks(self, a, **kwargs): await self._setup(storage_account_name) self.share_client._config.max_range_size = 1024 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcde' * 512 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data = b"abcde" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) @@ -106,7 +112,7 @@ async def test_upload_file_chunks(self, a, **kwargs): assert await content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_range(self, a, **kwargs): @@ -114,23 +120,25 @@ async def test_upload_range(self, a, **kwargs): await self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data1 = b'abcde' * 512 - data2 = '你好世界' * 10 - encoded2 = data2.encode('utf-16') + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-16") - assert_method = assert_structured_message if a in ('auto', 'crc64') else assert_content_md5 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 # Act await file.create_file(len(data1) + len(encoded2)) await file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) - await file.upload_range(data2, 
len(data1), len(encoded2), encoding='utf-16', validate_content=a, raw_request_hook=assert_method) + await file.upload_range( + data2, len(data1), len(encoded2), encoding="utf-16", validate_content=a, raw_request_hook=assert_method + ) # Assert content = await file.download_file() assert await content.readall() == data1 + encoded2 @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_upload_range_streaming(self, a, **kwargs): @@ -139,8 +147,8 @@ async def test_upload_range_streaming(self, a, **kwargs): await self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcd' * 1030 # 4 KiB + 24 - assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + data = b"abcd" * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 # Act await file.create_file(len(data)) @@ -153,7 +161,7 @@ async def test_upload_range_streaming(self, a, **kwargs): assert await content.readall() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'auto', 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file(self, a, **kwargs): @@ -161,9 +169,9 @@ async def test_download_file(self, a, **kwargs): await self._setup(storage_account_name) file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + data = b"abc" * 512 await file.upload_file(data) - assert_method = assert_structured_message_get if a in ('auto', 'crc64') else assert_content_md5_get + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get # Act downloader = await 
file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -179,7 +187,7 @@ async def test_download_file(self, a, **kwargs): assert stream.read() == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file_chunks(self, a, **kwargs): @@ -189,9 +197,9 @@ async def test_download_file_chunks(self, a, **kwargs): self.share_client._config.max_single_get_size = 512 self.share_client._config.max_chunk_get_size = 512 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await file.upload_file(data) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get # Act downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) @@ -213,7 +221,7 @@ async def test_download_file_chunks(self, a, **kwargs): assert read_content == data @FileSharePreparer() - @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content @GenericTestProxyParametrize1() @recorded_by_proxy_async async def test_download_file_chunks_partial(self, a, **kwargs): @@ -223,16 +231,20 @@ async def test_download_file_chunks_partial(self, a, **kwargs): self.share_client._config.max_single_get_size = 512 self.share_client._config.max_chunk_get_size = 512 file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abc' * 512 + b'abcde' + data = b"abc" * 512 + b"abcde" await file.upload_file(data) - assert_method = assert_structured_message_get if a == 'crc64' else assert_content_md5_get + assert_method = assert_structured_message_get if a == "crc64" else 
assert_content_md5_get # Act - downloader = await file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + downloader = await file.download_file( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) content = await downloader.readall() stream = BytesIO() - downloader = await file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader = await file.download_file( + offset=512, length=1024, validate_content=a, raw_response_hook=assert_method + ) await downloader.readinto(stream) stream.seek(0) @@ -249,16 +261,16 @@ async def test_download_file_large_chunks(self, **kwargs): # The service will use 4 MiB for structured message chunk size, so make chunk size larger self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB file = self.share_client.get_file_client(self._get_file_reference()) - data = b'abcde' * 30 * 1024 * 1024 + b'abcde' # 150 MiB + 5 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 await file.upload_file(data, max_concurrency=5) # Act - downloader = await file.download_file(validate_content='crc64', max_concurrency=5) + downloader = await file.download_file(validate_content="crc64", max_concurrency=5) content = await downloader.readall() - downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content='crc64') + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") partial = await downloader.readall() # Assert assert content == data - assert partial == data[5 * 1024 * 1024:30 * 1024 * 1024] + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] From 15ba37ae19880fa961b28c70c1137ada37f19723 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon Date: Tue, 19 May 2026 14:02:41 -0700 Subject: [PATCH 13/14] Fix changelogs after merge --- sdk/storage/azure-storage-blob/CHANGELOG.md | 12 ------------ 
sdk/storage/azure-storage-file-datalake/CHANGELOG.md | 8 -------- sdk/storage/azure-storage-file-share/CHANGELOG.md | 10 ---------- sdk/storage/azure-storage-queue/CHANGELOG.md | 3 --- 4 files changed, 33 deletions(-) diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 1073d9149a0c..52e79b9fe146 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -3,18 +3,6 @@ ## 12.31.0b1 (Unreleased) ### Features Added -- Added support for service version 2026-06-06. -- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes -for `BlobServiceClient`, `ContainerClient`, and `BlobClient`. -- Added support for `create` permission in `BlobSasPermissions` for `stage_block`, -`stage_block_from_url`, and `commit_block_list`. -- Added support for a new `Smart` access tier to `StandardBlobTier` used in `BlobClient.set_standard_blob_tier`, -which is optimized to automatically determine the most cost-effective access with no performance impact. -When set, `BlobProperties.smart_access_tier` will reveal the service's current access -tier choice between `Hot`, `Cool`, and `Archive`. - -### Other Changes -- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. ## 12.29.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md index 5a1c45f8691b..bc59d313ba2a 100644 --- a/sdk/storage/azure-storage-file-datalake/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-datalake/CHANGELOG.md @@ -3,14 +3,6 @@ ## 12.26.0b1 (Unreleased) ### Features Added -- Added support for service version 2026-06-06. 
-- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes -for `DataLakeServiceClient`, `FileSystemClient`, `DataLakeDirectoryClient`, and `DataLakeFileClient`. -- Added support for `DataLakeDirectoryClient` and `DataLakeFileClient`'s `set_tags` and `get_tags` APIs -to conditionally set and get tags associated with a directory or file client, respectively. - -### Other Changes -- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. ## 12.24.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-file-share/CHANGELOG.md b/sdk/storage/azure-storage-file-share/CHANGELOG.md index 470d6eb8bbe9..162222511349 100644 --- a/sdk/storage/azure-storage-file-share/CHANGELOG.md +++ b/sdk/storage/azure-storage-file-share/CHANGELOG.md @@ -3,16 +3,6 @@ ## 12.27.0b1 (Unreleased) ### Features Added -- Added support for service version 2026-06-06. -- Added support for the keyword `file_property_semantics` in `ShareClient`'s `create_directory` and `DirectoryClient`'s -`create_directory` APIs, which specifies permissions to be configured upon directory creation. -- Added support for the keyword `data` to `FileClient`'s `create_file` API, which specifies the -optional initial data to be uploaded (up to 4MB). -- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes -for `ShareClient`, `ShareDirectoryClient`, and `ShareFileClient`. - -### Other Changes -- Consolidated the behavior of `max_concurrency=None` by defaulting to the shared `DEFAULT_MAX_CONCURRENCY` constant. 
## 12.25.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-queue/CHANGELOG.md b/sdk/storage/azure-storage-queue/CHANGELOG.md index b886658fe3cb..b1b136d2741d 100644 --- a/sdk/storage/azure-storage-queue/CHANGELOG.md +++ b/sdk/storage/azure-storage-queue/CHANGELOG.md @@ -3,9 +3,6 @@ ## 12.18.0b1 (Unreleased) ### Features Added -- Added support for service version 2026-06-06. -- Added support for connection strings and `account_url`s to accept URLs with `-ipv6` and `-dualstack` suffixes -for `QueueServiceClient` and `QueueClient`. ## 12.16.0 (2026-05-14) From 07684716840d7ad24c13b0122d6172bdd6a37398 Mon Sep 17 00:00:00 2001 From: Jacob Lauzon Date: Tue, 19 May 2026 15:16:45 -0700 Subject: [PATCH 14/14] Tools black, Copilot feedback --- .../devtools_testutils/storage/aio/__init__.py | 6 +----- .../devtools_testutils/storage/aio/asyncdecorators.py | 3 +++ .../devtools_testutils/storage/decorators.py | 3 +++ .../azure/storage/blob/_shared/policies.py | 2 +- .../azure/storage/filedatalake/_shared/policies.py | 2 +- .../azure/storage/fileshare/_shared/policies.py | 2 +- .../azure/storage/queue/_shared/policies.py | 2 +- 7 files changed, 11 insertions(+), 9 deletions(-) diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py index f1c9a35dcb20..e377e97ee9c0 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py @@ -1,8 +1,4 @@ from .asynctestcase import AsyncStorageRecordedTestCase from .asyncdecorators import GenericTestProxyParametrize1, GenericTestProxyParametrize2 -__all__ = [ - "AsyncStorageRecordedTestCase", - "GenericTestProxyParametrize1", - "GenericTestProxyParametrize2" -] \ No newline at end of file +__all__ = ["AsyncStorageRecordedTestCase", "GenericTestProxyParametrize1", "GenericTestProxyParametrize2"] diff --git 
a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py index cc2455a69bbe..1cac397fa80a 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py @@ -4,10 +4,12 @@ # license information. # -------------------------------------------------------------------------- + class GenericTestProxyParametrize1: def __call__(self, fn): async def _wrapper(test_class, a, **kwargs): await fn(test_class, a, **kwargs) + return _wrapper @@ -15,4 +17,5 @@ class GenericTestProxyParametrize2: def __call__(self, fn): async def _wrapper(test_class, a, b, **kwargs): await fn(test_class, a, b, **kwargs) + return _wrapper diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py index 45f1db5c588c..ba347a6ad1ec 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py @@ -4,10 +4,12 @@ # license information. 
# -------------------------------------------------------------------------- + class GenericTestProxyParametrize1: def __call__(self, fn): def _wrapper(test_class, a, **kwargs): return fn(test_class, a, **kwargs) + return _wrapper @@ -15,4 +17,5 @@ class GenericTestProxyParametrize2: def __call__(self, fn): def _wrapper(test_class, a, b, **kwargs): return fn(test_class, a, b, **kwargs) + return _wrapper diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 832717f7457a..3a5f0b9d662f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -458,7 +458,7 @@ def _validate_content_response( raise AzureError( ( f"Expected structured message header in response does not match request. " - f"Request: {sm_request}, Response: {sm_response}", + f"Request: {sm_request}, Response: {sm_response}" ), response=response.http_response, ) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 7a02e4479149..b5d0b7d79766 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -458,7 +458,7 @@ def _validate_content_response( raise AzureError( ( f"Expected structured message header in response does not match request. 
" - f"Request: {sm_request}, Response: {sm_response}", + f"Request: {sm_request}, Response: {sm_response}" ), response=response.http_response, ) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 7a02e4479149..b5d0b7d79766 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -458,7 +458,7 @@ def _validate_content_response( raise AzureError( ( f"Expected structured message header in response does not match request. " - f"Request: {sm_request}, Response: {sm_response}", + f"Request: {sm_request}, Response: {sm_response}" ), response=response.http_response, ) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index 068bef5601f3..f4f602d1c669 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -464,7 +464,7 @@ def _validate_content_response( raise AzureError( ( f"Expected structured message header in response does not match request. " - f"Request: {sm_request}, Response: {sm_response}", + f"Request: {sm_request}, Response: {sm_response}" ), response=response.http_response, )