From 661591f855928a7db42de4a918c0a1e2bcbee44a Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 19 May 2026 18:54:15 -0500 Subject: [PATCH 1/3] utf8 bug fix - initial commit --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 + .../azure/cosmos/_response_decoding.py | 120 ++++++++++++ .../azure/cosmos/_synchronized_request.py | 8 +- .../azure/cosmos/aio/_asynchronous_request.py | 8 +- .../tests/test_content_length_encoding.py | 169 +++++++++++++++++ .../tests/test_response_decoding.py | 174 ++++++++++++++++++ 6 files changed, 477 insertions(+), 4 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_response_decoding.py diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 187cd853cb27..e5c5332a6d8e 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,6 +7,8 @@ #### Breaking Changes #### Bugs Fixed +* Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. +* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. 
See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py new file mode 100644 index 000000000000..dfd25fccda49 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py @@ -0,0 +1,120 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""UTF-8 decoding for HTTP response bodies, with an opt-in fallback for +payloads containing bytes that are not valid UTF-8. + +By default this module preserves the historical SDK behavior: strict +decode, ``UnicodeDecodeError`` raised on the first invalid byte. +Operators who need to read past corrupt payloads (for example, to +unblock a stuck change-feed processor) can opt in to a permissive +fallback by setting an environment variable at process start. + +The recognized environment variable is +``COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``: + +* ``REPLACE`` -> Python ``errors="replace"`` (substitute U+FFFD) +* ``IGNORE`` -> Python ``errors="ignore"`` (drop the bad bytes) +* anything else, including unset -> strict (raise on bad bytes) + +The value is read once at module import. Tests can call +``_reset_for_tests()`` to re-snapshot. +""" +import logging +import os +from typing import Optional + +__all__ = ["decode_response_body", "_reset_for_tests"] + +_MALFORMED_INPUT_ENV_VAR = "COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" + +# Mapping from the recognized env var values to Python's bytes.decode +# `errors=` argument. Anything not in this mapping (including the env var +# being unset) resolves to strict decoding, which is the historical default. 
+_ENV_VALUE_TO_DECODE_ERRORS_MODE = { + "REPLACE": "replace", + "IGNORE": "ignore", +} + +_logger = logging.getLogger(__name__) + + +def _resolve_fallback_mode_from_env() -> Optional[str]: + """Reads the malformed-input env var and returns the Python decode + ``errors=`` mode to use as a fallback, or ``None`` if the operator + has not opted in (in which case strict decoding stays in effect).""" + raw_value = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + if raw_value is None: + return None + return _ENV_VALUE_TO_DECODE_ERRORS_MODE.get(raw_value.strip().upper()) + + +# Snapshot at module import. The value is immutable after import unless +# `_reset_for_tests` is called. Reading a module-level string in CPython is +# atomic, so no lock is needed on the per-call read path. +_fallback_errors_mode: Optional[str] = _resolve_fallback_mode_from_env() + + +def _reset_for_tests() -> None: + """Re-reads the env var and refreshes the cached fallback mode. Tests + should call this after mutating ``os.environ`` so the next call to + ``decode_response_body`` sees the new value.""" + global _fallback_errors_mode # pylint: disable=global-statement + _fallback_errors_mode = _resolve_fallback_mode_from_env() + + +def decode_response_body(data: bytes, operation_context: Optional[str] = None) -> str: + """Decode an HTTP response body as UTF-8. + + The healthy path is strict decoding, identical in behavior and cost + to ``data.decode("utf-8")``. The slow path is taken only when the + payload contains bytes that are not valid UTF-8: + + * If the operator has opted in via the malformed-input env var, the + decode is retried in the configured permissive mode (``replace`` or + ``ignore``) and a WARNING is logged with the byte offset, the + decoder's reason, and the supplied operation context. + * Otherwise a ``UnicodeDecodeError`` is raised whose ``reason`` field + carries an actionable hint pointing the operator at the env var. + The original exception is preserved as ``__cause__``. 
+ + :param data: Response body bytes. + :param operation_context: Optional short string identifying the call + site (for example, ``"read_item"`` or ``"query_items page"``); + included in the WARNING log line when permissive fallback fires. + :returns: The decoded string. + :raises UnicodeDecodeError: If the body contains invalid UTF-8 and + the operator has not opted in to a permissive fallback. + """ + try: + return data.decode("utf-8") + except UnicodeDecodeError as strict_error: + fallback_mode = _fallback_errors_mode + if fallback_mode is None: + hint = ( + "{original}; set environment variable " + "{env_var}=REPLACE (or IGNORE) to tolerate invalid UTF-8 " + "in Cosmos response bodies" + ).format( + original=strict_error.reason, + env_var=_MALFORMED_INPUT_ENV_VAR, + ) + raise UnicodeDecodeError( + strict_error.encoding, + strict_error.object, + strict_error.start, + strict_error.end, + hint, + ) from strict_error + + _logger.warning( + "Cosmos response body contained invalid UTF-8 at byte offset %d " + "(reason: %s); decoding with errors=%r per %s%s.", + strict_error.start, + strict_error.reason, + fallback_mode, + _MALFORMED_INPUT_ENV_VAR, + " (operation: {0})".format(operation_context) if operation_context else "", + ) + return data.decode("utf-8", errors=fallback_mode) + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 7b18f52da2e2..01ab8e4abe94 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -33,6 +33,7 @@ from ._availability_strategy_config import CrossRegionHedgingStrategy from ._availability_strategy_handler import execute_with_hedging from ._constants import _Constants +from ._response_decoding import decode_response_body from ._request_object import RequestObject from .documents import _OperationType @@ -177,7 +178,7 @@ def _Request(global_endpoint_manager, 
request_params, connection_policy, pipelin data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, request_params.operation_type) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) @@ -257,7 +258,10 @@ def SynchronizedRequest( """ request.data = _request_body_from_data(request_data) if request.data and isinstance(request.data, str): - request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data) + # Use UTF-8 byte length, not str length (code-point count), so the + # header matches the bytes the transport actually writes for any + # non-ASCII payload. + request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8")) elif request.data is None: request.headers[http_constants.HttpHeaders.ContentLength] = 0 diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index cdf64a7cca76..46459774c711 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -35,6 +35,7 @@ from .._availability_strategy_config import CrossRegionHedgingStrategy from .._constants import _Constants from .._request_object import RequestObject +from .._response_decoding import decode_response_body from .._synchronized_request import _request_body_from_data, _replace_url_prefix from ..documents import _OperationType @@ -141,7 +142,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, request_params.operation_type) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) @@ -210,7 +211,10 @@ async def AsynchronousRequest( """ request.data = _request_body_from_data(request_data) if request.data 
and isinstance(request.data, str): - request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data) + # Use UTF-8 byte length, not str length (code-point count), so the + # header matches the bytes the transport actually writes for any + # non-ASCII payload. + request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8")) elif request.data is None: request.headers[http_constants.HttpHeaders.ContentLength] = 0 diff --git a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py new file mode 100644 index 000000000000..ee8b921eb7eb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py @@ -0,0 +1,169 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Regression tests for the Content-Length header computation. + +The SDK previously computed `Content-Length` from `len(request.data)` — +i.e. the number of Unicode code points in the JSON string — instead of +the UTF-8 byte length that actually goes on the wire. For any non-ASCII +payload that under-counted the body by the number of multi-byte +characters in it, which can cause downstream HTTP receivers to truncate +the body, reject the request, or mis-frame the next keep-alive request. + + +These tests exercise the exact arithmetic in both the sync and async +request paths via a minimal stand-in for the request object, so they +do not require a live Cosmos account. +""" +import unittest +from unittest import mock + +from azure.cosmos import _synchronized_request, http_constants +from azure.cosmos.aio import _asynchronous_request +from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import HttpHeaders + + +def _set_content_length_like_sdk(body): + """Mirrors the post-fix code path in both `_synchronized_request.py` + and `aio/_asynchronous_request.py`. 
Kept in lock-step with those + call sites so this test fails if either one regresses to + `len(body)` on the str branch.""" + headers = {} + if body and isinstance(body, str): + headers[HttpHeaders.ContentLength] = len(body.encode("utf-8")) + elif body is None: + headers[HttpHeaders.ContentLength] = 0 + return headers + + +class TestContentLengthEncoding(unittest.TestCase): + + def test_ascii_body_byte_length_equals_char_length(self): + """Regression guard: ASCII-only bodies must continue to produce + the same `Content-Length` value as before the fix (the new and + old computations agree when every code point is one byte).""" + body = '{"id":"x","name":"hello"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual(headers[HttpHeaders.ContentLength], len(body)) + self.assertEqual(headers[HttpHeaders.ContentLength], 25) + + def test_two_byte_character_adds_one_byte(self): + """`café` contains one 2-byte character (`é` → `\\xC3\\xA9`), + so the UTF-8 byte length must be `len(body) + 1`.""" + body = '{"name":"café"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual(headers[HttpHeaders.ContentLength], len(body) + 1) + self.assertEqual( + headers[HttpHeaders.ContentLength], + len(body.encode("utf-8")), + ) + + def test_mixed_multibyte_characters(self): + """Accented (2-byte), CJK (3-byte), and emoji (4-byte) + characters together. The header must equal the UTF-8 byte + length, not the code-point count. This catches future + 'let's strip the encode call to save a microsecond' + regressions.""" + body = '{"a":"é","b":"日","c":"🎉"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual( + headers[HttpHeaders.ContentLength], + len(body.encode("utf-8")), + ) + # And explicitly assert it differs from the buggy computation. 
+ self.assertNotEqual(headers[HttpHeaders.ContentLength], len(body)) + + def test_none_body_is_zero(self): + """`None` → `Content-Length: 0`, unchanged by the fix.""" + headers = _set_content_length_like_sdk(None) + self.assertEqual(headers[HttpHeaders.ContentLength], 0) + + def test_empty_string_does_not_set_header(self): + """An empty string is falsy, so the str branch does not fire and + the `is None` branch does not fire either — header is left for + the transport to set (unchanged by the fix).""" + headers = _set_content_length_like_sdk("") + self.assertNotIn(HttpHeaders.ContentLength, headers) + + +class _DummyRequestParams: + def __init__(self): + self.availability_strategy = None + self.is_hedging_request = False + self.resource_type = http_constants.ResourceType.Document + self.operation_type = _OperationType.Create + self.retry_write = 0 + + +class _DummyGlobalEndpointManager: + @staticmethod + def is_per_partition_automatic_failover_enabled(): + return False + + +class _DummyRequest: + def __init__(self): + self.headers = {} + self.data = None + + +class TestContentLengthWiringSyncAndAsync(unittest.TestCase): + def test_sync_request_sets_utf8_byte_content_length(self): + params = _DummyRequestParams() + manager = _DummyGlobalEndpointManager() + request = _DummyRequest() + + captured = {} + + def _fake_execute(*args, **kwargs): + request_arg = args[6] + captured["content_length"] = request_arg.headers.get(HttpHeaders.ContentLength) + captured["body"] = request_arg.data + return {}, {} + + with mock.patch.object(_synchronized_request._retry_utility, "Execute", side_effect=_fake_execute): + _synchronized_request.SynchronizedRequest( + client=object(), + request_params=params, + global_endpoint_manager=manager, + connection_policy=object(), + pipeline_client=object(), + request=request, + request_data={"name": "café"}, + ) + + self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + + +class 
TestContentLengthWiringAsync(unittest.IsolatedAsyncioTestCase): + async def test_async_request_sets_utf8_byte_content_length(self): + params = _DummyRequestParams() + manager = _DummyGlobalEndpointManager() + request = _DummyRequest() + + captured = {} + + async def _fake_execute_async(*args, **kwargs): + request_arg = args[6] + captured["content_length"] = request_arg.headers.get(HttpHeaders.ContentLength) + captured["body"] = request_arg.data + return {}, {} + + with mock.patch.object(_asynchronous_request._retry_utility_async, "ExecuteAsync", side_effect=_fake_execute_async): + await _asynchronous_request.AsynchronousRequest( + client=object(), + request_params=params, + global_endpoint_manager=manager, + connection_policy=object(), + pipeline_client=object(), + request=request, + request_data={"name": "café"}, + ) + + self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py new file mode 100644 index 000000000000..e0866892895e --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py @@ -0,0 +1,174 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests for the response-body UTF-8 decode helper and its env-var-driven +fallback behavior. Covers the healthy path (strict decode succeeds), the +default behavior when the env var is unset (strict decode fails with an +actionable hint), and the opt-in REPLACE / IGNORE modes. +""" +import os +import unittest +from unittest import mock + +from azure.cosmos import _response_decoding + + +# A small payload containing one valid 2-byte UTF-8 sequence followed by a +# byte (0xC3 followed by 0x28) that is not a valid UTF-8 continuation byte. +# `\xC3\x28` is the textbook example of an invalid UTF-8 sequence. 
+_INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' +_VALID_UTF8 = b'{"note":"hello world"}' + +_MALFORMED_INPUT_ENV_VAR = "COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" + + +class _DecoderEnvIsolatedTestCase(unittest.TestCase): + """Base class that isolates each test from the surrounding process + environment and from sibling tests by restoring both the recognized + env var and the module-level fallback-mode cache after every test.""" + + def setUp(self): + # Snapshot the cache value and start an env patch that will + # roll back any mutations the test makes. + self._original_fallback_mode = _response_decoding._fallback_errors_mode + self._env_patch = mock.patch.dict(os.environ, {}, clear=False) + self._env_patch.start() + # Strip the env var for the duration of the test so the helper's + # default behavior (no env -> strict) is the explicit baseline. + # Tests that need a specific env value set it themselves and call + # `_reset_for_tests()`. + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + _response_decoding._reset_for_tests() + + def tearDown(self): + # `mock.patch.dict` rolls back any env mutations the test made, + # including the pop in setUp. + self._env_patch.stop() + # Restore the cache last, so the next test's setUp sees the + # same module state the suite started with. + _response_decoding._fallback_errors_mode = self._original_fallback_mode + + +class TestStrictDecodingHealthyPath(_DecoderEnvIsolatedTestCase): + + def test_valid_utf8_decodes_unchanged(self): + """The healthy path must produce exactly the same string as + ``bytes.decode('utf-8')``. 
This is the regression guard for + the 99.99% case where the body is well-formed.""" + result = _response_decoding.decode_response_body(_VALID_UTF8) + self.assertEqual(result, '{"note":"hello world"}') + + def test_empty_bytes_decodes_to_empty_string(self): + self.assertEqual(_response_decoding.decode_response_body(b""), "") + + +class TestStrictDecodingRaisesActionableError(_DecoderEnvIsolatedTestCase): + + def test_invalid_utf8_without_env_var_raises_with_hint(self): + """When the env var is unset (the historical default) the helper + must raise ``UnicodeDecodeError`` so existing call sites continue + to behave the same way. The hint in ``reason`` points the + operator at the env var name so they can self-serve.""" + # setUp already cleared the env var and reset the cache; the + # assertion below makes that contract explicit for the reader. + self.assertIsNone(_response_decoding._fallback_errors_mode) + + with self.assertRaises(UnicodeDecodeError) as ctx: + _response_decoding.decode_response_body(_INVALID_UTF8, operation_context="read_item") + + self.assertIn(_MALFORMED_INPUT_ENV_VAR, ctx.exception.reason) + # Original exception must be chained so callers and log readers + # can still see the underlying decoder error. + self.assertIsInstance(ctx.exception.__cause__, UnicodeDecodeError) + + +class TestPermissiveFallback(_DecoderEnvIsolatedTestCase): + """Exercises the decode behavior in each fallback mode by writing the + cache directly. """ + + def test_replace_mode_substitutes_replacement_character(self): + _response_decoding._fallback_errors_mode = "replace" + result = _response_decoding.decode_response_body(_INVALID_UTF8) + # The bad byte is replaced by U+FFFD; the surrounding text is preserved. 
+ self.assertIn("\ufffd", result) + self.assertIn("hello", result) + self.assertIn("world", result) + + def test_ignore_mode_drops_bad_bytes(self): + _response_decoding._fallback_errors_mode = "ignore" + result = _response_decoding.decode_response_body(_INVALID_UTF8) + # No replacement character; the bad byte is silently dropped. + self.assertNotIn("\ufffd", result) + self.assertIn("hello", result) + self.assertIn("world", result) + + + +class TestEnvVarSnapshot(_DecoderEnvIsolatedTestCase): + """Tests for the env-var parser in isolation. Each test sets the + env var, calls ``_reset_for_tests()``, and asserts the cached + fallback mode matches the documented mapping.""" + + def test_replace_env_value_resolves_to_replace_mode(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + def test_ignore_env_value_resolves_to_ignore_mode(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "ignore") + + def test_unknown_env_value_resolves_to_strict(self): + """Anything other than REPLACE / IGNORE (case-insensitive) must + leave strict decoding in effect.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "BOGUS" + _response_decoding._reset_for_tests() + self.assertIsNone(_response_decoding._fallback_errors_mode) + + def test_env_value_is_case_insensitive_and_trims_whitespace(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = " replace " + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + def test_unset_env_resolves_to_strict(self): + # setUp already pops the var; this makes the contract explicit. 
+ self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) + _response_decoding._reset_for_tests() + self.assertIsNone(_response_decoding._fallback_errors_mode) + + +class TestEnvVarToBehaviorEndToEnd(_DecoderEnvIsolatedTestCase): + """Verifies the full contract: setting the env var and calling + ``_reset_for_tests()`` actually changes what ``decode_response_body`` + does. Catches regressions where the env parser and the per-call + read drift apart.""" + + def test_setting_replace_env_var_makes_invalid_utf8_decode_succeed(self): + # Baseline: with no env var, the same input raises. + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body(_INVALID_UTF8) + + # Opt in via the env var, refresh, and prove the same input now + # decodes to a replacement-character-bearing string instead of raising. + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + result = _response_decoding.decode_response_body(_INVALID_UTF8) + self.assertIn("\ufffd", result) + + def test_clearing_env_var_after_reset_returns_to_strict(self): + # Opt in. + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + # Opt out by removing the var and re-snapshotting. 
+ del os.environ[_MALFORMED_INPUT_ENV_VAR] + _response_decoding._reset_for_tests() + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body(_INVALID_UTF8) + + +if __name__ == "__main__": + unittest.main() + From baef4893011df84904c0414fb2d37f8f4b58850c Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Tue, 19 May 2026 20:15:17 -0500 Subject: [PATCH 2/3] addressing copilot comments --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 4 +- .../azure-cosmos/azure/cosmos/_constants.py | 6 ++ .../azure/cosmos/_inference_service.py | 3 +- .../azure/cosmos/_response_decoding.py | 6 +- .../cosmos/aio/_inference_service_async.py | 3 +- .../tests/test_response_decoding.py | 3 +- .../tests/test_semantic_reranker_unit.py | 56 +++++++++++++++++++ 7 files changed, 74 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index e5c5332a6d8e..3269a2e1ecce 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,8 +7,8 @@ #### Breaking Changes #### Bugs Fixed -* Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. -* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. +* Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) +* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). 
Setting `AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 5338ea116340..73ba0649a859 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -90,6 +90,12 @@ class _Constants: TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT: int = 10 # ------------------------------------------------------------------------- + # Controls how the SDK handles invalid UTF-8 bytes in HTTP response bodies. + # Accepted values: "REPLACE", "IGNORE". Anything else (including unset) + # leaves strict decoding in effect, which is the historical default. 
+ CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT: str = \ + "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" + # Error code translations ERROR_TRANSLATIONS: dict[int, str] = { 400: "BAD_REQUEST - Request being sent is invalid.", diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py index 29e49217fb02..7734aeba5b29 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py @@ -38,6 +38,7 @@ from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from ._cosmos_responses import CosmosDict from ._inference_auth_policy import InferenceServiceBearerTokenPolicy +from ._response_decoding import decode_response_body from ._retry_utility import ConnectionRetryPolicy from .http_constants import HttpHeaders @@ -202,7 +203,7 @@ def rerank( data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, "inference_request") if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py index dfd25fccda49..1299e2875fd8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py @@ -11,7 +11,7 @@ fallback by setting an environment variable at process start. 
The recognized environment variable is -``COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``: +``AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``: * ``REPLACE`` -> Python ``errors="replace"`` (substitute U+FFFD) * ``IGNORE`` -> Python ``errors="ignore"`` (drop the bad bytes) @@ -24,9 +24,11 @@ import os from typing import Optional +from ._constants import _Constants + __all__ = ["decode_response_body", "_reset_for_tests"] -_MALFORMED_INPUT_ENV_VAR = "COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" +_MALFORMED_INPUT_ENV_VAR = _Constants.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT # Mapping from the recognized env var values to Python's bytes.decode # `errors=` argument. Anything not in this mapping (including the env var diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py index 05ec30bec755..a6ea3a095bb2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py @@ -40,6 +40,7 @@ from .._constants import _Constants as Constants from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from .._cosmos_responses import CosmosDict +from .._response_decoding import decode_response_body from ..http_constants import HttpHeaders @@ -235,7 +236,7 @@ async def rerank( data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, "inference_request") if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py index e0866892895e..f2cc6aecad82 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py @@ -6,6 +6,7 @@ default behavior when the env var is unset (strict 
decode fails with an actionable hint), and the opt-in REPLACE / IGNORE modes. """ +# cspell:ignore ufffd import os import unittest from unittest import mock @@ -19,7 +20,7 @@ _INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' _VALID_UTF8 = b'{"note":"hello world"}' -_MALFORMED_INPUT_ENV_VAR = "COSMOS.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" +_MALFORMED_INPUT_ENV_VAR = "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" class _DecoderEnvIsolatedTestCase(unittest.TestCase): diff --git a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py index 9c2a2e034d01..ce3bcb6f78a1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py @@ -155,6 +155,62 @@ async def run_test(): asyncio.run(run_test()) + def test_sync_inference_uses_shared_response_decoder(self): + """Test that sync inference service decodes response bytes via decode_response_body.""" + from azure.cosmos._inference_service import _InferenceService + + mock_connection = self._create_mock_connection() + service = _InferenceService(mock_connection) + + raw_response_data = b'{"Scores": []}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = raw_response_data + + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response + ), patch( + "azure.cosmos._inference_service.decode_response_body", + return_value='{"Scores": []}' + ) as mock_decode: + service.rerank( + reranking_context="test query", + documents=["doc1"] + ) + mock_decode.assert_called_once_with(raw_response_data, "inference_request") + + def test_async_inference_uses_shared_response_decoder(self): + """Test that async inference service decodes response bytes via decode_response_body.""" + async def run_test(): + from 
azure.cosmos.aio._inference_service_async import _InferenceService + + mock_connection = self._create_mock_connection() + mock_connection.connection_policy.DisableSSLVerification = False + service = _InferenceService(mock_connection) + + raw_response_data = b'{"Scores": []}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = raw_response_data + + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response + ), patch( + "azure.cosmos.aio._inference_service_async.decode_response_body", + return_value='{"Scores": []}' + ) as mock_decode: + await service.rerank( + reranking_context="test query", + documents=["doc1"] + ) + mock_decode.assert_called_once_with(raw_response_data, "inference_request") + + asyncio.run(run_test()) + def test_sync_inference_response_timeout_raises_408(self): """Test that sync inference service converts ServiceResponseError to 408.""" from azure.cosmos._inference_service import _InferenceService From 175153b3d207e235e6345fd6ad22483df19f3917 Mon Sep 17 00:00:00 2001 From: dibahlfi <106994927+dibahlfi@users.noreply.github.com> Date: Wed, 20 May 2026 18:22:07 -0500 Subject: [PATCH 3/3] refactoring tests --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- .../azure/cosmos/_inference_service.py | 6 +- .../azure/cosmos/_response_decoding.py | 76 ++++-- .../azure/cosmos/_synchronized_request.py | 6 +- .../azure/cosmos/aio/_asynchronous_request.py | 6 +- .../cosmos/aio/_inference_service_async.py | 6 +- .../tests/test_content_length_encoding.py | 180 +++++++------ .../tests/test_request_response_decoding.py | 189 ++++++++++++++ .../tests/test_response_decoding.py | 240 ++++++++++++++---- .../tests/test_semantic_reranker_unit.py | 15 +- 10 files changed, 564 insertions(+), 162 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py diff --git 
a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 3269a2e1ecce..7147220b9280 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -8,7 +8,7 @@ #### Bugs Fixed * Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) -* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) +* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. 
See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py index 7734aeba5b29..44d23e62a2b8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py @@ -38,7 +38,7 @@ from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from ._cosmos_responses import CosmosDict from ._inference_auth_policy import InferenceServiceBearerTokenPolicy -from ._response_decoding import decode_response_body +from ._response_decoding import decode_response_body_for_status from ._retry_utility import ConnectionRetryPolicy from .http_constants import HttpHeaders @@ -203,7 +203,9 @@ def rerank( data = response.body() if data: - data = decode_response_body(data, "inference_request") + data = decode_response_body_for_status( + data, response.status_code, "inference_request" + ) if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py index 1299e2875fd8..58979a1f2bb1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py @@ -8,7 +8,7 @@ decode, ``UnicodeDecodeError`` raised on the first invalid byte. Operators who need to read past corrupt payloads (for example, to unblock a stuck change-feed processor) can opt in to a permissive -fallback by setting an environment variable at process start. +fallback by setting an environment variable. 
The recognized environment variable is ``AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``: @@ -17,8 +17,10 @@ * ``IGNORE`` -> Python ``errors="ignore"`` (drop the bad bytes) * anything else, including unset -> strict (raise on bad bytes) -The value is read once at module import. Tests can call -``_reset_for_tests()`` to re-snapshot. +The env var is consulted only on the decode-failure path, so operators +can set or change it at any point during process lifetime and the next +malformed payload will pick up the new value. This follows the Cosmos +SDK's runtime-read pattern for environment-based controls. """ import logging import os @@ -26,7 +28,6 @@ from ._constants import _Constants -__all__ = ["decode_response_body", "_reset_for_tests"] _MALFORMED_INPUT_ENV_VAR = _Constants.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT @@ -51,19 +52,6 @@ def _resolve_fallback_mode_from_env() -> Optional[str]: return _ENV_VALUE_TO_DECODE_ERRORS_MODE.get(raw_value.strip().upper()) -# Snapshot at module import. The value is immutable after import unless -# `_reset_for_tests` is called. Reading a module-level string in CPython is -# atomic, so no lock is needed on the per-call read path. -_fallback_errors_mode: Optional[str] = _resolve_fallback_mode_from_env() - - -def _reset_for_tests() -> None: - """Re-reads the env var and refreshes the cached fallback mode. Tests - should call this after mutating ``os.environ`` so the next call to - ``decode_response_body`` sees the new value.""" - global _fallback_errors_mode # pylint: disable=global-statement - _fallback_errors_mode = _resolve_fallback_mode_from_env() - def decode_response_body(data: bytes, operation_context: Optional[str] = None) -> str: """Decode an HTTP response body as UTF-8. @@ -81,17 +69,20 @@ def decode_response_body(data: bytes, operation_context: Optional[str] = None) - The original exception is preserved as ``__cause__``. :param data: Response body bytes. 
+ :type data: bytes :param operation_context: Optional short string identifying the call site (for example, ``"read_item"`` or ``"query_items page"``); included in the WARNING log line when permissive fallback fires. + :type operation_context: Optional[str] :returns: The decoded string. + :rtype: str :raises UnicodeDecodeError: If the body contains invalid UTF-8 and the operator has not opted in to a permissive fallback. """ try: return data.decode("utf-8") except UnicodeDecodeError as strict_error: - fallback_mode = _fallback_errors_mode + fallback_mode = _resolve_fallback_mode_from_env() if fallback_mode is None: hint = ( "{original}; set environment variable " @@ -111,12 +102,57 @@ def decode_response_body(data: bytes, operation_context: Optional[str] = None) - _logger.warning( "Cosmos response body contained invalid UTF-8 at byte offset %d " - "(reason: %s); decoding with errors=%r per %s%s.", + "(reason: %s); decoding with errors=%r per %s (operation: %s).", strict_error.start, strict_error.reason, fallback_mode, _MALFORMED_INPUT_ENV_VAR, - " (operation: {0})".format(operation_context) if operation_context else "", + operation_context or "-", ) return data.decode("utf-8", errors=fallback_mode) + +def decode_response_body_for_status( + data: bytes, + status_code: int, + operation_context: Optional[str] = None, +) -> str: + """Decode an HTTP response body, with a best-effort fallback for HTTP + error responses whose body happens to contain invalid UTF-8. + + Behaves exactly like :func:`decode_response_body` on success and on + 2xx responses with malformed UTF-8. The difference is the error path: + if strict decoding fails AND the response is an HTTP error + (``status_code >= 400``), the body is decoded with ``errors="replace"`` + so the caller can still construct the real status-code exception + (``CosmosResourceNotFoundError``, ``CosmosHttpResponseError``, etc.). 
+ + The reason: the SDK's retry/refresh logic and customer error handlers + branch on status code, not on message contents. Masking a 404, 410 + (partition split), 429 (throttle), or 503 with a ``UnicodeDecodeError`` + breaks recovery paths that would otherwise have worked. ``U+FFFD`` in + an error message is acceptable; a wrong exception class is not. + + For 2xx responses with malformed UTF-8 the exception is still raised — + a successful response carrying corrupt bytes is a real data-integrity + problem the caller needs to see. + + :param data: Response body bytes. + :type data: bytes + :param status_code: The HTTP status code of the response. + :type status_code: int + :param operation_context: Optional short string identifying the call + site; forwarded to :func:`decode_response_body`. + :type operation_context: Optional[str] + :returns: The decoded string. + :rtype: str + :raises UnicodeDecodeError: If the body contains invalid UTF-8, the + operator has not opted in to a permissive fallback, and the + response is a success (2xx/3xx) rather than an HTTP error. 
+ """ + try: + return decode_response_body(data, operation_context) + except UnicodeDecodeError: + if status_code >= 400: + return data.decode("utf-8", errors="replace") + raise diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 01ab8e4abe94..456323d18876 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -33,7 +33,7 @@ from ._availability_strategy_config import CrossRegionHedgingStrategy from ._availability_strategy_handler import execute_with_hedging from ._constants import _Constants -from ._response_decoding import decode_response_body +from ._response_decoding import decode_response_body_for_status from ._request_object import RequestObject from .documents import _OperationType @@ -178,7 +178,9 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin data = response.body() if data: - data = decode_response_body(data, request_params.operation_type) + data = decode_response_body_for_status( + data, response.status_code, request_params.operation_type + ) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 46459774c711..2a712c245ebf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -35,7 +35,7 @@ from .._availability_strategy_config import CrossRegionHedgingStrategy from .._constants import _Constants from .._request_object import RequestObject -from .._response_decoding import decode_response_body +from .._response_decoding import decode_response_body_for_status from .._synchronized_request import _request_body_from_data, _replace_url_prefix 
from ..documents import _OperationType @@ -142,7 +142,9 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p data = response.body() if data: - data = decode_response_body(data, request_params.operation_type) + data = decode_response_body_for_status( + data, response.status_code, request_params.operation_type + ) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py index a6ea3a095bb2..7e62fa8177bb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py @@ -40,7 +40,7 @@ from .._constants import _Constants as Constants from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from .._cosmos_responses import CosmosDict -from .._response_decoding import decode_response_body +from .._response_decoding import decode_response_body_for_status from ..http_constants import HttpHeaders @@ -236,7 +236,9 @@ async def rerank( data = response.body() if data: - data = decode_response_body(data, "inference_request") + data = decode_response_body_for_status( + data, response.status_code, "inference_request" + ) if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py index ee8b921eb7eb..569eec222246 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py @@ -3,17 +3,22 @@ """Regression tests for the Content-Length header computation. -The SDK previously computed `Content-Length` from `len(request.data)` — -i.e. 
the number of Unicode code points in the JSON string — instead of -the UTF-8 byte length that actually goes on the wire. For any non-ASCII +The SDK previously computed ``Content-Length`` from ``len(request.data)`` — +the number of Unicode code points in the JSON string — instead of the +UTF-8 byte length that actually goes on the wire. For any non-ASCII payload that under-counted the body by the number of multi-byte -characters in it, which can cause downstream HTTP receivers to truncate -the body, reject the request, or mis-frame the next keep-alive request. - - -These tests exercise the exact arithmetic in both the sync and async -request paths via a minimal stand-in for the request object, so they -do not require a live Cosmos account. +characters, which can cause downstream HTTP receivers to truncate the +body, reject the request, or mis-frame the next keep-alive request. + +Every assertion in this file exercises the actual production code path +in ``_synchronized_request.SynchronizedRequest`` or +``_asynchronous_request.AsynchronousRequest`` by patching the retry +layer and inspecting the request object the SDK would have put on the +wire. A previous iteration of this file also contained "mirror" tests +that re-implemented the production formula locally — those have been +removed because they could not catch a production regression (they +only verified that ``len(s.encode("utf-8"))`` works, which is a Python +built-in). """ import unittest from unittest import mock @@ -24,67 +29,28 @@ from azure.cosmos.http_constants import HttpHeaders -def _set_content_length_like_sdk(body): - """Mirrors the post-fix code path in both `_synchronized_request.py` - and `aio/_asynchronous_request.py`. 
Kept in lock-step with those - call sites so this test fails if either one regresses to - `len(body)` on the str branch.""" - headers = {} - if body and isinstance(body, str): - headers[HttpHeaders.ContentLength] = len(body.encode("utf-8")) - elif body is None: - headers[HttpHeaders.ContentLength] = 0 - return headers - - -class TestContentLengthEncoding(unittest.TestCase): - - def test_ascii_body_byte_length_equals_char_length(self): - """Regression guard: ASCII-only bodies must continue to produce - the same `Content-Length` value as before the fix (the new and - old computations agree when every code point is one byte).""" - body = '{"id":"x","name":"hello"}' - headers = _set_content_length_like_sdk(body) - self.assertEqual(headers[HttpHeaders.ContentLength], len(body)) - self.assertEqual(headers[HttpHeaders.ContentLength], 25) - - def test_two_byte_character_adds_one_byte(self): - """`café` contains one 2-byte character (`é` → `\\xC3\\xA9`), - so the UTF-8 byte length must be `len(body) + 1`.""" - body = '{"name":"café"}' - headers = _set_content_length_like_sdk(body) - self.assertEqual(headers[HttpHeaders.ContentLength], len(body) + 1) - self.assertEqual( - headers[HttpHeaders.ContentLength], - len(body.encode("utf-8")), - ) - - def test_mixed_multibyte_characters(self): - """Accented (2-byte), CJK (3-byte), and emoji (4-byte) - characters together. The header must equal the UTF-8 byte - length, not the code-point count. This catches future - 'let's strip the encode call to save a microsecond' - regressions.""" - body = '{"a":"é","b":"日","c":"🎉"}' - headers = _set_content_length_like_sdk(body) - self.assertEqual( - headers[HttpHeaders.ContentLength], - len(body.encode("utf-8")), - ) - # And explicitly assert it differs from the buggy computation. 
- self.assertNotEqual(headers[HttpHeaders.ContentLength], len(body)) - - def test_none_body_is_zero(self): - """`None` → `Content-Length: 0`, unchanged by the fix.""" - headers = _set_content_length_like_sdk(None) - self.assertEqual(headers[HttpHeaders.ContentLength], 0) - - def test_empty_string_does_not_set_header(self): - """An empty string is falsy, so the str branch does not fire and - the `is None` branch does not fire either — header is left for - the transport to set (unchanged by the fix).""" - headers = _set_content_length_like_sdk("") - self.assertNotIn(HttpHeaders.ContentLength, headers) +# Payload matrix covering the four interesting char-vs-byte cases. Each +# str payload is named by the most divergent character it contains. +# The 4-byte emoji case maximizes the difference between ``len(s)`` +# (the old, buggy formula) and ``len(s.encode("utf-8"))`` (the new, +# correct formula), so a regression that reverts the fix will fail +# loudest on that case. +# +# Subtlety on why the payloads are pre-serialized JSON strings rather +# than dicts: the SDK's ``_request_body_from_data`` uses +# ``json.dumps(data, separators=(",", ":"))`` with default +# ``ensure_ascii=True``. That means dicts containing multi-byte chars +# get escaped to pure ASCII (e.g. ``"é"`` -> ``"\\u00e9"``) *before* +# Content-Length is computed — the byte-length code path is never +# actually exercised. The path that matters is when a customer passes +# a pre-serialized string, which ``_request_body_from_data`` returns +# unchanged. That is the path these tests exercise. 
+_STR_PAYLOADS = [ + ("ascii_baseline", '{"name":"hello"}'), # 1 byte per char + ("two_byte_latin", '{"name":"café"}'), # 2-byte 'é' + ("three_byte_cjk", '{"name":"日本"}'), # 3-byte CJK + ("four_byte_emoji", '{"name":"🎉🎊"}'), # 4-byte emoji +] class _DummyRequestParams: @@ -108,8 +74,12 @@ def __init__(self): self.data = None -class TestContentLengthWiringSyncAndAsync(unittest.TestCase): - def test_sync_request_sets_utf8_byte_content_length(self): +class TestContentLengthWiringSync(unittest.TestCase): + """Sync path: ``SynchronizedRequest`` → ``Execute`` should produce a + request whose ``Content-Length`` header equals the UTF-8 byte count + of the serialized body (not the code-point count).""" + + def _capture_outgoing_request(self, request_data): params = _DummyRequestParams() manager = _DummyGlobalEndpointManager() request = _DummyRequest() @@ -122,7 +92,9 @@ def _fake_execute(*args, **kwargs): captured["body"] = request_arg.data return {}, {} - with mock.patch.object(_synchronized_request._retry_utility, "Execute", side_effect=_fake_execute): + with mock.patch.object( + _synchronized_request._retry_utility, "Execute", side_effect=_fake_execute + ): _synchronized_request.SynchronizedRequest( client=object(), request_params=params, @@ -130,14 +102,40 @@ def _fake_execute(*args, **kwargs): connection_policy=object(), pipeline_client=object(), request=request, - request_data={"name": "café"}, + request_data=request_data, ) - - self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + return captured + + def test_str_bodies_set_utf8_byte_content_length(self): + """For each payload in the byte-divergence matrix, the + ``Content-Length`` header the SDK puts on the wire must equal + the UTF-8 byte count of the JSON-serialized body. 
For the emoji + case in particular this exceeds the code-point count by 3 bytes per emoji.""" + for label, payload in _STR_PAYLOADS: + with self.subTest(payload=label): + captured = self._capture_outgoing_request(payload) + body = captured["body"] + self.assertIsInstance(body, str) + expected_bytes = len(body.encode("utf-8")) + self.assertEqual(captured["content_length"], expected_bytes) + # Explicitly assert the value differs from the buggy + # formula for the multi-byte cases. ASCII is excluded + # because for ASCII both formulas agree. + if label != "ascii_baseline": + self.assertNotEqual(captured["content_length"], len(body)) + + def test_none_body_sets_content_length_zero(self): + """Covers the ``elif body is None`` branch in production: a + request with no body should still get ``Content-Length: 0``.""" + captured = self._capture_outgoing_request(None) + self.assertEqual(captured["content_length"], 0) class TestContentLengthWiringAsync(unittest.IsolatedAsyncioTestCase): - async def test_async_request_sets_utf8_byte_content_length(self): + """Async path: same contract as the sync test class, routed through + ``AsynchronousRequest`` → ``ExecuteAsync``.""" + + async def _capture_outgoing_request(self, request_data): params = _DummyRequestParams() manager = _DummyGlobalEndpointManager() request = _DummyRequest() @@ -150,7 +148,11 @@ async def _fake_execute_async(*args, **kwargs): captured["body"] = request_arg.data return {}, {} - with mock.patch.object(_asynchronous_request._retry_utility_async, "ExecuteAsync", side_effect=_fake_execute_async): + with mock.patch.object( + _asynchronous_request._retry_utility_async, + "ExecuteAsync", + side_effect=_fake_execute_async, + ): await _asynchronous_request.AsynchronousRequest( client=object(), request_params=params, @@ -158,10 +160,24 @@ async def _fake_execute_async(*args, **kwargs): connection_policy=object(), pipeline_client=object(), request=request, - request_data={"name": "café"}, + request_data=request_data, ) - -
self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + return captured + + async def test_str_bodies_set_utf8_byte_content_length(self): + for label, payload in _STR_PAYLOADS: + with self.subTest(payload=label): + captured = await self._capture_outgoing_request(payload) + body = captured["body"] + self.assertIsInstance(body, str) + expected_bytes = len(body.encode("utf-8")) + self.assertEqual(captured["content_length"], expected_bytes) + if label != "ascii_baseline": + self.assertNotEqual(captured["content_length"], len(body)) + + async def test_none_body_sets_content_length_zero(self): + captured = await self._capture_outgoing_request(None) + self.assertEqual(captured["content_length"], 0) if __name__ == "__main__": diff --git a/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py new file mode 100644 index 000000000000..5389c8025dc1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py @@ -0,0 +1,189 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Wiring tests for the response-body decode call in the core sync/async +``_Request()`` paths. + +Background: ``_synchronized_request._Request`` and +``_asynchronous_request._Request`` are the highest-traffic code paths in +the SDK — every CRUD, query, and change-feed read flows through them. +Both call ``decode_response_body_for_status`` to decode the HTTP +response body before the status-code branching that builds typed +``CosmosResourceNotFoundError`` / ``CosmosHttpResponseError`` exceptions. + +These tests lock in two contracts: + +1. **Wiring** — the call sites actually invoke the shared decoder. If + someone reverts the call back to ``data.decode("utf-8")`` we want a + unit test to fail immediately. Mirrors the wiring tests already in + place for the inference service in ``test_semantic_reranker_unit``. + +2. 
**Behavior** — when an HTTP error response body contains invalid + UTF-8, ``_Request`` surfaces the real typed exception + (``CosmosResourceNotFoundError`` etc.) instead of letting a + ``UnicodeDecodeError`` escape. This is the property the + ``decode_response_body_for_status`` helper was introduced to + guarantee; without an end-to-end test on ``_Request``, the helper + could quietly fall out of use and the regression would not be + caught at unit-test time. +""" +import asyncio +import unittest +from unittest.mock import MagicMock, patch, AsyncMock + +from azure.cosmos import _synchronized_request, exceptions +from azure.cosmos.aio import _asynchronous_request +from azure.cosmos.http_constants import ResourceType + + +# Same invalid UTF-8 used in test_response_decoding.py +_INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' +_VALID_UTF8 = b'{"ok":true}' + +_FAKE_ENDPOINT = "https://example.documents.azure.com:443/" + + +def _build_request_args(status_code: int, body: bytes): + """Build the minimal set of mocked dependencies ``_Request`` needs to + reach the decode call site. Returns (args_tuple, mock_response).""" + # ``endpoint_override`` short-circuits endpoint resolution so we do + # not need a real GlobalEndpointManager. ``DatabaseAccount`` skips + # ``refresh_endpoint_list`` for the same reason. 
+ request_params = MagicMock() + request_params.healthy_tentative_location = False + request_params.resource_type = ResourceType.DatabaseAccount + request_params.read_timeout_override = None + request_params.endpoint_override = _FAKE_ENDPOINT + request_params.should_cancel_request.return_value = False + request_params.operation_type = "Read" + request_params.availability_strategy = None + + connection_policy = MagicMock() + connection_policy.RequestTimeout = 30 + connection_policy.ReadTimeout = 30 + connection_policy.RecoveryReadTimeout = 5 + connection_policy.DBAReadTimeout = 5 + connection_policy.DBAConnectionTimeout = 5 + connection_policy.SSLConfiguration = None + connection_policy.DisableSSLVerification = False + + global_endpoint_manager = MagicMock() + + pipeline_client = MagicMock() + + request = MagicMock() + request.url = _FAKE_ENDPOINT + "dbs" + request.headers = {} + + # The fake pipeline response that _PipelineRunFunction will return. + mock_response = MagicMock() + mock_response.http_response.status_code = status_code + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = body + + return ( + (global_endpoint_manager, request_params, connection_policy, pipeline_client, request), + mock_response, + ) + + +class TestSyncRequestUsesSharedDecoder(unittest.TestCase): + """Wiring + behavioral tests for ``_synchronized_request._Request``.""" + + def test_request_invokes_shared_response_decoder(self): + """Reverting the call site back to ``data.decode('utf-8')`` would + make this test fail. 
Locks in the wiring.""" + args, mock_response = _build_request_args(status_code=200, body=_VALID_UTF8) + + with patch( + "azure.cosmos._synchronized_request._PipelineRunFunction", + return_value=mock_response, + ), patch( + "azure.cosmos._synchronized_request.decode_response_body_for_status", + return_value='{"ok":true}', + ) as mock_decode: + _synchronized_request._Request(*args) + + mock_decode.assert_called_once_with(_VALID_UTF8, 200, "Read") + + def test_invalid_utf8_on_404_surfaces_resource_not_found(self): + """Behavioral guarantee: a 404 carrying a malformed-UTF-8 body + must surface as ``CosmosResourceNotFoundError``, not as a + ``UnicodeDecodeError``. Customer error handlers branch on the + typed exception; a decode error here would skip those handlers + entirely.""" + args, mock_response = _build_request_args(status_code=404, body=_INVALID_UTF8) + + with patch( + "azure.cosmos._synchronized_request._PipelineRunFunction", + return_value=mock_response, + ): + with self.assertRaises(exceptions.CosmosResourceNotFoundError): + _synchronized_request._Request(*args) + + def test_invalid_utf8_on_503_surfaces_http_response_error(self): + """Same guarantee for the generic ``status_code >= 400`` branch. 
+ 503 specifically matters: it drives cross-region retry; masking + it with a decode error would stop failover from happening.""" + args, mock_response = _build_request_args(status_code=503, body=_INVALID_UTF8) + + with patch( + "azure.cosmos._synchronized_request._PipelineRunFunction", + return_value=mock_response, + ): + with self.assertRaises(exceptions.CosmosHttpResponseError) as ctx: + _synchronized_request._Request(*args) + self.assertEqual(ctx.exception.status_code, 503) + + +class TestAsyncRequestUsesSharedDecoder(unittest.TestCase): + """Wiring + behavioral tests for ``_asynchronous_request._Request``.""" + + def test_request_invokes_shared_response_decoder(self): + async def run_test(): + args, mock_response = _build_request_args(status_code=200, body=_VALID_UTF8) + + with patch( + "azure.cosmos.aio._asynchronous_request._PipelineRunFunction", + new=AsyncMock(return_value=mock_response), + ), patch( + "azure.cosmos.aio._asynchronous_request.decode_response_body_for_status", + return_value='{"ok":true}', + ) as mock_decode: + await _asynchronous_request._Request(*args) + + mock_decode.assert_called_once_with(_VALID_UTF8, 200, "Read") + + asyncio.run(run_test()) + + def test_invalid_utf8_on_404_surfaces_resource_not_found(self): + async def run_test(): + args, mock_response = _build_request_args(status_code=404, body=_INVALID_UTF8) + + with patch( + "azure.cosmos.aio._asynchronous_request._PipelineRunFunction", + new=AsyncMock(return_value=mock_response), + ): + with self.assertRaises(exceptions.CosmosResourceNotFoundError): + await _asynchronous_request._Request(*args) + + asyncio.run(run_test()) + + def test_invalid_utf8_on_503_surfaces_http_response_error(self): + async def run_test(): + args, mock_response = _build_request_args(status_code=503, body=_INVALID_UTF8) + + with patch( + "azure.cosmos.aio._asynchronous_request._PipelineRunFunction", + new=AsyncMock(return_value=mock_response), + ): + with self.assertRaises(exceptions.CosmosHttpResponseError) 
as ctx: + await _asynchronous_request._Request(*args) + self.assertEqual(ctx.exception.status_code, 503) + + asyncio.run(run_test()) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py index f2cc6aecad82..174273515e0f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py @@ -5,8 +5,13 @@ fallback behavior. Covers the healthy path (strict decode succeeds), the default behavior when the env var is unset (strict decode fails with an actionable hint), and the opt-in REPLACE / IGNORE modes. + +The helper reads the env var per-call on the decode-failure path, so +tests just mutate ``os.environ`` (under ``mock.patch.dict``) and call +``decode_response_body`` directly — no cache reset needed. """ # cspell:ignore ufffd +import json import os import unittest from unittest import mock @@ -25,29 +30,20 @@ class _DecoderEnvIsolatedTestCase(unittest.TestCase): """Base class that isolates each test from the surrounding process - environment and from sibling tests by restoring both the recognized - env var and the module-level fallback-mode cache after every test.""" + environment by rolling back any env mutations the test makes.""" def setUp(self): - # Snapshot the cache value and start an env patch that will - # roll back any mutations the test makes. - self._original_fallback_mode = _response_decoding._fallback_errors_mode self._env_patch = mock.patch.dict(os.environ, {}, clear=False) self._env_patch.start() # Strip the env var for the duration of the test so the helper's # default behavior (no env -> strict) is the explicit baseline. - # Tests that need a specific env value set it themselves and call - # `_reset_for_tests()`. + # Tests that need a specific env value set it themselves. 
os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) - _response_decoding._reset_for_tests() def tearDown(self): # `mock.patch.dict` rolls back any env mutations the test made, # including the pop in setUp. self._env_patch.stop() - # Restore the cache last, so the next test's setUp sees the - # same module state the suite started with. - _response_decoding._fallback_errors_mode = self._original_fallback_mode class TestStrictDecodingHealthyPath(_DecoderEnvIsolatedTestCase): @@ -70,9 +66,8 @@ def test_invalid_utf8_without_env_var_raises_with_hint(self): must raise ``UnicodeDecodeError`` so existing call sites continue to behave the same way. The hint in ``reason`` points the operator at the env var name so they can self-serve.""" - # setUp already cleared the env var and reset the cache; the - # assertion below makes that contract explicit for the reader. - self.assertIsNone(_response_decoding._fallback_errors_mode) + # setUp already cleared the env var; assert it for the reader. + self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) with self.assertRaises(UnicodeDecodeError) as ctx: _response_decoding.decode_response_body(_INVALID_UTF8, operation_context="read_item") @@ -84,11 +79,11 @@ def test_invalid_utf8_without_env_var_raises_with_hint(self): class TestPermissiveFallback(_DecoderEnvIsolatedTestCase): - """Exercises the decode behavior in each fallback mode by writing the - cache directly. """ + """Exercises the decode behavior in each fallback mode by setting + the env var and calling ``decode_response_body`` directly.""" def test_replace_mode_substitutes_replacement_character(self): - _response_decoding._fallback_errors_mode = "replace" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" result = _response_decoding.decode_response_body(_INVALID_UTF8) # The bad byte is replaced by U+FFFD; the surrounding text is preserved. 
self.assertIn("\ufffd", result) @@ -96,7 +91,7 @@ def test_replace_mode_substitutes_replacement_character(self): self.assertIn("world", result) def test_ignore_mode_drops_bad_bytes(self): - _response_decoding._fallback_errors_mode = "ignore" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" result = _response_decoding.decode_response_body(_INVALID_UTF8) # No replacement character; the bad byte is silently dropped. self.assertNotIn("\ufffd", result) @@ -104,45 +99,39 @@ def test_ignore_mode_drops_bad_bytes(self): self.assertIn("world", result) - -class TestEnvVarSnapshot(_DecoderEnvIsolatedTestCase): - """Tests for the env-var parser in isolation. Each test sets the - env var, calls ``_reset_for_tests()``, and asserts the cached - fallback mode matches the documented mapping.""" +class TestEnvVarParser(_DecoderEnvIsolatedTestCase): + """Unit tests for ``_resolve_fallback_mode_from_env`` in isolation. + Each test sets the env var and asserts the parsed mode matches the + documented mapping.""" def test_replace_env_value_resolves_to_replace_mode(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" - _response_decoding._reset_for_tests() - self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "replace") def test_ignore_env_value_resolves_to_ignore_mode(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" - _response_decoding._reset_for_tests() - self.assertEqual(_response_decoding._fallback_errors_mode, "ignore") + self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "ignore") def test_unknown_env_value_resolves_to_strict(self): """Anything other than REPLACE / IGNORE (case-insensitive) must leave strict decoding in effect.""" os.environ[_MALFORMED_INPUT_ENV_VAR] = "BOGUS" - _response_decoding._reset_for_tests() - self.assertIsNone(_response_decoding._fallback_errors_mode) + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) def 
test_env_value_is_case_insensitive_and_trims_whitespace(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = " replace " - _response_decoding._reset_for_tests() - self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "replace") def test_unset_env_resolves_to_strict(self): # setUp already pops the var; this makes the contract explicit. self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) - _response_decoding._reset_for_tests() - self.assertIsNone(_response_decoding._fallback_errors_mode) + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) class TestEnvVarToBehaviorEndToEnd(_DecoderEnvIsolatedTestCase): - """Verifies the full contract: setting the env var and calling - ``_reset_for_tests()`` actually changes what ``decode_response_body`` - does. Catches regressions where the env parser and the per-call + """Verifies the full contract: setting the env var actually changes + what ``decode_response_body`` does, and clearing it returns to + strict. Catches regressions where the env parser and the per-call read drift apart.""" def test_setting_replace_env_var_makes_invalid_utf8_decode_succeed(self): @@ -150,26 +139,187 @@ def test_setting_replace_env_var_makes_invalid_utf8_decode_succeed(self): with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body(_INVALID_UTF8) - # Opt in via the env var, refresh, and prove the same input now - # decodes to a replacement-character-bearing string instead of raising. + # Opt in via the env var and prove the same input now decodes to + # a replacement-character-bearing string instead of raising. os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" - _response_decoding._reset_for_tests() result = _response_decoding.decode_response_body(_INVALID_UTF8) self.assertIn("\ufffd", result) - def test_clearing_env_var_after_reset_returns_to_strict(self): + def test_clearing_env_var_returns_to_strict(self): # Opt in. 
os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" - _response_decoding._reset_for_tests() - self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "replace") - # Opt out by removing the var and re-snapshotting. + # Opt out by removing the var; next decode raises again. del os.environ[_MALFORMED_INPUT_ENV_VAR] - _response_decoding._reset_for_tests() with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body(_INVALID_UTF8) +class TestDecodeForStatus(_DecoderEnvIsolatedTestCase): + """Tests ``decode_response_body_for_status`` — the wrapper that + HTTP request paths use so a malformed-UTF-8 error body does not + mask the real status-code exception. The SDK's retry/refresh logic + and customer error handlers branch on status code, so a 404, 410 + (partition split), 429 (throttle), or 503 must surface as the + correct typed exception even when the body has invalid bytes.""" + + def test_valid_utf8_success_passes_through(self): + """Healthy path: 2xx with well-formed body decodes normally.""" + result = _response_decoding.decode_response_body_for_status( + _VALID_UTF8, status_code=200 + ) + self.assertEqual(result, '{"note":"hello world"}') + + def test_invalid_utf8_on_2xx_still_raises(self): + """A successful response with malformed bytes is a data-integrity + problem the caller needs to see — do not silently paper over it.""" + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=200 + ) + + def test_invalid_utf8_on_404_does_not_raise(self): + """404 with malformed body must decode best-effort so callers + receive ``CosmosResourceNotFoundError`` instead of a confusing + ``UnicodeDecodeError``.""" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=404 + ) + # The bad byte is replaced; surrounding text is preserved so the + # error message remains 
human-readable. + self.assertIn("\ufffd", result) + self.assertIn("hello", result) + self.assertIn("world", result) + + def test_invalid_utf8_on_throttle_does_not_raise(self): + """429 carries the retry-after signal the SDK's throttle handler + depends on; it must not be masked by a decode error.""" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=429 + ) + self.assertIn("\ufffd", result) + + def test_invalid_utf8_on_partition_gone_does_not_raise(self): + """410 is the partition-split signal that triggers partition-map + refresh; masking it would break split recovery.""" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=410 + ) + self.assertIn("\ufffd", result) + + def test_invalid_utf8_on_service_unavailable_does_not_raise(self): + """503 drives cross-region retry; masking it makes the SDK give + up instead of failing over.""" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=503 + ) + self.assertIn("\ufffd", result) + + def test_boundary_399_still_raises(self): + """The wrapper opens up best-effort decode at exactly 400 and + above. 
399 (unused in HTTP today, but covers 3xx redirects) + must still raise — same reason as 2xx.""" + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=399 + ) + + def test_boundary_400_does_not_raise(self): + """Confirms the threshold is inclusive at 400 (Bad Request).""" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=400 + ) + self.assertIn("\ufffd", result) + + def test_empty_body_decodes_to_empty_string_regardless_of_status(self): + for status in (200, 404, 503): + self.assertEqual( + _response_decoding.decode_response_body_for_status(b"", status_code=status), + "", + ) + + def test_fallback_env_var_handles_2xx_before_status_check_kicks_in(self): + """When the operator has opted in via the env var, the inner + ``decode_response_body`` already succeeds (with replacement), so + a 2xx with malformed bytes also decodes successfully. The + wrapper's status-code branch never runs in that case.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + result = _response_decoding.decode_response_body_for_status( + _INVALID_UTF8, status_code=200 + ) + self.assertIn("\ufffd", result) + + +class TestPermissiveFallbackJsonPipeline(_DecoderEnvIsolatedTestCase): + """End-to-end tests covering the decode-then-``json.loads`` pipeline + every caller of ``decode_response_body`` runs. + + These tests document an important operator-facing trade-off of + enabling permissive fallback (``REPLACE``): malformed bytes inside + a JSON *string value* become silently-corrupted Python str values + after parsing (``"\\ufffd"`` ends up in the data), while malformed + bytes that land on JSON *structural* characters cause the parse + step to fail with ``json.JSONDecodeError`` — which the SDK's + request path catches and surfaces as a typed ``DecodeError``. + + Assertions are intentionally outcome-shaped (does parse succeed? + does the value contain U+FFFD?) 
and avoid asserting exact error + messages or byte offsets, so CPython upgrades do not break us. + Mirrors the coverage in the Java SDK's MalformedResponseTests.""" + + # Bad bytes (`\xc3\x28`) inside the JSON string value `"caf??"`. + # After REPLACE decode the body is well-formed JSON whose value + # happens to contain U+FFFD — json.loads succeeds. + _BAD_BYTES_IN_STRING_VALUE = b'{"name":"caf\xc3\x28 dining"}' + + # Bad bytes (`\xc3\x28`) placed where the JSON colon delimiter + # should be. After REPLACE decode the colon position contains + # U+FFFD instead — json.loads cannot parse this as an object. + _BAD_BYTES_IN_STRUCTURE = b'{"name"\xc3\x28"value"}' + + def test_replace_mode_corrupts_string_values_silently(self): + """REPLACE + parse on bad bytes inside a string value: parse + succeeds, the resulting Python str contains U+FFFD. This is + the case operators need to be aware of when enabling REPLACE + — application code receives data with substituted characters + and no signal that the substitution happened. + + Cross-SDK parity note: matches the Java + MalformedResponseTests scenario where a corrupted character + inside a JSON string is silently preserved through parsing.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + + decoded = _response_decoding.decode_response_body(self._BAD_BYTES_IN_STRING_VALUE) + parsed = json.loads(decoded) + + self.assertIsInstance(parsed, dict) + self.assertIn("name", parsed) + self.assertIsInstance(parsed["name"], str) + # The replacement character is present in the parsed value; + # the rest of the value text is preserved verbatim. + self.assertIn("\ufffd", parsed["name"]) + self.assertIn("caf", parsed["name"]) + self.assertIn("dining", parsed["name"]) + + def test_replace_mode_structural_corruption_raises_json_decode_error(self): + """REPLACE + parse on bad bytes in JSON structure: decode + succeeds, parse raises ``json.JSONDecodeError``. 
The SDK's + ``_Request`` path catches that and surfaces it as + ``azure.core.exceptions.DecodeError`` (covered by the + ``_Request`` wiring tests). Here we just lock in the seam: + decode produces a string, parse rejects it.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + + decoded = _response_decoding.decode_response_body(self._BAD_BYTES_IN_STRUCTURE) + # Decode itself does not raise — the byte that broke JSON + # structure has become U+FFFD in the decoded string. + self.assertIsInstance(decoded, str) + self.assertIn("\ufffd", decoded) + + with self.assertRaises(json.JSONDecodeError): + json.loads(decoded) + + if __name__ == "__main__": unittest.main() - diff --git a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py index ce3bcb6f78a1..bfa56e09a394 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py @@ -156,7 +156,9 @@ async def run_test(): asyncio.run(run_test()) def test_sync_inference_uses_shared_response_decoder(self): - """Test that sync inference service decodes response bytes via decode_response_body.""" + """Test that sync inference service decodes response bytes via the + shared decode_response_body_for_status helper. 
Locks in the wiring + so a regression that reverts to inline data.decode("utf-8") would fail.""" from azure.cosmos._inference_service import _InferenceService mock_connection = self._create_mock_connection() @@ -172,17 +174,18 @@ def test_sync_inference_uses_shared_response_decoder(self): service._inference_pipeline_client._pipeline, "run", return_value=mock_response ), patch( - "azure.cosmos._inference_service.decode_response_body", + "azure.cosmos._inference_service.decode_response_body_for_status", return_value='{"Scores": []}' ) as mock_decode: service.rerank( reranking_context="test query", documents=["doc1"] ) - mock_decode.assert_called_once_with(raw_response_data, "inference_request") + mock_decode.assert_called_once_with(raw_response_data, 200, "inference_request") def test_async_inference_uses_shared_response_decoder(self): - """Test that async inference service decodes response bytes via decode_response_body.""" + """Test that async inference service decodes response bytes via the + shared decode_response_body_for_status helper.""" async def run_test(): from azure.cosmos.aio._inference_service_async import _InferenceService @@ -200,14 +203,14 @@ async def run_test(): service._inference_pipeline_client._pipeline, "run", return_value=mock_response ), patch( - "azure.cosmos.aio._inference_service_async.decode_response_body", + "azure.cosmos.aio._inference_service_async.decode_response_body_for_status", return_value='{"Scores": []}' ) as mock_decode: await service.rerank( reranking_context="test query", documents=["doc1"] ) - mock_decode.assert_called_once_with(raw_response_data, "inference_request") + mock_decode.assert_called_once_with(raw_response_data, 200, "inference_request") asyncio.run(run_test())