Skip to content

Commit 43d4a87

Browse files
authored
fix header casing for Twisted WebSockets (#26)
1 parent 886a8f2 commit 43d4a87

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "rolo"
77
authors = [
88
{ name = "LocalStack Contributors", email = "info@localstack.cloud" }
99
]
10-
version = "0.7.4"
10+
version = "0.7.5"
1111
description = "A Python framework for building HTTP-based server applications"
1212
dependencies = [
1313
"requests>=2.20",

rolo/testing/pytest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from rolo.gateway.asgi import AsgiGateway
1818
from rolo.gateway.wsgi import WsgiGateway
1919
from rolo.routing import handler_dispatcher
20-
from rolo.serving.twisted import TwistedGateway
20+
from rolo.serving.twisted import HeaderPreservingHTTPChannel, TwistedGateway
2121
from rolo.websocket.adapter import WebSocketListener
2222

2323
if typing.TYPE_CHECKING:
@@ -237,6 +237,12 @@ def _create(gateway):
237237

238238
@pytest.fixture
239239
def serve_twisted_websocket_listener(twisted_reactor, serve_twisted_tcp_server):
240+
"""
241+
This fixture creates a Twisted Site, without the need to serve a fully-fledged rolo Gateway.
242+
This is inspired by `rolo.serving.twisted.TwistedGateway`, but directly uses `WebsocketResourceDecorator` to
243+
pass the `WebSocketListener` instead of `gateway.accept` to the `websocketListener` parameter.
244+
It allows us to test the low-level behavior of WebSockets without being dependent on the Gateway implementation.
245+
"""
240246
from twisted.web.server import Site
241247

242248
from rolo.serving.twisted import HeaderPreservingWSGIResource, WebsocketResourceDecorator
@@ -250,6 +256,7 @@ def _create(websocket_listener: WebSocketListener):
250256
websocketListener=websocket_listener,
251257
)
252258
)
259+
site.protocol = HeaderPreservingHTTPChannel.protocol_factory
253260
return serve_twisted_tcp_server(site)
254261

255262
return _create

rolo/websocket/request.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from werkzeug import Response
55
from werkzeug._internal import _wsgi_decoding_dance
6-
from werkzeug.datastructures import EnvironHeaders, Headers
6+
from werkzeug.datastructures import EnvironHeaders, Headers, MultiDict
77
from werkzeug.sansio.request import Request as _SansIORequest
88
from werkzeug.wsgi import _get_server
99

@@ -115,6 +115,17 @@ def __init__(self, environ: WebSocketEnvironment):
115115
116116
:param environ: the WebSocketEnvironment
117117
"""
118+
raw_headers = environ.get("rolo.headers")
119+
if raw_headers:
120+
# restores raw headers from the server scope, to have proper casing or dashes. This can depend on server
121+
# behavior, but we want a unified way to keep header casing/formatting.
122+
# This is similar to what we do in wsgi.py
123+
headers = Headers(
124+
MultiDict([(k.decode("latin-1"), v.decode("latin-1")) for (k, v) in raw_headers])
125+
)
126+
else:
127+
headers = Headers(EnvironHeaders(environ))
128+
118129
# copied from werkzeug.wrappers.request
119130
super().__init__(
120131
method=environ.get("REQUEST_METHOD", "WEBSOCKET"),
@@ -123,7 +134,7 @@ def __init__(self, environ: WebSocketEnvironment):
123134
root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""),
124135
path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""),
125136
query_string=environ.get("QUERY_STRING", "").encode("latin1"),
126-
headers=Headers(EnvironHeaders(environ)),
137+
headers=headers,
127138
remote_addr=environ.get("REMOTE_ADDR"),
128139
)
129140
self.environ = environ

tests/websocket/test_websockets.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def echo_headers(request: WebSocketRequest):
9292

9393
client = websocket.WebSocket()
9494
client.connect(
95-
server.url.replace("http://", "ws://"), header=["Authorization: Basic let-me-in"]
95+
server.url.replace("http://", "ws://"),
96+
header=["Authorization: Basic let-me-in", "CasedHeader: hello"],
9697
)
9798

9899
assert client.handshake_response.status == 101
@@ -101,6 +102,7 @@ def echo_headers(request: WebSocketRequest):
101102
headers = json.loads(doc)
102103
assert headers["Connection"] == "Upgrade"
103104
assert headers["Authorization"] == "Basic let-me-in"
105+
assert headers["CasedHeader"] == "hello"
104106

105107

106108
def test_websocket_reject(serve_websocket_listener):
@@ -174,11 +176,15 @@ def _handler(request: WebSocketRequest, request_args: dict):
174176
with request.accept() as ws:
175177
ws.send("foo")
176178
ws.send(f"id={request_args['id']}")
179+
ws.send(json.dumps(dict(request.headers)))
177180

178181
router.add("/foo/<id>", _handler)
179182

180183
server = serve_websocket_listener(WebSocketRequest.listener(router.dispatch))
181184
client = websocket.WebSocket()
182-
client.connect(server.url.replace("http://", "ws://") + "/foo/bar")
185+
client.connect(
186+
server.url.replace("http://", "ws://") + "/foo/bar", header=["CasedHeader: hello"]
187+
)
183188
assert client.recv() == "foo"
184189
assert client.recv() == "id=bar"
190+
assert "CasedHeader" in json.loads(client.recv())

0 commit comments

Comments
 (0)