diff --git a/src/connectrpc/_codec.py b/src/connectrpc/_codec.py index 06190f4..864cdb8 100644 --- a/src/connectrpc/_codec.py +++ b/src/connectrpc/_codec.py @@ -8,9 +8,6 @@ CODEC_NAME_PROTO = "proto" CODEC_NAME_JSON = "json" -# Follow connect-go's hacky approach to handling charset parameter -# https://github.com/connectrpc/connect-go/blob/fe4915717d32438c40a24a50e3895271d4c24751/codec.go#L31 -CODEC_NAME_JSON_CHARSET_UTF8 = "json; charset=utf-8" T_contra = TypeVar("T_contra", contravariant=True) @@ -68,11 +65,7 @@ def decode(self, data: bytes | bytearray, message: V) -> V: _proto_binary_codec = ProtoBinaryCodec() _proto_json_codec = ProtoJSONCodec() -_default_codecs = [ - _proto_binary_codec, - _proto_json_codec, - ProtoJSONCodec(name=CODEC_NAME_JSON_CHARSET_UTF8), -] +_default_codecs = [_proto_binary_codec, _proto_json_codec] def get_default_codecs() -> list[Codec]: diff --git a/src/connectrpc/_protocol_connect.py b/src/connectrpc/_protocol_connect.py index d61fd6f..4997daf 100644 --- a/src/connectrpc/_protocol_connect.py +++ b/src/connectrpc/_protocol_connect.py @@ -5,7 +5,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any, TypeVar -from ._codec import CODEC_NAME_JSON, CODEC_NAME_JSON_CHARSET_UTF8, Codec +from ._codec import CODEC_NAME_JSON, Codec from ._compression import IdentityCompression, negotiate_compression from ._envelope import EnvelopeReader, EnvelopeWriter from ._protocol import ( @@ -36,6 +36,9 @@ CONNECT_PROTOCOL_VERSION = "1" CONNECT_HEADER_TIMEOUT = "connect-timeout-ms" CONNECT_UNARY_CONTENT_TYPE_PREFIX = "application/" +CONNECT_UNARY_CONTENT_TYPE_JSON = ( + f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON}" +) CONNECT_STREAMING_CONTENT_TYPE_PREFIX = "application/connect+" CONNECT_UNARY_HEADER_COMPRESSION = "content-encoding" CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION = "accept-encoding" @@ -46,7 +49,16 @@ _DEFAULT_CONNECT_USER_AGENT = f"connectrpc/{__version__}" +def _normalize_content_type(content_type: str) -> str: + # content-type can have parameters, most commonly charset. Our supported codecs, + # binary and JSON are always either non-text or utf-8 and the parameters are not + # important for matching to a codec. A custom codec could conceivably need to + # match on parameters, but we will reconsider that if it is ever asked for. + return content_type.partition(";")[0].strip().lower() + + def codec_name_from_content_type(content_type: str, *, stream: bool) -> str: + content_type = _normalize_content_type(content_type) prefix = ( CONNECT_STREAMING_CONTENT_TYPE_PREFIX if stream @@ -226,12 +238,10 @@ def create_request_context( def validate_response( self, request_codec_name: str, status_code: int, response_content_type: str ) -> None: + response_content_type = _normalize_content_type(response_content_type) if status_code != HTTPStatus.OK: # Error responses must be JSON-encoded - if response_content_type in ( - f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON}", - f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON_CHARSET_UTF8}", - ): + if response_content_type == CONNECT_UNARY_CONTENT_TYPE_JSON: return raise ConnectWireError.from_http_status(status_code).to_exception() @@ -247,16 +257,6 @@ def validate_response( if response_codec_name == request_codec_name: return - if ( - response_codec_name == CODEC_NAME_JSON - and request_codec_name == CODEC_NAME_JSON_CHARSET_UTF8 - ) or ( - response_codec_name == CODEC_NAME_JSON_CHARSET_UTF8 - and request_codec_name == CODEC_NAME_JSON - ): - # Both are JSON - return - raise ConnectError( Code.INTERNAL, f"invalid content-type: '{response_content_type}'; expecting '{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{request_codec_name}'", diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 46162c9..b2b87c9 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -213,7 +213,7 @@ async def __call__( codec_name = protocol.codec_name_from_content_type( headers.get("content-type", ""), stream=not is_unary ) - codec = self._codecs.get(codec_name.lower()) + codec = self._codecs.get(codec_name) if not codec: raise HTTPException( HTTPStatus.UNSUPPORTED_MEDIA_TYPE, diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index d3650e2..6655edd 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -70,13 +70,12 @@ ) -def _normalize_wsgi_headers(environ: WSGIEnvironment) -> dict: - """Extract and normalize HTTP headers from WSGI environment.""" - headers = {} +def _process_headers(environ: WSGIEnvironment) -> Headers: + headers = Headers() if "CONTENT_TYPE" in environ: - headers["content-type"] = environ["CONTENT_TYPE"].lower() + headers["content-type"] = environ["CONTENT_TYPE"] if "CONTENT_LENGTH" in environ: - headers["content-length"] = environ["CONTENT_LENGTH"].lower() + headers["content-length"] = environ["CONTENT_LENGTH"] for key, value in environ.items(): if key.startswith("HTTP_"): @@ -85,17 +84,6 @@ def _normalize_wsgi_headers(environ: WSGIEnvironment) -> dict: return headers -def _process_headers(headers: dict) -> Headers: - result = Headers() - for key, value in headers.items(): - if isinstance(value, list | tuple): - for v in value: - result.add(key, v) - else: - result.add(key, str(value)) - return result - - def prepare_response_headers( base_headers: dict[str, list[str]], selected_encoding: str ) -> dict[str, list[str]]: @@ -220,7 +208,7 @@ def __call__( http_method = environ["REQUEST_METHOD"] http_scheme = environ.get("wsgi.url_scheme", "http") - headers = _process_headers(_normalize_wsgi_headers(environ)) + headers = _process_headers(environ) if ra := environ.get("REMOTE_ADDR"): port = environ.get("REMOTE_PORT", "0") client_address = f"{ra}:{port}" diff --git a/test/test_http.py b/test/test_http.py new file mode 100644 index 0000000..e543cdc --- /dev/null +++ b/test/test_http.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pyqwest import Client, Request, SyncClient, SyncRequest, SyncTransport, Transport +from pyqwest.testing import ASGITransport, WSGITransport + +from connectrpc.codec import proto_json_codec + +from .haberdasher_connect import ( + Haberdasher, + HaberdasherASGIApplication, + HaberdasherClient, + HaberdasherClientSync, + HaberdasherSync, + HaberdasherWSGIApplication, +) +from .haberdasher_pb2 import Hat, Size + +if TYPE_CHECKING: + from pyqwest._pyqwest import Response, SyncResponse + +_charset_content_type_cases = [ + "application/json", + "application/json; charset=utf-8", + "application/json; charset=UTF-8", + "application/json;charset=utf-8", + "application/json; charset=utf-8", + "application/json; charset=utf-8; version=1", + "application/JSON", +] + + +@pytest.mark.parametrize("header", _charset_content_type_cases) +def test_json_charset_content_type(header: str) -> None: + class HeadersHaberdasherSync(HaberdasherSync): + def make_hat(self, request, ctx): + return Hat(size=2) + + transport = WSGITransport(HaberdasherWSGIApplication(HeadersHaberdasherSync())) + + client = SyncClient(transport=transport) + + res = client.post( + "http://localhost/connectrpc.example.Haberdasher/MakeHat", + content=b"{}", + headers={"content-type": header}, + ) + assert res.status == 200 + assert res.json() == {"size": 2} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("header", _charset_content_type_cases) +async def test_json_charset_content_type_async(header: str) -> None: + class HeadersHaberdasher(Haberdasher): + async def make_hat(self, request, ctx): + return Hat(size=2) + + transport = ASGITransport(HaberdasherASGIApplication(HeadersHaberdasher())) + + client = Client(transport=transport) + + res = await client.post( + "http://localhost/connectrpc.example.Haberdasher/MakeHat", + content=b"{}", + headers={"content-type": header}, + ) + assert res.status == 200 + assert res.json() == {"size": 2} + + +_streaming_charset_content_type_cases = [ + "application/connect+" + h.split("/")[1] for h in _charset_content_type_cases +] + + +@pytest.mark.parametrize("header", _streaming_charset_content_type_cases) +def test_json_charset_content_type_stream(header: str) -> None: + class HeadersHaberdasherSync(HaberdasherSync): + def make_similar_hats(self, request, ctx): + yield Hat(size=2) + yield Hat(size=3) + + # Difficult to parse an HTTP streaming response so override the header + # with a real client's transport instead. + class HeaderTransport(SyncTransport): + def __init__(self, delegate: SyncTransport): + self._delegate = delegate + + def execute_sync(self, request: SyncRequest) -> SyncResponse: + request.headers["content-type"] = header + return self._delegate.execute_sync(request) + + transport = HeaderTransport( + WSGITransport(HaberdasherWSGIApplication(HeadersHaberdasherSync())) + ) + + client = HaberdasherClientSync( + address="http://localhost", + codec=proto_json_codec(), + http_client=SyncClient(transport=transport), + ) + + hats = list(client.make_similar_hats(Size(inches=2))) + assert hats == [Hat(size=2), Hat(size=3)] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("header", _streaming_charset_content_type_cases) +async def test_json_charset_content_type_stream_async(header: str) -> None: + class HeadersHaberdasher(Haberdasher): + async def make_similar_hats(self, request, ctx): + yield Hat(size=2) + yield Hat(size=3) + + # Difficult to parse an HTTP streaming response so override the header + # with a real client's transport instead. + class HeaderTransport(Transport): + def __init__(self, delegate: Transport): + self._delegate = delegate + + async def execute(self, request: Request) -> Response: + request.headers["content-type"] = header + return await self._delegate.execute(request) + + transport = HeaderTransport( + ASGITransport(HaberdasherASGIApplication(HeadersHaberdasher())) + ) + + client = HaberdasherClient( + address="http://localhost", + codec=proto_json_codec(), + http_client=Client(transport=transport), + ) + + hats = [] + async for hat in client.make_similar_hats(Size(inches=2)): + hats.append(hat) + assert hats == [Hat(size=2), Hat(size=3)]