Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008)
* Added an opt-in fallback for invalid UTF-8 in response bodies. Default behavior is unchanged (strict decode). Setting `AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT` to `REPLACE` or `IGNORE` at process start enables a permissive decode so reads, queries, and change-feed iteration can make progress past corrupt payloads. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008)
* Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243)

#### Other Changes
Expand Down
6 changes: 6 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class _Constants:
TIMEOUT_ERROR_THRESHOLD_PPAF_DEFAULT: int = 10
# -------------------------------------------------------------------------

# Controls how the SDK handles invalid UTF-8 bytes in HTTP response bodies.
# Accepted values (matched case-insensitively, surrounding whitespace
# ignored): "REPLACE", "IGNORE". Anything else (including unset) leaves
# strict decoding in effect, which is the historical default.
CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT: str = \
"AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT"

# Error code translations
ERROR_TRANSLATIONS: dict[int, str] = {
400: "BAD_REQUEST - Request being sent is invalid.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ._cosmos_http_logging_policy import CosmosHttpLoggingPolicy
from ._cosmos_responses import CosmosDict
from ._inference_auth_policy import InferenceServiceBearerTokenPolicy
from ._response_decoding import decode_response_body
from ._retry_utility import ConnectionRetryPolicy
from .http_constants import HttpHeaders

Expand Down Expand Up @@ -202,7 +203,7 @@ def rerank(

data = response.body()
if data:
data = data.decode("utf-8")
data = decode_response_body(data, "inference_request")

if response.status_code >= 400:
raise exceptions.CosmosHttpResponseError(message=data, response=response)
Expand Down
122 changes: 122 additions & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_response_decoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

"""UTF-8 decoding for HTTP response bodies, with an opt-in fallback for
payloads containing bytes that are not valid UTF-8.

By default this module preserves the historical SDK behavior: strict
decode, ``UnicodeDecodeError`` raised on the first invalid byte.
Operators who need to read past corrupt payloads (for example, to
unblock a stuck change-feed processor) can opt in to a permissive
fallback by setting an environment variable at process start.

The recognized environment variable is
``AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT``:

* ``REPLACE`` -> Python ``errors="replace"`` (substitute U+FFFD)
* ``IGNORE`` -> Python ``errors="ignore"`` (drop the bad bytes)
* anything else, including unset -> strict (raise on bad bytes)

The value is read once at module import. Tests can call
``_reset_for_tests()`` to re-snapshot.
"""
import logging
import os
from typing import Optional

from ._constants import _Constants

# Public surface for star-imports. `_reset_for_tests` is listed explicitly so
# it is exported despite its leading underscore.
__all__ = ["decode_response_body", "_reset_for_tests"]

# Name of the environment variable that selects the fallback decode mode;
# the string itself is defined on `_Constants`.
_MALFORMED_INPUT_ENV_VAR = _Constants.CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT

# Mapping from the recognized env var values to Python's bytes.decode
# `errors=` argument. Keys are upper-case because the lookup in
# `_resolve_fallback_mode_from_env` strips and upper-cases the raw value
# first. Anything not in this mapping (including the env var being unset)
# resolves to strict decoding, which is the historical default.
_ENV_VALUE_TO_DECODE_ERRORS_MODE = {
"REPLACE": "replace",
"IGNORE": "ignore",
}

# Module-level logger; `decode_response_body` uses it for the WARNING emitted
# when the permissive fallback is applied.
_logger = logging.getLogger(__name__)


def _resolve_fallback_mode_from_env() -> Optional[str]:
    """Translate the malformed-input env var into a ``bytes.decode`` errors mode.

    The raw value is stripped of surrounding whitespace and upper-cased
    before being looked up, so ``" replace "`` and ``"REPLACE"`` are
    equivalent.

    :returns: The ``errors=`` mode to use as a fallback (``"replace"`` or
        ``"ignore"``), or ``None`` when the variable is unset or holds an
        unrecognized value, in which case strict decoding stays in effect.
    """
    try:
        configured = os.environ[_MALFORMED_INPUT_ENV_VAR]
    except KeyError:
        # Operator has not opted in at all.
        return None
    normalized = configured.strip().upper()
    return _ENV_VALUE_TO_DECODE_ERRORS_MODE.get(normalized)


# Cached fallback mode: "replace", "ignore", or None for strict decoding.
# Snapshot at module import. The value is immutable after import unless
# `_reset_for_tests` is called, so changes to the environment made later in
# the process are otherwise not picked up. Reading a module-level string in
# CPython is atomic, so no lock is needed on the per-call read path.
_fallback_errors_mode: Optional[str] = _resolve_fallback_mode_from_env()


def _reset_for_tests() -> None:
    """Refresh the cached fallback mode from the current environment.

    Intended for tests: call this after changing ``os.environ`` so that
    subsequent ``decode_response_body`` calls observe the new setting
    instead of the value captured at import time.
    """
    global _fallback_errors_mode  # pylint: disable=global-statement
    refreshed_mode = _resolve_fallback_mode_from_env()
    _fallback_errors_mode = refreshed_mode


def decode_response_body(data: bytes, operation_context: Optional[str] = None) -> str:
    """Return an HTTP response body decoded as UTF-8 text.

    Well-formed payloads go through a plain strict decode, exactly as
    ``data.decode("utf-8")`` would. Only when that strict decode fails
    does one of two things happen:

    * Permissive fallback configured (malformed-input env var resolved to
      ``replace`` or ``ignore``): a WARNING is logged with the failing
      byte offset, the decoder's reason, the configured mode and the
      operation context (if given), and the body is decoded again in
      that mode.
    * No fallback configured: a new ``UnicodeDecodeError`` is raised with
      the same encoding, object and offsets as the original, but with a
      ``reason`` extended by a hint naming the env var. The original
      error is attached as ``__cause__``.

    :param data: Response body bytes.
    :param operation_context: Optional short label for the call site (for
        example, ``"read_item"`` or ``"query_items page"``); appended to
        the WARNING log line when the permissive fallback is used.
    :returns: The decoded string.
    :raises UnicodeDecodeError: If the body contains invalid UTF-8 and no
        permissive fallback has been configured.
    """
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError as decode_failure:
        # Read the cached mode once so the whole slow path sees one value.
        permissive_mode = _fallback_errors_mode

        if permissive_mode is not None:
            if operation_context:
                operation_suffix = f" (operation: {operation_context})"
            else:
                operation_suffix = ""
            _logger.warning(
                "Cosmos response body contained invalid UTF-8 at byte offset %d "
                "(reason: %s); decoding with errors=%r per %s%s.",
                decode_failure.start,
                decode_failure.reason,
                permissive_mode,
                _MALFORMED_INPUT_ENV_VAR,
                operation_suffix,
            )
            return data.decode("utf-8", errors=permissive_mode)

        # Strict mode: re-raise with the same location details, but make the
        # reason tell the operator how to opt in to tolerant decoding.
        actionable_reason = (
            f"{decode_failure.reason}; set environment variable "
            f"{_MALFORMED_INPUT_ENV_VAR}=REPLACE (or IGNORE) to tolerate invalid UTF-8 "
            "in Cosmos response bodies"
        )
        raise UnicodeDecodeError(
            decode_failure.encoding,
            decode_failure.object,
            decode_failure.start,
            decode_failure.end,
            actionable_reason,
        ) from decode_failure

Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ._availability_strategy_config import CrossRegionHedgingStrategy
from ._availability_strategy_handler import execute_with_hedging
from ._constants import _Constants
from ._response_decoding import decode_response_body
from ._request_object import RequestObject
from .documents import _OperationType

Expand Down Expand Up @@ -177,7 +178,7 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin

data = response.body()
if data:
data = data.decode("utf-8")
data = decode_response_body(data, request_params.operation_type)

if response.status_code == 404:
raise exceptions.CosmosResourceNotFoundError(message=data, response=response)
Expand Down Expand Up @@ -257,7 +258,10 @@ def SynchronizedRequest(
"""
request.data = _request_body_from_data(request_data)
if request.data and isinstance(request.data, str):
request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data)
# Use UTF-8 byte length, not str length (code-point count), so the
# header matches the bytes the transport actually writes for any
# non-ASCII payload.
request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8"))
elif request.data is None:
request.headers[http_constants.HttpHeaders.ContentLength] = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .._availability_strategy_config import CrossRegionHedgingStrategy
from .._constants import _Constants
from .._request_object import RequestObject
from .._response_decoding import decode_response_body
from .._synchronized_request import _request_body_from_data, _replace_url_prefix
from ..documents import _OperationType

Expand Down Expand Up @@ -141,7 +142,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p

data = response.body()
if data:
data = data.decode("utf-8")
data = decode_response_body(data, request_params.operation_type)

if response.status_code == 404:
raise exceptions.CosmosResourceNotFoundError(message=data, response=response)
Expand Down Expand Up @@ -210,7 +211,10 @@ async def AsynchronousRequest(
"""
request.data = _request_body_from_data(request_data)
if request.data and isinstance(request.data, str):
request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data)
# Use UTF-8 byte length, not str length (code-point count), so the
# header matches the bytes the transport actually writes for any
# non-ASCII payload.
request.headers[http_constants.HttpHeaders.ContentLength] = len(request.data.encode("utf-8"))
elif request.data is None:
request.headers[http_constants.HttpHeaders.ContentLength] = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .._constants import _Constants as Constants
from .._cosmos_http_logging_policy import CosmosHttpLoggingPolicy
from .._cosmos_responses import CosmosDict
from .._response_decoding import decode_response_body
from ..http_constants import HttpHeaders


Expand Down Expand Up @@ -235,7 +236,7 @@ async def rerank(

data = response.body()
if data:
data = data.decode("utf-8")
data = decode_response_body(data, "inference_request")

if response.status_code >= 400:
raise exceptions.CosmosHttpResponseError(message=data, response=response)
Expand Down
Loading
Loading