From 4537d336751f82c83cb343b1e895368d8e2d7d3c Mon Sep 17 00:00:00 2001 From: Connor Brewster Date: Wed, 8 Jan 2025 13:56:12 -0600 Subject: [PATCH] Drop messages on the client for closed streams/subscriptions --- replit_river/client_session.py | 4 +++ replit_river/session.py | 6 +++- tests/common_handlers.py | 4 +-- tests/conftest.py | 5 +-- tests/test_communication.py | 57 +++++++++++++++++++++++++++++++++- 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/replit_river/client_session.py b/replit_river/client_session.py index ecd43370..a5a0c44e 100644 --- a/replit_river/client_session.py +++ b/replit_river/client_session.py @@ -234,6 +234,8 @@ async def send_subscription( ) from e except Exception as e: raise e + finally: + output.close() async def send_stream( self, @@ -335,6 +337,8 @@ async def _encode_stream() -> None: ) from e except Exception as e: raise e + finally: + output.close() async def send_close_stream( self, diff --git a/replit_river/session.py b/replit_river/session.py index 2be2b816..eb34d9ef 100644 --- a/replit_river/session.py +++ b/replit_river/session.py @@ -525,7 +525,11 @@ async def _add_msg_to_stream( return try: await stream.put(msg.payload) - except (RuntimeError, ChannelClosed) as e: + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: raise InvalidMessageException(e) from e async def _remove_acked_messages_in_buffer(self) -> None: diff --git a/tests/common_handlers.py b/tests/common_handlers.py index 3ddc4825..19a5e2fa 100644 --- a/tests/common_handlers.py +++ b/tests/common_handlers.py @@ -39,7 +39,7 @@ async def upload_handler( basic_upload: HandlerMapping = { ("test_service", "upload_method"): ( - "upload", + "upload-stream", upload_method_handler(upload_handler, deserialize_request, serialize_response), ), } @@ -54,7 +54,7 @@ async def subscription_handler( basic_subscription: HandlerMapping = { ("test_service", "subscription_method"): ( - "subscription", + "subscription-stream", subscription_method_handler( subscription_handler, deserialize_request, serialize_response ), diff --git a/tests/conftest.py b/tests/conftest.py index 16ee8270..529ffb23 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping +from typing import Any, Literal, Mapping import nanoid import pytest @@ -16,7 +16,8 @@ # Modular fixtures pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"] -HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]] +HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] +HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]] def transport_message( diff --git a/tests/test_communication.py b/tests/test_communication.py index fc86d270..488e4b81 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -3,9 +3,11 @@ from typing import AsyncGenerator import pytest +from grpc.aio import grpc from replit_river.client import Client from replit_river.error_schema import RiverError +from replit_river.rpc import subscription_method_handler from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE from tests.common_handlers import ( basic_rpc_method, @@ -14,9 +16,12 @@ basic_upload, ) from tests.conftest import ( + HandlerMapping, deserialize_error, + deserialize_request, deserialize_response, serialize_request, + serialize_response, ) @@ -101,6 +106,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]: @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**basic_subscription}]) async def test_subscription_method(client: Client) -> None: + messages = [] async for response in client.send_subscription( "test_service", "subscription_method", @@ -110,7 +116,8 @@ async def test_subscription_method(client: Client) -> None: deserialize_error, ): assert isinstance(response, str) - assert "Subscription message" in response + messages.append(response) + assert messages == [f"Subscription message {i} for Bob" for i in range(5)] @pytest.mark.asyncio @@ -213,3 +220,51 @@ async def stream_data() -> AsyncGenerator[str, None]: "Stream response for Stream Data 1", "Stream response for Stream Data 2", ] + + +async def flood_subscription_handler( + request: str, context: grpc.aio.ServicerContext +) -> AsyncGenerator[str, None]: + for i in range(1024): + yield f"Subscription message {i} for {request}" + + +flood_subscription: HandlerMapping = { + ("test_service", "flood_subscription_method"): ( + "subscription-stream", + subscription_method_handler( + flood_subscription_handler, deserialize_request, serialize_response + ), + ), +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method, **flood_subscription}]) +async def test_ignore_flood_subscription(client: Client) -> None: + sub = client.send_subscription( + "test_service", + "flood_subscription_method", + "Initial Subscription Data", + serialize_request, + deserialize_response, + deserialize_error, + ) + + # read one entry to start the subscription + await sub.__anext__() + # close the subscription so we can signal that we're not + # interested in the rest of the subscription. + await sub.aclose() + + # ensure that subsequent RPCs still work + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "Hello, Alice!"