Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/connectrpc/_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@

CODEC_NAME_PROTO = "proto"
CODEC_NAME_JSON = "json"
# Follow connect-go's hacky approach to handling charset parameter
# https://github.com/connectrpc/connect-go/blob/fe4915717d32438c40a24a50e3895271d4c24751/codec.go#L31
CODEC_NAME_JSON_CHARSET_UTF8 = "json; charset=utf-8"


T_contra = TypeVar("T_contra", contravariant=True)
Expand Down Expand Up @@ -68,11 +65,7 @@ def decode(self, data: bytes | bytearray, message: V) -> V:

_proto_binary_codec = ProtoBinaryCodec()
_proto_json_codec = ProtoJSONCodec()
_default_codecs = [
_proto_binary_codec,
_proto_json_codec,
ProtoJSONCodec(name=CODEC_NAME_JSON_CHARSET_UTF8),
]
_default_codecs = [_proto_binary_codec, _proto_json_codec]


def get_default_codecs() -> list[Codec]:
Expand Down
30 changes: 15 additions & 15 deletions src/connectrpc/_protocol_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, TypeVar

from ._codec import CODEC_NAME_JSON, CODEC_NAME_JSON_CHARSET_UTF8, Codec
from ._codec import CODEC_NAME_JSON, Codec
from ._compression import IdentityCompression, negotiate_compression
from ._envelope import EnvelopeReader, EnvelopeWriter
from ._protocol import (
Expand Down Expand Up @@ -36,6 +36,9 @@
CONNECT_PROTOCOL_VERSION = "1"
CONNECT_HEADER_TIMEOUT = "connect-timeout-ms"
CONNECT_UNARY_CONTENT_TYPE_PREFIX = "application/"
CONNECT_UNARY_CONTENT_TYPE_JSON = (
f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON}"
)
CONNECT_STREAMING_CONTENT_TYPE_PREFIX = "application/connect+"
CONNECT_UNARY_HEADER_COMPRESSION = "content-encoding"
CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION = "accept-encoding"
Expand All @@ -46,7 +49,16 @@
_DEFAULT_CONNECT_USER_AGENT = f"connectrpc/{__version__}"


def _normalize_content_type(content_type: str) -> str:
# content-type can have parameters, most commonly charset. Our supported codecs,
# binary and JSON are always either non-text or utf-8 and the parameters are not
# important for matching to a codec. A custom codec could conceivably need to
# match on parameters, but we will reconsider that if it is ever asked for.
return content_type.partition(";")[0].strip().lower()


