From 32b169a2ffae5f5cb4eafd1bb6fec55e0e5fb120 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 5 May 2025 12:17:37 -0700 Subject: [PATCH 01/14] Adding explicit CancelledError handlers during async waiting loops --- src/replit_river/v2/session.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 9600f4cc..0f0c7c2f 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -751,6 +751,13 @@ async def send_rpc[R, A]( # Block for backpressure and emission errors from the ws await backpressured_waiter() result = await anext(output) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="RPC cancelled", + span=span, + ) + raise except asyncio.TimeoutError as e: await self._send_cancel_stream( stream_id=stream_id, @@ -835,6 +842,13 @@ async def send_upload[I, R, A]( payload=payload, span=span, ) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="Upload cancelled", + span=span, + ) + raise except Exception as e: # If we get any exception other than WebsocketClosedException, # cancel the stream. @@ -916,6 +930,13 @@ async def send_subscription[I, E, A]( continue yield response_deserializer(item["payload"]) await self._send_close_stream(stream_id, span) + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="Subscription cancelled", + span=span, + ) + raise except Exception as e: await self._send_cancel_stream( stream_id=stream_id, @@ -1002,6 +1023,13 @@ async def _encode_stream() -> None: # ... block the outer function until the emitter is finished emitting, # possibly raising a terminal exception. await emitter_task + except asyncio.CancelledError: + await self._send_cancel_stream( + stream_id=stream_id, + message="Stream cancelled", + span=span, + ) + raise except Exception as e: await self._send_cancel_stream( stream_id=stream_id, From b2ad51b4db63acce616ba45a7453295ce0669992 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 19:23:46 -0700 Subject: [PATCH 02/14] re-raise should not be "raise e" --- src/replit_river/client.py | 4 ++-- src/replit_river/v2/client.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/replit_river/client.py b/src/replit_river/client.py index db4608ec..273f601b 100644 --- a/src/replit_river/client.py +++ b/src/replit_river/client.py @@ -235,11 +235,11 @@ def _trace_procedure( except RiverException as e: span.record_exception(e, escaped=True) _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) - raise e + raise except BaseException as e: span.record_exception(e, escaped=True) span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") - raise e + raise finally: span.end() diff --git a/src/replit_river/v2/client.py b/src/replit_river/v2/client.py index 1b900d38..f0b8db77 100644 --- a/src/replit_river/v2/client.py +++ b/src/replit_river/v2/client.py @@ -190,11 +190,11 @@ def _trace_procedure( except RiverException as e: span.record_exception(e, escaped=True) _record_river_error(span_handle, RiverError(code=e.code, message=e.message)) - raise e + raise except BaseException as e: span.record_exception(e, escaped=True) span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}") - raise e + raise finally: span.end() From 46f65eeb8c10781c9eaaeabffac5920223da1c55 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 19:27:42 -0700 Subject: [PATCH 03/14] Clear out the "connection" local --- tests/v2/test_v2_session_lifecycle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index bea6d2e0..961fc166 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -102,6 +102,8 @@ async def urimeta() -> UriAndMetadata[None]: yield (urimeta, recv, lambda: connection) + connection = None + try: await anext(server_generator) except StopAsyncIteration: From 250c4786d901834049e52fcd0a68b15b229948c9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 19:58:46 -0700 Subject: [PATCH 04/14] Making room --- tests/conftest.py | 2 +- tests/v2/{fixtures.py => fixtures/bound_client.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/v2/{fixtures.py => fixtures/bound_client.py} (100%) diff --git a/tests/conftest.py b/tests/conftest.py index 3866fdd1..f9d5e393 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ pytest_plugins = [ "tests.v1.river_fixtures.logging", "tests.v1.river_fixtures.clientserver", - "tests.v2.fixtures", + "tests.v2.fixtures.bound_client", ] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] diff --git a/tests/v2/fixtures.py b/tests/v2/fixtures/bound_client.py similarity index 100% rename from tests/v2/fixtures.py rename to tests/v2/fixtures/bound_client.py From 455ba7f58e0ecd89aaf18d39e17e1fa747f9abed Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 20:10:59 -0700 Subject: [PATCH 05/14] Renaming --- tests/v2/test_v2_session_lifecycle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index 961fc166..645cbdaf 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -231,7 +231,7 @@ async def handle_server_messages() -> None: stream_close_msg = msgpack.unpackb(await recv.get()) assert stream_close_msg["controlFlags"] == STREAM_CLOSED_BIT - stream_handler = asyncio.create_task(handle_server_messages()) + server_handler = asyncio.create_task(handle_server_messages()) try: async for datagram in client.send_subscription( @@ -245,5 +245,5 @@ async def handle_server_messages() -> None: await connecting # Ensure we're listening to close messages as well - stream_handler.cancel() - await stream_handler + server_handler.cancel() + await server_handler From b7074c6137a53d6ff843c5e8022881705f3262c2 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 19:30:17 -0700 Subject: [PATCH 06/14] Adding a test for cancelling upload --- tests/v2/test_v2_session_lifecycle.py | 120 +++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index 645cbdaf..62334d8c 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -1,6 +1,14 @@ import asyncio import logging -from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Literal, + TypeAlias, + TypedDict, +) import msgpack import nanoid @@ -20,7 +28,12 @@ ) from replit_river.transport_options import TransportOptions, UriAndMetadata from replit_river.v2.client import Client -from replit_river.v2.session import STREAM_CLOSED_BIT, Session +from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session + + +class OuterPayload[A](TypedDict): + ok: Literal[True] + payload: A class _PermissiveRateLimiter(RateLimiter): @@ -247,3 +260,106 @@ async def handle_server_messages() -> None: # Ensure we're listening to close messages as well server_handler.cancel() await server_handler + + +async def test_upload_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + while msg.payload.get("payload", {}).get("hello") == "world": + logging.debug("Found a hello:world %r", repr(msg)) + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + + assert msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + sent_waiter = asyncio.Event() + + async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: + count = 0 + while True: + await asyncio.sleep(0.1) + yield { + "ok": True, + "payload": { + "hello": "world", + }, + } + count += 1 + if count > 5: + # We've sent enough messages, interrupt the stream. + sent_waiter.set() + + upload_task = asyncio.create_task( + client.send_upload( + "test", + "bigstream", + {}, + upload_chunks(), + lambda x: x, + lambda x: x, + lambda x: x, + lambda x: x, + ) + ) + + # Wait until we've seen at least a few messages from the upload Task + await sent_waiter.wait() + + upload_task.cancel() + try: + await upload_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler From 6b2b84ba3f809ed05347e3c30969087a9fbc7a3f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 20:07:51 -0700 Subject: [PATCH 07/14] Moving fixtures out --- tests/conftest.py | 1 + tests/v2/fixtures/raw_ws_server.py | 86 +++++++++++++++++++++++++++ tests/v2/test_v2_session_lifecycle.py | 74 +---------------------- 3 files changed, 88 insertions(+), 73 deletions(-) create mode 100644 tests/v2/fixtures/raw_ws_server.py diff --git a/tests/conftest.py b/tests/conftest.py index f9d5e393..db928bc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ "tests.v1.river_fixtures.logging", "tests.v1.river_fixtures.clientserver", "tests.v2.fixtures.bound_client", + "tests.v2.fixtures.raw_ws_server", ] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] diff --git a/tests/v2/fixtures/raw_ws_server.py b/tests/v2/fixtures/raw_ws_server.py new file mode 100644 index 00000000..7ac1bca7 --- /dev/null +++ b/tests/v2/fixtures/raw_ws_server.py @@ -0,0 +1,86 @@ +import asyncio +from typing import ( + AsyncIterator, + Awaitable, + Callable, + TypeAlias, + TypedDict, +) + +import pytest +from websockets import ConnectionClosed, ConnectionClosedOK, Data +from websockets.asyncio.server import ServerConnection, serve + +from replit_river.transport_options import UriAndMetadata + +WsServerFixture: TypeAlias = tuple[ + Callable[[], Awaitable[UriAndMetadata[None]]], + asyncio.Queue[bytes], + Callable[[], ServerConnection | None], +] + + +class _WsServerState(TypedDict): + ipv4_laddr: tuple[str, int] | None + + +async def _ws_server_internal( + recv: asyncio.Queue[bytes], + set_conn: Callable[[ServerConnection], None], + state: _WsServerState, +) -> AsyncIterator[None]: + async def handle(websocket: ServerConnection) -> None: + set_conn(websocket) + datagram: Data + try: + while datagram := await websocket.recv(decode=False): + if isinstance(datagram, str): + continue + await recv.put(datagram) + except ConnectionClosedOK: + pass + except ConnectionClosed: + pass + + port: int | None = None + if state["ipv4_laddr"]: + port = state["ipv4_laddr"][1] + async with serve(handle, "localhost", port=port) as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + if state["ipv4_laddr"] is None: + state["ipv4_laddr"] = pair + serve_forever = asyncio.create_task(server.serve_forever()) + yield None + server.close() + await server.wait_closed() + # "serve_forever" should always be done after wait_closed finishes + assert serve_forever.done() + + +@pytest.fixture +async def ws_server() -> AsyncIterator[WsServerFixture]: + recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) + connection: ServerConnection | None = None + state: _WsServerState = {"ipv4_laddr": None} + + def set_conn(new_conn: ServerConnection) -> None: + nonlocal connection + connection = new_conn + + server_generator = _ws_server_internal(recv, set_conn, state) + await anext(server_generator) + + async def urimeta() -> UriAndMetadata[None]: + ipv4_laddr = state["ipv4_laddr"] + assert ipv4_laddr + return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) + + yield (urimeta, recv, lambda: connection) + + connection = None + + try: + await anext(server_generator) + except StopAsyncIteration: + pass diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index 62334d8c..f04dd226 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -29,6 +29,7 @@ from replit_river.transport_options import TransportOptions, UriAndMetadata from replit_river.v2.client import Client from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session +from tests.v2.fixtures.raw_ws_server import WsServerFixture class OuterPayload[A](TypedDict): @@ -50,79 +51,6 @@ def consume_budget(self, user: str) -> None: pass -WsServerFixture: TypeAlias = tuple[ - Callable[[], Awaitable[UriAndMetadata[None]]], - asyncio.Queue[bytes], - Callable[[], ServerConnection | None], -] - - -class _WsServerState(TypedDict): - ipv4_laddr: tuple[str, int] | None - - -async def _ws_server_internal( - recv: asyncio.Queue[bytes], - set_conn: Callable[[ServerConnection], None], - state: _WsServerState, -) -> AsyncIterator[None]: - async def handle(websocket: ServerConnection) -> None: - set_conn(websocket) - datagram: Data - try: - while datagram := await websocket.recv(decode=False): - if isinstance(datagram, str): - continue - await recv.put(datagram) - except ConnectionClosedOK: - pass - except ConnectionClosed: - pass - - port: int | None = None - if state["ipv4_laddr"]: - port = state["ipv4_laddr"][1] - async with serve(handle, "localhost", port=port) as server: - for sock in server.sockets: - if (pair := sock.getsockname())[0] == "127.0.0.1": - if state["ipv4_laddr"] is None: - state["ipv4_laddr"] = pair - serve_forever = asyncio.create_task(server.serve_forever()) - yield None - server.close() - await server.wait_closed() - # "serve_forever" should always be done after wait_closed finishes - assert serve_forever.done() - - -@pytest.fixture -async def ws_server() -> AsyncIterator[WsServerFixture]: - recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) - connection: ServerConnection | None = None - state: _WsServerState = {"ipv4_laddr": None} - - def set_conn(new_conn: ServerConnection) -> None: - nonlocal connection - connection = new_conn - - server_generator = _ws_server_internal(recv, set_conn, state) - await anext(server_generator) - - async def urimeta() -> UriAndMetadata[None]: - ipv4_laddr = state["ipv4_laddr"] - assert ipv4_laddr - return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) - - yield (urimeta, recv, lambda: connection) - - connection = None - - try: - await anext(server_generator) - except StopAsyncIteration: - pass - - async def test_connect(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server From b518386b7c4c9eb86b32d348d64454a39e7eebfd Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 20:09:29 -0700 Subject: [PATCH 08/14] Moving test_upload_cancel out --- tests/v2/test_v2_cancellation.py | 131 ++++++++++++++++++++++++++ tests/v2/test_v2_session_lifecycle.py | 116 +---------------------- 2 files changed, 133 insertions(+), 114 deletions(-) create mode 100644 tests/v2/test_v2_cancellation.py diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py new file mode 100644 index 00000000..89a39976 --- /dev/null +++ b/tests/v2/test_v2_cancellation.py @@ -0,0 +1,131 @@ +import asyncio +import logging +from typing import ( + Any, + AsyncIterator, + Literal, + TypedDict, +) + +import msgpack +import nanoid + +from replit_river.messages import parse_transport_msg +from replit_river.rpc import ( + ControlMessageHandshakeRequest, + ControlMessageHandshakeResponse, + HandShakeStatus, + TransportMessage, +) +from replit_river.transport_options import TransportOptions +from replit_river.v2.client import Client +from replit_river.v2.session import STREAM_CANCEL_BIT +from tests.v2.fixtures.raw_ws_server import WsServerFixture + + +class OuterPayload[A](TypedDict): + ok: Literal[True] + payload: A + + +async def test_upload_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + while msg.payload.get("payload", {}).get("hello") == "world": + logging.debug("Found a hello:world %r", repr(msg)) + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + + assert msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + sent_waiter = asyncio.Event() + + async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: + count = 0 + while True: + await asyncio.sleep(0.1) + yield { + "ok": True, + "payload": { + "hello": "world", + }, + } + count += 1 + if count > 5: + # We've sent enough messages, interrupt the stream. + sent_waiter.set() + + upload_task = asyncio.create_task( + client.send_upload( + "test", + "bigstream", + {}, + upload_chunks(), + lambda x: x, + lambda x: x, + lambda x: x, + lambda x: x, + ) + ) + + # Wait until we've seen at least a few messages from the upload Task + await sent_waiter.wait() + + upload_task.cancel() + try: + await upload_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index f04dd226..f8d4551b 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -1,21 +1,12 @@ import asyncio import logging from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, Literal, - TypeAlias, TypedDict, ) import msgpack import nanoid -import pytest -from websockets import ConnectionClosed, ConnectionClosedOK -from websockets.asyncio.server import ServerConnection, serve -from websockets.typing import Data from replit_river.common_session import SessionState from replit_river.messages import parse_transport_msg @@ -26,9 +17,9 @@ HandShakeStatus, TransportMessage, ) -from replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.transport_options import TransportOptions from replit_river.v2.client import Client -from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session +from replit_river.v2.session import STREAM_CLOSED_BIT, Session from tests.v2.fixtures.raw_ws_server import WsServerFixture @@ -188,106 +179,3 @@ async def handle_server_messages() -> None: # Ensure we're listening to close messages as well server_handler.cancel() await server_handler - - -async def test_upload_cancel(ws_server: WsServerFixture) -> None: - (urimeta, recv, conn) = ws_server - - client = Client( - client_id="CLIENT1", - server_id="SERVER", - transport_options=TransportOptions(), - uri_and_metadata_factory=urimeta, - ) - - connecting = asyncio.create_task(client.ensure_connected()) - request_msg = parse_transport_msg(await recv.get()) - - assert not isinstance(request_msg, str) - assert (serverconn := conn()) - handshake_request: ControlMessageHandshakeRequest[None] = ( - ControlMessageHandshakeRequest(**request_msg.payload) - ) - - handshake_resp = ControlMessageHandshakeResponse( - status=HandShakeStatus( - ok=True, - ), - ) - handshake_request.sessionId - - msg = TransportMessage( - from_=request_msg.from_, - to=request_msg.to, - streamId=request_msg.streamId, - controlFlags=0, - id=nanoid.generate(), - seq=0, - ack=0, - payload=handshake_resp.model_dump(), - ) - packed = msgpack.packb( - msg.model_dump(by_alias=True, exclude_none=True), datetime=True - ) - await serverconn.send(packed) - - async def handle_server_messages() -> None: - request_msg = parse_transport_msg(await recv.get()) - assert not isinstance(request_msg, str) - - logging.debug("request_msg: %r", repr(request_msg)) - - msg = TransportMessage(**msgpack.unpackb(await recv.get())) - while msg.payload.get("payload", {}).get("hello") == "world": - logging.debug("Found a hello:world %r", repr(msg)) - msg = TransportMessage(**msgpack.unpackb(await recv.get())) - - assert msg.controlFlags == STREAM_CANCEL_BIT - - server_handler = asyncio.create_task(handle_server_messages()) - - sent_waiter = asyncio.Event() - - async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: - count = 0 - while True: - await asyncio.sleep(0.1) - yield { - "ok": True, - "payload": { - "hello": "world", - }, - } - count += 1 - if count > 5: - # We've sent enough messages, interrupt the stream. - sent_waiter.set() - - upload_task = asyncio.create_task( - client.send_upload( - "test", - "bigstream", - {}, - upload_chunks(), - lambda x: x, - lambda x: x, - lambda x: x, - lambda x: x, - ) - ) - - # Wait until we've seen at least a few messages from the upload Task - await sent_waiter.wait() - - upload_task.cancel() - try: - await upload_task - except asyncio.CancelledError: - pass - - await client.close() - await connecting - - # Ensure we're listening to close messages as well - server_handler.cancel() - await server_handler From a93b0e98d0eb24839b397b2bdd3c2de117303c64 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 20:14:02 -0700 Subject: [PATCH 09/14] Consolidating --- tests/v2/fixtures/raw_ws_server.py | 6 ++++++ tests/v2/test_v2_cancellation.py | 9 +-------- tests/v2/test_v2_session_lifecycle.py | 9 --------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/tests/v2/fixtures/raw_ws_server.py b/tests/v2/fixtures/raw_ws_server.py index 7ac1bca7..2353f941 100644 --- a/tests/v2/fixtures/raw_ws_server.py +++ b/tests/v2/fixtures/raw_ws_server.py @@ -3,6 +3,7 @@ AsyncIterator, Awaitable, Callable, + Literal, TypeAlias, TypedDict, ) @@ -20,6 +21,11 @@ ] +class OuterPayload[A](TypedDict): + ok: Literal[True] + payload: A + + class _WsServerState(TypedDict): ipv4_laddr: tuple[str, int] | None diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py index 89a39976..de918f89 100644 --- a/tests/v2/test_v2_cancellation.py +++ b/tests/v2/test_v2_cancellation.py @@ -3,8 +3,6 @@ from typing import ( Any, AsyncIterator, - Literal, - TypedDict, ) import msgpack @@ -20,12 +18,7 @@ from replit_river.transport_options import TransportOptions from replit_river.v2.client import Client from replit_river.v2.session import STREAM_CANCEL_BIT -from tests.v2.fixtures.raw_ws_server import WsServerFixture - - -class OuterPayload[A](TypedDict): - ok: Literal[True] - payload: A +from tests.v2.fixtures.raw_ws_server import OuterPayload, WsServerFixture async def test_upload_cancel(ws_server: WsServerFixture) -> None: diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index f8d4551b..736a35f8 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -1,9 +1,5 @@ import asyncio import logging -from typing import ( - Literal, - TypedDict, -) import msgpack import nanoid @@ -23,11 +19,6 @@ from tests.v2.fixtures.raw_ws_server import WsServerFixture -class OuterPayload[A](TypedDict): - ok: Literal[True] - payload: A - - class _PermissiveRateLimiter(RateLimiter): def start_restoring_budget(self, user: str) -> None: pass From d6cda645f6c996e9d8ec9b74c90a79d3a1bbf715 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 20:31:32 -0700 Subject: [PATCH 10/14] Adding an RPC cancellation test --- tests/v2/test_v2_cancellation.py | 96 +++++++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 1 deletion(-) diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py index de918f89..c776f682 100644 --- a/tests/v2/test_v2_cancellation.py +++ b/tests/v2/test_v2_cancellation.py @@ -1,5 +1,6 @@ import asyncio import logging +from datetime import timedelta from typing import ( Any, AsyncIterator, @@ -10,6 +11,7 @@ from replit_river.messages import parse_transport_msg from replit_river.rpc import ( + STREAM_OPEN_BIT, ControlMessageHandshakeRequest, ControlMessageHandshakeResponse, HandShakeStatus, @@ -17,10 +19,102 @@ ) from replit_river.transport_options import TransportOptions from replit_river.v2.client import Client -from replit_river.v2.session import STREAM_CANCEL_BIT +from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT from tests.v2.fixtures.raw_ws_server import OuterPayload, WsServerFixture +async def test_rpc_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + sent_waiter = asyncio.Event() + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + + assert request_msg.payload.get("payload", {}).get("hello") == "world" + logging.debug("Found a hello:world %r", repr(request_msg)) + + sent_waiter.set() + + assert request_msg.controlFlags == STREAM_OPEN_BIT | STREAM_CLOSED_BIT + + cancel_msg = parse_transport_msg(await recv.get()) + assert not isinstance(cancel_msg, str) + assert cancel_msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + rpc_task = asyncio.create_task( + client.send_rpc( + "test", + "bigstream", + {"ok": True, "payload": {"hello": "world"}}, + lambda x: x, + lambda x: x, + lambda x: x, + timedelta(seconds=2), + ) + ) + + # Wait until we've seen at least a few messages from the upload Task + await sent_waiter.wait() + + rpc_task.cancel() + + try: + await rpc_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler + + async def test_upload_cancel(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server From b5971b6a4a6d3bbb89e8daff8f89050234b213df Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 21:04:54 -0700 Subject: [PATCH 11/14] Add a subscription test --- tests/v2/test_v2_cancellation.py | 117 +++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py index c776f682..bf82edce 100644 --- a/tests/v2/test_v2_cancellation.py +++ b/tests/v2/test_v2_cancellation.py @@ -115,6 +115,123 @@ async def handle_server_messages() -> None: await server_handler +async def test_subscription_cancel(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + received_waiter = asyncio.Event() + + async def handle_server_messages() -> None: + request_msg = parse_transport_msg(await recv.get()) + assert not isinstance(request_msg, str) + + logging.debug("request_msg: %r", repr(request_msg)) + seq = 0 + + while True: + try: + cancel_msg = parse_transport_msg(recv.get_nowait()) + break + except asyncio.queues.QueueEmpty: + pass + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=seq, + ack=0, + payload={ + "ok": True, + "payload": { + "hello": "world", + }, + }, + ) + seq += 1 + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + await asyncio.sleep(0.1) + + if seq > 5: + received_waiter.set() + + assert not isinstance(cancel_msg, str) + assert cancel_msg.controlFlags == STREAM_CANCEL_BIT + + server_handler = asyncio.create_task(handle_server_messages()) + + async def receive_chunks() -> None: + async for chunk in client.send_subscription( + "test", + "bigstream", + {}, + lambda x: x, + lambda x: x, + lambda x: x, + ): + print(repr(chunk)) + + receive_task = asyncio.create_task(receive_chunks()) + + # Wait until we've seen at least a few messages from the upload Task + await received_waiter.wait() + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + server_handler.cancel() + await server_handler + + async def test_upload_cancel(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server From 31084f48818ecb8dd4bd8ffc333063961aef260b Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 22:46:55 -0700 Subject: [PATCH 12/14] Propagating stream_id into logs --- src/replit_river/error_schema.py | 2 ++ src/replit_river/task_manager.py | 21 ++++++++++++++++++++- src/replit_river/v2/session.py | 21 +++++++++++++++++---- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/replit_river/error_schema.py b/src/replit_river/error_schema.py index 5bff801a..94a6789a 100644 --- a/src/replit_river/error_schema.py +++ b/src/replit_river/error_schema.py @@ -86,8 +86,10 @@ class SessionClosedRiverServiceException(RiverException): def __init__( self, message: str, + streamId: str, ) -> None: super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message) + self.streamId = streamId def exception_from_message(code: str) -> type[RiverServiceException]: diff --git a/src/replit_river/task_manager.py b/src/replit_river/task_manager.py index 531292d0..1182bef3 100644 --- a/src/replit_river/task_manager.py +++ b/src/replit_river/task_manager.py @@ -2,7 +2,11 @@ import logging from typing import Coroutine, Set -from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException +from replit_river.error_schema import ( + ERROR_CODE_STREAM_CLOSED, + RiverException, + SessionClosedRiverServiceException, +) logger = logging.getLogger(__name__) @@ -37,6 +41,13 @@ async def cancel_task( # If we cancel the task manager we will get called here as well, # if we want to handle the cancellation differently we can do it here. logger.debug("Task was cancelled %r", task_to_remove) + except SessionClosedRiverServiceException as e: + logger.warning( + "Session was closed", + extra={ + "stream_id": e.streamId, + }, + ) except RiverException as e: if e.code == ERROR_CODE_STREAM_CLOSED: # Task is cancelled @@ -76,6 +87,14 @@ def _task_done_callback( ): # Task is cancelled pass + elif isinstance(exception, SessionClosedRiverServiceException): + # Session is closed, don't bother logging + logger.info( + "Session closed", + extra={ + "stream_id": exception.streamId, + }, + ) else: logger.error( "Exception on cancelling task", diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 0f0c7c2f..aeae6cfd 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -353,6 +353,7 @@ async def _enqueue_message( # session is closing / closed, raise raise SessionClosedRiverServiceException( "river session is closed, dropping message", + stream_id, ) # Begin critical section: Avoid any await between here and _send_buffer.append @@ -448,7 +449,7 @@ async def do_close() -> None: await self._task_manager.cancel_all_tasks() - for stream_meta in self._streams.values(): + for stream_id, stream_meta in self._streams.items(): stream_meta["output"].close() # Wake up backpressured writers try: @@ -456,6 +457,7 @@ async def do_close() -> None: reason or SessionClosedRiverServiceException( "river session is closed", + stream_id, ) ) except ChannelFull: @@ -1023,12 +1025,14 @@ async def _encode_stream() -> None: # ... block the outer function until the emitter is finished emitting, # possibly raising a terminal exception. await emitter_task - except asyncio.CancelledError: + except asyncio.CancelledError as e: await self._send_cancel_stream( stream_id=stream_id, message="Stream cancelled", span=span, ) + if emitter_task.done() and (err := emitter_task.exception()): + raise e from err raise except Exception as e: await self._send_cancel_stream( @@ -1316,6 +1320,7 @@ async def _recv_from_ws( # the outer loop. await transition_no_connection() break + msg: TransportMessage | str | None = None try: msg = parse_transport_msg(message) logger.debug( @@ -1395,9 +1400,13 @@ async def _recv_from_ws( stream_meta["output"].close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") + stream_id = "unknown" + if isinstance(msg, TransportMessage): + stream_id = msg.streamId close_session( SessionClosedRiverServiceException( - "Out of order message, closing connection" + "Out of order message, closing connection", + stream_id, ) ) continue @@ -1405,9 +1414,13 @@ async def _recv_from_ws( logger.exception( "Got invalid transport message, closing session", ) + stream_id = "unknown" + if isinstance(msg, TransportMessage): + stream_id = msg.streamId close_session( SessionClosedRiverServiceException( - "Out of order message, closing connection" + "Out of order message, closing connection", + stream_id, ) ) continue From 3ce93d65283cbb0763a7164381df8209e2248771 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 22:50:19 -0700 Subject: [PATCH 13/14] Adding stream cancel test --- tests/v2/test_v2_cancellation.py | 147 +++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py index bf82edce..ddac7103 100644 --- a/tests/v2/test_v2_cancellation.py +++ b/tests/v2/test_v2_cancellation.py @@ -4,10 +4,12 @@ from typing import ( Any, AsyncIterator, + Literal, ) import msgpack import nanoid +import pytest from replit_river.messages import parse_transport_msg from replit_river.rpc import ( @@ -115,6 +117,151 @@ async def handle_server_messages() -> None: await server_handler +@pytest.mark.parametrize("direction", ["send", "receive"]) +async def test_stream_cancel( + ws_server: WsServerFixture, direction: Literal["send", "receive"] +) -> None: + (urimeta, recv, conn) = ws_server + + client = Client( + client_id="CLIENT1", + server_id="SERVER", + transport_options=TransportOptions(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(client.ensure_connected()) + request_msg = parse_transport_msg(await recv.get()) + + assert not isinstance(request_msg, str) + assert (serverconn := conn()) + handshake_request: ControlMessageHandshakeRequest[None] = ( + ControlMessageHandshakeRequest(**request_msg.payload) + ) + + handshake_resp = ControlMessageHandshakeResponse( + status=HandShakeStatus( + ok=True, + ), + ) + handshake_request.sessionId + + msg = TransportMessage( + from_=request_msg.from_, + to=request_msg.to, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=0, + ack=0, + payload=handshake_resp.model_dump(), + ) + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + + bidi_waiter = asyncio.Event() + + async def send_server_messages(request_msg: TransportMessage) -> None: + seq = 0 + + while True: + msg = TransportMessage( + from_=request_msg.to, + to=request_msg.from_, + streamId=request_msg.streamId, + controlFlags=0, + id=nanoid.generate(), + seq=seq, + ack=0, + payload={ + "ok": True, + "payload": { + "hello": "world", + }, + }, + ) + seq += 1 + packed = msgpack.packb( + msg.model_dump(by_alias=True, exclude_none=True), datetime=True + ) + await serverconn.send(packed) + await asyncio.sleep(0.1) + + if seq > 5 and direction == "send": + bidi_waiter.set() + + async def handle_server_messages(request_msg: TransportMessage) -> None: + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + while msg.payload.get("payload", {}).get("hello") == "world": + logging.debug("Found a hello:world %r", repr(msg)) + msg = TransportMessage(**msgpack.unpackb(await recv.get())) + + assert msg.controlFlags == STREAM_CANCEL_BIT + + async def receive_chunks() -> None: + async def _upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: + count = 0 + while True: + await asyncio.sleep(0.1) + yield { + "ok": True, + "payload": { + "hello": "world", + }, + } + count += 1 + if count > 5 and direction == "receive": + # We've sent enough messages, interrupt the stream. + bidi_waiter.set() + + async for chunk in client.send_stream( + "test", + "bigstream", + {}, + _upload_chunks(), + lambda x: x, + lambda x: x, + lambda x: x, + lambda x: x, + ): + print(repr(chunk)) + + receive_task = asyncio.create_task(receive_chunks()) + request_msg = parse_transport_msg(await recv.get()) + logging.debug("request_msg: %r", repr(request_msg)) + assert not isinstance(request_msg, str) + + server_sender = asyncio.create_task(send_server_messages(request_msg)) + server_receiver = asyncio.create_task(handle_server_messages(request_msg)) + + # Wait until we've seen at least a few messages from the requisite Task + await bidi_waiter.wait() + + receive_task.cancel() + try: + await receive_task + except asyncio.CancelledError: + pass + + await client.close() + await connecting + + # Ensure we're listening to close messages as well + assert server_sender + server_sender.cancel() + try: + await server_sender + except asyncio.CancelledError: + pass + server_receiver.cancel() + try: + await server_receiver + except Exception: + pass + + async def test_subscription_cancel(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server From e9d44cd2573a9ea39c2a26ad4fbbc7b7f9214131 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 6 May 2025 22:53:23 -0700 Subject: [PATCH 14/14] method names --- tests/v2/test_v2_cancellation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/v2/test_v2_cancellation.py b/tests/v2/test_v2_cancellation.py index ddac7103..b259d27e 100644 --- a/tests/v2/test_v2_cancellation.py +++ b/tests/v2/test_v2_cancellation.py @@ -90,7 +90,7 @@ async def handle_server_messages() -> None: rpc_task = asyncio.create_task( client.send_rpc( "test", - "bigstream", + "cancel_rpc", {"ok": True, "payload": {"hello": "world"}}, lambda x: x, lambda x: x, @@ -218,7 +218,7 @@ async def _upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: async for chunk in client.send_stream( "test", - "bigstream", + "cancel_stream", {}, _upload_chunks(), lambda x: x, @@ -352,7 +352,7 @@ async def handle_server_messages() -> None: async def receive_chunks() -> None: async for chunk in client.send_subscription( "test", - "bigstream", + "subscription_cancel", {}, lambda x: x, lambda x: x, @@ -455,7 +455,7 @@ async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]: upload_task = asyncio.create_task( client.send_upload( "test", - "bigstream", + "upload_cancel", {}, upload_chunks(), lambda x: x,