diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py index 1482a6448440..4098c53f93dd 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/__init__.py @@ -1,9 +1,12 @@ from .api_version_policy import ApiVersionAssertPolicy +from .decorators import GenericTestProxyParametrize1, GenericTestProxyParametrize2 from .service_versions import service_version_map, ServiceVersion, is_version_before from .testcase import StorageRecordedTestCase, LogCaptured __all__ = [ "ApiVersionAssertPolicy", + "GenericTestProxyParametrize1", + "GenericTestProxyParametrize2", "service_version_map", "StorageRecordedTestCase", "ServiceVersion", diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py index bc57c1b559e8..e377e97ee9c0 100644 --- a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/__init__.py @@ -1,3 +1,4 @@ from .asynctestcase import AsyncStorageRecordedTestCase +from .asyncdecorators import GenericTestProxyParametrize1, GenericTestProxyParametrize2 -__all__ = ["AsyncStorageRecordedTestCase"] +__all__ = ["AsyncStorageRecordedTestCase", "GenericTestProxyParametrize1", "GenericTestProxyParametrize2"] diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py new file mode 100644 index 000000000000..1cac397fa80a --- /dev/null +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/aio/asyncdecorators.py @@ -0,0 +1,21 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +class GenericTestProxyParametrize1: + def __call__(self, fn): + async def _wrapper(test_class, a, **kwargs): + await fn(test_class, a, **kwargs) + + return _wrapper + + +class GenericTestProxyParametrize2: + def __call__(self, fn): + async def _wrapper(test_class, a, b, **kwargs): + await fn(test_class, a, b, **kwargs) + + return _wrapper diff --git a/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py new file mode 100644 index 000000000000..ba347a6ad1ec --- /dev/null +++ b/eng/tools/azure-sdk-tools/devtools_testutils/storage/decorators.py @@ -0,0 +1,21 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +class GenericTestProxyParametrize1: + def __call__(self, fn): + def _wrapper(test_class, a, **kwargs): + return fn(test_class, a, **kwargs) + + return _wrapper + + +class GenericTestProxyParametrize2: + def __call__(self, fn): + def _wrapper(test_class, a, b, **kwargs): + return fn(test_class, a, b, **kwargs) + + return _wrapper diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index 007e0506d1c8..93dd0f1e298c 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-blob", - "Tag": "python/storage/azure-storage-blob_28cfcca089" + "Tag": "python/storage/azure-storage-blob_b09e37b521" } diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py index b3dafa603afd..04ae539788b3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py @@ -65,6 +65,7 @@ from ._quick_query_helper import BlobQueryReader from ._shared.base_client import parse_connection_str, StorageAccountHostsMixin, TransportWrapper from ._shared.response_handlers import process_storage_error, return_response_headers +from ._shared.validation import is_crc64_validation, parse_validation_option from ._serialize import ( get_access_conditions, get_api_version, @@ -505,15 +506,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. Value can be a BlobLeaseClient object @@ -616,6 +613,9 @@ def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if is_crc64_validation(validate_content) and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -627,6 +627,7 @@ def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -683,15 +684,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -765,6 +762,9 @@ def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if is_crc64_validation(validate_content) and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -778,6 +778,7 @@ def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2009,15 +2010,11 @@ def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2850,13 +2847,11 @@ def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. @@ -3157,13 +3152,11 @@ def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi index fee0f4757f79..f5679a595ad8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.pyi @@ -173,7 +173,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -200,7 +200,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -222,7 +222,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -244,7 +244,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -486,7 +486,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -671,7 +671,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -741,7 +741,7 @@ class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py index 33d5dfc7f0b2..430a437bf886 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client_helpers.py @@ -58,6 +58,7 @@ from ._shared.response_handlers import return_headers_and_deserialized, return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer +from ._shared.validation import CV_TYPE_PARSED, parse_validation_option from ._upload_helpers import _any_conditions if TYPE_CHECKING: @@ -110,6 +111,7 @@ def _upload_blob_options( # pylint:disable=too-many-statements length: Optional[int], metadata: Optional[Dict[str, str]], encryption_options: Dict[str, Any], + validate_content: CV_TYPE_PARSED, config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", @@ -135,7 +137,6 @@ def _upload_blob_options( # pylint:disable=too-many-statements else: raise TypeError(f"Unsupported data type: {type(data)}") - validate_content = kwargs.pop('validate_content', False) content_settings = kwargs.pop('content_settings', None) overwrite = kwargs.pop('overwrite', False) max_concurrency = kwargs.pop('max_concurrency', None) @@ -258,42 +259,16 @@ def _download_blob_options( length: Optional[int], encoding: Optional[str], encryption_options: Dict[str, Any], + validate_content: CV_TYPE_PARSED, config: "StorageConfiguration", sdk_moniker: str, client: "AzureBlobStorage", **kwargs ) -> Dict[str, Any]: - """Creates a dictionary containing the options for a download blob operation. - - :param str blob_name: - The name of the blob. - :param str container_name: - The name of the container. - :param Optional[str] version_id: - The version id parameter is a value that, when present, specifies the version of the blob to download. - :param Optional[int] offset: - Start of byte range to use for downloading a section of the blob. Must be set if length is provided. - :param Optional[int] length: - Number of bytes to read from the stream. This is optional, but should be supplied for optimal performance. - :param Optional[str] encoding: - Encoding to decode the downloaded bytes. Default is None, i.e. no decoding. - :param Dict[str, Any] encryption_options: - The options for encryption, if enabled. - :param StorageConfiguration config: - The Storage configuration options. - :param str sdk_moniker: - The string representing the SDK package version. - :param AzureBlobStorage client: - The generated Blob Storage client. - :return: A dictionary containing the download blob options. - :rtype: Dict[str, Any] - """ if length is not None: if offset is None: raise ValueError("Offset must be provided if length is provided.") length = offset + length - 1 # Service actually uses an end-range inclusive index - - validate_content = kwargs.pop('validate_content', False) access_conditions = get_access_conditions(kwargs.pop('lease', None)) mod_conditions = get_modify_conditions(kwargs) @@ -721,7 +696,7 @@ def _stage_block_options( if isinstance(data, bytes): data = data[:length] - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk_scope_info = get_cpk_scope_info(kwargs) cpk = kwargs.pop('cpk', None) cpk_info = None @@ -1004,7 +979,7 @@ def _upload_page_options( ) mod_conditions = get_modify_conditions(kwargs) cpk_scope_info = get_cpk_scope_info(kwargs) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) cpk = kwargs.pop('cpk', None) cpk_info = None if cpk: @@ -1149,7 +1124,7 @@ def _append_block_options( appendpos_condition = kwargs.pop('appendpos_condition', None) maxsize_condition = kwargs.pop('maxsize_condition', None) - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) append_conditions = None if maxsize_condition or appendpos_condition is not None: append_conditions = AppendPositionAccessConditions( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 6d1a96d4bb88..dfba8f54eb63 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -1010,15 +1010,11 @@ def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1255,15 +1251,11 @@ def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi index 9ce6d9b7acdb..2ad0514290d3 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi @@ -13,6 +13,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Iterator, @@ -246,7 +247,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -288,7 +289,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -309,7 +310,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -331,7 +332,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py index a2f50ebc91ec..17304f6bed8f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_download.py @@ -21,6 +21,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from ._shared.constants import DEFAULT_MAX_CONCURRENCY +from ._shared.validation import is_md5_validation, CV_TYPE_PARSED from ._deserialize import deserialize_blob_properties, get_page_ranges_result from ._encryption import ( adjust_blob_size_for_encryption, @@ -91,7 +92,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: bool, + validate_content: CV_TYPE_PARSED, encryption_options: Dict[str, Any], encryption_data: Optional["_EncryptionData"] = None, stream: Any = None, @@ -214,7 +215,7 @@ def _download_chunk(self, chunk_start: int, chunk_end: int) -> Tuple[bytes, int] range_header, range_validation = validate_and_format_range_headers( download_range[0], download_range[1], - check_content_md5=self.validate_content + check_content_md5=is_md5_validation(self.validate_content) ) retry_active = True @@ -329,7 +330,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] @@ -358,6 +359,7 @@ def __init__( self._file_size = 0 self._non_empty_ranges = None self._encryption_data: Optional["_EncryptionData"] = None + self._is_structured_message = False # The content download offset, after any processing (decryption), in bytes self._download_offset = 0 @@ -382,11 +384,14 @@ def __init__( self._get_encryption_data_request() # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) + initial_request_start = self._download_start if self._end_range is not None and self._end_range - initial_request_start < first_get_size: initial_request_end = self._end_range @@ -445,7 +450,7 @@ def _get_encryption_data_request(self) -> None: @property def _download_complete(self): - if is_encryption_v2(self._encryption_data): + if is_encryption_v2(self._encryption_data) or self._is_structured_message: return self._download_offset >= self.size return self._raw_download_offset >= self.size @@ -455,7 +460,7 @@ def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) retry_active = True @@ -543,6 +548,7 @@ def _initial_request(self): except HttpResponseError: pass + self._is_structured_message = response.response.headers.get("x-ms-structured-body") is not None if not self._download_complete and self._request_options.get("modified_access_conditions"): self._request_options["modified_access_conditions"].if_match = response.properties.etag diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 16aba3116029..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -36,12 +36,15 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +133,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 61a4fdb15bdd..3a5f0b9d662f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -5,14 +5,13 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +31,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_crc64_validation, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +55,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +72,12 @@ def encode_base64(data): # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -101,11 +123,15 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True @@ -237,7 +263,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -352,64 +387,113 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. - - This will overwrite any headers already defined in the request. - """ +def _prepare_content_validation(request: "PipelineRequest") -> None: + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" + # Download + if request.http_request.method == "GET": + if is_crc64_validation(validate_content): + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() - - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 + + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) - return md5.digest() + request.context["validate_content"] = validate_content + + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif is_crc64_validation(validate_content): + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}" + ), + response=response.http_response, + ) + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + _prepare_content_validation(request) def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, - ) + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -456,7 +540,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. @@ -496,7 +580,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -570,7 +654,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. @@ -584,11 +668,16 @@ def send(self, request): response = self.next.send(request) if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -598,7 +687,12 @@ def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -705,7 +799,7 @@ def __init__( self.random_jitter_range = random_jitter_range super(LinearRetry, self).__init__(retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs) - def get_backoff_time(self, settings: Dict[str, Any]) -> float: + def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disable=unused-argument """ Calculates how long to sleep before retrying. @@ -731,11 +825,11 @@ def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 4cb32f23248b..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -37,21 +51,52 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -125,11 +170,16 @@ async def send(self, request): response = await self.next.send(request) if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -139,7 +189,12 @@ async def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py index b23f65859690..9f106333e1fb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/request_handlers.py @@ -139,7 +139,7 @@ def validate_and_format_range_headers( raise ValueError("Both start and end range required for MD5 content validation.") if end_range - start_range > 4 * 1024 * 1024: raise ValueError("Getting content MD5 for a range greater than 4MB is not supported.") - range_validation = "true" + range_validation = True return range_header, range_validation diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py new file mode 100644 index 000000000000..cb745693921c --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -0,0 +1,565 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + @property + def closed(self) -> bool: + return self._inner_stream.closed + + def close(self) -> None: + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return self._inner_stream.seekable() and self._initial_content_position is not None + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") + + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") + + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region(self._current_region, remaining, output) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min(self._segment_size, self.content_length - self._content_offset) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[self._current_region_offset : self._current_region_offset + read_size] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + read_size = min(size, self._current_region_length - self._current_region_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py new file mode 100644 index 000000000000..9fedc055a623 --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams_async.py @@ -0,0 +1,210 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py new file mode 100644 index 000000000000..21d3b081d8cc --- /dev/null +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/validation.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + +_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") + +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." + ) from exc + + +def parse_validation_option( + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if parsed == "auto": + parsed = "crc64" + + if parsed == "crc64": + _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" + + return cast(CV_TYPE_PARSED, parsed) + + +def is_md5_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return False + return validate_content in ("crc64", "crc64-sm") + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py index 2ce55f7ab237..1b782ca1a8d8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_upload_helpers.py @@ -32,6 +32,7 @@ upload_data_chunks, upload_substream_blocks ) +from ._shared.validation import CV_TYPE_PARSED if TYPE_CHECKING: from ._generated.operations import AppendBlobOperations, BlockBlobOperations, PageBlobOperations @@ -71,7 +72,7 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: CV_TYPE_PARSED, max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -123,11 +124,15 @@ def upload_block_blob( # pylint: disable=too-many-locals, too-many-statements return cast(Dict[str, Any], response) - use_original_upload_path = blob_settings.use_byte_buffer or \ - validate_content or encryption_options.get('required') or \ - blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or \ - hasattr(stream, 'seekable') and not stream.seekable() or \ - not hasattr(stream, 'seek') or not hasattr(stream, 'tell') + use_original_upload_path = ( + blob_settings.use_byte_buffer + or validate_content not in (None, False) + or encryption_options.get('required') + or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold + or hasattr(stream, 'seekable') and not stream.seekable() + or not hasattr(stream, 'seek') + or not hasattr(stream, 'tell') + ) if use_original_upload_path: total_size = length @@ -209,7 +214,7 @@ def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -287,7 +292,7 @@ def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index b73b5691a6ca..a050d8206df2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -77,6 +77,7 @@ from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper, parse_connection_str from .._shared.policies_async import ExponentialRetry from .._shared.response_handlers import process_storage_error, return_response_headers +from .._shared.validation import is_crc64_validation, parse_validation_option if TYPE_CHECKING: from azure.core import MatchConditions @@ -516,15 +517,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: If specified, upload_blob only succeeds if the blob's lease is active and matches this ID. @@ -629,6 +626,9 @@ async def upload_blob( raise ValueError("Encryption required but no key was provided.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if is_crc64_validation(validate_content) and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _upload_blob_options( data=data, blob_type=blob_type, @@ -640,6 +640,7 @@ async def upload_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -696,15 +697,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a @@ -778,6 +775,9 @@ async def download_blob( raise ValueError("Offset value must not be None if length is set.") if kwargs.get('cpk') and self.scheme.lower() != 'https': raise ValueError("Customer provided encryption key must be used over HTTPS.") + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + if is_crc64_validation(validate_content) and self.key_encryption_key: + raise ValueError("Using encryption and content validation together is not currently supported.") options = _download_blob_options( blob_name=self.blob_name, container_name=self.container_name, @@ -791,6 +791,7 @@ async def download_blob( 'key': self.key_encryption_key, 'resolver': self.key_resolver_function }, + validate_content=validate_content, config=self._config, sdk_moniker=self._sdk_moniker, client=self._client, @@ -2053,15 +2054,11 @@ async def stage_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -2896,13 +2893,11 @@ async def upload_page( Required if the blob has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. :paramtype lease: ~azure.storage.blob.aio.BlobLeaseClient or str - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int if_sequence_number_lte: If the blob's sequence number is less than or equal to the specified value, the request proceeds; otherwise it fails. @@ -3204,13 +3199,11 @@ async def append_block( :param int length: Size of the block. Optional if the length of data can be determined. For Iterable and IO, if the length is not provided and cannot be determined, all data will be read into memory. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https, as https (the default), - will already validate. Note that this MD5 hash is not stored with the - blob. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int maxsize_condition: Optional conditional header. The max length in bytes permitted for the append blob. If the Append Block operation would cause the blob diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi index 9c4a37007f65..ba9a7460425c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi @@ -175,7 +175,7 @@ class BlobClient( # type: ignore[misc] tags: Optional[Dict[str, str]] = None, overwrite: bool = False, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[BlobLeaseClient] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -202,7 +202,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -224,7 +224,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -246,7 +246,7 @@ class BlobClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -470,7 +470,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, encoding: Optional[str] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, @@ -655,7 +655,7 @@ class BlobClient( # type: ignore[misc] length: int, *, lease: Optional[Union[BlobLeaseClient, str]] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, if_sequence_number_lte: Optional[int] = None, if_sequence_number_lt: Optional[int] = None, if_sequence_number_eq: Optional[int] = None, @@ -725,7 +725,7 @@ class BlobClient( # type: ignore[misc] data: Union[bytes, Iterable[bytes], IO[bytes]], length: Optional[int] = None, *, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, maxsize_condition: Optional[int] = None, appendpos_condition: Optional[int] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 32260b62b6b5..413ba2cd0f0c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -1015,15 +1015,11 @@ async def upload_blob( :keyword ~azure.storage.blob.ContentSettings content_settings: ContentSettings object used to set blob properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the container has an active lease. Value can be a BlobLeaseClient object or the lease ID as a string. @@ -1261,15 +1257,11 @@ async def download_blob( This keyword argument was introduced in API version '2019-12-12'. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the blob. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the blob has an active lease. If specified, download_blob only succeeds if the blob's lease is active and matches this ID. Value can be a diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi index 56063ac847ff..49362aac8058 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi @@ -16,6 +16,7 @@ from typing import ( Callable, Dict, List, + Literal, IO, Iterable, Optional, @@ -254,7 +255,7 @@ class ContainerClient( # type: ignore[misc] *, overwrite: Optional[bool] = None, content_settings: Optional[ContentSettings] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -296,7 +297,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -318,7 +319,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, @@ -340,7 +341,7 @@ class ContainerClient( # type: ignore[misc] length: Optional[int] = None, *, version_id: Optional[str] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[BlobLeaseClient, str]] = None, if_modified_since: Optional[datetime] = None, if_unmodified_since: Optional[datetime] = None, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 30cbb0c68fbf..2f7d2a7f1e95 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -24,6 +24,7 @@ from .._shared.request_handlers import validate_and_format_range_headers from .._shared.response_handlers import parse_length_from_content_range, process_storage_error from .._shared.constants import DEFAULT_MAX_CONCURRENCY +from .._shared.validation import is_md5_validation, CV_TYPE_PARSED from .._deserialize import deserialize_blob_properties, get_page_ranges_result from .._download import process_range_and_offset, _ChunkDownloader from .._encryption import ( @@ -124,7 +125,7 @@ async def _download_chunk(self, chunk_start: int, chunk_end: int) -> Tuple[bytes range_header, range_validation = validate_and_format_range_headers( download_range[0], download_range[1], - check_content_md5=self.validate_content + check_content_md5=is_md5_validation(self.validate_content) ) retry_active = True @@ -238,7 +239,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, encryption_options: Dict[str, Any] = None, # type: ignore [assignment] max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] @@ -267,6 +268,7 @@ def __init__( self._file_size = 0 self._non_empty_ranges = None self._encryption_data: Optional["_EncryptionData"] = None + self._is_structured_message = False # The content download offset, after any processing (decryption), in bytes self._download_offset = 0 @@ -318,11 +320,14 @@ async def _setup(self) -> None: await self._get_encryption_data_request() # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) + initial_request_start = self._start_range if self._start_range is not None else 0 if self._end_range is not None and self._end_range - initial_request_start < first_get_size: initial_request_end = self._end_range @@ -356,7 +361,7 @@ async def _setup(self) -> None: @property def _download_complete(self): - if is_encryption_v2(self._encryption_data): + if is_encryption_v2(self._encryption_data) or self._is_structured_message: return self._download_offset >= self.size return self._raw_download_offset >= self.size @@ -366,7 +371,7 @@ async def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) retry_active = True @@ -449,6 +454,7 @@ async def _initial_request(self): except HttpResponseError: pass + self._is_structured_message = response.response.headers.get("x-ms-structured-body") is not None if not self._download_complete and self._request_options.get("modified_access_conditions"): self._request_options["modified_access_conditions"].if_match = response.properties.etag diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py index 794beee36e3b..befdc7f505a0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_upload_helpers.py @@ -32,6 +32,7 @@ upload_data_chunks, upload_substream_blocks ) +from .._shared.validation import CV_TYPE_PARSED from .._upload_helpers import _any_conditions, _convert_mod_error if TYPE_CHECKING: @@ -47,7 +48,7 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem encryption_options: Dict[str, Any], blob_settings: "StorageConfiguration", headers: Dict[str, Any], - validate_content: bool, + validate_content: CV_TYPE_PARSED, max_concurrency: Optional[int], length: Optional[int] = None, **kwargs: Any @@ -103,11 +104,15 @@ async def upload_block_blob( # pylint: disable=too-many-locals, too-many-statem return response - use_original_upload_path = blob_settings.use_byte_buffer or \ - validate_content or encryption_options.get('required') or \ - blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold or \ - hasattr(stream, 'seekable') and not stream.seekable() or \ - not hasattr(stream, 'seek') or not hasattr(stream, 'tell') + use_original_upload_path = ( + blob_settings.use_byte_buffer + or validate_content not in (None, False) + or encryption_options.get('required') + or blob_settings.max_block_size < blob_settings.min_large_block_upload_threshold + or hasattr(stream, 'seekable') and not stream.seekable() + or not hasattr(stream, 'seek') + or not hasattr(stream, 'tell') + ) if use_original_upload_path: total_size = length @@ -189,7 +194,7 @@ async def upload_page_blob( headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: @@ -267,7 +272,7 @@ async def upload_append_blob( # pylint: disable=unused-argument headers: Dict[str, Any], stream: IO, length: Optional[int] = None, - validate_content: Optional[bool] = None, + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: diff --git a/sdk/storage/azure-storage-blob/dev_requirements.txt b/sdk/storage/azure-storage-blob/dev_requirements.txt index d6946c1ccfc1..9657377026f0 100644 --- a/sdk/storage/azure-storage-blob/dev_requirements.txt +++ b/sdk/storage/azure-storage-blob/dev_requirements.txt @@ -1,6 +1,7 @@ -e ../../../eng/tools/azure-sdk-tools ../../core/azure-core ../../identity/azure-identity +../azure-storage-extensions azure-mgmt-storage==20.1.0 aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-blob/tests/test_append_blob.py b/sdk/storage/azure-storage-blob/tests/test_append_blob.py index d642a6edf2ec..65da562e69e5 100644 --- a/sdk/storage/azure-storage-blob/tests/test_append_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_append_blob.py @@ -28,7 +28,7 @@ generate_blob_sas, ImmutabilityPolicy, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -375,7 +375,7 @@ def test_append_block_from_url_and_validate_content_md5(self, **kwargs): self._setup(bsc) source_blob_data = self.get_random_bytes(LARGE_BLOB_SIZE) source_blob_client = self._create_source_blob(source_blob_data, bsc) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -405,9 +405,9 @@ def test_append_block_from_url_and_validate_content_md5(self, **kwargs): # Act part 2: put block from url with wrong md5 with pytest.raises(HttpResponseError): - destination_blob_client.append_block_from_url(source_blob_client.url + '?' + sas, - source_content_md5=StorageContentValidation.get_content_md5( - b"POTATO")) + destination_blob_client.append_block_from_url( + source_blob_client.url + '?' + sas, + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy diff --git a/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py index 2e2bf71bc765..710abb01db9f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_append_blob_async.py @@ -27,7 +27,7 @@ from azure.mgmt.storage.aio import StorageManagementClient from azure.storage.blob import BlobImmutabilityPolicyMode, BlobSasPermissions, generate_blob_sas, ImmutabilityPolicy from azure.storage.blob import BlobType -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient @@ -379,7 +379,7 @@ async def test_append_block_from_url_and_validate_content_md5(self, **kwargs): await self._setup(bsc) source_blob_data = self.get_random_bytes(LARGE_BLOB_SIZE) source_blob_client = await self._create_source_blob(source_blob_data, bsc) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -411,8 +411,7 @@ async def test_append_block_from_url_and_validate_content_md5(self, **kwargs): with pytest.raises(HttpResponseError): await destination_blob_client.append_block_from_url( source_blob_client.url + '?' + sas, - source_content_md5=StorageContentValidation.get_content_md5(b"POTATO") - ) + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy_async diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob.py b/sdk/storage/azure-storage-blob/tests/test_block_blob.py index 2f939e5e5665..d2278573ac80 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob.py @@ -33,7 +33,7 @@ ImmutabilityPolicy, StandardBlobTier, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -555,7 +555,7 @@ def test_upload_blob_from_url_with_source_content_md5(self, **kwargs): source_blob = self._create_blob(data=b"This is test data to be copied over.") source_blob_props = source_blob.get_blob_properties() source_md5 = source_blob_props.content_settings.content_md5 - bad_source_md5 = StorageContentValidation.get_content_md5(b"this is bad data") + bad_source_md5 = calculate_content_md5(b"this is bad data") sas = self.generate_sas( generate_blob_sas, account_name=storage_account_name, diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py index c52a301cba20..4f672d5e1b9f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_async.py @@ -36,7 +36,7 @@ ImmutabilityPolicy, StandardBlobTier, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient @@ -580,7 +580,7 @@ async def test_upload_blob_from_url_with_source_content_md5(self, **kwargs): source_blob = await self._create_blob(data=b"This is test data to be copied over.") source_blob_props = await source_blob.get_blob_properties() source_md5 = source_blob_props.content_settings.content_md5 - bad_source_md5 = StorageContentValidation.get_content_md5(b"this is bad data") + bad_source_md5 = calculate_content_md5(b"this is bad data") sas = self.generate_sas( generate_blob_sas, account_name=storage_account_name, diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py index 1aa1bf73a6a6..d13a16cb7fa5 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy.py @@ -22,7 +22,7 @@ StandardBlobTier, StorageErrorCode ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -206,7 +206,7 @@ def test_put_block_from_url_and_validate_content_md5(self, **kwargs): self._setup(storage_account_name, storage_account_key) dest_blob_name = self.get_resource_name('destblob') dest_blob = self.bsc.get_blob_client(self.container_name, dest_blob_name) - src_md5 = StorageContentValidation.get_content_md5(self.source_blob_data) + src_md5 = calculate_content_md5(self.source_blob_data) # Act part 1: put block from url with md5 validation dest_blob.stage_block_from_url( @@ -222,7 +222,7 @@ def test_put_block_from_url_and_validate_content_md5(self, **kwargs): assert len(committed) == 0 # Act part 2: put block from url with wrong md5 - fake_md5 = StorageContentValidation.get_content_md5(b"POTATO") + fake_md5 = calculate_content_md5(b"POTATO") with pytest.raises(HttpResponseError) as error: dest_blob.stage_block_from_url( block_id=2, diff --git a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py index 03928dbc08fb..46b004f12c29 100644 --- a/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_block_blob_sync_copy_async.py @@ -17,7 +17,7 @@ from azure.core.exceptions import HttpResponseError, ResourceExistsError from azure.storage.blob import BlobSasPermissions, StandardBlobTier, StorageErrorCode, generate_blob_sas from azure.storage.blob.aio import BlobClient, BlobServiceClient -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -167,7 +167,7 @@ async def test_put_block_from_url_and_vldte_content_md5(self, **kwargs): await self._setup(storage_account_name, storage_account_key) dest_blob_name = self.get_resource_name('destblob') dest_blob = self.bsc.get_blob_client(self.container_name, dest_blob_name) - src_md5 = StorageContentValidation.get_content_md5(self.source_blob_data) + src_md5 = calculate_content_md5(self.source_blob_data) # Act part 1: put block from url with md5 validation await dest_blob.stage_block_from_url( @@ -183,7 +183,7 @@ async def test_put_block_from_url_and_vldte_content_md5(self, **kwargs): assert len(committed) == 0 # Act part 2: put block from url with wrong md5 - fake_md5 = StorageContentValidation.get_content_md5(b"POTATO") + fake_md5 = calculate_content_md5(b"POTATO") with pytest.raises(HttpResponseError) as error: await dest_blob.stage_block_from_url( block_id=2, diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py new file mode 100644 index 000000000000..c4327291e3e0 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -0,0 +1,647 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live, recorded_by_proxy +from devtools_testutils.storage import ( + GenericTestProxyParametrize1, + GenericTestProxyParametrize2, + StorageRecordedTestCase, +) +from encryption_test_helper import KeyWrapper +from settings.testcase import BlobPreparer + +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobBlock, BlobClient, BlobServiceClient, BlobType, ContainerClient + + +def assert_content_md5(request): + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("Content-MD5") is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None + + +def assert_content_crc64(request): + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("x-ms-content-crc64") is not None + + +def assert_structured_message(request): + if ( + request.http_request.query.get("comp") in ("block", "page") + or request.http_request.headers.get("x-ms-blob-type") == "BlockBlob" + ): + assert request.http_request.headers.get("x-ms-structured-body") is not None + + +def assert_structured_message_get(response): + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None + + +class TestIter: + def __init__(self, data, *, chunk_size=100): + self.data = data + self.chunk_size = chunk_size + self.length = len(data) + self.offset = 0 + + def __len__(self): + return self.length + + def __iter__(self): + return self + + def __next__(self): + if self.offset >= self.length: + raise StopIteration + + result = self.data[self.offset : self.offset + self.chunk_size] + self.offset += len(result) + return result + + +class TestStorageContentValidation(StorageRecordedTestCase): + bsc: BlobServiceClient + container: ContainerClient + + def _setup(self, account_name): + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + self.container.create_container() + except ResourceExistsError: + pass + + def teardown_method(self, _): + if self.container and is_live(): + try: + self.container.delete_container() + except: + pass + + def _get_blob_reference(self): + return self.get_resource_name("blob") + + @BlobPreparer() + def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper("key1") + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient), + require_encryption=True, + encryption_version="2.0", + key_encryption_key=kek, + ) + + with pytest.raises(ValueError): + blob.upload_blob(b"123", validate_content="crc64") + + # Needed for teardown + self.container = None + + @BlobPreparer() + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "auto", "md5", "crc64"]) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy + def test_upload_blob(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b in ("auto", "crc64") else assert_content_md5 + + # Test supported data types + byte_data = b"abc" * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode("utf-8") + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert blob.download_blob().read() == str_data_encoded + blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) + assert blob.download_blob().read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "md5", "crc64"]) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy + def test_upload_blob_chunks(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.max_page_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == "crc64" else assert_content_md5 + + # Test supported data types + byte_data = b"abc" * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode("utf-8") + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + blob.upload_blob(byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert blob.download_blob().read() == str_data_encoded + blob.upload_blob(byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob(byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method) + assert blob.download_blob().read() == byte_data + blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) + assert blob.download_blob().read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_blob_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + data = b"abc" * 512 + b"abcde" + io = BytesIO(data) + + # Act + blob.upload_blob(io, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = blob.download_blob() + assert content.read() == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_stage_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界" * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + blob.stage_block("1", data1, validate_content=a, raw_request_hook=assert_method) + blob.stage_block("2", data2, encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) + blob.stage_block("3", generator(), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) + + # Assert + content = blob.download_blob() + assert content.read() == data1 + data2.encode("utf-8-sig") + data1 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_stage_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + blob.stage_block("1", BytesIO(content), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1")]) + + result = blob.download_blob() + assert result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + def test_stage_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + blob.stage_block("1", BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + blob.stage_block("2", BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + blob.stage_block("3", BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) + + result = blob.download_blob() + assert result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界" * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + blob.create_append_blob() + blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) + blob.append_block(data2, encoding="utf-16", validate_content=a, raw_request_hook=assert_method) + blob.append_block(generator(), validate_content=a, raw_request_hook=assert_method) + + content = blob.download_blob() + assert content.read() == data1 + data2.encode("utf-16") + data1 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + blob.create_append_blob() + blob.append_block(BytesIO(content), validate_content=a, raw_request_hook=assert_method) + + result = blob.download_blob() + assert result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + def test_append_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + blob.create_append_blob() + blob.append_block(BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + blob.append_block(BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + blob.append_block(BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + + result = blob.download_blob() + assert result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_page(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界abcd" * 32 + data2_encoded = data2.encode("utf-8") + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + blob.create_page_blob(5 * 1024) + blob.upload_page(data1, offset=0, length=len(data1), validate_content=a, raw_request_hook=assert_method) + blob.upload_page( + data2, + offset=len(data1), + length=len(data2_encoded), + encoding="utf-8", + validate_content=a, + raw_request_hook=assert_method, + ) + + # Assert + content = blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) + assert content.read() == data1 + data2_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + b"abcde" + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + read_content = b"" + downloader = blob.download_blob(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content += downloader.read(100) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + b"abcde" + blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = blob.download_blob(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[10:1010] + + @BlobPreparer() + @pytest.mark.live_test_only + def test_download_blob_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.container._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + blob.upload_blob(data, overwrite=True, max_concurrency=5) + + # Act + downloader = blob.download_blob(validate_content="crc64", max_concurrency=5) + content = downloader.read() + + downloader = blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_blob_chars(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + + data = "你好世界" * 256 # 3 KiB + blob = self.container.get_blob_client(self._get_blob_reference()) + blob.upload_blob(data, encoding="utf-8", overwrite=True) + + stream = blob.download_blob(encoding="utf-8", validate_content=a) + assert stream.read() == data + + stream = blob.download_blob(encoding="utf-8", validate_content=a) + assert stream.read(chars=100000) == data + + result = "" + stream = blob.download_blob(encoding="utf-8", validate_content=a) + for _ in range(4): + chunk = stream.read(chars=100) + result += chunk + assert len(chunk) == 100 + + result += stream.readall() + assert result == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_content_validation_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True, + ) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + + # Determine the appropriate assert methods based on validation mode + upload_assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + download_assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Test upload with retry + upload_call_count = 0 + + def upload_hook_fail_once(response): + nonlocal upload_call_count + upload_call_count += 1 + # Assert content validation headers are present on both attempts + upload_assert_method(response) + if upload_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + blob.upload_blob(data, validate_content=a, overwrite=True, raw_response_hook=upload_hook_fail_once) + assert upload_call_count == 2 # Original + retry + assert blob.download_blob().read() == data + + # Test download with retry + download_call_count = 0 + + def download_hook_fail_once(response): + nonlocal download_call_count + download_call_count += 1 + # Assert content validation headers are present on both attempts + download_assert_method(response) + if download_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + downloader = blob.download_blob(validate_content=a, raw_response_hook=download_hook_fail_once) + content = downloader.read() + assert download_call_count == 2 # Original + retry + assert content == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True, + ) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abc" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + call_count = 0 + + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + blob.stage_block("1", BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + blob.commit_block_list([BlobBlock("1")]) + result = blob.download_blob() + assert result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py new file mode 100644 index 000000000000..e22a03b1fe75 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -0,0 +1,630 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import ( + AsyncStorageRecordedTestCase, + GenericTestProxyParametrize1, + GenericTestProxyParametrize2, +) +from encryption_test_helper import KeyWrapper +from settings.testcase import BlobPreparer +from test_content_validation import ( + assert_content_crc64, + assert_content_md5, + assert_content_md5_get, + assert_structured_message, + assert_structured_message_get, + TestIter, +) + +from azure.core.exceptions import ResourceExistsError +from azure.storage.blob import BlobBlock, BlobType, ContainerClient as SyncContainerClient +from azure.storage.blob.aio import BlobClient, BlobServiceClient, ContainerClient + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + bsc: BlobServiceClient + container: ContainerClient + + async def _setup(self, account_name): + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient(self.account_url(account_name, "blob"), token_credential, logging_enable=True) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + await self.container.create_container() + except ResourceExistsError: + pass + + def teardown_method(self, _): + if not is_live(): + return + # Use sync client as teardown_method must be sync + if self.container: + sync_credential = self.get_credential(SyncContainerClient, is_async=False) + sync_container = SyncContainerClient.from_container_url(self.container.url, credential=sync_credential) + + try: + sync_container.delete_container() + except: + pass + + def _get_blob_reference(self): + return self.get_resource_name("blob") + + @BlobPreparer() + async def test_encryption_blocked_crc64(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + kek = KeyWrapper("key1") + blob = BlobClient( + self.account_url(storage_account_name, "blob"), + "testing", + "testing", + credential=self.get_credential(BlobServiceClient, is_async=True), + require_encryption=True, + encryption_version="2.0", + key_encryption_key=kek, + ) + + with pytest.raises(ValueError): + await blob.upload_blob(b"123", validate_content="crc64") + + # Needed for teardown + self.container = None + + @BlobPreparer() + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "auto", "md5", "crc64"]) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy_async + async def test_upload_blob(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b in ("auto", "crc64") else assert_content_md5 + + # Test supported data types + byte_data = b"abc" * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode("utf-8") + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + # Act / Assert + await blob.upload_blob( + byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == str_data_encoded + await blob.upload_blob( + byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) + assert await (await blob.download_blob()).read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [BlobType.BLOCKBLOB, BlobType.PAGEBLOB, BlobType.APPENDBLOB]) # a: blob_type + @pytest.mark.parametrize("b", [True, "md5", "crc64"]) # b: validate_content + @GenericTestProxyParametrize2() + @recorded_by_proxy_async + async def test_upload_blob_chunks(self, a, b, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.max_page_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if b == "crc64" else assert_content_md5 + + # Test supported data types + byte_data = b"abc" * 512 + str_data = "你好世界abcd" * 32 + str_data_encoded = str_data.encode("utf-8") + byte_stream = BytesIO(byte_data) + byte_iter = TestIter(byte_data) + str_iter = TestIter(str_data) + + # Act / Assert + await blob.upload_blob( + byte_data, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + str_data, blob_type=a, encoding="utf-8", validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == str_data_encoded + await blob.upload_blob( + byte_stream, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + byte_iter, blob_type=a, validate_content=b, overwrite=True, raw_request_hook=assert_method + ) + assert await (await blob.download_blob()).read() == byte_data + await blob.upload_blob( + str_iter, + blob_type=a, + length=len(str_data_encoded), + encoding="utf-8", + validate_content=b, + overwrite=True, + raw_request_hook=assert_method, + ) + assert await (await blob.download_blob()).read() == str_data_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_blob_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_put_size = 512 + self.container._config.max_block_size = 512 + self.container._config.min_large_block_upload_threshold = 1 # Set less than block size to enable substream + blob = self.container.get_blob_client(self._get_blob_reference()) + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + data = b"abc" * 512 + b"abcde" + io = BytesIO(data) + + # Act + await blob.upload_blob(io, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await blob.download_blob() + assert await content.read() == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_stage_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界" * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + await blob.stage_block("1", data1, validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("2", data2, encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("3", generator(), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) + + # Assert + content = await blob.download_blob() + assert await content.read() == data1 + data2.encode("utf-8-sig") + data1 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_stage_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + await blob.stage_block("1", BytesIO(content), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1")]) + + # Assert + result = await blob.download_blob() + assert await result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + async def test_stage_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + await blob.stage_block("1", BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("2", BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + await blob.stage_block("3", BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + await blob.commit_block_list([BlobBlock("1"), BlobBlock("2"), BlobBlock("3")]) + + result = await blob.download_blob() + assert await result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_block(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界" * 10 + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + await blob.create_append_blob() + await blob.append_block(data1, validate_content=a, raw_request_hook=assert_method) + await blob.append_block(data2, encoding="utf-16", validate_content=a, raw_request_hook=assert_method) + await blob.append_block(generator(), validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await blob.download_blob() + assert await content.readall() == data1 + data2.encode("utf-16") + data1 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_block_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + await blob.create_append_blob() + await blob.append_block(BytesIO(content), validate_content=a, raw_request_hook=assert_method) + + result = await blob.download_blob() + assert await result.read() == content + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + async def test_append_block_streaming_large(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + await blob.create_append_blob() + await blob.append_block(BytesIO(data1), validate_content=a, raw_request_hook=assert_method) + await blob.append_block(BytesIO(data2), validate_content=a, raw_request_hook=assert_method) + await blob.append_block(BytesIO(data3), validate_content=a, raw_request_hook=assert_method) + + result = await blob.download_blob() + assert await result.read() == data1 + data2 + data3 + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_page(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data1 = b"abc" * 512 + data2 = "你好世界abcd" * 32 + data2_encoded = data2.encode("utf-8") + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + await blob.create_page_blob(5 * 1024) + await blob.upload_page(data1, offset=0, length=len(data1), validate_content=a, raw_request_hook=assert_method) + await blob.upload_page( + data2, + offset=len(data1), + length=len(data2_encoded), + encoding="utf-8", + validate_content=a, + raw_request_hook=assert_method, + ) + + # Assert + content = await blob.download_blob(offset=0, length=len(data1) + len(data2_encoded)) + assert await content.read() == data1 + data2_encoded + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + b"abcde" + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = b"" + downloader = await blob.download_blob(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content += await downloader.read(100) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + b"abcde" + await blob.upload_blob(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await blob.download_blob( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) + content = await downloader.read() + + stream = BytesIO() + downloader = await blob.download_blob( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[10:1010] + + @BlobPreparer() + @pytest.mark.live_test_only + async def test_download_blob_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + blob = self.container.get_blob_client(self._get_blob_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.container._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + await blob.upload_blob(data, overwrite=True, max_concurrency=5) + + # Act + downloader = await blob.download_blob(validate_content="crc64", max_concurrency=5) + content = await downloader.read() + + downloader = await blob.download_blob(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = await downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_blob_chars(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.container._config.max_single_get_size = 512 + self.container._config.max_chunk_get_size = 512 + + data = "你好世界" * 256 # 3 KiB + blob = self.container.get_blob_client(self._get_blob_reference()) + await blob.upload_blob(data, encoding="utf-8", overwrite=True) + + stream = await blob.download_blob(encoding="utf-8", validate_content=a) + assert await stream.read() == data + + stream = await blob.download_blob(encoding="utf-8", validate_content=a) + assert await stream.read(chars=100000) == data + + result = "" + stream = await blob.download_blob(encoding="utf-8", validate_content=a) + for _ in range(4): + chunk = await stream.read(chars=100) + result += chunk + assert len(chunk) == 100 + + result += await stream.readall() + assert result == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_content_validation_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True, + ) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + await self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + data = b"abc" * 512 + + # Determine the appropriate assert methods based on validation mode + upload_assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + download_assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Test upload with retry + upload_call_count = 0 + + def upload_hook_fail_once(response): + nonlocal upload_call_count + upload_call_count += 1 + # Assert content validation headers are present on both attempts + upload_assert_method(response) + if upload_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + await blob.upload_blob(data, validate_content=a, overwrite=True, raw_response_hook=upload_hook_fail_once) + assert upload_call_count == 2 # Original + retry + assert await (await blob.download_blob()).read() == data + + # Test download with retry + download_call_count = 0 + + def download_hook_fail_once(response): + nonlocal download_call_count + download_call_count += 1 + # Assert content validation headers are present on both attempts + download_assert_method(response) + if download_call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + downloader = await blob.download_blob(validate_content=a, raw_response_hook=download_hook_fail_once) + content = await downloader.read() + assert download_call_count == 2 # Original + retry + assert content == data + + @BlobPreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True, + ) + self.container = self.bsc.get_container_client(self.get_resource_name("utcontainer")) + try: + await self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b"abc" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Test stage_block streaming with retry + call_count = 0 + + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + await blob.stage_block("1", BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + await blob.commit_block_list([BlobBlock("1")]) + result = await blob.download_blob() + assert await result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob.py b/sdk/storage/azure-storage-blob/tests/test_page_blob.py index 8caaf49412f8..f246c400f24c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob.py @@ -22,7 +22,7 @@ BlobClient, BlobImmutabilityPolicyMode, BlobProperties, BlobSasPermissions, BlobServiceClient, BlobType, generate_blob_sas, ImmutabilityPolicy, PremiumPageBlobTier, SequenceNumberAction, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 # ------------------------------------------------------------------------------ @@ -639,7 +639,7 @@ def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): self._setup(bsc) source_blob_data = self.get_random_bytes(SOURCE_BLOB_SIZE) source_blob_client = self._create_source_blob(bsc, source_blob_data, 0, SOURCE_BLOB_SIZE) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -669,12 +669,12 @@ def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): # Act part 2: put block from url with wrong md5 with pytest.raises(HttpResponseError): - destination_blob_client.upload_pages_from_url(source_blob_client.url + "?" + sas, - offset=0, - length=SOURCE_BLOB_SIZE, - source_offset=0, - source_content_md5=StorageContentValidation.get_content_md5( - b"POTATO")) + destination_blob_client.upload_pages_from_url( + source_blob_client.url + "?" + sas, + offset=0, + length=SOURCE_BLOB_SIZE, + source_offset=0, + source_content_md5=calculate_content_md5(b"POTATO")) @BlobPreparer() @recorded_by_proxy diff --git a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py index bc7b7f6f8fe4..af6ab9c7c46c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_page_blob_async.py @@ -22,7 +22,7 @@ BlobImmutabilityPolicyMode, BlobProperties, BlobSasPermissions, BlobType, generate_blob_sas, ImmutabilityPolicy, PremiumPageBlobTier, SequenceNumberAction, ) -from azure.storage.blob._shared.policies import StorageContentValidation +from azure.storage.blob._shared.validation import calculate_content_md5 from azure.storage.blob.aio import BlobClient, BlobServiceClient @@ -619,7 +619,7 @@ async def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): await self._setup(bsc) source_blob_data = self.get_random_bytes(SOURCE_BLOB_SIZE) source_blob_client = await self._create_source_blob(bsc, source_blob_data, 0, SOURCE_BLOB_SIZE) - src_md5 = StorageContentValidation.get_content_md5(source_blob_data) + src_md5 = calculate_content_md5(source_blob_data) sas = self.generate_sas( generate_blob_sas, source_blob_client.account_name, @@ -654,7 +654,7 @@ async def test_upload_pages_from_url_and_validate_content_md5(self, **kwargs): 0, SOURCE_BLOB_SIZE, 0, - source_content_md5=StorageContentValidation.get_content_md5(b"POTATO") + source_content_md5=calculate_content_md5(b"POTATO") ) @BlobPreparer() diff --git a/sdk/storage/azure-storage-blob/tests/test_streams.py b/sdk/storage/azure-storage-blob/tests/test_streams.py new file mode 100644 index 000000000000..11f63b3be962 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_streams.py @@ -0,0 +1,580 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import os +import random +from io import BytesIO, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from typing import Iterator, List, Optional, Tuple, Union + +import pytest +from test_helpers import NonSeekableStream + +from azure.storage.blob._shared.streams import ( + StructuredMessageConstants, + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from azure.storage.extensions import crc64 + + +def _iter_bytes(data: bytes, chunk_size: int = 1024) -> Iterator[bytes]: + """Convert bytes to an Iterator[bytes] with the given chunk size.""" + for i in range(0, len(data), chunk_size): + yield data[i:i + chunk_size] + + +def _write_segment( + number: int, + data: bytes, + data_crc: Optional[int], + stream: BytesIO, +) -> None: + stream.write(number.to_bytes(2, 'little')) # Segment number + stream.write(len(data).to_bytes(8, 'little')) # Segment length + stream.write(data) # Segment content + if data_crc is not None: + stream.write(data_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + +def _build_structured_message( + data: bytes, + segment_size: Union[int, List[int]], + flags: StructuredMessageProperties, + invalidate_crc_segment: Optional[int] = None, +) -> Tuple[BytesIO, int]: + if isinstance(segment_size, list): + segment_count = len(segment_size) + else: + segment_count = math.ceil(len(data) / segment_size) or 1 + segment_footer_length = StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0 + + message_length = ( + StructuredMessageConstants.V1_HEADER_LENGTH + + ((StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + segment_footer_length) * segment_count) + + len(data) + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + + message = BytesIO() + message_crc = 0 + + # Message Header + message.write(b'\x01') # Version + message.write(message_length.to_bytes(8, 'little')) # Message length + message.write(int(flags).to_bytes(2, 'little')) # Flags + message.write(segment_count.to_bytes(2, 'little')) # Num segments + + # Special case for 0 length content + if len(data) == 0: + crc = 0 if StructuredMessageProperties.CRC64 in flags else None + _write_segment(1, data, crc, message) + else: + # Segments + segment_sizes = segment_size if isinstance(segment_size, list) else [segment_size] * segment_count + offset = 0 + for i in range(1, segment_count + 1): + size = segment_sizes[i - 1] + segment_data = data[offset: offset + size] + offset += size + + segment_crc = None + if StructuredMessageProperties.CRC64 in flags: + segment_crc = crc64.compute(segment_data, 0) # pylint: disable=I1101 + if i == invalidate_crc_segment: + segment_crc += 5 + _write_segment(i, segment_data, segment_crc, message) + + message_crc = crc64.compute(segment_data, message_crc) # pylint: disable=I1101 + + # Message footer + if StructuredMessageProperties.CRC64 in flags: + if invalidate_crc_segment == -1: + message_crc += 5 + message.write(message_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + message.seek(0, 0) + return message, message_length + + +class TestStructuredMessageEncodeStream: + def test_close(self): + inner = BytesIO() + stream = StructuredMessageEncodeStream(inner, 0, StructuredMessageProperties.NONE) + assert not stream.closed + assert not inner.closed + + stream.close() # no-op + assert not stream.closed + assert not inner.closed + assert stream.read(1) is not None + + inner.close() # closing inner will block reads + with pytest.raises(ValueError): + stream.read(0) + + def test_read_past_end(self): + data = os.urandom(10) + inner_stream = BytesIO(data) + + stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + expected = _build_structured_message(data, len(data), StructuredMessageProperties.CRC64)[0].getvalue() + + result = stream.read(100) + assert result == expected + + result = stream.read(100) + assert result == b'' + + @pytest.mark.parametrize("size, segment_size, flags", [ + (0, 1, StructuredMessageProperties.NONE), + (0, 1, StructuredMessageProperties.CRC64), + (10, 1, StructuredMessageProperties.NONE), + (10, 1, StructuredMessageProperties.CRC64), + (1024, 1024, StructuredMessageProperties.NONE), + (1024, 1024, StructuredMessageProperties.CRC64), + (1024, 512, StructuredMessageProperties.NONE), + (1024, 512, StructuredMessageProperties.CRC64), + (1024, 200, StructuredMessageProperties.NONE), + (1024, 200, StructuredMessageProperties.CRC64), + (123456, 1234, StructuredMessageProperties.NONE), + (123456, 1234, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_read_all(self, size, segment_size, flags): + data = os.urandom(size) + inner_stream = BytesIO(data) + + stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=segment_size) + actual = stream.read() + + expected = _build_structured_message(data, segment_size, flags)[0].getvalue() + assert actual == expected + + @pytest.mark.parametrize("size, segment_size, chunk_size, flags", [ + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (12345, 678, 90, StructuredMessageProperties.NONE), + (12345, 678, 90, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_read_chunks(self, size, segment_size, chunk_size, flags): + data = os.urandom(size) + inner_stream = BytesIO(data) + + stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=segment_size) + + read = 0 + content = b'' + while read < len(stream): + chunk = stream.read(chunk_size) + assert len(chunk) == min(chunk_size, len(stream) - read) + + content += chunk + read += chunk_size + + expected = _build_structured_message(data, segment_size, flags)[0].getvalue() + assert content == expected + + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + def test_random_reads(self, data_size): + data = os.urandom(data_size) + inner_stream = BytesIO(data) + + segment_size = data_size // 3 + stream = StructuredMessageEncodeStream( + inner_stream, + len(data), + StructuredMessageProperties.CRC64, + segment_size=segment_size) + + count = 0 + content = b'' + stream_size = len(stream) + while count < stream_size: + read_size = random.randint(10, stream_size // 3) + read_size = min(read_size, stream_size - count) + + content += stream.read(read_size) + count += read_size + + expected = _build_structured_message(data, segment_size, StructuredMessageProperties.CRC64)[0].getvalue() + assert content == expected + + def test_seekable(self): + data = os.urandom(10) + inner_stream = BytesIO(data) + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + + assert sm_stream.seekable() + + def test_not_seekable(self): + data = os.urandom(10) + inner_stream = NonSeekableStream(BytesIO(data)) + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + + assert not sm_stream.seekable() + with pytest.raises(UnsupportedOperation): + sm_stream.seek(0) + + def test_seek(self): + data = os.urandom(10) + inner_stream = BytesIO(data) + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + sm_stream.read(25) + + with pytest.raises(UnsupportedOperation): + sm_stream.seek(5) + + with pytest.raises(UnsupportedOperation): + sm_stream.seek(0, SEEK_CUR) + + with pytest.raises(UnsupportedOperation): + sm_stream.seek(0, SEEK_END) + + # Only SEEK_SET to 0 is supported + pos = sm_stream.seek(0, SEEK_SET) + assert pos == 0 + + @pytest.mark.parametrize("initial_read, segment_size, flags", [ + # Single segment + (5000, 2048, StructuredMessageProperties.NONE), # End -> Beginning + (5000, 2048, StructuredMessageProperties.CRC64), # End -> Beginning + (5, 2048, StructuredMessageProperties.NONE), # Message header + (5, 2048, StructuredMessageProperties.CRC64), + (20, 2048, StructuredMessageProperties.NONE), # Segment header + (20, 2048, StructuredMessageProperties.CRC64), + (100, 2048, StructuredMessageProperties.NONE), # First segment content + (100, 2048, StructuredMessageProperties.CRC64), + (1000, 2048, StructuredMessageProperties.NONE), # Second segment content + (1000, 2048, StructuredMessageProperties.CRC64), + (525, 2048, StructuredMessageProperties.CRC64), # Segment footer + (1092, 2048, StructuredMessageProperties.CRC64), # Message footer + # Multiple segments + (5000, 500, StructuredMessageProperties.NONE), # End -> Beginning + (5000, 500, StructuredMessageProperties.CRC64), # End -> Beginning + (5, 500, StructuredMessageProperties.NONE), # Message header + (5, 500, StructuredMessageProperties.CRC64), + (20, 500, StructuredMessageProperties.NONE), # Segment header + (20, 500, StructuredMessageProperties.CRC64), + (100, 500, StructuredMessageProperties.NONE), # First segment content + (100, 500, StructuredMessageProperties.CRC64), + (1000, 500, StructuredMessageProperties.NONE), # Second segment content + (1000, 500, StructuredMessageProperties.CRC64), + (525, 500, StructuredMessageProperties.CRC64), # Segment footer + (1092, 500, StructuredMessageProperties.CRC64), # Message footer + ]) + def test_seek_reverse_beginning(self, initial_read, segment_size, flags): + data = os.urandom(1024) + inner_stream = BytesIO(data) + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=segment_size) + expected = _build_structured_message(data, segment_size, flags)[0].getvalue() + + initial = sm_stream.read(initial_read) + assert initial == expected[:initial_read] + + sm_stream.seek(0) + result = sm_stream.read() + assert result == expected + + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_partial_stream_read(self, flags): + data = os.urandom(1024) + partial_read = 100 + + inner_stream = BytesIO(data) + inner_stream.seek(partial_read) + expected = _build_structured_message(data[partial_read:], 500, flags)[0].getvalue() + + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data) - partial_read, flags, segment_size=500) + result = sm_stream.read() + assert result == expected + + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_partial_stream_seek_beginning(self, flags): + data = os.urandom(1024) + partial_read = 100 + + inner_stream = BytesIO(data) + inner_stream.seek(partial_read) + expected = _build_structured_message(data[partial_read:], 500, flags)[0].getvalue() + + sm_stream = StructuredMessageEncodeStream(inner_stream, len(data) - partial_read, flags, segment_size=500) + initial = sm_stream.read(101) + assert initial == expected[:101] + + sm_stream.seek(0) + assert inner_stream.tell() == partial_read + + result = sm_stream.read() + assert result == expected + + + + +class TestStructuredMessageDecoder: + + def test_empty_inner_stream(self): + with pytest.raises(ValueError): + StructuredMessageDecoder(iter([b'']), 0) + + def test_read_past_end(self): + data = os.urandom(10) + message_stream, length = _build_structured_message(data, len(data), StructuredMessageProperties.CRC64) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + result = stream.read(100) + assert result == data + + result = stream.read(100) + assert result == b'' + + @pytest.mark.parametrize("size, segment_size, flags", [ + (0, 1, StructuredMessageProperties.NONE), + (0, 1, StructuredMessageProperties.CRC64), + (10, 1, StructuredMessageProperties.NONE), + (10, 1, StructuredMessageProperties.CRC64), + (1024, 1024, StructuredMessageProperties.NONE), + (1024, 1024, StructuredMessageProperties.CRC64), + (1024, 512, StructuredMessageProperties.NONE), + (1024, 512, StructuredMessageProperties.CRC64), + (1024, 200, StructuredMessageProperties.NONE), + (1024, 200, StructuredMessageProperties.CRC64), + (123456, 1234, StructuredMessageProperties.NONE), + (123456, 1234, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_read_all(self, size, segment_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + content = stream.read() + + assert content == data + + @pytest.mark.parametrize("size, segment_size, chunk_size, flags", [ + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_read_chunks(self, size, segment_size, chunk_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + read = 0 + content = b'' + while read < len(data): + chunk = stream.read(chunk_size) + content += chunk + read += chunk_size + + assert content == data + + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + def test_random_reads(self, data_size): + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, data_size // 3, StructuredMessageProperties.CRC64) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + + count = 0 + content = b'' + while count < data_size: + read_size = random.randint(10, data_size // 3) + read_size = min(read_size, data_size - count) + + content += stream.read(read_size) + count += read_size + + assert content == data + + @pytest.mark.parametrize("size, segment_size, block_size, flags", [ + (0, 1, 1, StructuredMessageProperties.NONE), + (0, 1, 1, StructuredMessageProperties.CRC64), + (1, 1, 1, StructuredMessageProperties.NONE), + (1, 1, 1, StructuredMessageProperties.CRC64), + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (1024, 200, 1024, StructuredMessageProperties.NONE), + (1024, 200, 1024, StructuredMessageProperties.CRC64), + (1024, 200, 50, StructuredMessageProperties.NONE), + (1024, 200, 50, StructuredMessageProperties.CRC64), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.NONE), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.CRC64), + ]) + def test_iterate(self, size, segment_size, block_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + decoder = StructuredMessageDecoder( + _iter_bytes(message_stream.getvalue()), length, block_size=block_size) + content = b''.join(decoder) + + assert content == data + + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + def test_random_segment_sizes(self, data_size): + segment_sizes = [] + count = 0 + while count < data_size: + size = random.randint(10, data_size // 3) + size = min(size, data_size - count) + segment_sizes.append(size) + count += size + + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, segment_sizes, StructuredMessageProperties.CRC64) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + content = stream.read() + + assert content == data + + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + def test_crc64_mismatch_read_all(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + invalid_segment) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + with pytest.raises(ValueError) as e: + stream.read() + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + def test_crc64_mismatch_read_chunks(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + invalid_segment) + + stream = StructuredMessageDecoder(_iter_bytes(message_stream.getvalue()), length) + # Since we only check CRC on segment borders, some reads will succeed, but we test + # to ensure eventually the stream reading will error out. + with pytest.raises(ValueError) as e: + read = 0 + while read < len(data): + stream.read(512) + + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_invalid_message_version(self, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + # Corrupt the version byte + raw = bytearray(message_stream.getvalue()) + raw[0] = 0xFF + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() + + @pytest.mark.parametrize("message_length", [100, 1234567]) # Correct value: 1057 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_incorrect_message_length(self, message_length, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + raw = bytearray(message_stream.getvalue()) + raw[1:9] = int.to_bytes(message_length, 8, 'little') + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() + + @pytest.mark.parametrize("segment_count", [2, 123]) # Correct value: 4 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_incorrect_segment_count(self, segment_count, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[11:13] = int.to_bytes(segment_count, 2, 'little') + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() + + @pytest.mark.parametrize("segment_number", [1, 123]) # Correct value: 2 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_incorrect_segment_number(self, segment_number, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the incorrect number + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 2] = int.to_bytes(segment_number, 2, 'little') + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() + + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_incorrect_segment_size(self, segment_size, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the incorrect size + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0) + + 2) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 8] = int.to_bytes(segment_size, 8, 'little') + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() + + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + def test_incorrect_segment_size_single_segment(self, segment_size, flags): + data = os.urandom(256) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[15:23] = int.to_bytes(segment_size, 8, 'little') + + stream = StructuredMessageDecoder(_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + stream.read() diff --git a/sdk/storage/azure-storage-blob/tests/test_streams_async.py b/sdk/storage/azure-storage-blob/tests/test_streams_async.py new file mode 100644 index 000000000000..1d1bb3aff654 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_streams_async.py @@ -0,0 +1,371 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import os +import random +from io import BytesIO +from typing import AsyncIterator, List, Optional, Tuple, Union + +import pytest +from azure.storage.blob._shared.streams import ( + StructuredMessageConstants, + StructuredMessageProperties, +) +from azure.storage.blob._shared.streams_async import AsyncStructuredMessageDecoder +from azure.storage.extensions import crc64 + + +async def _async_iter_bytes(data: bytes, chunk_size: int = 1024) -> AsyncIterator[bytes]: + """Convert bytes to an AsyncIterator[bytes] with the given chunk size.""" + for i in range(0, len(data), chunk_size): + yield data[i:i + chunk_size] + + +def _write_segment( + number: int, + data: bytes, + data_crc: Optional[int], + stream: BytesIO, +) -> None: + stream.write(number.to_bytes(2, 'little')) # Segment number + stream.write(len(data).to_bytes(8, 'little')) # Segment length + stream.write(data) # Segment content + if data_crc is not None: + stream.write(data_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + +def _build_structured_message( + data: bytes, + segment_size: Union[int, List[int]], + flags: StructuredMessageProperties, + invalidate_crc_segment: Optional[int] = None, +) -> Tuple[BytesIO, int]: + if isinstance(segment_size, list): + segment_count = len(segment_size) + else: + segment_count = math.ceil(len(data) / segment_size) or 1 + segment_footer_length = StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0 + + message_length = ( + StructuredMessageConstants.V1_HEADER_LENGTH + + ((StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + segment_footer_length) * segment_count) + + len(data) + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + + message = BytesIO() + message_crc = 0 + + # Message Header + message.write(b'\x01') # Version + message.write(message_length.to_bytes(8, 'little')) # Message length + message.write(int(flags).to_bytes(2, 'little')) # Flags + message.write(segment_count.to_bytes(2, 'little')) # Num segments + + # Special case for 0 length content + if len(data) == 0: + crc = 0 if StructuredMessageProperties.CRC64 in flags else None + _write_segment(1, data, crc, message) + else: + # Segments + segment_sizes = segment_size if isinstance(segment_size, list) else [segment_size] * segment_count + offset = 0 + for i in range(1, segment_count + 1): + size = segment_sizes[i - 1] + segment_data = data[offset: offset + size] + offset += size + + segment_crc = None + if StructuredMessageProperties.CRC64 in flags: + segment_crc = crc64.compute(segment_data, 0) # pylint: disable=I1101 + if i == invalidate_crc_segment: + segment_crc += 5 + _write_segment(i, segment_data, segment_crc, message) + + message_crc = crc64.compute(segment_data, message_crc) # pylint: disable=I1101 + + # Message footer + if StructuredMessageProperties.CRC64 in flags: + if invalidate_crc_segment == -1: + message_crc += 5 + message.write(message_crc.to_bytes(StructuredMessageConstants.CRC64_LENGTH, 'little')) + + message.seek(0, 0) + return message, message_length + + +class TestAsyncStructuredMessageDecoder: + + @pytest.mark.asyncio + async def test_empty_inner_stream(self): + with pytest.raises(ValueError): + AsyncStructuredMessageDecoder(_async_iter_bytes(b''), 0) + + @pytest.mark.asyncio + async def test_read_past_end(self): + data = os.urandom(10) + message_stream, length = _build_structured_message(data, len(data), StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + result = await stream.read(100) + assert result == data + + result = await stream.read(100) + assert result == b'' + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, flags", [ + (0, 1, StructuredMessageProperties.NONE), + (0, 1, StructuredMessageProperties.CRC64), + (10, 1, StructuredMessageProperties.NONE), + (10, 1, StructuredMessageProperties.CRC64), + (1024, 1024, StructuredMessageProperties.NONE), + (1024, 1024, StructuredMessageProperties.CRC64), + (1024, 512, StructuredMessageProperties.NONE), + (1024, 512, StructuredMessageProperties.CRC64), + (1024, 200, StructuredMessageProperties.NONE), + (1024, 200, StructuredMessageProperties.CRC64), + (123456, 1234, StructuredMessageProperties.NONE), + (123456, 1234, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_read_all(self, size, segment_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + content = await stream.read() + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, chunk_size, flags", [ + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.NONE), + (10 * 1024 * 1024, 4 * 1024 * 1024, 1 * 1024 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_read_chunks(self, size, segment_size, chunk_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + read = 0 + content = b'' + while read < len(data): + chunk = await stream.read(chunk_size) + content += chunk + read += chunk_size + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + async def test_random_reads(self, data_size): + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, data_size // 3, StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + + count = 0 + content = b'' + while count < data_size: + read_size = random.randint(10, data_size // 3) + read_size = min(read_size, data_size - count) + + content += await stream.read(read_size) + count += read_size + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("size, segment_size, block_size, flags", [ + (0, 1, 1, StructuredMessageProperties.NONE), + (0, 1, 1, StructuredMessageProperties.CRC64), + (1, 1, 1, StructuredMessageProperties.NONE), + (1, 1, 1, StructuredMessageProperties.CRC64), + (10, 10, 1, StructuredMessageProperties.NONE), + (10, 10, 1, StructuredMessageProperties.CRC64), + (1024, 512, 512, StructuredMessageProperties.NONE), + (1024, 512, 512, StructuredMessageProperties.CRC64), + (1024, 512, 123, StructuredMessageProperties.NONE), + (1024, 512, 123, StructuredMessageProperties.CRC64), + (1024, 200, 512, StructuredMessageProperties.NONE), + (1024, 200, 512, StructuredMessageProperties.CRC64), + (1024, 200, 1024, StructuredMessageProperties.NONE), + (1024, 200, 1024, StructuredMessageProperties.CRC64), + (1024, 200, 50, StructuredMessageProperties.NONE), + (1024, 200, 50, StructuredMessageProperties.CRC64), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.NONE), + (100 * 1024, 4 * 1024, 7 * 1024, StructuredMessageProperties.CRC64), + ]) + async def test_iterate(self, size, segment_size, block_size, flags): + data = os.urandom(size) + message_stream, length = _build_structured_message(data, segment_size, flags) + + decoder = AsyncStructuredMessageDecoder( + _async_iter_bytes(message_stream.getvalue()), length, block_size=block_size) + content = b'' + async for chunk in decoder: + content += chunk + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("data_size", [100, 1024, 10 * 1024, 100 * 1024]) + async def test_random_segment_sizes(self, data_size): + segment_sizes = [] + count = 0 + while count < data_size: + size = random.randint(10, data_size // 3) + size = min(size, data_size - count) + segment_sizes.append(size) + count += size + + data = os.urandom(data_size) + message_stream, length = _build_structured_message(data, segment_sizes, StructuredMessageProperties.CRC64) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + content = await stream.read() + + assert content == data + + @pytest.mark.asyncio + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + async def test_crc64_mismatch_read_all(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + invalid_segment) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + with pytest.raises(ValueError) as e: + await stream.read() + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize("invalid_segment", [1, 2, 3, -1]) + async def test_crc64_mismatch_read_chunks(self, invalid_segment): + data = os.urandom(3 * 1024) + message_stream, length = _build_structured_message( + data, + 1024, + StructuredMessageProperties.CRC64, + invalid_segment) + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(message_stream.getvalue()), length) + # Since we only check CRC on segment borders, some reads will succeed, but we test + # to ensure eventually the stream reading will error out. + with pytest.raises(ValueError) as e: + read = 0 + while read < len(data): + await stream.read(512) + + assert 'CRC64 mismatch' in str(e.value) + + @pytest.mark.asyncio + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_invalid_message_version(self, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + # Corrupt the version byte + raw = bytearray(message_stream.getvalue()) + raw[0] = 0xFF + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("message_length", [100, 1234567]) # Correct value: 1057 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_message_length(self, message_length, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 512, flags) + + raw = bytearray(message_stream.getvalue()) + raw[1:9] = int.to_bytes(message_length, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_count", [2, 123]) # Correct value: 4 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_count(self, segment_count, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[11:13] = int.to_bytes(segment_count, 2, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_number", [1, 123]) # Correct value: 2 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_number(self, segment_number, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the incorrect number + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0)) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 2] = int.to_bytes(segment_number, 2, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_size(self, segment_size, flags): + data = os.urandom(1024) + message_stream, length = _build_structured_message(data, 256, flags) + + # Change the second segment to be the incorrect size + position = (StructuredMessageConstants.V1_HEADER_LENGTH + + StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + 256 + + (StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in flags else 0) + + 2) + raw = bytearray(message_stream.getvalue()) + raw[position:position + 8] = int.to_bytes(segment_size, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() + + @pytest.mark.asyncio + @pytest.mark.parametrize("segment_size", [123, 345]) # Correct value: 256 + @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) + async def test_incorrect_segment_size_single_segment(self, segment_size, flags): + data = os.urandom(256) + message_stream, length = _build_structured_message(data, 256, flags) + + raw = bytearray(message_stream.getvalue()) + raw[15:23] = int.to_bytes(segment_size, 8, 'little') + + stream = AsyncStructuredMessageDecoder(_async_iter_bytes(bytes(raw)), length) + with pytest.raises(ValueError): + await stream.read() diff --git a/sdk/storage/azure-storage-file-datalake/assets.json b/sdk/storage/azure-storage-file-datalake/assets.json index 73f72bbd8f6d..a0c79aa64e2b 100644 --- a/sdk/storage/azure-storage-file-datalake/assets.json +++ b/sdk/storage/azure-storage-file-datalake/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-file-datalake", - "Tag": "python/storage/azure-storage-file-datalake_c9ea6b56de" + "Tag": "python/storage/azure-storage-file-datalake_34be4c4176" } diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py index 6641ea74f21f..7991643287f0 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.py @@ -427,15 +427,11 @@ def upload_data( If a date is passed in without timezone info, it is assumed to be UTC. Specify this header to perform the operation only if the resource has not been modified since the specified date/time. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. @@ -497,13 +493,11 @@ def append_data( :type length: int or None :keyword bool flush: If true, will commit the data after it is appended. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease_action: Used to perform lease operations along with appending data. @@ -714,6 +708,11 @@ def download_file( :paramtype progress_hook: ~typing.Callable[[int, int], None] :keyword bool decompress: If True, any compressed content, identified by the Content-Encoding header, will be decompressed automatically before being returned. Default value is True. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-blob-service-operations. diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi index c7e7594894b3..020a89378f46 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client.pyi @@ -149,7 +149,7 @@ class DataLakeFileClient(PathClient): if_unmodified_since: Optional[datetime] = None, etag: Optional[str] = None, match_condition: Optional[MatchConditions] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, chunk_size: Optional[int] = None, @@ -166,7 +166,7 @@ class DataLakeFileClient(PathClient): length: Optional[int] = None, *, flush: Optional[bool] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease_action: Optional[Literal["acquire", "auto-renew", "release", "acquire-release"]] = None, lease_duration: int = -1, lease: Optional[Union[DataLakeLeaseClient, str]] = None, @@ -207,6 +207,7 @@ class DataLakeFileClient(PathClient): cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, progress_hook: Optional[Callable[[int, int], None]] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> StorageStreamDownloader: ... diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py index 86a0a521d15d..5cf92d7bcfce 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_data_lake_file_client_helpers.py @@ -24,6 +24,7 @@ from ._shared.response_handlers import return_response_headers from ._shared.uploads import IterStreamer from ._shared.uploads_async import AsyncIterStreamer +from ._shared.validation import parse_validation_option if TYPE_CHECKING: from ._generated.operations import PathOperations @@ -47,6 +48,8 @@ def _append_data_options( if isinstance(data, bytes): data = data[:length] + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) + cpk_info = get_cpk_info(scheme, kwargs) kwargs.update(get_lease_action_properties(kwargs)) @@ -54,7 +57,7 @@ def _append_data_options( 'body': data, 'position': offset, 'content_length': length, - 'validate_content': kwargs.pop('validate_content', False), + 'validate_content': validate_content, 'cpk_info': cpk_info, 'timeout': kwargs.pop('timeout', None), 'cls': return_response_headers @@ -122,7 +125,7 @@ def _upload_options( else: raise TypeError(f"Unsupported data type: {type(data)}") - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) max_concurrency = kwargs.pop('max_concurrency', None) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py index 16aba3116029..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/base_client_async.py @@ -36,12 +36,15 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +133,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py index 61a4fdb15bdd..b5d0b7d79766 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies.py @@ -5,14 +5,13 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +31,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_crc64_validation, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +55,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +72,12 @@ def encode_base64(data): # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -101,11 +123,15 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True @@ -237,7 +263,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -352,64 +387,113 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. - - This will overwrite any headers already defined in the request. - """ +def _prepare_content_validation(request: "PipelineRequest") -> None: + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if is_crc64_validation(validate_content): + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + request.context["validate_content"] = validate_content + + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif is_crc64_validation(validate_content): + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}" + ), + response=response.http_response, + ) + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + _prepare_content_validation(request) def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, - ) + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -456,7 +540,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. @@ -496,7 +580,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -570,7 +654,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. @@ -584,11 +668,16 @@ def send(self, request): response = self.next.send(request) if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -598,7 +687,12 @@ def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -731,11 +825,11 @@ def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 4cb32f23248b..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -37,21 +51,52 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -125,11 +170,16 @@ async def send(self, request): response = await self.next.send(request) if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -139,7 +189,12 @@ async def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py new file mode 100644 index 000000000000..cb745693921c --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -0,0 +1,565 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + @property + def closed(self) -> bool: + return self._inner_stream.closed + + def close(self) -> None: + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return self._inner_stream.seekable() and self._initial_content_position is not None + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") + + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") + + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region(self._current_region, remaining, output) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min(self._segment_size, self.content_length - self._content_offset) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[self._current_region_offset : self._current_region_offset + read_size] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + read_size = min(size, self._current_region_length - self._current_region_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py new file mode 100644 index 000000000000..9fedc055a623 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams_async.py @@ -0,0 +1,210 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py new file mode 100644 index 000000000000..21d3b081d8cc --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/validation.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + +_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") + +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." + ) from exc + + +def parse_validation_option( + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if parsed == "auto": + parsed = "crc64" + + if parsed == "crc64": + _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" + + return cast(CV_TYPE_PARSED, parsed) + + +def is_md5_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return False + return validate_content in ("crc64", "crc64-sm") + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py index 1edf3832a690..f13042518f98 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.py @@ -443,15 +443,11 @@ async def upload_data( If a date is passed in without timezone info, it is assumed to be UTC. Specify this header to perform the operation only if the resource has not been modified since the specified date/time. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https, as https (the default), will - already validate. Note that this MD5 hash is not stored with the - blob. Also note that if enabled, the memory-efficient upload algorithm - will not be used because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. @@ -513,13 +509,11 @@ async def append_data( :type length: int or None :keyword bool flush: If true, will commit the data after it is appended. - :keyword bool validate_content: - If true, calculates an MD5 hash of the block content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease_action: Used to perform lease operations along with appending data. @@ -730,6 +724,11 @@ async def download_file( :paramtype progress_hook: ~typing.Callable[[int, Optional[int]], Awaitable[None]] :keyword bool decompress: If True, any compressed content, identified by the Content-Encoding header, will be decompressed automatically before being returned. Default value is True. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the blob. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-blob-service-operations. diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi index 8da588a32f51..ff21c42b73e6 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/aio/_data_lake_file_client_async.pyi @@ -148,7 +148,7 @@ class DataLakeFileClient(PathClient): if_unmodified_since: Optional[datetime] = None, etag: Optional[str] = None, match_condition: Optional[MatchConditions] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, chunk_size: Optional[int] = None, @@ -165,7 +165,7 @@ class DataLakeFileClient(PathClient): length: Optional[int] = None, *, flush: Optional[bool] = None, - validate_content: Optional[bool] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease_action: Optional[Literal["acquire", "auto-renew", "release", "acquire-release"]] = None, lease_duration: int = -1, lease: Optional[Union[DataLakeLeaseClient, str]] = None, @@ -206,6 +206,7 @@ class DataLakeFileClient(PathClient): cpk: Optional[CustomerProvidedEncryptionKey] = None, max_concurrency: Optional[int] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, timeout: Optional[int] = None, **kwargs: Any ) -> StorageStreamDownloader: ... diff --git a/sdk/storage/azure-storage-file-datalake/dev_requirements.txt b/sdk/storage/azure-storage-file-datalake/dev_requirements.txt index 60d588f5e1e1..b42a7dff1a22 100644 --- a/sdk/storage/azure-storage-file-datalake/dev_requirements.txt +++ b/sdk/storage/azure-storage-file-datalake/dev_requirements.txt @@ -2,4 +2,5 @@ ../../core/azure-core ../../identity/azure-identity ../azure-storage-blob +../azure-storage-extensions aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py new file mode 100644 index 000000000000..d0668fd75a8d --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation.py @@ -0,0 +1,311 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live, recorded_by_proxy +from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase +from settings.testcase import DataLakePreparer + +from azure.storage.filedatalake import DataLakeServiceClient + + +def assert_content_md5(request): + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("Content-MD5") is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None + + +def assert_content_crc64(request): + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("x-ms-content-crc64") is not None + + +def assert_structured_message(request): + if request.http_request.query.get("action") == "append": + assert request.http_request.headers.get("x-ms-structured-body") is not None + + +def assert_structured_message_get(response): + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None + + +class TestStorageContentValidation(StorageRecordedTestCase): + def _setup(self, account_name): + token_credential = self.get_credential(DataLakeServiceClient) + self.dsc = DataLakeServiceClient( + self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True + ) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name("filesystem")) + self.file_system.create_file_system() + + def teardown_method(self, _): + if self.file_system and is_live(): + try: + self.file_system.delete_file_system() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name("file") + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abcde" * 512 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + # Act + file.upload_data(data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_data_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk size to enable substream + file = self.file_system.get_file_client(self._get_file_reference()) + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + data = b"abc" * 512 + b"abcde" + io = BytesIO(data) + + # Act + file.upload_data(io, overwrite=True, validate_content=a, chunk_size=512, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-8-sig") + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + file.create_file() + file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) + file.append_data(data2, len(data1), encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method) + file.append_data(generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method) + file.flush_data(len(data1) + len(encoded2) + len(data1)) + + # Assert + content = file.download_file() + assert content.read() == data1 + encoded2 + data1 + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_append_data_streaming(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + file.create_file() + file.append_data(BytesIO(content), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + result = file.download_file() + assert result.read() == content + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + def test_append_data_streaming_large(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + file.create_file() + file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + file.append_data(BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method) + file.append_data( + BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method + ) + file.flush_data(len(data1) + len(data2) + len(data3)) + + # Assert + result = file.download_file() + assert result.read() == data1 + data2 + data3 + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content.extend(downloader.read(100)) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks_partial(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.read() + + stream = BytesIO() + downloader = file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @DataLakePreparer() + @pytest.mark.live_test_only + def test_download_file_large_chunks(self, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + file.upload_data(data, overwrite=True, max_concurrency=5) + + # Act + downloader = file.download_file(validate_content="crc64", max_concurrency=5) + content = downloader.read() + + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py new file mode 100644 index 000000000000..a9bcf2c2ec50 --- /dev/null +++ b/sdk/storage/azure-storage-file-datalake/tests/test_content_validation_async.py @@ -0,0 +1,317 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 +from settings.testcase import DataLakePreparer +from test_content_validation import ( + assert_content_crc64, + assert_content_md5, + assert_content_md5_get, + assert_structured_message, + assert_structured_message_get, +) + +from azure.storage.filedatalake import FileSystemClient as SyncFileSystemClient +from azure.storage.filedatalake.aio import DataLakeServiceClient + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + async def _setup(self, account_name): + token_credential = self.get_credential(DataLakeServiceClient, is_async=True) + self.dsc = DataLakeServiceClient( + self.account_url(account_name, "dfs"), credential=token_credential, logging_enable=True + ) + self.file_system = self.dsc.get_file_system_client(self.get_resource_name("filesystem")) + await self.file_system.create_file_system() + + def teardown_method(self, _): + if not is_live(): + return + # Use sync client as teardown_method must be sync + if self.file_system: + sync_credential = self.get_credential(SyncFileSystemClient, is_async=False) + sync_file_system = SyncFileSystemClient( + self.account_url(self.file_system.account_name, "dfs"), + self.file_system.file_system_name, + credential=sync_credential, + ) + + try: + sync_file_system.delete_file_system() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name("file") + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + await file.upload_data(data, overwrite=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abcde" * 512 + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + # Act + await file.upload_data( + data, overwrite=True, validate_content=a, chunk_size=1024, raw_request_hook=assert_method + ) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_data_substream(self, a, **kwargs): + # Substream is disabled when using content validation so this will behave like regular upload (buffer) + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + self.file_system._config.min_large_chunk_upload_threshold = 1 # Set less than chunk size to enable substream + file = self.file_system.get_file_client(self._get_file_reference()) + assert_method = assert_content_crc64 if a == "crc64" else assert_content_md5 + + data = b"abc" * 512 + b"abcde" + io = BytesIO(data) + + # Act + await file.upload_data(io, overwrite=True, validate_content=a, chunk_size=512, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_data(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-8-sig") + + # An iterable with no length will be read into bytes and therefore will behave like + # bytes when it comes to testing content validation. + def generator(): + for i in range(0, len(data1), 500): + yield data1[i : i + 500] + + assert_method = assert_content_crc64 if a in ("auto", "crc64") else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(data1, 0, validate_content=a, raw_request_hook=assert_method) + await file.append_data( + data2, len(data1), encoding="utf-8-sig", validate_content=a, raw_request_hook=assert_method + ) + await file.append_data( + generator(), len(data1) + len(encoded2), validate_content=a, raw_request_hook=assert_method + ) + await file.flush_data(len(data1) + len(encoded2) + len(data1)) + + # Assert + content = await file.download_file() + assert await content.read() == data1 + encoded2 + data1 + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_append_data_streaming(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + content = b"abcde" * 1030 # 5 KiB + 30 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(BytesIO(content), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + + # Assert + result = await file.download_file() + assert await result.read() == content + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @pytest.mark.live_test_only + async def test_append_data_streaming_large(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + + data1 = b"abcde" * 1024 * 1024 # 5 MiB + data2 = b"12345" * 2 * 1024 * 1024 + b"abcdefg" # 10 MiB + 7 + data3 = b"12345678" * 8 * 1024 * 1024 # 64 MiB + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + await file.create_file() + await file.append_data(BytesIO(data1), 0, flush=True, validate_content=a, raw_request_hook=assert_method) + await file.append_data( + BytesIO(data2), len(data1), flush=True, validate_content=a, raw_request_hook=assert_method + ) + await file.append_data( + BytesIO(data3), len(data1) + len(data2), flush=True, validate_content=a, raw_request_hook=assert_method + ) + await file.flush_data(len(data1) + len(data2) + len(data3)) + + # Assert + result = await file.download_file() + assert await result.read() == data1 + data2 + data3 + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + for _ in range(len(data) // 100 + 1): + read_content.extend(await downloader.read(100)) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @DataLakePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks_partial(self, a, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + self.file_system._config.max_single_get_size = 512 + self.file_system._config.max_chunk_get_size = 512 + file = self.file_system.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + await file.upload_data(data, overwrite=True) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await file.download_file( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) + content = await downloader.read() + + stream = BytesIO() + downloader = await file.download_file( + offset=512, length=1024, validate_content=a, raw_response_hook=assert_method + ) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @DataLakePreparer() + @pytest.mark.live_test_only + async def test_download_file_large_chunks(self, **kwargs): + datalake_storage_account_name = kwargs.pop("datalake_storage_account_name") + + await self._setup(datalake_storage_account_name) + file = self.file_system.get_file_client(self._get_file_reference()) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.file_system._config.max_chunk_get_size = 10 * 1024 * 1024 + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + await file.upload_data(data, overwrite=True, max_concurrency=5) + + # Act + downloader = await file.download_file(validate_content="crc64", max_concurrency=5) + content = await downloader.read() + + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = await downloader.read() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/assets.json b/sdk/storage/azure-storage-file-share/assets.json index f8436a861d3c..79ed1f733678 100644 --- a/sdk/storage/azure-storage-file-share/assets.json +++ b/sdk/storage/azure-storage-file-share/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-file-share", - "Tag": "python/storage/azure-storage-file-share_4afd6de033" + "Tag": "python/storage/azure-storage-file-share_d5f376c42f" } diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py index 935b17ebfde9..c087390eb09b 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_download.py @@ -18,6 +18,7 @@ from ._shared.request_handlers import validate_and_format_range_headers from ._shared.response_handlers import parse_length_from_content_range, process_storage_error from ._shared.constants import DEFAULT_MAX_CONCURRENCY +from ._shared.validation import is_md5_validation, CV_TYPE_PARSED if TYPE_CHECKING: from ._generated.operations import FileOperations @@ -43,7 +44,7 @@ def __init__( current_progress: int, start_range: int, end_range: int, - validate_content: bool, + validate_content: CV_TYPE_PARSED, etag: str, stream: Any = None, parallel: Optional[int] = None, @@ -120,7 +121,9 @@ def _write_to_stream(self, chunk_data: bytes, chunk_start: int) -> None: def _download_chunk(self, chunk_start: int, chunk_end: int) -> bytes: range_header, range_validation = validate_and_format_range_headers( - chunk_start, chunk_end, check_content_md5=self.validate_content + chunk_start, + chunk_end, + check_content_md5=is_md5_validation(self.validate_content) ) try: @@ -217,7 +220,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] path: str = None, # type: ignore [assignment] @@ -247,11 +250,14 @@ def __init__( self._etag = "" # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. self._first_get_size = ( - self._config.max_single_get_size if not self._validate_content else self._config.max_chunk_get_size + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) + else self._config.max_chunk_get_size ) + initial_request_start = self._start_range or 0 if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size: initial_request_end = self._end_range @@ -292,7 +298,7 @@ def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content + check_content_md5=is_md5_validation(self._validate_content) ) try: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py index 78bfa03a6945..2dc032a9eb43 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.py @@ -45,6 +45,7 @@ from ._shared.request_handlers import add_metadata_headers, get_length from ._shared.response_handlers import return_response_headers, process_storage_error from ._shared.uploads import IterStreamer, FileChunkUploader, upload_data_chunks +from ._shared.validation import CV_TYPE_PARSED, parse_validation_option if TYPE_CHECKING: from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential @@ -58,7 +59,7 @@ def _upload_file_helper( size: Optional[int], metadata: Optional[Dict[str, str]], content_settings: Optional["ContentSettings"], - validate_content: bool, + validate_content: CV_TYPE_PARSED, timeout: Optional[int], max_concurrency: int, file_settings: "StorageConfiguration", @@ -447,8 +448,13 @@ def create_file( Restore - apply changes without further modification. :paramtype file_property_semantics: Optional[Literal["New", "Restore"]] - :keyword data: Optional initial data to upload, up to 4MB. - :paramtype data: bytes + :keyword bytes data: Optional initial data to upload, up to 4MB. + :keyword validate_content: + Only applicable when `data` is provided. + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-file-service-operations. @@ -471,6 +477,10 @@ def create_file( content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) headers = kwargs.pop('headers', {}) headers.update(add_metadata_headers(metadata)) data = kwargs.pop('data', None) @@ -498,6 +508,7 @@ def create_file( file_permission_key=permission_key, file_http_headers=file_http_headers, optionalbody=data, + validate_content=validate_content, content_length=len(data) if data is not None else None, lease_access_conditions=access_conditions, headers=headers, @@ -562,13 +573,11 @@ def upload_file( :keyword ~azure.storage.fileshare.ContentSettings content_settings: ContentSettings object used to set file properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each range of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int max_concurrency: Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may @@ -610,7 +619,10 @@ def upload_file( max_concurrency = kwargs.pop('max_concurrency', None) if max_concurrency is None: max_concurrency = DEFAULT_MAX_CONCURRENCY - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) progress_hook = kwargs.pop('progress_hook', None) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') @@ -877,15 +889,11 @@ def download_file( Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may require a separate configuration of the connection pool. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the file has an active lease. Value can be a ShareLeaseClient object or the lease ID as a string. @@ -925,12 +933,14 @@ def download_file( range_end = offset + length - 1 # Service actually uses an end-range inclusive index access_conditions = get_access_conditions(kwargs.pop('lease', None)) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) return StorageStreamDownloader( client=self._client.file, config=self._config, start_range=offset, end_range=range_end, + validate_content=validate_content, name=self.file_name, path='/'.join(self.file_path), share=self.share_name, @@ -1296,13 +1306,11 @@ def upload_range( :param int length: Number of bytes to use for uploading a section of the file. The range can be up to 4 MB in size. - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword file_last_write_mode: If the file last write time should be preserved or overwritten. Possible values are "preserve" or "now". If not specified, file last write time will be changed to @@ -1331,7 +1339,10 @@ def upload_range( :returns: File-updated property dict (Etag and last modified). :rtype: Dict[str, Any] """ - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') file_last_write_mode = kwargs.pop('file_last_write_mode', None) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi index 22d645d975cc..a517773f19ca 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_file_client.pyi @@ -134,6 +134,7 @@ class ShareFileClient(StorageAccountHostsMixin): file_mode: Optional[str] = None, file_property_semantics: Optional[Literal["New", "Restore"]] = None, data: Optional[bytes] = None, + validate_content: Optional[Literal['auto', 'crc64', 'md5']] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: ... @@ -151,7 +152,7 @@ class ShareFileClient(StorageAccountHostsMixin): file_change_time: Optional[Union[str, datetime]] = None, metadata: Optional[Dict[str, str]] = None, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], None]] = None, @@ -199,7 +200,7 @@ class ShareFileClient(StorageAccountHostsMixin): length: Optional[int] = None, *, max_concurrency: Optional[int] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], None]] = None, decompress: Optional[bool] = None, @@ -270,7 +271,7 @@ class ShareFileClient(StorageAccountHostsMixin): offset: int, length: int, *, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, encoding: str = "UTF-8", file_last_write_mode: Optional[Literal["preserve", "now"]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py index 16aba3116029..2e023b1cc8d9 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/base_client_async.py @@ -36,12 +36,15 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) -from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook +from .policies_async import ( + AsyncStorageBearerTokenCredentialPolicy, + AsyncContentValidationPolicy, + AsyncStorageResponseHook, +) from .response_handlers import PartialBatchErrorException, process_storage_error from .._shared_access_signature import _is_credential_sastoken @@ -130,7 +133,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py index 61a4fdb15bdd..b5d0b7d79766 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies.py @@ -5,14 +5,13 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +31,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_crc64_validation, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +55,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -55,7 +72,12 @@ def encode_base64(data): # Are we out of retries? def is_exhausted(settings): - retry_counts = (settings["total"], settings["connect"], settings["read"], settings["status"]) + retry_counts = ( + settings["total"], + settings["connect"], + settings["read"], + settings["status"], + ) retry_counts = list(filter(None, retry_counts)) if not retry_counts: return False @@ -101,11 +123,15 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True @@ -237,7 +263,16 @@ def on_request(self, request: "PipelineRequest") -> None: parsed_qs["sig"] = "*****" # the SAS needs to be put back together - value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment)) + value = urlunparse( + ( + scheme, + netloc, + path, + params, + urlencode(parsed_qs), + fragment, + ) + ) _LOGGER.debug(" %r: %r", header, value) _LOGGER.debug("Request body:") @@ -352,64 +387,113 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. - - This will overwrite any headers already defined in the request. - """ +def _prepare_content_validation(request: "PipelineRequest") -> None: + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if is_crc64_validation(validate_content): + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 - return md5.digest() + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + request.context["validate_content"] = validate_content + + +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif is_crc64_validation(validate_content): + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}" + ), + response=response.http_response, + ) + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + _prepare_content_validation(request) def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, - ) + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): @@ -456,7 +540,7 @@ def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRe def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]: """ Configure the retry settings for the request. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A dictionary containing the retry settings. @@ -496,7 +580,7 @@ def get_backoff_time(self, settings: Dict[str, Any]) -> float: # pylint: disabl def sleep(self, settings, transport): """Sleep for the backoff time. - + :param Dict[str, Any] settings: The configurable values pertaining to the sleep operation. :param transport: The transport to use for sleeping. :type transport: @@ -570,7 +654,7 @@ def increment( def send(self, request): """Send the request with retry logic. - + :param request: A pipeline request object. :type request: ~azure.core.pipeline.PipelineRequest :return: A pipeline response object. @@ -584,11 +668,16 @@ def send(self, request): response = self.next.send(request) if is_retry(response, retry_settings["mode"]) or is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) self.sleep(retry_settings, request.context.transport) continue @@ -598,7 +687,12 @@ def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - retry_hook(retry_settings, request=request.http_request, response=None, error=err) + retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) self.sleep(retry_settings, request.context.transport) continue raise err @@ -731,11 +825,11 @@ def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: """Handle the challenge from the service and authorize the request. - + :param request: The request object. :type request: ~azure.core.pipeline.PipelineRequest :param response: The response object. - :type response: ~azure.core.pipeline.PipelineResponse + :type response: ~azure.core.pipeline.PipelineResponse :return: True if the request was authorized, False otherwise. :rtype: bool """ diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 4cb32f23248b..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -11,11 +11,25 @@ from typing import Any, Dict, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError -from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy +from azure.core.pipeline.policies import ( + AsyncBearerTokenCredentialPolicy, + AsyncHTTPPolicy, +) from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE -from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy +from .policies import ( + _prepare_content_validation, + _validate_content_response, + encode_base64, + is_retry, + StorageRetryPolicy, +) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -37,21 +51,52 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): @@ -125,11 +170,16 @@ async def send(self, request): response = await self.next.send(request) if is_retry(response, retry_settings["mode"]) or await is_checksum_retry(response): retries_remaining = self.increment( - retry_settings, request=request.http_request, response=response.http_response + retry_settings, + request=request.http_request, + response=response.http_response, ) if retries_remaining: await retry_hook( - retry_settings, request=request.http_request, response=response.http_response, error=None + retry_settings, + request=request.http_request, + response=response.http_response, + error=None, ) await self.sleep(retry_settings, request.context.transport) continue @@ -139,7 +189,12 @@ async def send(self, request): raise retries_remaining = self.increment(retry_settings, request=request.http_request, error=err) if retries_remaining: - await retry_hook(retry_settings, request=request.http_request, response=None, error=err) + await retry_hook( + retry_settings, + request=request.http_request, + response=None, + error=err, + ) await self.sleep(retry_settings, request.context.transport) continue raise err diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py new file mode 100644 index 000000000000..cb745693921c --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -0,0 +1,565 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + @property + def closed(self) -> bool: + return self._inner_stream.closed + + def close(self) -> None: + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return self._inner_stream.seekable() and self._initial_content_position is not None + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") + + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") + + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region(self._current_region, remaining, output) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min(self._segment_size, self.content_length - self._content_offset) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[self._current_region_offset : self._current_region_offset + read_size] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + read_size = min(size, self._current_region_length - self._current_region_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py new file mode 100644 index 000000000000..9fedc055a623 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams_async.py @@ -0,0 +1,210 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py new file mode 100644 index 000000000000..21d3b081d8cc --- /dev/null +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/validation.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + +_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") + +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." + ) from exc + + +def parse_validation_option( + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if parsed == "auto": + parsed = "crc64" + + if parsed == "crc64": + _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" + + return cast(CV_TYPE_PARSED, parsed) + + +def is_md5_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return False + return validate_content in ("crc64", "crc64-sm") + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 731e8c86bd92..b3a2a1750f04 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -22,6 +22,7 @@ from .._shared.request_handlers import validate_and_format_range_headers from .._shared.response_handlers import parse_length_from_content_range, process_storage_error from .._shared.constants import DEFAULT_MAX_CONCURRENCY +from .._shared.validation import is_md5_validation, CV_TYPE_PARSED if TYPE_CHECKING: from .._generated.aio.operations import FileOperations @@ -82,7 +83,7 @@ async def _download_chunk(self, chunk_start: int, chunk_end: int) -> bytes: range_header, range_validation = validate_and_format_range_headers( chunk_start, chunk_end, - check_content_md5=self.validate_content + check_content_md5=is_md5_validation(self.validate_content) ) try: _, response = await cast(Awaitable[Any], self.client.download( @@ -178,7 +179,7 @@ def __init__( config: "StorageConfiguration" = None, # type: ignore [assignment] start_range: Optional[int] = None, end_range: Optional[int] = None, - validate_content: bool = None, # type: ignore [assignment] + validate_content: CV_TYPE_PARSED = None, max_concurrency: Optional[int] = None, name: str = None, # type: ignore [assignment] path: str = None, # type: ignore [assignment] @@ -208,10 +209,14 @@ def __init__( self._etag = "" # The service only provides transactional MD5s for chunks under 4MB. - # If validate_content is on, get only self.MAX_CHUNK_GET_SIZE for the first + # If validate_content is using MD5, get only self.MAX_CHUNK_GET_SIZE for the first # chunk so a transactional MD5 can be retrieved. - self._first_get_size = self._config.max_single_get_size if not self._validate_content \ + self._first_get_size = ( + self._config.max_single_get_size + if not is_md5_validation(self._validate_content) else self._config.max_chunk_get_size + ) + initial_request_start = self._start_range or 0 if self._end_range is not None and self._end_range - initial_request_start < self._first_get_size: initial_request_end = self._end_range @@ -253,7 +258,7 @@ async def _initial_request(self): self._initial_range[1], start_range_required=False, end_range_required=False, - check_content_md5=self._validate_content) + check_content_md5=is_md5_validation(self._validate_content)) try: location_mode, response = cast(Tuple[Optional[str], Any], await self._client.download( diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py index 45b1a96fefb9..636228460008 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.py @@ -48,6 +48,7 @@ from .._shared.request_handlers import add_metadata_headers, get_length from .._shared.response_handlers import process_storage_error, return_response_headers from .._shared.uploads_async import AsyncIterStreamer, FileChunkUploader, IterStreamer, upload_data_chunks +from .._shared.validation import CV_TYPE_PARSED, parse_validation_option from ._download_async import StorageStreamDownloader from ._lease_async import ShareLeaseClient from ._models import FileProperties, Handle, HandlesPaged @@ -65,7 +66,7 @@ async def _upload_file_helper( size: Optional[int], metadata: Optional[Dict[str, str]], content_settings: Optional["ContentSettings"], - validate_content: bool, + validate_content: CV_TYPE_PARSED, timeout: Optional[int], max_concurrency: int, file_settings: "StorageConfiguration", @@ -442,8 +443,13 @@ async def create_file( Restore - apply changes without further modification. :paramtype file_property_semantics: Optional[Literal["New", "Restore"]] - :keyword data: Optional initial data to upload, up to 4MB. - :paramtype data: bytes + :keyword bytes data: Optional initial data to upload, up to 4MB. + :keyword validate_content: + Only applicable when `data` is provided. + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[Literal['auto', 'crc64', 'md5']] :keyword int timeout: Sets the server-side timeout for the operation in seconds. For more details see https://learn.microsoft.com/rest/api/storageservices/setting-timeouts-for-file-service-operations. @@ -466,6 +472,10 @@ async def create_file( content_settings = kwargs.pop('content_settings', None) metadata = kwargs.pop('metadata', None) timeout = kwargs.pop('timeout', None) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) headers = kwargs.pop("headers", {}) headers.update(add_metadata_headers(metadata)) data = kwargs.pop('data', None) @@ -493,6 +503,7 @@ async def create_file( file_permission_key=permission_key, file_http_headers=file_http_headers, optionalbody=data, + validate_content=validate_content, content_length=len(data) if data is not None else None, lease_access_conditions=access_conditions, headers=headers, @@ -556,13 +567,11 @@ async def upload_file( :keyword ~azure.storage.fileshare.ContentSettings content_settings: ContentSettings object used to set file properties. Used to set content type, encoding, language, disposition, md5, and cache control. - :keyword bool validate_content: - If true, calculates an MD5 hash for each range of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword int max_concurrency: Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may @@ -604,7 +613,10 @@ async def upload_file( max_concurrency = kwargs.pop('max_concurrency', None) if max_concurrency is None: max_concurrency = DEFAULT_MAX_CONCURRENCY - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) progress_hook = kwargs.pop('progress_hook', None) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') @@ -873,15 +885,11 @@ async def download_file( Maximum number of parallel connections to use when transferring the file in chunks. This option does not affect the underlying connection pool, and may require a separate configuration of the connection pool. - :keyword bool validate_content: - If true, calculates an MD5 hash for each chunk of the file. The storage - service checks the hash of the content that has arrived with the hash - that was sent. This is primarily valuable for detecting bitflips on - the wire if using http instead of https as https (the default) will - already validate. Note that this MD5 hash is not stored with the - file. Also note that if enabled, the memory-efficient upload algorithm - will not be used, because computing the MD5 hash requires buffering - entire blocks, and doing so defeats the purpose of the memory-efficient algorithm. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword lease: Required if the file has an active lease. Value can be a ShareLeaseClient object or the lease ID as a string. @@ -924,12 +932,14 @@ async def download_file( range_end = offset + length - 1 # Service actually uses an end-range inclusive index access_conditions = get_access_conditions(kwargs.pop('lease', None)) + validate_content = parse_validation_option(kwargs.pop('validate_content', None)) downloader = StorageStreamDownloader( client=self._client.file, config=self._config, start_range=offset, end_range=range_end, + validate_content=validate_content, name=self.file_name, path='/'.join(self.file_path), share=self.share_name, @@ -1295,13 +1305,11 @@ async def upload_range( :param int length: Number of bytes to use for uploading a section of the file. The range can be up to 4 MB in size. - :keyword bool validate_content: - If true, calculates an MD5 hash of the page content. The storage - service checks the hash of the content that has arrived - with the hash that was sent. This is primarily valuable for detecting - bitflips on the wire if using http instead of https as https (the default) - will already validate. Note that this MD5 hash is not stored with the - file. + :keyword validate_content: + Enables checksum validation for the transfer. Any checksum calculated is NOT stored with the file. + Choose "auto" (let the SDK choose the best algorithm), "crc64", or "md5". The use of bool is deprecated. + NOTE: The use of "auto" or "crc64" requires the `azure-storage-extensions` package to be installed. + :paramtype validate_content: Union[bool, Literal['auto', 'crc64', 'md5']] :keyword file_last_write_mode: If the file last write time should be preserved or overwritten. Possible values are "preserve" or "now". If not specified, file last write time will be changed to @@ -1330,7 +1338,10 @@ async def upload_range( :returns: File-updated property dict (Etag and last modified). :rtype: Dict[str, Any] """ - validate_content = kwargs.pop('validate_content', False) + validate_content = parse_validation_option( + kwargs.pop('validate_content', None), + force_structured_message=True + ) timeout = kwargs.pop('timeout', None) encoding = kwargs.pop('encoding', 'UTF-8') file_last_write_mode = kwargs.pop('file_last_write_mode', None) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi index 3f70def69364..905f5a2bb002 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_file_client_async.pyi @@ -135,6 +135,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): file_mode: Optional[str] = None, file_property_semantics: Optional[Literal["New", "Restore"]] = None, data: Optional[bytes] = None, + validate_content: Optional[Literal['auto', 'crc64', 'md5']] = None, timeout: Optional[int] = None, **kwargs: Any ) -> Dict[str, Any]: ... @@ -151,7 +152,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): file_change_time: Optional[Union[str, datetime]] = None, metadata: Optional[Dict[str, str]] = None, content_settings: Optional[ContentSettings] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, max_concurrency: Optional[int] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, @@ -199,7 +200,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): length: Optional[int] = None, *, max_concurrency: Optional[int] = None, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, progress_hook: Optional[Callable[[int, Optional[int]], Awaitable[None]]] = None, decompress: Optional[bool] = None, @@ -270,7 +271,7 @@ class ShareFileClient(AsyncStorageAccountHostsMixin, StorageAccountHostsMixin): offset: int, length: int, *, - validate_content: bool = False, + validate_content: Optional[Union[bool, Literal['auto', 'crc64', 'md5']]] = None, file_last_write_mode: Optional[Literal["preserve", "now"]] = None, lease: Optional[Union[ShareLeaseClient, str]] = None, encoding: str = "UTF-8", diff --git a/sdk/storage/azure-storage-file-share/dev_requirements.txt b/sdk/storage/azure-storage-file-share/dev_requirements.txt index 60d588f5e1e1..b42a7dff1a22 100644 --- a/sdk/storage/azure-storage-file-share/dev_requirements.txt +++ b/sdk/storage/azure-storage-file-share/dev_requirements.txt @@ -2,4 +2,5 @@ ../../core/azure-core ../../identity/azure-identity ../azure-storage-blob +../azure-storage-extensions aiohttp>=3.13.5 diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py new file mode 100644 index 000000000000..6afdfa1194ac --- /dev/null +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation.py @@ -0,0 +1,277 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live, recorded_by_proxy +from devtools_testutils.storage import GenericTestProxyParametrize1, StorageRecordedTestCase +from settings.testcase import FileSharePreparer + +from azure.storage.fileshare import ShareClient, ShareServiceClient + + +def assert_content_md5(request): + if request.http_request.query.get("comp") == "range": + assert request.http_request.headers.get("Content-MD5") is not None + + +def assert_content_md5_get(response): + assert response.http_request.headers.get("x-ms-range-get-content-md5") == "true" + assert response.http_response.headers.get("Content-MD5") is not None + + +def assert_structured_message(request): + if request.http_request.query.get("comp") == "range": + assert request.http_request.headers.get("x-ms-structured-body") is not None + + +def assert_structured_message_get(response): + assert response.http_request.headers.get("x-ms-structured-body") is not None + assert response.http_response.headers.get("x-ms-structured-body") is not None + + +class TestStorageContentValidation(StorageRecordedTestCase): + share_client: ShareClient + + def _setup(self, account_name): + token_credential = self.get_credential(ShareServiceClient) + self.ssc = ShareServiceClient( + self.account_url(account_name, "file"), + credential=token_credential, + token_intent="backup", + logging_enable=True, + ) + self.share_client = self.ssc.get_share_client(self.get_resource_name("utshare")) + self.share_client.create_share() + + def teardown_method(self, _): + if self.share_client and is_live(): + try: + self.share_client.delete_share() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name("file") + + @FileSharePreparer() + @pytest.mark.parametrize("a", ["auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_create_file_with_data(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_range_size = 1024 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abcde" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_range(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-16") + + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + file.create_file(len(data1) + len(encoded2)) + file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) + file.upload_range( + data2, len(data1), len(encoded2), encoding="utf-16", validate_content=a, raw_request_hook=assert_method + ) + + # Assert + content = file.download_file() + assert content.readall() == data1 + encoded2 + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_upload_range_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + + data = b"abcd" * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + file.create_file(len(data)) + # This is not an officially supported data type for upload_range, but we should still be able to handle it + # as there are probably users out there using it. + file.upload_range(BytesIO(data), 0, len(data), validate_content=a, raw_request_hook=assert_method) + + # Assert + content = file.download_file() + assert content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + file.upload_file(data) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + file.upload_file(data) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = file.download_file(validate_content=a, raw_response_hook=assert_method) + for chunk in downloader.chunks(): + read_content.extend(chunk) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_download_file_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + file.upload_file(data) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = file.download_file(offset=10, length=1000, validate_content=a, raw_response_hook=assert_method) + content = downloader.readall() + + stream = BytesIO() + downloader = file.download_file(offset=512, length=1024, validate_content=a, raw_response_hook=assert_method) + downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @FileSharePreparer() + @pytest.mark.live_test_only + def test_download_file_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + self._setup(storage_account_name) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + file.upload_file(data, max_concurrency=5) + + # Act + downloader = file.download_file(validate_content="crc64", max_concurrency=5) + content = downloader.readall() + + downloader = file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = downloader.readall() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py new file mode 100644 index 000000000000..628356cbca07 --- /dev/null +++ b/sdk/storage/azure-storage-file-share/tests/test_content_validation_async.py @@ -0,0 +1,276 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from io import BytesIO + +import pytest +from devtools_testutils import is_live +from devtools_testutils.aio import recorded_by_proxy_async +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase, GenericTestProxyParametrize1 +from settings.testcase import FileSharePreparer +from test_content_validation import ( + assert_content_md5, + assert_content_md5_get, + assert_structured_message, + assert_structured_message_get, +) + +from azure.storage.fileshare import ShareClient as SyncShareClient +from azure.storage.fileshare.aio import ShareServiceClient + + +class TestStorageContentValidationAsync(AsyncStorageRecordedTestCase): + async def _setup(self, account_name): + token_credential = self.get_credential(ShareServiceClient, is_async=True) + self.ssc = ShareServiceClient( + self.account_url(account_name, "file"), + credential=token_credential, + token_intent="backup", + logging_enable=True, + ) + self.share_client = self.ssc.get_share_client(self.get_resource_name("utshare")) + await self.share_client.create_share() + + def teardown_method(self, _): + if not is_live(): + return + if self.share_client: + sync_credential = self.get_credential(SyncShareClient, is_async=False) + sync_share_client = SyncShareClient( + self.account_url(self.share_client.account_name, "file"), + self.share_client.share_name, + credential=sync_credential, + token_intent="backup", + ) + try: + sync_share_client.delete_share() + except: + pass + + def _get_file_reference(self): + return self.get_resource_name("file") + + @FileSharePreparer() + @pytest.mark.parametrize("a", ["auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_create_file_with_data(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + await file.create_file(len(data), data=data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_range_size = 1024 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abcde" * 512 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + await file.upload_file(data, validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_range(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data1 = b"abcde" * 512 + data2 = "你好世界" * 10 + encoded2 = data2.encode("utf-16") + + assert_method = assert_structured_message if a in ("auto", "crc64") else assert_content_md5 + + # Act + await file.create_file(len(data1) + len(encoded2)) + await file.upload_range(data1, 0, len(data1), validate_content=a, raw_request_hook=assert_method) + await file.upload_range( + data2, len(data1), len(encoded2), encoding="utf-16", validate_content=a, raw_request_hook=assert_method + ) + + # Assert + content = await file.download_file() + assert await content.readall() == data1 + encoded2 + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_upload_range_streaming(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + + data = b"abcd" * 1030 # 4 KiB + 24 + assert_method = assert_structured_message if a == "crc64" else assert_content_md5 + + # Act + await file.create_file(len(data)) + # This is not an officially supported data type for upload_range, but we should still be able to handle it + # as there are probably users out there using it. + await file.upload_range(BytesIO(data), 0, len(data), validate_content=a, raw_request_hook=assert_method) + + # Assert + content = await file.download_file() + assert await content.readall() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "auto", "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + await file.upload_file(data) + assert_method = assert_structured_message_get if a in ("auto", "crc64") else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data + assert stream.read() == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + await file.upload_file(data) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + await downloader.readinto(stream) + stream.seek(0) + + read_content = bytearray() + downloader = await file.download_file(validate_content=a, raw_response_hook=assert_method) + async for chunk in downloader.chunks(): + read_content.extend(chunk) + + # Assert + assert content == data + assert stream.read() == data + assert read_content == data + + @FileSharePreparer() + @pytest.mark.parametrize("a", [True, "md5", "crc64"]) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_download_file_chunks_partial(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + self.share_client._config.max_single_get_size = 512 + self.share_client._config.max_chunk_get_size = 512 + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abc" * 512 + b"abcde" + await file.upload_file(data) + assert_method = assert_structured_message_get if a == "crc64" else assert_content_md5_get + + # Act + downloader = await file.download_file( + offset=10, length=1000, validate_content=a, raw_response_hook=assert_method + ) + content = await downloader.readall() + + stream = BytesIO() + downloader = await file.download_file( + offset=512, length=1024, validate_content=a, raw_response_hook=assert_method + ) + await downloader.readinto(stream) + stream.seek(0) + + # Assert + assert content == data[10:1010] + assert stream.read() == data[512:1536] + + @FileSharePreparer() + @pytest.mark.live_test_only + async def test_download_file_large_chunks(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + await self._setup(storage_account_name) + # The service will use 4 MiB for structured message chunk size, so make chunk size larger + self.share_client._config.max_chunk_get_size = 5 * 1024 * 1024 # 5 MiB + file = self.share_client.get_file_client(self._get_file_reference()) + data = b"abcde" * 30 * 1024 * 1024 + b"abcde" # 150 MiB + 5 + await file.upload_file(data, max_concurrency=5) + + # Act + downloader = await file.download_file(validate_content="crc64", max_concurrency=5) + content = await downloader.readall() + + downloader = await file.download_file(offset=5 * 1024 * 1024, length=25 * 1024 * 1024, validate_content="crc64") + partial = await downloader.readall() + + # Assert + assert content == data + assert partial == data[5 * 1024 * 1024 : 30 * 1024 * 1024] diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py index 54446f7fc5b4..993c2cfb354f 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/base_client_async.py @@ -36,12 +36,12 @@ from .parser import DEVSTORE_ACCOUNT_KEY, _get_development_storage_endpoint from .policies import ( QueueMessagePolicy, - StorageContentValidation, StorageHeadersPolicy, StorageHosts, StorageRequestHook, ) from .policies_async import ( + AsyncContentValidationPolicy, AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook, ) @@ -161,7 +161,7 @@ def _create_pipeline( QueueMessagePolicy(), config.proxy_policy, config.user_agent_policy, - StorageContentValidation(), + AsyncContentValidationPolicy(), ContentDecodePolicy(response_encoding="utf-8"), AsyncRedirectPolicy(**kwargs), StorageHosts(hosts=hosts, **kwargs), diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py index cc275ae4af61..f4f602d1c669 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies.py @@ -5,14 +5,13 @@ # -------------------------------------------------------------------------- import base64 -import hashlib import logging import random import re import uuid -from io import SEEK_SET, UnsupportedOperation +from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, urlencode, @@ -32,8 +31,20 @@ ) from .authentication import AzureSigningError, StorageHttpChallenge -from .constants import DEFAULT_OAUTH_SCOPE +from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode +from .streams import ( + StructuredMessageDecoder, + StructuredMessageEncodeStream, + StructuredMessageProperties, +) +from .validation import ( + CV_TYPE_ERROR_MSG, + calculate_content_md5, + calculate_crc64_bytes, + is_crc64_validation, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -44,9 +55,15 @@ _LOGGER = logging.getLogger(__name__) +CONTENT_LENGTH_HEADER = "Content-Length" +MD5_HEADER = "Content-MD5" +CRC64_HEADER = "x-ms-content-crc64" +SM_HEADER = "x-ms-structured-body" +SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" +SM_LENGTH_HEADER = "x-ms-structured-content-length" -def encode_base64(data): +def encode_base64(data: Union[bytes, str]) -> str: if isinstance(data, str): data = data.encode("utf-8") encoded = base64.b64encode(data) @@ -106,11 +123,15 @@ def is_retry(response, mode): # pylint: disable=too-many-return-statements return False -def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): +def is_checksum_retry(response) -> bool: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True @@ -372,64 +393,113 @@ def send(self, request: "PipelineRequest") -> "PipelineResponse": return response -class StorageContentValidation(SansIOHTTPPolicy): - """A simple policy that sends the given headers - with the request. +def _prepare_content_validation(request: "PipelineRequest") -> None: + validate_content = request.context.options.pop("validate_content", False) + if not validate_content: + return - This will overwrite any headers already defined in the request. - """ - - header_name = "Content-MD5" - - def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument - super(StorageContentValidation, self).__init__() + # Download + if request.http_request.method == "GET": + if is_crc64_validation(validate_content): + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 - @staticmethod - def get_content_md5(data): + # Upload + else: # Since HTTP does not differentiate between no content and empty content, # we have to perform a None check. - data = data or b"" - md5 = hashlib.md5() # nosec - if isinstance(data, bytes): - md5.update(data) - elif hasattr(data, "read"): - pos = 0 - try: - pos = data.tell() - except: # pylint: disable=bare-except - pass - for chunk in iter(lambda: data.read(4096), b""): - md5.update(chunk) - try: - data.seek(pos, SEEK_SET) - except (AttributeError, IOError) as exc: - raise ValueError("Data should be bytes or a seekable file-like object.") from exc - else: - raise ValueError("Data should be bytes or a seekable file-like object.") + data = request.http_request.data or b"" + if is_md5_validation(validate_content): + computed_md5 = encode_base64(calculate_content_md5(data)) + request.http_request.headers[MD5_HEADER] = computed_md5 + request.context["validate_content_md5"] = computed_md5 + + elif is_crc64_validation(validate_content): + # For crc64-sm, force structured message even for bytes + if validate_content == "crc64-sm" and isinstance(data, bytes): + data = BytesIO(data) + + if isinstance(data, bytes): + request.http_request.headers[CRC64_HEADER] = encode_base64(calculate_crc64_bytes(data)) + elif hasattr(data, "read"): + content_length = int(request.http_request.headers.get(CONTENT_LENGTH_HEADER)) + # Wrap data in structured message stream and adjust HTTP request + sm_stream = StructuredMessageEncodeStream(data, content_length, StructuredMessageProperties.CRC64) + request.http_request.data = sm_stream + request.http_request.headers[CONTENT_LENGTH_HEADER] = str(len(sm_stream)) + request.http_request.headers[SM_LENGTH_HEADER] = str(content_length) + request.http_request.headers[SM_HEADER] = SM_HEADER_V1_CRC64 + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + request.context["validate_content"] = validate_content + - return md5.digest() +def _validate_content_response( + request: "PipelineRequest", + response: "PipelineResponse", + decoder_cls: type, +) -> None: + validate_content = response.context.get("validate_content", False) + if not validate_content: + return + + if is_md5_validation(validate_content) and response.http_response.headers.get("content-md5"): + computed_md5 = request.context.get("validate_content_md5") or encode_base64( + calculate_content_md5(response.http_response.body()) + ) + if response.http_response.headers["content-md5"] != computed_md5: + raise AzureError( + ( + f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " + f"computed value is '{computed_md5}'." + ), + response=response.http_response, + ) + + elif is_crc64_validation(validate_content): + # For upload and download verify structured message header present in response if provided in request. + sm_request = request.http_request.headers.get(SM_HEADER) + sm_response = response.http_response.headers.get(SM_HEADER) + if sm_request != sm_response: + raise AzureError( + ( + f"Expected structured message header in response does not match request. " + f"Request: {sm_request}, Response: {sm_response}" + ), + response=response.http_response, + ) + + if response.http_request.method == "GET": + # Raises exception if missing + content_length = int(response.http_response.headers[CONTENT_LENGTH_HEADER]) + + # Patch response to return response iterator wrapped in structured message decoder + original_stream_download = response.http_response.stream_download + + def wrapped_stream_download(*args, **kwargs): + iterator = original_stream_download(*args, **kwargs) + decoder = decoder_cls(iterator, content_length, block_size=DATA_BLOCK_SIZE) + decoder.request = iterator.request # type: ignore + decoder.response = iterator.response # type: ignore + return decoder + + response.http_response.stream_download = wrapped_stream_download + + +class StorageContentValidation(SansIOHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() def on_request(self, request: "PipelineRequest") -> None: - validate_content = request.context.options.pop("validate_content", False) - if validate_content and request.http_request.method != "GET": - computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data)) - request.http_request.headers[self.header_name] = computed_md5 - request.context["validate_content_md5"] = computed_md5 - request.context["validate_content"] = validate_content + _prepare_content_validation(request) def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None: - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): - computed_md5 = request.context.get("validate_content_md5") or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) - ) - if response.http_response.headers["content-md5"] != computed_md5: - raise AzureError( - ( - f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', " - f"computed value is '{computed_md5}'." - ), - response=response.http_response, - ) + _validate_content_response(request, response, StructuredMessageDecoder) class StorageRetryPolicy(HTTPPolicy): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index f4d235c082d8..e1d13b1a83fa 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -19,11 +19,17 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE from .policies import ( + _prepare_content_validation, + _validate_content_response, encode_base64, is_retry, - StorageContentValidation, StorageRetryPolicy, ) +from .streams_async import AsyncStructuredMessageDecoder +from .validation import ( + calculate_content_md5, + is_md5_validation, +) if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -45,21 +51,52 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): - # retry if invalid content md5 - if response.context.get("validate_content", False) and response.http_response.headers.get("content-md5"): + validate_content = response.context.get("validate_content", False) + if not validate_content: + return False + + # Legacy code - evaluate retry only on validate_content=True + if validate_content is True and response.http_response.headers.get("content-md5"): if hasattr(response.http_response, "load_body"): try: await response.http_response.load_body() # Load the body in memory and close the socket except (StreamClosedError, StreamConsumedError): pass computed_md5 = response.http_request.headers.get("content-md5", None) or encode_base64( - StorageContentValidation.get_content_md5(response.http_response.body()) + calculate_content_md5(response.http_response.body()) ) if response.http_response.headers["content-md5"] != computed_md5: return True return False +class AsyncContentValidationPolicy(AsyncHTTPPolicy): + """A pipeline policy that performs content validation on uploads and downloads when enabled by the user. + This is enabled by setting the "validate_content" key in the request context. When enabled, this policy will + calculate and verify content checksums for uploads and downloads, and raise an exception if a mismatch is detected. + """ + + def __init__(self, **kwargs: Any) -> None: # pylint: disable=unused-argument + super().__init__() + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + _prepare_content_validation(request) + + response = await self.next.send(request) + + validate_content = response.context.get("validate_content", False) + if validate_content and is_md5_validation(validate_content): + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass + + _validate_content_response(request, response, AsyncStructuredMessageDecoder) + + return response + + class AsyncStorageResponseHook(AsyncHTTPPolicy): def __init__(self, **kwargs): diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py new file mode 100644 index 000000000000..cb745693921c --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py @@ -0,0 +1,565 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import sys +from enum import auto, Enum, IntFlag +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET +from typing import IO, Iterator, Optional + +from .validation import calculate_crc64 + +DEFAULT_MESSAGE_VERSION = 1 +DEFAULT_SEGMENT_SIZE = 4 * 1024 * 1024 + + +class StructuredMessageConstants: + V1_HEADER_LENGTH = 13 + V1_SEGMENT_HEADER_LENGTH = 10 + CRC64_LENGTH = 8 + + +class StructuredMessageProperties(IntFlag): + NONE = 0 + CRC64 = auto() + + +class SMRegion(Enum): + MESSAGE_HEADER = 1 + SEGMENT_HEADER = 2 + SEGMENT_CONTENT = 3 + SEGMENT_FOOTER = 4 + MESSAGE_FOOTER = 5 + + +def generate_message_header(version: int, size: int, flags: StructuredMessageProperties, num_segments: int) -> bytes: + return ( + version.to_bytes(1, "little") + + size.to_bytes(8, "little") + + flags.to_bytes(2, "little") + + num_segments.to_bytes(2, "little") + ) + + +def generate_segment_header(number: int, size: int) -> bytes: + return number.to_bytes(2, "little") + size.to_bytes(8, "little") + + +def parse_message_header(data: bytes, expected_message_length: int) -> tuple[int, StructuredMessageProperties, int]: + version = data[0] + if version != 1: + raise ValueError(f"The structured message version is not supported: {version}") + message_length = int.from_bytes(data[1:9], "little") + if message_length != expected_message_length: + raise ValueError( + f"Structured message length {message_length} " f"did not match content length {expected_message_length}" + ) + flags = StructuredMessageProperties(int.from_bytes(data[9:11], "little")) + num_segments = int.from_bytes(data[11:13], "little") + return version, flags, num_segments + + +def parse_segment_header(data: bytes, expected_segment_number: int) -> tuple[int, int]: + segment_number = int.from_bytes(data[0:2], "little") + if segment_number != expected_segment_number: + raise ValueError(f"Structured message segment number invalid or out of order {segment_number}") + segment_content_length = int.from_bytes(data[2:10], "little") + return segment_number, segment_content_length + + +class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instance-attributes + message_version: int + content_length: int + message_length: int + flags: StructuredMessageProperties + + _inner_stream: IO[bytes] + _segment_size: int + _num_segments: int + + _initial_content_position: Optional[int] + """Initial position of the inner stream, None if it did not implement tell()""" + _content_offset: int + _current_segment_number: int + _current_region: SMRegion + _current_region_length: int + _current_region_offset: int + + _message_crc64: int + _segment_crc64s: dict[int, int] + + def __init__( + self, + inner_stream: IO[bytes], + content_length: int, + flags: StructuredMessageProperties, + *, + segment_size: int = DEFAULT_SEGMENT_SIZE, + ) -> None: + if segment_size < 1: + raise ValueError("Segment size must be greater than 0.") + + self.message_version = DEFAULT_MESSAGE_VERSION + self.content_length = content_length + self.flags = flags + + self._inner_stream = inner_stream + self._segment_size = segment_size + self._num_segments = math.ceil(self.content_length / self._segment_size) or 1 + + self.message_length = self._calculate_message_length() + + self._content_offset = 0 + self._current_segment_number = 0 # Will be incremented before first segment + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + + self._message_crc64 = 0 + self._segment_crc64s = {} + + # Attempt to get starting position of inner stream. If we can't, this stream will not be seekable + try: + self._initial_content_position = self._inner_stream.tell() + except (AttributeError, UnsupportedOperation, OSError): + self._initial_content_position = None + super().__init__() + + @property + def _message_header_length(self) -> int: + return StructuredMessageConstants.V1_HEADER_LENGTH + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + def _update_current_region_length(self) -> None: + if self._current_region == SMRegion.MESSAGE_HEADER: + self._current_region_length = self._message_header_length + elif self._current_region == SMRegion.SEGMENT_HEADER: + self._current_region_length = self._segment_header_length + elif self._current_region == SMRegion.SEGMENT_CONTENT: + # Last segment size is remaining content + if self._current_segment_number == self._num_segments: + self._current_region_length = self.content_length - ( + (self._current_segment_number - 1) * self._segment_size + ) + else: + self._current_region_length = self._segment_size + elif self._current_region == SMRegion.SEGMENT_FOOTER: + self._current_region_length = self._segment_footer_length + elif self._current_region == SMRegion.MESSAGE_FOOTER: + self._current_region_length = self._message_footer_length + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def __len__(self): + return self.message_length + + @property + def closed(self) -> bool: + return self._inner_stream.closed + + def close(self) -> None: + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + try: + # Only seekable if the inner stream is and we could get its initial position + return self._inner_stream.seekable() and self._initial_content_position is not None + except (AttributeError, UnsupportedOperation, OSError): + return False + + def tell(self) -> int: + if self._current_region == SMRegion.MESSAGE_HEADER: + return self._current_region_offset + if self._current_region == SMRegion.SEGMENT_HEADER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + if self._current_region == SMRegion.SEGMENT_CONTENT: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + ) + if self._current_region == SMRegion.SEGMENT_FOOTER: + return ( + self._message_header_length + + self._content_offset + + (self._current_segment_number - 1) * (self._segment_header_length + self._segment_footer_length) + + self._segment_header_length + + self._current_region_offset + ) + if self._current_region == SMRegion.MESSAGE_FOOTER: + return ( + self._message_header_length + + self._content_offset + + self._current_segment_number * (self._segment_header_length + self._segment_footer_length) + + self._current_region_offset + ) + + raise ValueError(f"Invalid SMRegion {self._current_region}") + + def seek(self, offset: int, whence: int = SEEK_SET) -> int: + if not self.seekable(): + raise UnsupportedOperation("Inner stream is not seekable.") + + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") + + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") + + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0: + return b"" + if size < 0: + size = sys.maxsize + + count = 0 + output = BytesIO() + + while count < size and self.tell() < self.message_length: + remaining = size - count + if self._current_region in ( + SMRegion.MESSAGE_HEADER, + SMRegion.SEGMENT_HEADER, + SMRegion.SEGMENT_FOOTER, + SMRegion.MESSAGE_FOOTER, + ): + count += self._read_metadata_region(self._current_region, remaining, output) + elif self._current_region == SMRegion.SEGMENT_CONTENT: + count += self._read_content(remaining, output) + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + return output.getvalue() + + def _calculate_message_length(self) -> int: + length = self._message_header_length + length += (self._segment_header_length + self._segment_footer_length) * self._num_segments + length += self.content_length + length += self._message_footer_length + return length + + def _get_metadata_region(self, region: SMRegion) -> bytes: + if region == SMRegion.MESSAGE_HEADER: + return generate_message_header( + self.message_version, + self.message_length, + self.flags, + self._num_segments, + ) + + if region == SMRegion.SEGMENT_HEADER: + segment_size = min(self._segment_size, self.content_length - self._content_offset) + return generate_segment_header(self._current_segment_number, segment_size) + + if region == SMRegion.SEGMENT_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._segment_crc64s[self._current_segment_number].to_bytes( + StructuredMessageConstants.CRC64_LENGTH, "little" + ) + return b"" + + if region == SMRegion.MESSAGE_FOOTER: + if StructuredMessageProperties.CRC64 in self.flags: + return self._message_crc64.to_bytes(StructuredMessageConstants.CRC64_LENGTH, "little") + return b"" + + raise ValueError(f"Invalid metadata SMRegion {self._current_region}") + + def _advance_region(self, current: SMRegion): + self._current_region_offset = 0 + + if current == SMRegion.MESSAGE_HEADER: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + elif current == SMRegion.SEGMENT_HEADER: + self._current_region = SMRegion.SEGMENT_CONTENT + elif current == SMRegion.SEGMENT_CONTENT: + self._current_region = SMRegion.SEGMENT_FOOTER + elif current == SMRegion.SEGMENT_FOOTER: + # If we're at the end of the content + if self._content_offset == self.content_length: + self._current_region = SMRegion.MESSAGE_FOOTER + else: + self._current_region = SMRegion.SEGMENT_HEADER + self._increment_current_segment() + else: + raise ValueError(f"Invalid SMRegion {self._current_region}") + + self._update_current_region_length() + + def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> int: + metadata = self._get_metadata_region(region) + + read_size = min(size, self._current_region_length - self._current_region_offset) + content = metadata[self._current_region_offset : self._current_region_offset + read_size] + output.write(content) + + self._current_region_offset += read_size + if ( + self._current_region_offset == self._current_region_length + and self._current_region != SMRegion.MESSAGE_FOOTER + ): + self._advance_region(region) + + return read_size + + def _read_content(self, size: int, output: BytesIO) -> int: + read_size = min(size, self._current_region_length - self._current_region_offset) + + content = self._inner_stream.read(read_size) + if len(content) != read_size: + raise ValueError("Content ended early when encoding structured message.") + output.write(content) + + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) + + self._content_offset += read_size + self._current_region_offset += read_size + if self._current_region_offset == self._current_region_length: + self._advance_region(SMRegion.SEGMENT_CONTENT) + + return read_size + + def _increment_current_segment(self): + self._current_segment_number += 1 + if StructuredMessageProperties.CRC64 in self.flags: + # If seek was used, we may already have this segment's CRC (could be partial), otherwise initialize to 0 + self._segment_crc64s.setdefault(self._current_segment_number, 0) + + +class StructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: Iterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: Iterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __iter__(self): + return self + + def __next__(self) -> bytes: + data = self.read(self._block_size) + if not data: + raise StopIteration + return data + + def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + self._read_message_header() + self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = next(self._inner_iterator) + except StopIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + def _read_message_header(self) -> None: + header_data = self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + def _read_segment_header(self) -> None: + header_data = self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py new file mode 100644 index 000000000000..9fedc055a623 --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams_async.py @@ -0,0 +1,210 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import sys +from io import BytesIO, IOBase +from typing import AsyncIterator + +from .streams import ( + StructuredMessageConstants, + StructuredMessageProperties, + parse_message_header, + parse_segment_header, +) +from .validation import calculate_crc64 + + +class AsyncStructuredMessageDecoder(IOBase): # pylint: disable=too-many-instance-attributes + + message_version: int + """The version of the structured message.""" + message_length: int + """The total length of the structured message.""" + flags: StructuredMessageProperties + """The properties included in the structured message.""" + num_segments: int + """The number of message segments.""" + + _inner_iterator: AsyncIterator[bytes] + _buffer: bytes + _message_offset: int + _message_crc64: int + _segment_number: int + _segment_crc64: int + _segment_content_length: int + _segment_content_offset: int + _block_size: int + + def __init__( + self, + inner_iterator: AsyncIterator[bytes], + content_length: int, + *, + block_size: int = 4096, + ) -> None: + self.message_length = content_length + # The stream should be at least long enough to hold minimum header length + if self.message_length < StructuredMessageConstants.V1_HEADER_LENGTH: + raise ValueError("Content not long enough to contain a valid message header.") + + self._inner_iterator = inner_iterator + self._buffer = b"" + self._message_offset = 0 + self._message_crc64 = 0 + + self._segment_number = 0 + self._segment_crc64 = 0 + self._segment_content_length = 0 + self._segment_content_offset = 0 + self._block_size = block_size + super().__init__() + + @property + def content_length(self) -> int: + return self.message_length + + @property + def _segment_header_length(self) -> int: + return StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH + + @property + def _segment_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _message_footer_length(self) -> int: + return StructuredMessageConstants.CRC64_LENGTH if StructuredMessageProperties.CRC64 in self.flags else 0 + + @property + def _end_of_segment_content(self) -> bool: + return self._segment_content_offset == self._segment_content_length + + def readable(self) -> bool: + return True + + def seekable(self) -> bool: + return False + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + data = await self.read(self._block_size) + if not data: + raise StopAsyncIteration + return data + + async def read(self, size: int = -1) -> bytes: + if self.closed: # pylint: disable=using-constant-test + raise ValueError("Stream is closed") + + if size == 0 or self._message_offset >= self.message_length: + return b"" + if size < 0: + size = sys.maxsize + + # On the first read, read message header and first segment header + if self._message_offset == 0: + await self._read_message_header() + await self._read_segment_header() + + # Special case for 0 length content + if self._end_of_segment_content: + await self._read_segment_footer() + if self.num_segments > 1: + raise ValueError("First message segment was empty but more segments were detected.") + await self._read_message_footer() + return b"" + + count = 0 + content = BytesIO() + while count < size and not (self._end_of_segment_content and self._message_offset == self.message_length): + if self._end_of_segment_content: + await self._read_segment_header() + + segment_remaining = self._segment_content_length - self._segment_content_offset + read_size = min(segment_remaining, size - count) + + segment_content = await self._read_from_inner(read_size) + content.write(segment_content) + + # Update the running CRC64 for the segment and message + if StructuredMessageProperties.CRC64 in self.flags: + self._segment_crc64 = calculate_crc64(segment_content, self._segment_crc64) + self._message_crc64 = calculate_crc64(segment_content, self._message_crc64) + + self._segment_content_offset += read_size + self._message_offset += read_size + count += read_size + + if self._end_of_segment_content: + await self._read_segment_footer() + # If we are on the last segment, also read the message footer + if self._segment_number == self.num_segments: + await self._read_message_footer() + + # One final check to ensure if we think we've reached the end of the stream + # that the current segment number matches the total. + if self._message_offset == self.message_length and self._segment_number != self.num_segments: + raise ValueError("Invalid structured message data detected.") + + return content.getvalue() + + async def _read_from_inner(self, size: int) -> bytes: + while len(self._buffer) < size: + try: + chunk = await self._inner_iterator.__anext__() + except StopAsyncIteration: + break + self._buffer += chunk + + if len(self._buffer) < size: + raise ValueError("Invalid structured message data detected. Stream content incomplete.") + + data = self._buffer[:size] + self._buffer = self._buffer[size:] + return data + + async def _read_message_header(self) -> None: + header_data = await self._read_from_inner(StructuredMessageConstants.V1_HEADER_LENGTH) + self.message_version, self.flags, self.num_segments = parse_message_header(header_data, self.message_length) + self._message_offset += StructuredMessageConstants.V1_HEADER_LENGTH + + async def _read_message_footer(self) -> None: + # Sanity check: There should only be self._message_footer_length (could be 0) bytes left to consume. + # If not, it is likely the message header contained incorrect info. + if self.message_length - self._message_offset != self._message_footer_length: + raise ValueError("Invalid structured message data detected.") + + if StructuredMessageProperties.CRC64 in self.flags: + message_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._message_crc64 != int.from_bytes(message_crc, "little"): + raise ValueError( + "CRC64 mismatch detected in message trailer. All data read should be considered invalid." + ) + + self._message_offset += self._message_footer_length + + async def _read_segment_header(self) -> None: + header_data = await self._read_from_inner(self._segment_header_length) + self._segment_number, self._segment_content_length = parse_segment_header(header_data, self._segment_number + 1) + self._message_offset += self._segment_header_length + + self._segment_content_offset = 0 + self._segment_crc64 = 0 + + async def _read_segment_footer(self) -> None: + if StructuredMessageProperties.CRC64 in self.flags: + segment_crc = await self._read_from_inner(StructuredMessageConstants.CRC64_LENGTH) + + if self._segment_crc64 != int.from_bytes(segment_crc, "little"): + raise ValueError( + f"CRC64 mismatch detected in segment {self._segment_number}. " + f"All data read should be considered invalid." + ) + + self._message_offset += self._segment_footer_length diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py new file mode 100644 index 000000000000..21d3b081d8cc --- /dev/null +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/validation.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# pylint: disable=c-extension-no-member + +import hashlib +from io import SEEK_SET +from typing import IO, Literal, Optional, Union, cast + +CRC64_LENGTH = 8 +CV_TYPE_ERROR_MSG = "Data should be bytes or seekable IO[bytes] for content validation." + +_VALID_CV_OPTIONS = ("auto", "crc64", "crc64-sm", "md5") + +CV_TYPE = Optional[Union[bool, Literal["auto", "crc64", "md5"]]] +CV_TYPE_PARSED = Optional[Union[bool, Literal["crc64", "crc64-sm", "md5"]]] + + +def _verify_extensions(module: str) -> None: + try: + import azure.storage.extensions # pylint: disable=unused-import + except ImportError as exc: + raise ValueError( + f"The use of {module} requires the azure-storage-extensions package to be installed. " + f"Please install this package and try again." + ) from exc + + +def parse_validation_option( + validate_content: CV_TYPE, + *, + force_structured_message: bool = False, +) -> CV_TYPE_PARSED: + if validate_content is None: + return None + + # Legacy support for bool + if isinstance(validate_content, bool): + return validate_content + + parsed = validate_content.lower() + if parsed not in _VALID_CV_OPTIONS: + raise ValueError("Invalid value for `validate_content` specified.") + + # Resolve auto + if parsed == "auto": + parsed = "crc64" + + if parsed == "crc64": + _verify_extensions("crc64") + if force_structured_message: + parsed = "crc64-sm" + + return cast(CV_TYPE_PARSED, parsed) + + +def is_md5_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return validate_content + return validate_content == "md5" + + +def is_crc64_validation( + validate_content: CV_TYPE_PARSED, +) -> bool: + if validate_content is None: + return False + if isinstance(validate_content, bool): + return False + return validate_content in ("crc64", "crc64-sm") + + +def calculate_content_md5(data: Union[bytes, IO[bytes]]) -> bytes: + md5 = hashlib.md5() # nosec + if isinstance(data, bytes): + md5.update(data) + elif hasattr(data, "read"): + pos = 0 + try: + pos = data.tell() + except: # pylint: disable=bare-except + pass + for chunk in iter(lambda: data.read(4096), b""): + md5.update(chunk) + try: + data.seek(pos, SEEK_SET) + except (AttributeError, IOError) as exc: + raise ValueError(CV_TYPE_ERROR_MSG) from exc + else: + raise ValueError(CV_TYPE_ERROR_MSG) + + return md5.digest() + + +def calculate_crc64(data: bytes, initial_crc: int) -> int: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(int, crc64.compute(data, initial_crc)) + + +def calculate_crc64_bytes(data: bytes) -> bytes: + # Locally import to avoid error if not installed. + from azure.storage.extensions import crc64 + + return cast(bytes, crc64.compute(data, 0).to_bytes(CRC64_LENGTH, "little")) diff --git a/sdk/storage/azure-storage-queue/dev_requirements.txt b/sdk/storage/azure-storage-queue/dev_requirements.txt index c9cf9f09f32d..b8770ca0db3a 100644 --- a/sdk/storage/azure-storage-queue/dev_requirements.txt +++ b/sdk/storage/azure-storage-queue/dev_requirements.txt @@ -1,5 +1,6 @@ -e ../../../eng/tools/azure-sdk-tools ../../core/azure-core ../../identity/azure-identity +../azure-storage-extensions azure-mgmt-storage==20.1.0 aiohttp>=3.13.5