def codec_name_from_content_type(content_type: str, *, stream: bool) -> str:
content_type = _normalize_content_type(content_type)
prefix = (
CONNECT_STREAMING_CONTENT_TYPE_PREFIX
if stream
Expand Down Expand Up @@ -226,12 +238,10 @@ def create_request_context(
def validate_response(
self, request_codec_name: str, status_code: int, response_content_type: str
) -> None:
response_content_type = _normalize_content_type(response_content_type)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit that we're double-normalizing the content type between here and codec_name_from_content_type below, but no real suggestions on a refactor — probably fine to leave.

if status_code != HTTPStatus.OK:
# Error responses must be JSON-encoded
if response_content_type in (
f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON}",
f"{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{CODEC_NAME_JSON_CHARSET_UTF8}",
):
if response_content_type == CONNECT_UNARY_CONTENT_TYPE_JSON:
return
raise ConnectWireError.from_http_status(status_code).to_exception()

Expand All @@ -247,16 +257,6 @@ def validate_response(
if response_codec_name == request_codec_name:
return

if (
response_codec_name == CODEC_NAME_JSON
and request_codec_name == CODEC_NAME_JSON_CHARSET_UTF8
) or (
response_codec_name == CODEC_NAME_JSON_CHARSET_UTF8
and request_codec_name == CODEC_NAME_JSON
):
# Both are JSON
return

raise ConnectError(
Code.INTERNAL,
f"invalid content-type: '{response_content_type}'; expecting '{CONNECT_UNARY_CONTENT_TYPE_PREFIX}{request_codec_name}'",
Expand Down
2 changes: 1 addition & 1 deletion src/connectrpc/_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ async def __call__(
codec_name = protocol.codec_name_from_content_type(
headers.get("content-type", ""), stream=not is_unary
)
codec = self._codecs.get(codec_name.lower())
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed a couple of spots of random lowercasing and consolidated them, with some cleanup to WSGI

codec = self._codecs.get(codec_name)
if not codec:
raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
Expand Down
22 changes: 5 additions & 17 deletions src/connectrpc/_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,12 @@
)


def _normalize_wsgi_headers(environ: WSGIEnvironment) -> dict:
"""Extract and normalize HTTP headers from WSGI environment."""
headers = {}
def _process_headers(environ: WSGIEnvironment) -> Headers:
headers = Headers()
if "CONTENT_TYPE" in environ:
headers["content-type"] = environ["CONTENT_TYPE"].lower()
headers["content-type"] = environ["CONTENT_TYPE"]
if "CONTENT_LENGTH" in environ:
headers["content-length"] = environ["CONTENT_LENGTH"].lower()
headers["content-length"] = environ["CONTENT_LENGTH"]

for key, value in environ.items():
if key.startswith("HTTP_"):
Expand All @@ -85,17 +84,6 @@ def _normalize_wsgi_headers(environ: WSGIEnvironment) -> dict:
return headers


def _process_headers(headers: dict) -> Headers:
result = Headers()
for key, value in headers.items():
if isinstance(value, list | tuple):
for v in value:
result.add(key, v)
else:
result.add(key, str(value))
return result


def prepare_response_headers(
base_headers: dict[str, list[str]], selected_encoding: str
) -> dict[str, list[str]]:
Expand Down Expand Up @@ -220,7 +208,7 @@ def __call__(

http_method = environ["REQUEST_METHOD"]
http_scheme = environ.get("wsgi.url_scheme", "http")
headers = _process_headers(_normalize_wsgi_headers(environ))
headers = _process_headers(environ)
if ra := environ.get("REMOTE_ADDR"):
port = environ.get("REMOTE_PORT", "0")
client_address = f"{ra}:{port}"
Expand Down
62 changes: 62 additions & 0 deletions test/test_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import pytest
from pyqwest import Client, SyncClient
from pyqwest.testing import ASGITransport, WSGITransport

from .haberdasher_connect import (
Haberdasher,
HaberdasherASGIApplication,
HaberdasherSync,
HaberdasherWSGIApplication,
)
from .haberdasher_pb2 import Hat

_charset_content_type_cases = [
"application/json",
"application/json; charset=utf-8",
"application/json; charset=UTF-8",
"application/json;charset=utf-8",
"application/json; charset=utf-8",
"application/json; charset=utf-8; version=1",
"application/JSON",
]


@pytest.mark.parametrize("header", _charset_content_type_cases)
def test_json_charset_content_type(header: str) -> None:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it also be good to flex streaming in these tests? I see we have validate_stream_response; presumably a well-behaved connect server wouldn't have an issue, though...

class HeadersHaberdasherSync(HaberdasherSync):
def make_hat(self, request, ctx):
return Hat(size=2)

transport = WSGITransport(HaberdasherWSGIApplication(HeadersHaberdasherSync()))

client = SyncClient(transport=transport)

res = client.post(
"http://localhost/connectrpc.example.Haberdasher/MakeHat",
content=b"{}",
headers={"content-type": header},
)
assert res.status == 200
assert res.json() == {"size": 2}


@pytest.mark.asyncio
@pytest.mark.parametrize("header", _charset_content_type_cases)
async def test_json_charset_content_type_async(header: str) -> None:
class HeadersHaberdasher(Haberdasher):
async def make_hat(self, request, ctx):
return Hat(size=2)

transport = ASGITransport(HaberdasherASGIApplication(HeadersHaberdasher()))

client = Client(transport=transport)

res = await client.post(
"http://localhost/connectrpc.example.Haberdasher/MakeHat",
content=b"{}",
headers={"content-type": header},
)
assert res.status == 200
assert res.json() == {"size": 2}
Loading