Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 56 additions & 53 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ class connect:
compression: The "permessage-deflate" extension is enabled by default.
Set ``compression`` to :obj:`None` to disable it. See the
:doc:`compression guide <../../topics/compression>` for details.
additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
to the handshake request.
additional_headers: Arbitrary HTTP headers to add to the handshake
request.
user_agent_header: Value of the ``User-Agent`` request header.
It defaults to ``"Python/x.y.z websockets/X.Y"``.
Setting it to :obj:`None` removes the header.
Expand Down Expand Up @@ -328,6 +328,9 @@ def __init__(
**kwargs: Any,
) -> None:
self.uri = uri
self.ws_uri = parse_uri(uri)
if not self.ws_uri.secure and kwargs.get("ssl") is not None:
raise ValueError("ssl argument is incompatible with a ws:// URI")

if subprotocols is not None:
validate_subprotocols(subprotocols)
Expand All @@ -343,7 +346,7 @@ def __init__(
if create_connection is None:
create_connection = ClientConnection

def protocol_factory(uri: WebSocketURI) -> ClientConnection:
def factory(uri: WebSocketURI) -> ClientConnection:
# This is a protocol in the Sans-I/O implementation of websockets.
protocol = ClientProtocol(
uri,
Expand All @@ -365,40 +368,35 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
return connection

self.proxy = proxy
self.protocol_factory = protocol_factory
self.factory = factory
self.additional_headers = additional_headers
self.user_agent_header = user_agent_header
self.process_exception = process_exception
self.open_timeout = open_timeout
self.logger = logger
self.connection_kwargs = kwargs
self.create_connection_kwargs = kwargs

async def create_connection(self) -> ClientConnection:
"""Create TCP or Unix connection."""
async def open_tcp_connection(self) -> ClientConnection:
"""Create TCP or Unix connection to the server, possibly through a proxy."""
loop = asyncio.get_running_loop()
kwargs = self.connection_kwargs.copy()

ws_uri = parse_uri(self.uri)
kwargs = self.create_connection_kwargs.copy()

proxy = self.proxy
if kwargs.get("unix", False):
proxy = None
if kwargs.get("sock") is not None:
proxy = None
if proxy is True:
proxy = get_proxy(ws_uri)
proxy = get_proxy(self.ws_uri)

def factory() -> ClientConnection:
return self.protocol_factory(ws_uri)
return self.factory(self.ws_uri)

if ws_uri.secure:
if self.ws_uri.secure:
kwargs.setdefault("ssl", True)
kwargs.setdefault("server_hostname", ws_uri.host)
if kwargs.get("ssl") is None:
raise ValueError("ssl=None is incompatible with a wss:// URI")
else:
if kwargs.get("ssl") is not None:
raise ValueError("ssl argument is incompatible with a ws:// URI")
kwargs.setdefault("server_hostname", self.ws_uri.host)

if kwargs.pop("unix", False):
_, connection = await loop.create_unix_connection(factory, **kwargs)
Expand All @@ -408,7 +406,7 @@ def factory() -> ClientConnection:
# Connect to the server through the proxy.
sock = await connect_socks_proxy(
proxy_parsed,
ws_uri,
self.ws_uri,
local_addr=kwargs.pop("local_addr", None),
)
# Initialize WebSocket connection via the proxy.
Expand Down Expand Up @@ -442,7 +440,7 @@ def factory() -> ClientConnection:
# Connect to the server through the proxy.
transport = await connect_http_proxy(
proxy_parsed,
ws_uri,
self.ws_uri,
user_agent_header=self.user_agent_header,
**proxy_kwargs,
)
Expand All @@ -459,18 +457,18 @@ def factory() -> ClientConnection:
assert new_transport is not None # help mypy
transport = new_transport
connection.connection_made(transport)
else:
raise AssertionError("unsupported proxy")
else: # pragma: no cover
raise NotImplementedError(f"unsupported proxy: {proxy}")
else:
# Connect to the server directly.
if kwargs.get("sock") is None:
kwargs.setdefault("host", ws_uri.host)
kwargs.setdefault("port", ws_uri.port)
kwargs.setdefault("host", self.ws_uri.host)
kwargs.setdefault("port", self.ws_uri.port)
# Initialize WebSocket connection.
_, connection = await loop.create_connection(factory, **kwargs)
return connection

def process_redirect(self, exc: Exception) -> Exception | str:
def process_redirect(self, exc: Exception) -> Exception | tuple[str, WebSocketURI]:
"""
Determine whether a connection error is a redirect that can be followed.

Expand All @@ -492,12 +490,12 @@ def process_redirect(self, exc: Exception) -> Exception | str:
):
return exc

old_ws_uri = parse_uri(self.uri)
old_ws_uri = self.ws_uri
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
new_ws_uri = parse_uri(new_uri)

# If connect() received a socket, it is closed and cannot be reused.
if self.connection_kwargs.get("sock") is not None:
if self.create_connection_kwargs.get("sock") is not None:
return ValueError(
f"cannot follow redirect to {new_uri} with a preexisting socket"
)
Expand All @@ -513,23 +511,23 @@ def process_redirect(self, exc: Exception) -> Exception | str:
or old_ws_uri.port != new_ws_uri.port
):
# Cross-origin redirects on Unix sockets don't quite make sense.
if self.connection_kwargs.get("unix", False):
if self.create_connection_kwargs.get("unix", False):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with a Unix socket"
)

# Cross-origin redirects when host and port are overridden are ill-defined.
if (
self.connection_kwargs.get("host") is not None
or self.connection_kwargs.get("port") is not None
self.create_connection_kwargs.get("host") is not None
or self.create_connection_kwargs.get("port") is not None
):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with an explicit host or port"
)

return new_uri
return new_uri, new_ws_uri

# ... = await connect(...)

Expand All @@ -541,14 +539,14 @@ async def __await_impl__(self) -> ClientConnection:
try:
async with asyncio_timeout(self.open_timeout):
for _ in range(MAX_REDIRECTS):
self.connection = await self.create_connection()
connection = await self.open_tcp_connection()
try:
await self.connection.handshake(
await connection.handshake(
self.additional_headers,
self.user_agent_header,
)
except asyncio.CancelledError:
self.connection.transport.abort()
connection.transport.abort()
raise
except Exception as exc:
# Always close the connection even though keep-alive is
Expand All @@ -557,22 +555,23 @@ async def __await_impl__(self) -> ClientConnection:
# protocol. In the current design of connect(), there is
# no easy way to reuse the network connection that works
# in every case nor to reinitialize the protocol.
self.connection.transport.abort()
connection.transport.abort()

uri_or_exc = self.process_redirect(exc)
# Response is a valid redirect; follow it.
if isinstance(uri_or_exc, str):
self.uri = uri_or_exc
continue
exc_or_uri = self.process_redirect(exc)
# Response isn't a valid redirect; raise the exception.
if uri_or_exc is exc:
raise
if isinstance(exc_or_uri, Exception):
if exc_or_uri is exc:
raise
else:
raise exc_or_uri from exc
# Response is a valid redirect; follow it.
else:
raise uri_or_exc from exc
self.uri, self.ws_uri = exc_or_uri
continue

else:
self.connection.start_keepalive()
return self.connection
connection.start_keepalive()
return connection
else:
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

Expand All @@ -587,24 +586,30 @@ async def __await_impl__(self) -> ClientConnection:
# async with connect(...) as ...: ...

async def __aenter__(self) -> ClientConnection:
return await self
if hasattr(self, "connection"):
raise RuntimeError("connect() isn't reentrant")
self.connection = await self
return self.connection

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self.connection.close()
try:
await self.connection.close()
finally:
del self.connection

# async for ... in connect(...):

async def __aiter__(self) -> AsyncIterator[ClientConnection]:
delays: Generator[float] | None = None
while True:
try:
async with self as protocol:
yield protocol
async with self as connection:
yield connection
except Exception as exc:
# Determine whether the exception is retryable or fatal.
# The API of process_exception is "return an exception or None";
Expand Down Expand Up @@ -633,7 +638,6 @@ async def __aiter__(self) -> AsyncIterator[ClientConnection]:
traceback.format_exception_only(exc)[0].strip(),
)
await asyncio.sleep(delay)
continue

else:
# The connection succeeded. Reset backoff.
Expand Down Expand Up @@ -777,8 +781,7 @@ def eof_received(self) -> None:

def connection_lost(self, exc: Exception | None) -> None:
self.reader.feed_eof()
if exc is not None:
self.response.set_exception(exc)
self.run_parser()


async def connect_http_proxy(
Expand All @@ -797,8 +800,8 @@ async def connect_http_proxy(
try:
# This raises exceptions if the connection to the proxy fails.
await protocol.response
except Exception:
transport.close()
except (asyncio.CancelledError, Exception):
transport.abort()
raise

return transport
37 changes: 19 additions & 18 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def handshake(
assert isinstance(response, Response) # help mypy
self.response = response

if server_header:
if server_header is not None:
self.response.headers["Server"] = server_header

response = None
Expand Down Expand Up @@ -231,12 +231,9 @@ class Server:

This class mirrors the API of :class:`asyncio.Server`.

It keeps track of WebSocket connections in order to close them properly
when shutting down.

Args:
handler: Connection handler. It receives the WebSocket connection,
which is a :class:`ServerConnection`, in argument.
which is a :class:`ServerConnection`.
process_request: Intercept the request during the opening handshake.
Return an HTTP response to force the response. Return :obj:`None` to
continue normally. When you force an HTTP 101 Continue response, the
Expand Down Expand Up @@ -310,7 +307,11 @@ def connections(self) -> set[ServerConnection]:
It can be useful in combination with :func:`~broadcast`.

"""
return {connection for connection in self.handlers if connection.state is OPEN}
return {
connection
for connection in self.handlers
if connection.protocol.state is OPEN
}

def wrap(self, server: asyncio.Server) -> None:
"""
Expand Down Expand Up @@ -351,6 +352,8 @@ async def conn_handler(self, connection: ServerConnection) -> None:

"""
try:
# Apply open_timeout to the WebSocket handshake.
# Use ssl_handshake_timeout for the TLS handshake.
async with asyncio_timeout(self.open_timeout):
try:
await connection.handshake(
Expand Down Expand Up @@ -425,7 +428,7 @@ def close(
``code`` and ``reason`` can be customized, for example to use code
1012 (service restart).

* Wait until all connection handlers terminate.
* Wait until all connection handlers have returned.

:meth:`close` is idempotent.

Expand All @@ -452,22 +455,20 @@ async def _close(
self.logger.info("server closing")

# Stop accepting new connections.
# Reject OPENING connections with HTTP 503 -- see handshake().
self.server.close()

# Wait until all accepted connections reach connection_made() and call
# register(). See https://github.com/python/cpython/issues/79033 for
# details. This workaround can be removed when dropping Python < 3.11.
await asyncio.sleep(0)

# After server.close(), handshake() closes OPENING connections with an
# HTTP 503 error.

# Close OPEN connections.
if close_connections:
# Close OPEN connections with code 1001 by default.
close_tasks = [
asyncio.create_task(connection.close(code, reason))
for connection in self.handlers
if connection.protocol.state is not CONNECTING
if connection.protocol.state is OPEN
]
# asyncio.wait doesn't accept an empty first argument.
if close_tasks:
Expand All @@ -476,7 +477,7 @@ async def _close(
# Wait until all TCP connections are closed.
await self.server.wait_closed()

# Wait until all connection handlers terminate.
# Wait until all connection handlers have returned.
# asyncio.wait doesn't accept an empty first argument.
if self.handlers:
await asyncio.wait(self.handlers.values())
Expand Down Expand Up @@ -590,18 +591,18 @@ class serve:

This coroutine returns a :class:`Server` whose API mirrors
:class:`asyncio.Server`. Treat it as an asynchronous context manager to
ensure that the server will be closed::
ensure that the server will be closed gracefully::

from websockets.asyncio.server import serve

def handler(websocket):
async def handler(websocket):
...

# set this future to exit the server
stop = asyncio.get_running_loop().create_future()
# set this event to exit the server
stop = asyncio.Event()

async with serve(handler, host, port):
await stop
await stop.wait()

Alternatively, call :meth:`~Server.serve_forever` to serve requests and
cancel it to stop the server::
Expand Down
Loading
Loading