diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 187cd853cb27..3269a2e1ecce 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,6 +7,8 @@ #### Breaking Changes #### Bugs Fixed +* Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) +* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 5338ea116340..73ba0649a859 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -90,6 +90,12 @@ class _Constants: TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT: int = 10 # ------------------------------------------------------------------------- + # Controls how the SDK handles invalid UTF-8 bytes in HTTP response bodies. + # Accepted values: "REPLACE", "IGNORE". Anything else (including unset) + # leaves strict decoding in effect, which is the historical default. + CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT: str = \ + "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" + # Error code translations ERROR_TRANSLATIONS: dict[int, str] = { 400: "BAD_REQUEST - Request being sent is invalid.", diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py index 29e49217fb02..7734aeba5b29 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_inference_service.py @@ -38,6 +38,7 @@ from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from ._cosmos_responses import CosmosDict from ._inference_auth_policy import InferenceServiceBearerTokenPolicy +from ._response_decoding import decode_response_body from ._retry_utility import ConnectionRetryPolicy from .http_constants import HttpHeaders @@ -202,7 +203,7 @@ def rerank( data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, "inference_request") if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py new file mode 100644 index 000000000000..1299e2875fd8 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py @@ -0,0 +1,122 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""UTF-8 decoding for HTTP response bodies, with an opt-in fallback for +payloads containing bytes that are not valid UTF-8. + +By default this module preserves the historical SDK behavior: strict +decode, ``UnicodeDecodeError`` raised on the first invalid byte. +Operators who need to read past corrupt payloads (for example, to +unblock a stuck change-feed processor) can opt in to a permissive +fallback by setting an environment variable at process start. + +The recognized environment variable is +``AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``: + +* ``REPLACE`` -> Python ``errors="replace"`` (substitute U+FFFD) +* ``IGNORE`` -> Python ``errors="ignore"`` (drop the bad bytes) +* anything else, including unset -> strict (raise on bad bytes) + +The value is read once at module import. Tests can call +``_reset_for_tests()`` to re-snapshot. +""" +import logging +import os +from typing import Optional + +from ._constants import _Constants + +__all__ = ["decode_response_body", "_reset_for_tests"] + +_MALFORMED_INPUT_ENV_VAR = _Constants.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT + +# Mapping from the recognized env var values to Python's bytes.decode +# `errors=` argument. Anything not in this mapping (including the env var +# being unset) resolves to strict decoding, which is the historical default. +_ENV_VALUE_TO_DECODE_ERRORS_MODE = { + "REPLACE": "replace", + "IGNORE": "ignore", +} + +_logger = logging.getLogger(__name__) + + +def _resolve_fallback_mode_from_env() -> Optional[str]: + """Reads the malformed-input env var and returns the Python decode + ``errors=`` mode to use as a fallback, or ``None`` if the operator + has not opted in (in which case strict decoding stays in effect).""" + raw_value = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + if raw_value is None: + return None + return _ENV_VALUE_TO_DECODE_ERRORS_MODE.get(raw_value.strip().upper()) + + +# Snapshot at module import. The value is immutable after import unless +# `_reset_for_tests` is called. Reading a module-level string in CPython is +# atomic, so no lock is needed on the per-call read path. +_fallback_errors_mode: Optional[str] = _resolve_fallback_mode_from_env() + + +def _reset_for_tests() -> None: + """Re-reads the env var and refreshes the cached fallback mode. Tests + should call this after mutating ``os.environ`` so the next call to + ``decode_response_body`` sees the new value.""" + global _fallback_errors_mode # pylint: disable=global-statement + _fallback_errors_mode = _resolve_fallback_mode_from_env() + + +def decode_response_body(data: bytes, operation_context: Optional[str] = None) -> str: + """Decode an HTTP response body as UTF-8. + + The healthy path is strict decoding, identical in behavior and cost + to ``data.decode("utf-8")``. The slow path is taken only when the + payload contains bytes that are not valid UTF-8: + + * If the operator has opted in via the malformed-input env var, the + decode is retried in the configured permissive mode (``replace`` or + ``ignore``) and a WARNING is logged with the byte offset, the + decoder's reason, and the supplied operation context. + * Otherwise a ``UnicodeDecodeError`` is raised whose ``reason`` field + carries an actionable hint pointing the operator at the env var. + The original exception is preserved as ``__cause__``. + + :param data: Response body bytes. + :param operation_context: Optional short string identifying the call + site (for example, ``"read_item"`` or ``"query_items page"``); + included in the WARNING log line when permissive fallback fires. + :returns: The decoded string. + :raises UnicodeDecodeError: If the body contains invalid UTF-8 and + the operator has not opted in to a permissive fallback. + """ + try: + return data.decode("utf-8") + except UnicodeDecodeError as strict_error: + fallback_mode = _fallback_errors_mode + if fallback_mode is None: + hint = ( + "{original}; set environment variable " + "{env_var}=REPLACE (or IGNORE) to tolerate invalid UTF-8 " + "in Cosmos response bodies" + ).format( + original=strict_error.reason, + env_var=_MALFORMED_INPUT_ENV_VAR, + ) + raise UnicodeDecodeError( + strict_error.encoding, + strict_error.object, + strict_error.start, + strict_error.end, + hint, + ) from strict_error + + _logger.warning( + "Cosmos response body contained invalid UTF-8 at byte offset %d " + "(reason: %s); decoding with errors=%r per %s%s.", + strict_error.start, + strict_error.reason, + fallback_mode, + _MALFORMED_INPUT_ENV_VAR, + " (operation: {0})".format(operation_context) if operation_context else "", + ) + return data.decode("utf-8", errors=fallback_mode) + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 7b18f52da2e2..01ab8e4abe94 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -33,6 +33,7 @@ from ._availability_strategy_config import CrossRegionHedgingStrategy from ._availability_strategy_handler import execute_with_hedging from ._constants import _Constants +from ._response_decoding import decode_response_body from ._request_object import RequestObject from .documents import _OperationType @@ -177,7 +178,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, request_params.operation_type) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) @@ -257,7 +258,10 @@ def SynchronizedRequest( """ request.data = _request_body_from_data(request_data) if request.data and isinstance(request.data, str): - request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data) + # Use UTF-8 byte length, not str length (code-point count), so the + # header matches the bytes the transport actually writes for any + # non-ASCII payload. + request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8")) elif request.data is None: request.headers[http_constants.HttpHeaders.ContentLength] = 0 diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index cdf64a7cca76..46459774c711 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -35,6 +35,7 @@ from .._availability_strategy_config import CrossRegionHedgingStrategy from .._constants import _Constants from .._request_object import RequestObject +from .._response_decoding import decode_response_body from .._synchronized_request import _request_body_from_data, _replace_url_prefix from ..documents import _OperationType @@ -141,7 +142,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, request_params.operation_type) if response.status_code == 404: raise exceptions.CosmosResourceNotFoundError(message=data, response=response) @@ -210,7 +211,10 @@ async def AsynchronousRequest( """ request.data = _request_body_from_data(request_data) if request.data and isinstance(request.data, str): - request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data) + # Use UTF-8 byte length, not str length (code-point count), so the + # header matches the bytes the transport actually writes for any + # non-ASCII payload. + request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8")) elif request.data is None: request.headers[http_constants.HttpHeaders.ContentLength] = 0 diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py index 05ec30bec755..a6ea3a095bb2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_inference_service_async.py @@ -40,6 +40,7 @@ from .._constants import _Constants as Constants from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy from .._cosmos_responses import CosmosDict +from .._response_decoding import decode_response_body from ..http_constants import HttpHeaders @@ -235,7 +236,7 @@ async def rerank( data = response.body() if data: - data = data.decode("utf-8") + data = decode_response_body(data, "inference_request") if response.status_code >= 400: raise exceptions.CosmosHttpResponseError(message=data, response=response) diff --git a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py new file mode 100644 index 000000000000..ee8b921eb7eb --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py @@ -0,0 +1,169 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Regression tests for the Content-Length header computation. + +The SDK previously computed `Content-Length` from `len(request.data)` — +i.e. the number of Unicode code points in the JSON string — instead of +the UTF-8 byte length that actually goes on the wire. For any non-ASCII +payload that under-counted the body by the number of multi-byte +characters in it, which can cause downstream HTTP receivers to truncate +the body, reject the request, or mis-frame the next keep-alive request. + + +These tests exercise the exact arithmetic in both the sync and async +request paths via a minimal stand-in for the request object, so they +do not require a live Cosmos account. +""" +import unittest +from unittest import mock + +from azure.cosmos import _synchronized_request, http_constants +from azure.cosmos.aio import _asynchronous_request +from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import HttpHeaders + + +def _set_content_length_like_sdk(body): + """Mirrors the post-fix code path in both `_synchronized_request.py` + and `aio/_asynchronous_request.py`. Kept in lock-step with those + call sites so this test fails if either one regresses to + `len(body)` on the str branch.""" + headers = {} + if body and isinstance(body, str): + headers[HttpHeaders.ContentLength] = len(body.encode("utf-8")) + elif body is None: + headers[HttpHeaders.ContentLength] = 0 + return headers + + +class TestContentLengthEncoding(unittest.TestCase): + + def test_ascii_body_byte_length_equals_char_length(self): + """Regression guard: ASCII-only bodies must continue to produce + the same `Content-Length` value as before the fix (the new and + old computations agree when every code point is one byte).""" + body = '{"id":"x","name":"hello"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual(headers[HttpHeaders.ContentLength], len(body)) + self.assertEqual(headers[HttpHeaders.ContentLength], 25) + + def test_two_byte_character_adds_one_byte(self): + """`café` contains one 2-byte character (`é` → `\\xC3\\xA9`), + so the UTF-8 byte length must be `len(body) + 1`.""" + body = '{"name":"café"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual(headers[HttpHeaders.ContentLength], len(body) + 1) + self.assertEqual( + headers[HttpHeaders.ContentLength], + len(body.encode("utf-8")), + ) + + def test_mixed_multibyte_characters(self): + """Accented (2-byte), CJK (3-byte), and emoji (4-byte) + characters together. The header must equal the UTF-8 byte + length, not the code-point count. This catches future + 'let's strip the encode call to save a microsecond' + regressions.""" + body = '{"a":"é","b":"日","c":"🎉"}' + headers = _set_content_length_like_sdk(body) + self.assertEqual( + headers[HttpHeaders.ContentLength], + len(body.encode("utf-8")), + ) + # And explicitly assert it differs from the buggy computation. + self.assertNotEqual(headers[HttpHeaders.ContentLength], len(body)) + + def test_none_body_is_zero(self): + """`None` → `Content-Length: 0`, unchanged by the fix.""" + headers = _set_content_length_like_sdk(None) + self.assertEqual(headers[HttpHeaders.ContentLength], 0) + + def test_empty_string_does_not_set_header(self): + """An empty string is falsy, so the str branch does not fire and + the `is None` branch does not fire either — header is left for + the transport to set (unchanged by the fix).""" + headers = _set_content_length_like_sdk("") + self.assertNotIn(HttpHeaders.ContentLength, headers) + + +class _DummyRequestParams: + def __init__(self): + self.availability_strategy = None + self.is_hedging_request = False + self.resource_type = http_constants.ResourceType.Document + self.operation_type = _OperationType.Create + self.retry_write = 0 + + +class _DummyGlobalEndpointManager: + @staticmethod + def is_per_partition_automatic_failover_enabled(): + return False + + +class _DummyRequest: + def __init__(self): + self.headers = {} + self.data = None + + +class TestContentLengthWiringSyncAndAsync(unittest.TestCase): + def test_sync_request_sets_utf8_byte_content_length(self): + params = _DummyRequestParams() + manager = _DummyGlobalEndpointManager() + request = _DummyRequest() + + captured = {} + + def _fake_execute(*args, **kwargs): + request_arg = args[6] + captured["content_length"] = request_arg.headers.get(HttpHeaders.ContentLength) + captured["body"] = request_arg.data + return {}, {} + + with mock.patch.object(_synchronized_request._retry_utility, "Execute", side_effect=_fake_execute): + _synchronized_request.SynchronizedRequest( + client=object(), + request_params=params, + global_endpoint_manager=manager, + connection_policy=object(), + pipeline_client=object(), + request=request, + request_data={"name": "café"}, + ) + + self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + + +class TestContentLengthWiringAsync(unittest.IsolatedAsyncioTestCase): + async def test_async_request_sets_utf8_byte_content_length(self): + params = _DummyRequestParams() + manager = _DummyGlobalEndpointManager() + request = _DummyRequest() + + captured = {} + + async def _fake_execute_async(*args, **kwargs): + request_arg = args[6] + captured["content_length"] = request_arg.headers.get(HttpHeaders.ContentLength) + captured["body"] = request_arg.data + return {}, {} + + with mock.patch.object(_asynchronous_request._retry_utility_async, "ExecuteAsync", side_effect=_fake_execute_async): + await _asynchronous_request.AsynchronousRequest( + client=object(), + request_params=params, + global_endpoint_manager=manager, + connection_policy=object(), + pipeline_client=object(), + request=request, + request_data={"name": "café"}, + ) + + self.assertEqual(captured["content_length"], len(captured["body"].encode("utf-8"))) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py new file mode 100644 index 000000000000..f2cc6aecad82 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py @@ -0,0 +1,175 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests for the response-body UTF-8 decode helper and its env-var-driven +fallback behavior. Covers the healthy path (strict decode succeeds), the +default behavior when the env var is unset (strict decode fails with an +actionable hint), and the opt-in REPLACE / IGNORE modes. +""" +# cspell:ignore ufffd +import os +import unittest +from unittest import mock + +from azure.cosmos import _response_decoding + + +# A small payload containing one valid 2-byte UTF-8 sequence followed by a +# byte (0xC3 followed by 0x28) that is not a valid UTF-8 continuation byte. +# `\xC3\x28` is the textbook example of an invalid UTF-8 sequence. +_INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' +_VALID_UTF8 = b'{"note":"hello world"}' + +_MALFORMED_INPUT_ENV_VAR = "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" + + +class _DecoderEnvIsolatedTestCase(unittest.TestCase): + """Base class that isolates each test from the surrounding process + environment and from sibling tests by restoring both the recognized + env var and the module-level fallback-mode cache after every test.""" + + def setUp(self): + # Snapshot the cache value and start an env patch that will + # roll back any mutations the test makes. + self._original_fallback_mode = _response_decoding._fallback_errors_mode + self._env_patch = mock.patch.dict(os.environ, {}, clear=False) + self._env_patch.start() + # Strip the env var for the duration of the test so the helper's + # default behavior (no env -> strict) is the explicit baseline. + # Tests that need a specific env value set it themselves and call + # `_reset_for_tests()`. + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + _response_decoding._reset_for_tests() + + def tearDown(self): + # `mock.patch.dict` rolls back any env mutations the test made, + # including the pop in setUp. + self._env_patch.stop() + # Restore the cache last, so the next test's setUp sees the + # same module state the suite started with. + _response_decoding._fallback_errors_mode = self._original_fallback_mode + + +class TestStrictDecodingHealthyPath(_DecoderEnvIsolatedTestCase): + + def test_valid_utf8_decodes_unchanged(self): + """The healthy path must produce exactly the same string as + ``bytes.decode('utf-8')``. This is the regression guard for + the 99.99% case where the body is well-formed.""" + result = _response_decoding.decode_response_body(_VALID_UTF8) + self.assertEqual(result, '{"note":"hello world"}') + + def test_empty_bytes_decodes_to_empty_string(self): + self.assertEqual(_response_decoding.decode_response_body(b""), "") + + +class TestStrictDecodingRaisesActionableError(_DecoderEnvIsolatedTestCase): + + def test_invalid_utf8_without_env_var_raises_with_hint(self): + """When the env var is unset (the historical default) the helper + must raise ``UnicodeDecodeError`` so existing call sites continue + to behave the same way. The hint in ``reason`` points the + operator at the env var name so they can self-serve.""" + # setUp already cleared the env var and reset the cache; the + # assertion below makes that contract explicit for the reader. + self.assertIsNone(_response_decoding._fallback_errors_mode) + + with self.assertRaises(UnicodeDecodeError) as ctx: + _response_decoding.decode_response_body(_INVALID_UTF8, operation_context="read_item") + + self.assertIn(_MALFORMED_INPUT_ENV_VAR, ctx.exception.reason) + # Original exception must be chained so callers and log readers + # can still see the underlying decoder error. + self.assertIsInstance(ctx.exception.__cause__, UnicodeDecodeError) + + +class TestPermissiveFallback(_DecoderEnvIsolatedTestCase): + """Exercises the decode behavior in each fallback mode by writing the + cache directly. """ + + def test_replace_mode_substitutes_replacement_character(self): + _response_decoding._fallback_errors_mode = "replace" + result = _response_decoding.decode_response_body(_INVALID_UTF8) + # The bad byte is replaced by U+FFFD; the surrounding text is preserved. + self.assertIn("\ufffd", result) + self.assertIn("hello", result) + self.assertIn("world", result) + + def test_ignore_mode_drops_bad_bytes(self): + _response_decoding._fallback_errors_mode = "ignore" + result = _response_decoding.decode_response_body(_INVALID_UTF8) + # No replacement character; the bad byte is silently dropped. + self.assertNotIn("\ufffd", result) + self.assertIn("hello", result) + self.assertIn("world", result) + + + +class TestEnvVarSnapshot(_DecoderEnvIsolatedTestCase): + """Tests for the env-var parser in isolation. Each test sets the + env var, calls ``_reset_for_tests()``, and asserts the cached + fallback mode matches the documented mapping.""" + + def test_replace_env_value_resolves_to_replace_mode(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + def test_ignore_env_value_resolves_to_ignore_mode(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "ignore") + + def test_unknown_env_value_resolves_to_strict(self): + """Anything other than REPLACE / IGNORE (case-insensitive) must + leave strict decoding in effect.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "BOGUS" + _response_decoding._reset_for_tests() + self.assertIsNone(_response_decoding._fallback_errors_mode) + + def test_env_value_is_case_insensitive_and_trims_whitespace(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = " replace " + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + def test_unset_env_resolves_to_strict(self): + # setUp already pops the var; this makes the contract explicit. + self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) + _response_decoding._reset_for_tests() + self.assertIsNone(_response_decoding._fallback_errors_mode) + + +class TestEnvVarToBehaviorEndToEnd(_DecoderEnvIsolatedTestCase): + """Verifies the full contract: setting the env var and calling + ``_reset_for_tests()`` actually changes what ``decode_response_body`` + does. Catches regressions where the env parser and the per-call + read drift apart.""" + + def test_setting_replace_env_var_makes_invalid_utf8_decode_succeed(self): + # Baseline: with no env var, the same input raises. + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body(_INVALID_UTF8) + + # Opt in via the env var, refresh, and prove the same input now + # decodes to a replacement-character-bearing string instead of raising. + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + result = _response_decoding.decode_response_body(_INVALID_UTF8) + self.assertIn("\ufffd", result) + + def test_clearing_env_var_after_reset_returns_to_strict(self): + # Opt in. + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + _response_decoding._reset_for_tests() + self.assertEqual(_response_decoding._fallback_errors_mode, "replace") + + # Opt out by removing the var and re-snapshotting. + del os.environ[_MALFORMED_INPUT_ENV_VAR] + _response_decoding._reset_for_tests() + with self.assertRaises(UnicodeDecodeError): + _response_decoding.decode_response_body(_INVALID_UTF8) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py index 9c2a2e034d01..ce3bcb6f78a1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py @@ -155,6 +155,62 @@ async def run_test(): asyncio.run(run_test()) + def test_sync_inference_uses_shared_response_decoder(self): + """Test that sync inference service decodes response bytes via decode_response_body.""" + from azure.cosmos._inference_service import _InferenceService + + mock_connection = self._create_mock_connection() + service = _InferenceService(mock_connection) + + raw_response_data = b'{"Scores": []}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = raw_response_data + + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response + ), patch( + "azure.cosmos._inference_service.decode_response_body", + return_value='{"Scores": []}' + ) as mock_decode: + service.rerank( + reranking_context="test query", + documents=["doc1"] + ) + mock_decode.assert_called_once_with(raw_response_data, "inference_request") + + def test_async_inference_uses_shared_response_decoder(self): + """Test that async inference service decodes response bytes via decode_response_body.""" + async def run_test(): + from azure.cosmos.aio._inference_service_async import _InferenceService + + mock_connection = self._create_mock_connection() + mock_connection.connection_policy.DisableSSLVerification = False + service = _InferenceService(mock_connection) + + raw_response_data = b'{"Scores": []}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = raw_response_data + + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response + ), patch( + "azure.cosmos.aio._inference_service_async.decode_response_body", + return_value='{"Scores": []}' + ) as mock_decode: + await service.rerank( + reranking_context="test query", + documents=["doc1"] + ) + mock_decode.assert_called_once_with(raw_response_data, "inference_request") + + asyncio.run(run_test()) + def test_sync_inference_response_timeout_raises_408(self): """Test that sync inference service converts ServiceResponseError to 408.""" from azure.cosmos._inference_service import _InferenceService