diff --git a/src/replit_river/client.py b/src/replit_river/client.py index db4608ec..273f601b 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -235,11 +235,11 @@ def _trace_procedure( except RiverException as e: span.record_exception(e, escaped=True) _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) - raise e + raise except BaseException as e: span.record_exception(e, escaped=True) span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") - raise e + raise finally: span.end() diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index 5bff801a..94a6789a 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -86,8 +86,10 @@ class SessionClosedRiverServiceException(RiverException): def __init__( self, message: str, + streamId: str, ) -> None: super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message) + self.streamId = streamId def exception_from_message(code: str) -> type[RiverServiceException]: diff --git a/src/replit_river/task_manager.py b/src/replit_river/task_manager.py index 531292d0..1182bef3 100644 --- a/src/replit_river/task_manager.py +++ b/src/replit_river/task_manager.py @@ -2,7 +2,11 @@ import logging from typing import Coroutine, Set -from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException +from replit_river.error_schema import ( + ERROR_CODE_STREAM_CLOSED, + RiverException, + SessionClosedRiverServiceException, +) logger = logging.getLogger(__name__) @@ -37,6 +41,13 @@ async def cancel_task( # If we cancel the task manager we will get called here as well, # if we want to handle the cancellation differently we can do it here. logger.debug("Task was cancelled %r", task_to_remove) + except SessionClosedRiverServiceException as e: + logger.warning( + "Session was closed", + extra={ + "stream_id": e.streamId, + }, + ) except RiverException as e: if e.code == ERROR_CODE_STREAM_CLOSED: # Task is cancelled @@ -76,6 +87,14 @@ def _task_done_callback( ): # Task is cancelled pass + elif isinstance(exception, SessionClosedRiverServiceException): + # Session is closed, don't bother logging + logger.info( + "Session closed", + extra={ + "stream_id": exception.streamId, + }, + ) else: logger.error( "Exception on cancelling task", diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 1b900d38..f0b8db77 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -190,11 +190,11 @@ def _trace_procedure( except RiverException as e: span.record_exception(e, escaped=True) _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) - raise e + raise except BaseException as e: span.record_exception(e, escaped=True) span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") - raise e + raise finally: span.end() diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9600f4cc..aeae6cfd 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -353,6 +353,7 @@ async def _enqueue_message( # session is closing / closed, raise raise SessionClosedRiverServiceException( "river session is closed, dropping message", + stream_id, ) # Begin critical section: Avoid any await between here and _send_buffer.append @@ -448,7 +449,7 @@ async def do_close() -> None: await self._task_manager.cancel_all_tasks() - for stream_meta in self._streams.values(): + for stream_id, stream_meta in self._streams.items(): stream_meta["output"].close() # Wake up backpressured writers try: @@ -456,6 +457,7 @@ async def do_close() -> None: reason or SessionClosedRiverServiceException( "river session is closed", + stream_id, ) ) except ChannelFull: @@ -751,6 +753,13 @@ async def send_rpc[R, A]( # Block for backpressure and emission errors from the ws await backpressured_waiter() result = await anext(output) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="RPC cancelled", + span=span, + ) + raise except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, @@ -835,6 +844,13 @@ async def send_upload[I, R, A]( payload=payload, span=span, ) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="Upload cancelled", + span=span, + ) + raise except Exception as e: # If we get any exception other than WebsocketClosedException, # cancel the stream. @@ -916,6 +932,13 @@ async def send_subscription[I, E, A]( continue yield response_deserializer(item["payload"]) await self._send_close_stream(stream_id, span) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="Subscription cancelled", + span=span, + ) + raise except Exception as e: await self._send_cancel_stream( stream_id=stream_id, @@ -1002,6 +1025,15 @@ async def _encode_stream() -> None: # ... block the outer function until the emitter is finished emitting, # possibly raising a terminal exception. await emitter_task + except asyncio.CancelledError as e: + await self._send_cancel_stream( + stream_id=stream_id, + message="Stream cancelled", + span=span, + ) + if emitter_task.done() and (err := emitter_task.exception()): + raise e from err + raise except Exception as e: await self._send_cancel_stream( stream_id=stream_id, @@ -1288,6 +1320,7 @@ async def _recv_from_ws( # the outer loop. await transition_no_connection() break + msg: TransportMessage | str | None = None try: msg = parse_transport_msg(message) logger.debug( @@ -1367,9 +1400,13 @@ async def _recv_from_ws( stream_meta["output"].close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") + stream_id = "unknown" + if isinstance(msg, TransportMessage): + stream_id = msg.streamId close_session( SessionClosedRiverServiceException( - "Out of order message, closing connection" + "Out of order message, closing connection", + stream_id, ) ) continue @@ -1377,9 +1414,13 @@ async def _recv_from_ws( logger.exception( "Got invalid transport message, closing session", ) + stream_id = "unknown" + if isinstance(msg, TransportMessage): + stream_id = msg.streamId close_session( SessionClosedRiverServiceException( - "Out of order message, closing connection" + "Out of order message, closing connection", + stream_id, ) ) continue diff --git a/tests/conftest.py b/tests/conftest.py index 3866fdd1..db928bc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,8 @@ pytest_plugins = [ "tests.v1.river_fixtures.logging", "tests.v1.river_fixtures.clientserver", - "tests.v2.fixtures", + "tests.v2.fixtures.bound_client", + "tests.v2.fixtures.raw_ws_server", ] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] diff --git a/tests/v2/fixtures.py b/tests/v2/fixtures/bound_client.py similarity index 100% rename from tests/v2/fixtures.py rename to tests/v2/fixtures/bound_client.py diff --git a/tests/v2/fixtures/raw_ws_server.py b/tests/v2/fixtures/raw_ws_server.py new file mode 100644 index 00000000..2353f941 --- /dev/null +++ b/tests/v2/fixtures/raw_ws_server.py @@ -0,0 +1,92 @@ +import asyncio +from typing import ( + AsyncIterator, + Awaitable, + Callable, + Literal, + TypeAlias, + TypedDict, +) + +import pytest +from websockets import ConnectionClosed, ConnectionClosedOK, Data +from websockets.asyncio.server import ServerConnection, serve + +from replit_river.transport_options import UriAndMetadata + +WsServerFixture: TypeAlias = tuple[ + Callable[[], Awaitable[UriAndMetadata[None]]], + asyncio.Queue[bytes], + Callable[[], ServerConnection | None], +] + + +class OuterPayload[A](TypedDict): + ok: Literal[True] + payload: A + + +class _WsServerState(TypedDict): + ipv4_laddr: tuple[str, int] | None + + +async def _ws_server_internal( + recv: asyncio.Queue[bytes], + set_conn: Callable[[ServerConnection], None], + state: _WsServerState, +) -> AsyncIterator[None]: + async def handle(websocket: ServerConnection) -> None: + set_conn(websocket) + datagram: Data + try: + while datagram := await websocket.recv(decode=False): + if isinstance(datagram, str): + continue + await recv.put(datagram) + except ConnectionClosedOK: + pass + except ConnectionClosed: + pass + + port: int | None = None + if state["ipv4_laddr"]: + port = state["ipv4_laddr"][1] + async with serve(handle, "localhost", port=port) as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + if state["ipv4_laddr"] is None: + state["ipv4_laddr"] = pair + serve_forever = asyncio.create_task(server.serve_forever()) + yield None + server.close() + await server.wait_closed() + # "serve_forever" should always be done after wait_closed finishes + assert serve_forever.done() + + +@pytest.fixture +async def ws_server() -> AsyncIterator[WsServerFixture]: + recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) + connection: ServerConnection | None = None + state: _WsServerState = {"ipv4_laddr": None} + + def set_conn(new_conn: ServerConnection) -> None: + nonlocal connection + connection = new_conn + + server_generator = _ws_server_internal(recv, set_conn, state) + await anext(server_generator) + + async def urimeta() -> UriAndMetadata[None]: + ipv4_laddr = state["ipv4_laddr"] + assert ipv4_laddr + return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) + + yield (urimeta, recv, lambda: connection) + + connection = None + + try: + await anext(server_generator) + except StopAsyncIteration: + pass diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py new file mode 100644 index 00000000..b259d27e --- /dev/null +++ b/tests/v2/test_v2_cancellation.py @@ -0,0 +1,482 @@ +import asyncio +import logging +from datetime import timedelta +from typing import ( + Any, + AsyncIterator, + Literal, +) + +import msgpack +import nanoid +import pytest + +from replit_river.messages import parse_transport_msg +from replit_river.rpc import ( + STREAM_OPEN_BIT, + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + HandShakeStatus, + TransportMessage, +) +from replit_river.transport_options import TransportOptions +from replit_river.v2.client import Client +from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT +from tests.v2.fixtures.raw_ws_server import OuterPayload, WsServerFixture + + +async def test_rpc_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + sent_waiter = asyncio.Event() + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + + assert request_msg.payload.get("payload", {}).get("hello") == "world" + logging.debug("Found a hello:world %r", repr(request_msg)) + + sent_waiter.set() + + assert request_msg.controlFlags == STREAM_OPEN_BIT | STREAM_CLOSED_BIT + + cancel_msg = parse_transport_msg(await recv.get()) + assert not isinstance(cancel_msg, str) + assert cancel_msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + rpc_task = asyncio.create_task( + client.send_rpc( + "test", + "cancel_rpc", + {"ok": True, "payload": {"hello": "world"}}, + lambda x: x, + lambda x: x, + lambda x: x, + timedelta(seconds=2), + ) + ) + + # Wait until we've seen at least a few messages from the upload Task + await sent_waiter.wait() + + rpc_task.cancel() + + try: + await rpc_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler + + +@pytest.mark.parametrize("direction", ["send", "receive"]) +async def test_stream_cancel( + ws_server: WsServerFixture, direction: Literal["send", "receive"] +) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + bidi_waiter = asyncio.Event() + + async def send_server_messages(request_msg: TransportMessage) -> None: + seq = 0 + + while True: + msg = TransportMessage( + from_=request_msg.to, + to=request_msg.from_, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=seq, + ack=0, + payload={ + "ok": True, + "payload": { + "hello": "world", + }, + }, + ) + seq += 1 + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + await asyncio.sleep(0.1) + + if seq > 5 and direction == "send": + bidi_waiter.set() + + async def handle_server_messages(request_msg: TransportMessage) -> None: + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + while msg.payload.get("payload", {}).get("hello") == "world": + logging.debug("Found a hello:world %r", repr(msg)) + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + + assert msg.controlFlags == STREAM_CANCEL_BIT + + async def receive_chunks() -> None: + async def _upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: + count = 0 + while True: + await asyncio.sleep(0.1) + yield { + "ok": True, + "payload": { + "hello": "world", + }, + } + count += 1 + if count > 5 and direction == "receive": + # We've sent enough messages, interrupt the stream. + bidi_waiter.set() + + async for chunk in client.send_stream( + "test", + "cancel_stream", + {}, + _upload_chunks(), + lambda x: x, + lambda x: x, + lambda x: x, + lambda x: x, + ): + print(repr(chunk)) + + receive_task = asyncio.create_task(receive_chunks()) + request_msg = parse_transport_msg(await recv.get()) + logging.debug("request_msg: %r", repr(request_msg)) + assert not isinstance(request_msg, str) + + server_sender = asyncio.create_task(send_server_messages(request_msg)) + server_receiver = asyncio.create_task(handle_server_messages(request_msg)) + + # Wait until we've seen at least a few messages from the requisite Task + await bidi_waiter.wait() + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + assert server_sender + server_sender.cancel() + try: + await server_sender + except asyncio.CancelledError: + pass + server_receiver.cancel() + try: + await server_receiver + except Exception: + pass + + +async def test_subscription_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + received_waiter = asyncio.Event() + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + seq = 0 + + while True: + try: + cancel_msg = parse_transport_msg(recv.get_nowait()) + break + except asyncio.queues.QueueEmpty: + pass + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=seq, + ack=0, + payload={ + "ok": True, + "payload": { + "hello": "world", + }, + }, + ) + seq += 1 + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + await asyncio.sleep(0.1) + + if seq > 5: + received_waiter.set() + + assert not isinstance(cancel_msg, str) + assert cancel_msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + async def receive_chunks() -> None: + async for chunk in client.send_subscription( + "test", + "subscription_cancel", + {}, + lambda x: x, + lambda x: x, + lambda x: x, + ): + print(repr(chunk)) + + receive_task = asyncio.create_task(receive_chunks()) + + # Wait until we've seen at least a few messages from the upload Task + await received_waiter.wait() + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler + + +async def test_upload_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + while msg.payload.get("payload", {}).get("hello") == "world": + logging.debug("Found a hello:world %r", repr(msg)) + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + + assert msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + sent_waiter = asyncio.Event() + + async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: + count = 0 + while True: + await asyncio.sleep(0.1) + yield { + "ok": True, + "payload": { + "hello": "world", + }, + } + count += 1 + if count > 5: + # We've sent enough messages, interrupt the stream. + sent_waiter.set() + + upload_task = asyncio.create_task( + client.send_upload( + "test", + "upload_cancel", + {}, + upload_chunks(), + lambda x: x, + lambda x: x, + lambda x: x, + lambda x: x, + ) + ) + + # Wait until we've seen at least a few messages from the upload Task + await sent_waiter.wait() + + upload_task.cancel() + try: + await upload_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index bea6d2e0..736a35f8 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -1,13 +1,8 @@ import asyncio import logging -from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict import msgpack import nanoid -import pytest -from websockets import ConnectionClosed, ConnectionClosedOK -from websockets.asyncio.server import ServerConnection, serve -from websockets.typing import Data from replit_river.common_session import SessionState from replit_river.messages import parse_transport_msg @@ -18,9 +13,10 @@ HandShakeStatus, TransportMessage, ) -from replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.transport_options import TransportOptions from replit_river.v2.client import Client from replit_river.v2.session import STREAM_CLOSED_BIT, Session +from tests.v2.fixtures.raw_ws_server import WsServerFixture class _PermissiveRateLimiter(RateLimiter): @@ -37,77 +33,6 @@ def consume_budget(self, user: str) -> None: pass -WsServerFixture: TypeAlias = tuple[ - Callable[[], Awaitable[UriAndMetadata[None]]], - asyncio.Queue[bytes], - Callable[[], ServerConnection | None], -] - - -class _WsServerState(TypedDict): - ipv4_laddr: tuple[str, int] | None - - -async def _ws_server_internal( - recv: asyncio.Queue[bytes], - set_conn: Callable[[ServerConnection], None], - state: _WsServerState, -) -> AsyncIterator[None]: - async def handle(websocket: ServerConnection) -> None: - set_conn(websocket) - datagram: Data - try: - while datagram := await websocket.recv(decode=False): - if isinstance(datagram, str): - continue - await recv.put(datagram) - except ConnectionClosedOK: - pass - except ConnectionClosed: - pass - - port: int | None = None - if state["ipv4_laddr"]: - port = state["ipv4_laddr"][1] - async with serve(handle, "localhost", port=port) as server: - for sock in server.sockets: - if (pair := sock.getsockname())[0] == "127.0.0.1": - if state["ipv4_laddr"] is None: - state["ipv4_laddr"] = pair - serve_forever = asyncio.create_task(server.serve_forever()) - yield None - server.close() - await server.wait_closed() - # "serve_forever" should always be done after wait_closed finishes - assert serve_forever.done() - - -@pytest.fixture -async def ws_server() -> AsyncIterator[WsServerFixture]: - recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) - connection: ServerConnection | None = None - state: _WsServerState = {"ipv4_laddr": None} - - def set_conn(new_conn: ServerConnection) -> None: - nonlocal connection - connection = new_conn - - server_generator = _ws_server_internal(recv, set_conn, state) - await anext(server_generator) - - async def urimeta() -> UriAndMetadata[None]: - ipv4_laddr = state["ipv4_laddr"] - assert ipv4_laddr - return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) - - yield (urimeta, recv, lambda: connection) - - try: - await anext(server_generator) - except StopAsyncIteration: - pass - - async def test_connect(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server @@ -229,7 +154,7 @@ async def handle_server_messages() -> None: stream_close_msg = msgpack.unpackb(await recv.get()) assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT - stream_handler = asyncio.create_task(handle_server_messages()) + server_handler = asyncio.create_task(handle_server_messages()) try: async for datagram in client.send_subscription( @@ -243,5 +168,5 @@ async def handle_server_messages() -> None: await connecting # Ensure we're listening to close messages as well - stream_handler.cancel() - await stream_handler + server_handler.cancel() + await server_handler