From 11fcf1733cf0b097677bfb3cba4f1a69cc2d42ea Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 13:47:43 -0700 Subject: [PATCH 01/23] Remove noqa again, it was not necessary --- src/replit_river/codegen/client.py | 1 - tests/codegen/rpc/generated/test_service/rpc_method.py | 1 - .../snapshots/test_basic_stream/test_service/stream_method.py | 1 - .../test_pathological_types/test_service/pathological_method.py | 1 - .../snapshots/test_unknown_enum/enumService/needsEnum.py | 1 - .../snapshots/test_unknown_enum/enumService/needsEnumObject.py | 1 - 6 files changed, 6 deletions(-) diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8d2791b9..0c837e12 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -67,7 +67,6 @@ FILE_HEADER = dedent( """\ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 91f0e562..f7dff38d 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 23d2ab6d..eff77816 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index f2bb120f..a9b530f5 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index dbe6e51e..0204f9c2 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 75f00e1c..97559be3 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime From 6aa87be6d7e32de88907283d14fef3aade04b372 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 14:40:13 -0700 Subject: [PATCH 02/23] Bubble state out of heartbeat --- src/replit_river/session.py | 85 +++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index a94156a8..9fb96920 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,7 @@ import asyncio import enum import logging -from typing import Any, Callable, Coroutine +from typing import Any, Awaitable, Callable, Coroutine, Protocol import nanoid # type: ignore import websockets @@ -42,6 +42,19 @@ trace_setter = TransportMessageTracingSetter() +class SendMessage(Protocol): + async def __call__( + self, + *, + stream_id: str, + payload: dict[Any, Any] | str, + control_flags: int, + service_name: str | None, + procedure_name: str | None, + span: Span | None, + ) -> None: ... + + class SessionState(enum.Enum): """The state a session can be in. @@ -53,7 +66,7 @@ class SessionState(enum.Enum): CLOSED = 2 -class Session(object): +class Session: """A transport object that handles the websocket connection with a client.""" def __init__( @@ -106,7 +119,29 @@ def __init__( self._setup_heartbeats_task() def _setup_heartbeats_task(self) -> None: - self._task_manager.create_task(self._heartbeat()) + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=not self._is_server, + ) + await self._begin_close_session_countdown() + + def increment_and_get_heartbeat_misses() -> int: + self._heartbeat_misses += 1 + return self._heartbeat_misses + + self._task_manager.create_task( + self._heartbeat( + self.session_id, + self._transport_options.heartbeat_ms, + self._transport_options.heartbeats_until_dead, + lambda: self._state, + lambda: self._close_session_after_time_secs, + close_websocket=do_close_websocket, + send_message=self.send_message, + increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + ) + ) self._task_manager.create_task(self._check_to_close_session()) async def is_session_open(self) -> bool: @@ -276,45 +311,51 @@ async def _check_to_close_session(self) -> None: async def _heartbeat( self, + session_id: str, + heartbeat_ms: float, + heartbeats_until_dead: int, + get_state: Callable[[], SessionState], + get_closing_grace_period: Callable[[], float | None], + close_websocket: Callable[[], Awaitable[None]], + send_message: SendMessage, + increment_and_get_heartbeat_misses: Callable[[], int], ) -> None: logger.debug("Start heartbeat") while True: - await asyncio.sleep(self._transport_options.heartbeat_ms / 1000) - if self._state != SessionState.ACTIVE: + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state != SessionState.ACTIVE: logger.debug( "Session is closed, no need to send heartbeat, state : " "%r close_session_after_this: %r", - {self._state}, - {self._close_session_after_time_secs}, + {state}, + {get_closing_grace_period()}, ) # session is closing / closed, no need to send heartbeat anymore return try: - await self.send_message( - "heartbeat", + await send_message( + stream_id="heartbeat", # TODO: make this a message class # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - { + payload={ "ack": 0, }, - ACK_BIT, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, ) - self._heartbeat_misses += 1 - if ( - self._heartbeat_misses - > self._transport_options.heartbeats_until_dead - ): - if self._close_session_after_time_secs is not None: + + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: # already in grace period, no need to set again continue logger.info( "%r closing websocket because of heartbeat misses", - self.session_id, + session_id, ) - await self.close_websocket( - self._ws_wrapper, should_retry=not self._is_server - ) - await self._begin_close_session_countdown() + await close_websocket() continue except FailedSendingMessageException: # this is expected during websocket closed period From b0f989b0b680661838a7d56bf2bc025fb5f571c8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 14:46:20 -0700 Subject: [PATCH 03/23] Break out heartbeat lifecycle --- src/replit_river/common_session.py | 61 ++++++++++++++++++++++++++++++ src/replit_river/session.py | 57 ++-------------------------- 2 files changed, 64 insertions(+), 54 deletions(-) create mode 100644 src/replit_river/common_session.py diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py new file mode 100644 index 00000000..05e7a1c8 --- /dev/null +++ b/src/replit_river/common_session.py @@ -0,0 +1,61 @@ +import asyncio +import logging +from typing import Awaitable, Callable + +from replit_river.messages import FailedSendingMessageException +from replit_river.rpc import ACK_BIT +from replit_river.session import SendMessage, SessionState + + +logger = logging.getLogger(__name__) + +async def setup_heartbeat( + session_id: str, + heartbeat_ms: float, + heartbeats_until_dead: int, + get_state: Callable[[], SessionState], + get_closing_grace_period: Callable[[], float | None], + close_websocket: Callable[[], Awaitable[None]], + send_message: SendMessage, + increment_and_get_heartbeat_misses: Callable[[], int], +) -> None: + logger.debug("Start heartbeat") + while True: + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state != SessionState.ACTIVE: + logger.debug( + "Session is closed, no need to send heartbeat, state : " + "%r close_session_after_this: %r", + {state}, + {get_closing_grace_period()}, + ) + # session is closing / closed, no need to send heartbeat anymore + return + try: + await send_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "ack": 0, + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) + + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: + # already in grace period, no need to set again + continue + logger.info( + "%r closing websocket because of heartbeat misses", + session_id, + ) + await close_websocket() + continue + except FailedSendingMessageException: + # this is expected during websocket closed period + continue diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 9fb96920..dd6506b3 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,7 @@ import asyncio import enum import logging -from typing import Any, Awaitable, Callable, Coroutine, Protocol +from typing import Any, Callable, Coroutine, Protocol import nanoid # type: ignore import websockets @@ -10,6 +10,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed +from replit_river.common_session import setup_heartbeat from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, @@ -131,7 +132,7 @@ def increment_and_get_heartbeat_misses() -> int: return self._heartbeat_misses self._task_manager.create_task( - self._heartbeat( + setup_heartbeat( self.session_id, self._transport_options.heartbeat_ms, self._transport_options.heartbeats_until_dead, @@ -309,58 +310,6 @@ async def _check_to_close_session(self) -> None: await self.close() return - async def _heartbeat( - self, - session_id: str, - heartbeat_ms: float, - heartbeats_until_dead: int, - get_state: Callable[[], SessionState], - get_closing_grace_period: Callable[[], float | None], - close_websocket: Callable[[], Awaitable[None]], - send_message: SendMessage, - increment_and_get_heartbeat_misses: Callable[[], int], - ) -> None: - logger.debug("Start heartbeat") - while True: - await asyncio.sleep(heartbeat_ms / 1000) - state = get_state() - if state != SessionState.ACTIVE: - logger.debug( - "Session is closed, no need to send heartbeat, state : " - "%r close_session_after_this: %r", - {state}, - {get_closing_grace_period()}, - ) - # session is closing / closed, no need to send heartbeat anymore - return - try: - await send_message( - stream_id="heartbeat", - # TODO: make this a message class - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - payload={ - "ack": 0, - }, - control_flags=ACK_BIT, - procedure_name=None, - service_name=None, - span=None, - ) - - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: - if get_closing_grace_period() is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - session_id, - ) - await close_websocket() - continue - except FailedSendingMessageException: - # this is expected during websocket closed period - continue - async def _send_buffered_messages( self, websocket: websockets.WebSocketCommonProtocol ) -> None: From cfc25ac213e686e25bbabe8e3402403e98cfeb9e Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 15:10:40 -0700 Subject: [PATCH 04/23] Break out "ServerSession" type --- src/replit_river/client_transport.py | 20 ++++++++++++++++---- src/replit_river/common_session.py | 2 +- src/replit_river/server.py | 4 ++-- src/replit_river/server_session.py | 20 ++++++++++++++++++++ src/replit_river/server_transport.py | 28 ++++++++++++++++++++-------- src/replit_river/session.py | 5 +---- src/replit_river/transport.py | 17 ++++++----------- 7 files changed, 66 insertions(+), 30 deletions(-) create mode 100644 src/replit_river/server_session.py diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 79552b58..abd8019e 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic +from typing import Generic, Mapping import websockets from pydantic import ValidationError @@ -36,6 +36,7 @@ IgnoreMessageException, InvalidMessageException, ) +from replit_river.session import Session from replit_river.transport import Transport from replit_river.transport_options import ( HandshakeMetadataType, @@ -47,6 +48,8 @@ class ClientTransport(Transport, Generic[HandshakeMetadataType]): + _sessions: dict[str, ClientSession] + def __init__( self, uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]], @@ -59,6 +62,7 @@ def __init__( transport_options=transport_options, is_server=False, ) + self._sessions = {} self._uri_and_metadata_factory = uri_and_metadata_factory self._client_id = client_id self._server_id = server_id @@ -70,7 +74,7 @@ def __init__( async def close(self) -> None: self._rate_limiter.close() - await self._close_all_sessions() + await self._close_all_sessions(self._get_all_sessions) async def get_or_create_session(self) -> ClientSession: async with self._create_session_lock: @@ -207,13 +211,13 @@ async def _create_new_session( handlers={}, ) - self._set_session(new_session) + self._sessions[new_session._to_id] = new_session await new_session.start_serve_responses() return new_session async def _retry_connection(self) -> ClientSession: if not self._transport_options.transparent_reconnect: - await self._close_all_sessions() + await self._close_all_sessions(self._get_all_sessions) return await self.get_or_create_session() async def _send_handshake_request( @@ -352,3 +356,11 @@ async def _establish_handshake( + f"{handshake_response.status.reason}", ) return handshake_request, handshake_response + + def _get_all_sessions(self) -> Mapping[str, Session]: + return self._sessions + + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 05e7a1c8..2388059d 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -6,9 +6,9 @@ from replit_river.rpc import ACK_BIT from replit_river.session import SendMessage, SessionState - logger = logging.getLogger(__name__) + async def setup_heartbeat( session_id: str, heartbeat_ms: float, diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 2bdf05b9..85713415 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -8,8 +8,8 @@ from replit_river.messages import WebsocketClosedException from replit_river.seq_manager import SessionStateMismatchException +from replit_river.server_session import ServerSession from replit_river.server_transport import ServerTransport -from replit_river.session import Session from replit_river.transport import TransportOptions from .rpc import ( @@ -41,7 +41,7 @@ def add_rpc_handlers( async def _handshake_to_get_session( self, websocket: WebSocketServerProtocol - ) -> Session | None: + ) -> ServerSession | None: """This is a wrapper to make sentry happy, sentry doesn't recognize the exception handling outside of a task or asyncio.wait_for. So we need to catch the errors specifically here. diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py new file mode 100644 index 00000000..3c1061b9 --- /dev/null +++ b/src/replit_river/server_session.py @@ -0,0 +1,20 @@ +import logging + +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +from replit_river.session import Session + +from .rpc import ( + TransportMessageTracingSetter, +) + +logger = logging.getLogger(__name__) + +trace_propagator = TraceContextTextMapPropagator() +trace_setter = TransportMessageTracingSetter() + + +class ServerSession(Session): + """A transport object that handles the websocket connection with a client.""" + + pass diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 888e0ce3..df38c896 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Mapping import nanoid # type: ignore # type: ignore from pydantic import ValidationError @@ -27,6 +27,7 @@ InvalidMessageException, SessionStateMismatchException, ) +from replit_river.server_session import ServerSession from replit_river.session import Session from replit_river.transport import Transport from replit_river.transport_options import TransportOptions @@ -35,6 +36,8 @@ class ServerTransport(Transport): + _sessions: dict[str, ServerSession] + def __init__( self, transport_id: str, @@ -45,11 +48,12 @@ def __init__( transport_options=transport_options, is_server=True, ) + self._sessions = {} async def handshake_to_get_session( self, websocket: WebSocketServerProtocol, - ) -> Session: + ) -> ServerSession: async for message in websocket: try: msg = parse_transport_msg(message, self._transport_options) @@ -88,7 +92,7 @@ async def handshake_to_get_session( raise WebsocketClosedException("No handshake message received") async def close(self) -> None: - await self._close_all_sessions() + await self._close_all_sessions(self._get_all_sessions) async def _get_or_create_session( self, @@ -96,15 +100,15 @@ async def _get_or_create_session( to_id: str, session_id: str, websocket: WebSocketCommonProtocol, - ) -> Session: + ) -> ServerSession: async with self._session_lock: session_to_close: Session | None = None - new_session: Session | None = None + new_session: ServerSession | None = None if to_id not in self._sessions: logger.info( 'Creating new session with "%s" using ws: %s', to_id, websocket.id ) - new_session = Session( + new_session = ServerSession( transport_id, to_id, session_id, @@ -125,7 +129,7 @@ async def _get_or_create_session( old_session.session_id, ) session_to_close = old_session - new_session = Session( + new_session = ServerSession( transport_id, to_id, session_id, @@ -152,7 +156,7 @@ async def _get_or_create_session( if session_to_close: logger.info("Closing stale session %s", session_to_close.session_id) await session_to_close.close() - self._set_session(new_session) + self._sessions[new_session._to_id] = new_session return new_session async def _send_handshake_response( @@ -293,3 +297,11 @@ async def _establish_handshake( ) return handshake_request, handshake_response + + def _get_all_sessions(self) -> Mapping[str, Session]: + return self._sessions + + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] diff --git a/src/replit_river/session.py b/src/replit_river/session.py index dd6506b3..0702746d 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -68,7 +68,7 @@ class SessionState(enum.Enum): class Session: - """A transport object that handles the websocket connection with a client.""" + """Common functionality shared between client_session and server_session""" def __init__( self, @@ -278,9 +278,6 @@ async def replace_with_new_websocket( await old_wrapper.close() self._ws_wrapper = WebsocketWrapper(new_ws) await self._send_buffered_messages(new_ws) - # Server will call serve itself. - if not self._is_server: - await self.start_serve_responses() async def _get_current_time(self) -> float: return asyncio.get_event_loop().time() diff --git a/src/replit_river/transport.py b/src/replit_river/transport.py index f0e2b920..3c5b50db 100644 --- a/src/replit_river/transport.py +++ b/src/replit_river/transport.py @@ -1,5 +1,6 @@ import asyncio import logging +from typing import Callable, Mapping import nanoid # type: ignore @@ -22,12 +23,14 @@ def __init__( self._transport_id = transport_id self._transport_options = transport_options self._is_server = is_server - self._sessions: dict[str, Session] = {} self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} self._session_lock = asyncio.Lock() - async def _close_all_sessions(self) -> None: - sessions = self._sessions.values() + async def _close_all_sessions( + self, + get_all_sessions: Callable[[], Mapping[str, Session]], + ) -> None: + sessions = get_all_sessions().values() logger.info( f"start closing sessions {self._transport_id}, number sessions : " f"{len(sessions)}" @@ -41,13 +44,5 @@ async def _close_all_sessions(self) -> None: logger.info(f"Transport closed {self._transport_id}") - async def _delete_session(self, session: Session) -> None: - async with self._session_lock: - if session._to_id in self._sessions: - del self._sessions[session._to_id] - - def _set_session(self, session: Session) -> None: - self._sessions[session._to_id] = session - def generate_nanoid(self) -> str: return str(nanoid.generate()) From bd36123900ee7226fa00c14d9d3ea6cbc778040f Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 16:38:20 -0700 Subject: [PATCH 05/23] Flattening Transport into ClientTransport and ServerTransport --- src/replit_river/client_transport.py | 34 ++++++++++++++++---- src/replit_river/server.py | 2 +- src/replit_river/server_transport.py | 39 ++++++++++++++++------ src/replit_river/transport.py | 48 ---------------------------- 4 files changed, 57 insertions(+), 66 deletions(-) delete mode 100644 src/replit_river/transport.py diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index abd8019e..869398e5 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -3,6 +3,7 @@ from collections.abc import Awaitable, Callable from typing import Generic, Mapping +import nanoid import websockets from pydantic import ValidationError from websockets import ( @@ -37,7 +38,6 @@ InvalidMessageException, ) from replit_river.session import Session -from replit_river.transport import Transport from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) -class ClientTransport(Transport, Generic[HandshakeMetadataType]): +class ClientTransport(Generic[HandshakeMetadataType]): _sessions: dict[str, ClientSession] def __init__( @@ -57,12 +57,11 @@ def __init__( server_id: str, transport_options: TransportOptions, ): - super().__init__( - transport_id=client_id, - transport_options=transport_options, - is_server=False, - ) self._sessions = {} + self._transport_id = client_id + self._transport_options = transport_options + self._session_lock = asyncio.Lock() + self._uri_and_metadata_factory = uri_and_metadata_factory self._client_id = client_id self._server_id = server_id @@ -72,6 +71,27 @@ def __init__( # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() + async def _close_all_sessions( + self, + get_all_sessions: Callable[[], Mapping[str, Session]], + ) -> None: + sessions = get_all_sessions().values() + logger.info( + f"start closing sessions {self._transport_id}, number sessions : " + f"{len(sessions)}" + ) + sessions_to_close = list(sessions) + + # closing sessions requires access to the session lock, so we need to close + # them one by one to be safe + for session in sessions_to_close: + await session.close() + + logger.info(f"Transport closed {self._transport_id}") + + def generate_nanoid(self) -> str: + return str(nanoid.generate()) + async def close(self) -> None: self._rate_limiter.close() await self._close_all_sessions(self._get_all_sessions) diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 85713415..a186f4c3 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -10,7 +10,7 @@ from replit_river.seq_manager import SessionStateMismatchException from replit_river.server_session import ServerSession from replit_river.server_transport import ServerTransport -from replit_river.transport import TransportOptions +from replit_river.transport_options import TransportOptions from .rpc import ( GenericRpcHandler, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index df38c896..64fc7fcc 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,7 +1,9 @@ +import asyncio import logging -from typing import Any, Mapping +from typing import Any, Callable, Mapping import nanoid # type: ignore # type: ignore +from grpc import GenericRpcHandler from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, @@ -29,26 +31,43 @@ ) from replit_river.server_session import ServerSession from replit_river.session import Session -from replit_river.transport import Transport from replit_river.transport_options import TransportOptions logger = logging.getLogger(__name__) -class ServerTransport(Transport): +class ServerTransport: _sessions: dict[str, ServerSession] + _handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] def __init__( self, transport_id: str, transport_options: TransportOptions, ) -> None: - super().__init__( - transport_id=transport_id, - transport_options=transport_options, - is_server=True, - ) self._sessions = {} + self._transport_id = transport_id + self._transport_options = transport_options + self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} + self._session_lock = asyncio.Lock() + + async def _close_all_sessions( + self, + get_all_sessions: Callable[[], Mapping[str, Session]], + ) -> None: + sessions = get_all_sessions().values() + logger.info( + f"start closing sessions {self._transport_id}, number sessions : " + f"{len(sessions)}" + ) + sessions_to_close = list(sessions) + + # closing sessions requires access to the session lock, so we need to close + # them one by one to be safe + for session in sessions_to_close: + await session.close() + + logger.info(f"Transport closed {self._transport_id}") async def handshake_to_get_session( self, @@ -114,7 +133,7 @@ async def _get_or_create_session( session_id, websocket, self._transport_options, - self._is_server, + True, self._handlers, close_session_callback=self._delete_session, ) @@ -135,7 +154,7 @@ async def _get_or_create_session( session_id, websocket, self._transport_options, - self._is_server, + True, self._handlers, close_session_callback=self._delete_session, ) diff --git a/src/replit_river/transport.py b/src/replit_river/transport.py deleted file mode 100644 index 3c5b50db..00000000 --- a/src/replit_river/transport.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -import logging -from typing import Callable, Mapping - -import nanoid # type: ignore - -from replit_river.rpc import ( - GenericRpcHandler, -) -from replit_river.session import Session -from replit_river.transport_options import TransportOptions - -logger = logging.getLogger(__name__) - - -class Transport: - def __init__( - self, - transport_id: str, - transport_options: TransportOptions, - is_server: bool, - ) -> None: - self._transport_id = transport_id - self._transport_options = transport_options - self._is_server = is_server - self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} - self._session_lock = asyncio.Lock() - - async def _close_all_sessions( - self, - get_all_sessions: Callable[[], Mapping[str, Session]], - ) -> None: - sessions = get_all_sessions().values() - logger.info( - f"start closing sessions {self._transport_id}, number sessions : " - f"{len(sessions)}" - ) - sessions_to_close = list(sessions) - - # closing sessions requires access to the session lock, so we need to close - # them one by one to be safe - for session in sessions_to_close: - await session.close() - - logger.info(f"Transport closed {self._transport_id}") - - def generate_nanoid(self) -> str: - return str(nanoid.generate()) From 1c9e76d9c99939c8f5effb54da224d56854569f3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 17:02:19 -0700 Subject: [PATCH 06/23] Disambiguate between builtin type --- src/replit_river/__init__.py | 4 ++-- src/replit_river/codegen/server.py | 2 +- src/replit_river/rpc.py | 10 +++++----- src/replit_river/server.py | 4 ++-- src/replit_river/server_transport.py | 6 +++--- src/replit_river/session.py | 4 ++-- tests/conftest.py | 4 ++-- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/replit_river/__init__.py b/src/replit_river/__init__.py index f837f16d..bc166e3e 100644 --- a/src/replit_river/__init__.py +++ b/src/replit_river/__init__.py @@ -1,7 +1,7 @@ from .client import Client from .error_schema import RiverError from .rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, GrpcContext, rpc_method_handler, stream_method_handler, @@ -15,7 +15,7 @@ "Server", "GrpcContext", "RiverError", - "GenericRpcHandler", + "GenericRpcHandlerBuilder", "rpc_method_handler", "subscription_method_handler", "upload_method_handler", diff --git a/src/replit_river/codegen/server.py b/src/replit_river/codegen/server.py index acee8106..c8e765a9 100644 --- a/src/replit_river/codegen/server.py +++ b/src/replit_river/codegen/server.py @@ -342,7 +342,7 @@ def add_{service.name}Servicer_to_server( ) -> None: rpc_method_handlers: Mapping[ tuple[str, str], - tuple[str, river.GenericRpcHandler] + tuple[str, river.GenericRpcHandlerBuilder] ] = {{ """ ), diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 8415143a..0d1bd4d1 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -43,7 +43,7 @@ _MetadataType: TypeAlias = grpc.aio.Metadata | Sequence[tuple[str, str | bytes]] -GenericRpcHandler = Callable[ +GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] ACK_BIT = 0x0001 @@ -220,7 +220,7 @@ def rpc_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -277,7 +277,7 @@ def subscription_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -336,7 +336,7 @@ def upload_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -414,7 +414,7 @@ def stream_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], diff --git a/src/replit_river/server.py b/src/replit_river/server.py index a186f4c3..64974fc3 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -13,7 +13,7 @@ from replit_river.transport_options import TransportOptions from .rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, ) logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ async def close(self) -> None: def add_rpc_handlers( self, - rpc_handlers: Mapping[tuple[str, str], tuple[str, GenericRpcHandler]], + rpc_handlers: Mapping[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], ) -> None: self._transport._handlers.update(rpc_handlers) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 64fc7fcc..78bee4b0 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Mapping import nanoid # type: ignore # type: ignore -from grpc import GenericRpcHandler from pydantic import ValidationError from websockets import ( WebSocketCommonProtocol, @@ -21,6 +20,7 @@ from replit_river.rpc import ( ControlMessageHandshakeRequest, ControlMessageHandshakeResponse, + GenericRpcHandlerBuilder, HandShakeStatus, TransportMessage, ) @@ -38,7 +38,7 @@ class ServerTransport: _sessions: dict[str, ServerSession] - _handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] + _handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] def __init__( self, @@ -48,7 +48,7 @@ def __init__( self._sessions = {} self._transport_id = transport_id self._transport_options = transport_options - self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} + self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] = {} self._session_lock = asyncio.Lock() async def _close_all_sessions( diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 0702746d..828ec89e 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -32,7 +32,7 @@ ACK_BIT, STREAM_CLOSED_BIT, STREAM_OPEN_BIT, - GenericRpcHandler, + GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, ) @@ -78,7 +78,7 @@ def __init__( websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, is_server: bool, - handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]], + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: ( Callable[ diff --git a/tests/conftest.py b/tests/conftest.py index 529ffb23..b9b8cdf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from replit_river.error_schema import RiverError from replit_river.rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, TransportMessage, ) @@ -17,7 +17,7 @@ pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] -HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]] +HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandlerBuilder]] def transport_message( From 034005f46684db039c9b6f518530de69bf18db39 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 17:19:53 -0700 Subject: [PATCH 07/23] Split serve() functionality between client and server --- src/replit_river/client_session.py | 104 ++++++++++++++++++ src/replit_river/server_session.py | 169 ++++++++++++++++++++++++++++- src/replit_river/session.py | 158 +-------------------------- 3 files changed, 273 insertions(+), 158 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 479a9f50..53d922e8 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -8,6 +8,7 @@ from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span +from websockets.exceptions import ConnectionClosed from replit_river.error_schema import ( ERROR_CODE_CANCEL, @@ -17,22 +18,125 @@ StreamClosedRiverServiceException, exception_from_message, ) +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + OutOfOrderMessageException, +) from replit_river.session import Session from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE from .rpc import ( + ACK_BIT, STREAM_CLOSED_BIT, STREAM_OPEN_BIT, ErrorType, InitType, RequestType, ResponseType, + TransportMessage, ) logger = logging.getLogger(__name__) class ClientSession(Session): + async def start_serve_responses(self) -> None: + self._task_manager.create_task(self.serve()) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + async with asyncio.TaskGroup() as tg: + try: + await self._handle_messages_from_ws(tg) + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task( + self._retry_connection_callback() + ) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _update_book_keeping(self, msg: TransportMessage) -> None: + await self._seq_manager.check_seq_and_update(msg) + await self._remove_acked_messages_in_buffer() + self._reset_session_close_countdown() + + async def _handle_messages_from_ws( + self, tg: asyncio.TaskGroup | None = None + ) -> None: + logger.debug( + "%s start handling messages from ws %s", + "client", + self._ws_wrapper.id, + ) + try: + ws_wrapper = self._ws_wrapper + async for message in ws_wrapper.ws: + try: + if not await ws_wrapper.is_open(): + # We should not process messages if the websocket is closed. + break + msg = parse_transport_msg(message, self._transport_options) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + await self._update_book_keeping(msg) + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await self._add_msg_to_stream(msg, stream) + else: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await ws_wrapper.close() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + return + except ConnectionClosed as e: + raise e + async def send_rpc( self, service_name: str, diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 3c1061b9..d4a04170 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -1,13 +1,34 @@ +import asyncio import logging +from typing import Any +from aiochannel import Channel, ChannelClosed from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.exceptions import ConnectionClosed +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + OutOfOrderMessageException, +) from replit_river.session import Session +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE from .rpc import ( + ACK_BIT, + STREAM_CLOSED_BIT, + STREAM_OPEN_BIT, + TransportMessage, TransportMessageTracingSetter, ) +logger = logging.getLogger(__name__) + + logger = logging.getLogger(__name__) trace_propagator = TraceContextTextMapPropagator() @@ -17,4 +38,150 @@ class ServerSession(Session): """A transport object that handles the websocket connection with a client.""" - pass + async def start_serve_responses(self) -> None: + self._task_manager.create_task(self.serve()) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + async with asyncio.TaskGroup() as tg: + try: + await self._handle_messages_from_ws(tg) + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task( + self._retry_connection_callback() + ) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _update_book_keeping(self, msg: TransportMessage) -> None: + await self._seq_manager.check_seq_and_update(msg) + await self._remove_acked_messages_in_buffer() + self._reset_session_close_countdown() + + async def _handle_messages_from_ws( + self, tg: asyncio.TaskGroup | None = None + ) -> None: + logger.debug( + "%s start handling messages from ws %s", + "server", + self._ws_wrapper.id, + ) + try: + ws_wrapper = self._ws_wrapper + async for message in ws_wrapper.ws: + try: + if not await ws_wrapper.is_open(): + # We should not process messages if the websocket is closed. + break + msg = parse_transport_msg(message, self._transport_options) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + await self._update_book_keeping(msg) + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await self._add_msg_to_stream(msg, stream) + else: + # TODO(dstewart) This looks like it opens a new call to handler + # on ever ws message, instead of demuxing and + # routing. + _stream = await self._open_stream_and_call_handler(msg, tg) + if not stream: + async with self._stream_lock: + self._streams[msg.streamId] = _stream + stream = _stream + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await ws_wrapper.close() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + return + except ConnectionClosed as e: + raise e + + async def _open_stream_and_call_handler( + self, + msg: TransportMessage, + tg: asyncio.TaskGroup | None, + ) -> Channel: + if not msg.serviceName or not msg.procedureName: + raise IgnoreMessageException( + f"Service name or procedure name is missing in the message {msg}" + ) + key = (msg.serviceName, msg.procedureName) + handler = self._handlers.get(key, None) + if not handler: + raise IgnoreMessageException( + f"No handler for {key} handlers : {self._handlers.keys()}" + ) + method_type, handler_func = handler + is_streaming_output = method_type in ( + "subscription-stream", # subscription + "stream", + ) + is_streaming_input = method_type in ( + "upload-stream", # subscription + "stream", + ) + # New channel pair. + input_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 + ) + output_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 + ) + if ( + msg.controlFlags & STREAM_CLOSED_BIT == 0 + or msg.payload.get("type", None) != "CLOSE" + ): + try: + await input_stream.put(msg.payload) + except (RuntimeError, ChannelClosed) as e: + raise InvalidMessageException(e) from e + # Start the handler. + self._task_manager.create_task( + handler_func(msg.from_, input_stream, output_stream), tg + ) + self._task_manager.create_task( + self._send_responses_from_output_stream( + msg.streamId, output_stream, is_streaming_output + ), + tg, + ) + return input_stream diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 828ec89e..f6b5dec4 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -8,30 +8,24 @@ from aiochannel import Channel, ChannelClosed from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from websockets.exceptions import ConnectionClosed from replit_river.common_session import setup_heartbeat from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, WebsocketClosedException, - parse_transport_msg, send_transport_message, ) from replit_river.seq_manager import ( - IgnoreMessageException, InvalidMessageException, - OutOfOrderMessageException, SeqManager, ) from replit_river.task_manager import BackgroundTaskManager -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +from replit_river.transport_options import TransportOptions from replit_river.websocket_wrapper import WebsocketWrapper from .rpc import ( - ACK_BIT, STREAM_CLOSED_BIT, - STREAM_OPEN_BIT, GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, @@ -174,100 +168,6 @@ async def _begin_close_session_countdown(self) -> None: ) self._close_session_after_time_secs = close_session_after_time_secs - async def serve(self) -> None: - """Serve messages from the websocket.""" - self._reset_session_close_countdown() - try: - async with asyncio.TaskGroup() as tg: - try: - await self._handle_messages_from_ws(tg) - except ConnectionClosed: - if self._retry_connection_callback: - self._task_manager.create_task( - self._retry_connection_callback() - ) - - await self._begin_close_session_countdown() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - raise ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - - async def _update_book_keeping(self, msg: TransportMessage) -> None: - await self._seq_manager.check_seq_and_update(msg) - await self._remove_acked_messages_in_buffer() - self._reset_session_close_countdown() - - async def _handle_messages_from_ws( - self, tg: asyncio.TaskGroup | None = None - ) -> None: - logger.debug( - "%s start handling messages from ws %s", - "server" if self._is_server else "client", - self._ws_wrapper.id, - ) - try: - ws_wrapper = self._ws_wrapper - async for message in ws_wrapper.ws: - try: - if not await ws_wrapper.is_open(): - # We should not process messages if the websocket is closed. - break - msg = parse_transport_msg(message, self._transport_options) - - logger.debug(f"{self._transport_id} got a message %r", msg) - - await self._update_book_keeping(msg) - if msg.controlFlags & ACK_BIT != 0: - continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) - if msg.controlFlags & STREAM_OPEN_BIT == 0: - if not stream: - logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) - await self._add_msg_to_stream(msg, stream) - else: - # TODO(dstewart) This looks like it opens a new call to handler - # on ever ws message, instead of demuxing and - # routing. - _stream = await self._open_stream_and_call_handler(msg, tg) - if not stream: - async with self._stream_lock: - self._streams[msg.streamId] = _stream - stream = _stream - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - await ws_wrapper.close() - return - except InvalidMessageException: - logger.exception("Got invalid transport message, closing session") - await self.close() - return - except ConnectionClosed as e: - raise e - async def replace_with_new_websocket( self, new_ws: websockets.WebSocketCommonProtocol ) -> None: @@ -447,59 +347,6 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) - async def _open_stream_and_call_handler( - self, - msg: TransportMessage, - tg: asyncio.TaskGroup | None, - ) -> Channel: - if not self._is_server: - raise InvalidMessageException("Client should not receive stream open bit") - if not msg.serviceName or not msg.procedureName: - raise IgnoreMessageException( - f"Service name or procedure name is missing in the message {msg}" - ) - key = (msg.serviceName, msg.procedureName) - handler = self._handlers.get(key, None) - if not handler: - raise IgnoreMessageException( - f"No handler for {key} handlers : {self._handlers.keys()}" - ) - method_type, handler_func = handler - is_streaming_output = method_type in ( - "subscription-stream", # subscription - "stream", - ) - is_streaming_input = method_type in ( - "upload-stream", # subscription - "stream", - ) - # New channel pair. - input_stream: Channel[Any] = Channel( - MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 - ) - output_stream: Channel[Any] = Channel( - MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 - ) - if ( - msg.controlFlags & STREAM_CLOSED_BIT == 0 - or msg.payload.get("type", None) != "CLOSE" - ): - try: - await input_stream.put(msg.payload) - except (RuntimeError, ChannelClosed) as e: - raise InvalidMessageException(e) from e - # Start the handler. - self._task_manager.create_task( - handler_func(msg.from_, input_stream, output_stream), tg - ) - self._task_manager.create_task( - self._send_responses_from_output_stream( - msg.streamId, output_stream, is_streaming_output - ), - tg, - ) - return input_stream - async def _add_msg_to_stream( self, msg: TransportMessage, @@ -523,9 +370,6 @@ async def _add_msg_to_stream( async def _remove_acked_messages_in_buffer(self) -> None: await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) - async def start_serve_responses(self) -> None: - self._task_manager.create_task(self.serve()) - async def close(self) -> None: """Close the session and all associated streams.""" logger.info( From 6e1b781cdb12e5590d3be454c92a38705be979d0 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 17:27:46 -0700 Subject: [PATCH 08/23] Remove handlers from Client* --- src/replit_river/client_transport.py | 1 - src/replit_river/server_session.py | 38 ++++++++++++++++++++++++++-- src/replit_river/session.py | 3 --- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 869398e5..dfed9c31 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -228,7 +228,6 @@ async def _create_new_session( is_server=False, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, - handlers={}, ) self._sessions[new_session._to_id] = new_session diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index d4a04170..650f60e3 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -1,7 +1,8 @@ import asyncio import logging -from typing import Any +from typing import Any, Callable, Coroutine +import websockets from aiochannel import Channel, ChannelClosed from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed @@ -16,12 +17,13 @@ OutOfOrderMessageException, ) from replit_river.session import Session -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions from .rpc import ( ACK_BIT, STREAM_CLOSED_BIT, STREAM_OPEN_BIT, + GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, ) @@ -38,6 +40,38 @@ class ServerSession(Session): """A transport object that handles the websocket connection with a client.""" + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] + + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + is_server: bool, + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], + close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], + retry_connection_callback: ( + Callable[ + [], + Coroutine[Any, Any, Any], + ] + | None + ) = None, + ) -> None: + super().__init__( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + transport_options=transport_options, + is_server=is_server, + close_session_callback=close_session_callback, + retry_connection_callback=retry_connection_callback, + ) + self._handlers = handlers + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index f6b5dec4..ee8da373 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -26,7 +26,6 @@ from .rpc import ( STREAM_CLOSED_BIT, - GenericRpcHandlerBuilder, TransportMessage, TransportMessageTracingSetter, ) @@ -72,7 +71,6 @@ def __init__( websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, is_server: bool, - handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: ( Callable[ @@ -85,7 +83,6 @@ def __init__( self._transport_id = transport_id self._to_id = to_id self.session_id = session_id - self._handlers = handlers self._is_server = is_server self._transport_options = transport_options From 42ead4197cd0a68fa44dab74c999595cefa0cb23 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 17:29:20 -0700 Subject: [PATCH 09/23] Strip is_server from Session __init__ --- src/replit_river/client_transport.py | 1 - src/replit_river/server_session.py | 3 +-- src/replit_river/server_transport.py | 2 -- src/replit_river/session.py | 4 ++-- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index dfed9c31..059aa6b0 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -225,7 +225,6 @@ async def _create_new_session( session_id=hs_request.sessionId, websocket=new_ws, transport_options=self._transport_options, - is_server=False, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, ) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 650f60e3..873f3452 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -49,7 +49,6 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - is_server: bool, handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: ( @@ -66,10 +65,10 @@ def __init__( session_id=session_id, websocket=websocket, transport_options=transport_options, - is_server=is_server, close_session_callback=close_session_callback, retry_connection_callback=retry_connection_callback, ) + self._is_server = True self._handlers = handlers async def start_serve_responses(self) -> None: diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 78bee4b0..3facc587 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -133,7 +133,6 @@ async def _get_or_create_session( session_id, websocket, self._transport_options, - True, self._handlers, close_session_callback=self._delete_session, ) @@ -154,7 +153,6 @@ async def _get_or_create_session( session_id, websocket, self._transport_options, - True, self._handlers, close_session_callback=self._delete_session, ) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index ee8da373..fb6359e5 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -62,6 +62,7 @@ class SessionState(enum.Enum): class Session: """Common functionality shared between client_session and server_session""" + _is_server: bool def __init__( self, @@ -70,7 +71,6 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - is_server: bool, close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: ( Callable[ @@ -83,7 +83,7 @@ def __init__( self._transport_id = transport_id self._to_id = to_id self.session_id = session_id - self._is_server = is_server + self._is_server = False self._transport_options = transport_options # session state, only modified during closing From c8ace77ff7a189b920cb4316445e2c52908005e7 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 17:44:00 -0700 Subject: [PATCH 10/23] Adding __init__ to ClientSession --- src/replit_river/client_session.py | 31 ++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 53d922e8..8eea0a00 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,12 +2,13 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable +from typing import Any, AsyncGenerator, Callable, Coroutine import nanoid # type: ignore from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span +import websockets from websockets.exceptions import ConnectionClosed from replit_river.error_schema import ( @@ -28,7 +29,7 @@ OutOfOrderMessageException, ) from replit_river.session import Session -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions from .rpc import ( ACK_BIT, @@ -45,6 +46,32 @@ class ClientSession(Session): + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], + retry_connection_callback: ( + Callable[ + [], + Coroutine[Any, Any, Any], + ] + | None + ) = None, + ) -> None: + super().__init__( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + transport_options=transport_options, + close_session_callback=close_session_callback, + retry_connection_callback=retry_connection_callback, + ) + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) From 7f0c323833c8d83c58d17e49f29f457b64304e43 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 18:14:31 -0700 Subject: [PATCH 11/23] Remove is_server --- src/replit_river/client_session.py | 11 ++++++++++- src/replit_river/server_session.py | 12 ++++++++++-- src/replit_river/session.py | 18 +++++------------- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 8eea0a00..482534a1 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -5,10 +5,10 @@ from typing import Any, AsyncGenerator, Callable, Coroutine import nanoid # type: ignore +import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span -import websockets from websockets.exceptions import ConnectionClosed from replit_river.error_schema import ( @@ -72,6 +72,15 @@ def __init__( retry_connection_callback=retry_connection_callback, ) + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=True, + ) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 873f3452..01ec233b 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -50,7 +50,7 @@ def __init__( websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], - close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], + close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], retry_connection_callback: ( Callable[ [], @@ -68,9 +68,17 @@ def __init__( close_session_callback=close_session_callback, retry_connection_callback=retry_connection_callback, ) - self._is_server = True self._handlers = handlers + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=False, + ) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + async def start_serve_responses(self) -> None: self._task_manager.create_task(self.serve()) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index fb6359e5..514af272 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,7 @@ import asyncio import enum import logging -from typing import Any, Callable, Coroutine, Protocol +from typing import Any, Awaitable, Callable, Coroutine, Protocol import nanoid # type: ignore import websockets @@ -62,7 +62,6 @@ class SessionState(enum.Enum): class Session: """Common functionality shared between client_session and server_session""" - _is_server: bool def __init__( self, @@ -83,7 +82,6 @@ def __init__( self._transport_id = transport_id self._to_id = to_id self.session_id = session_id - self._is_server = False self._transport_options = transport_options # session state, only modified during closing @@ -108,16 +106,10 @@ def __init__( self._buffer = MessageBuffer(self._transport_options.buffer_size) self._task_manager = BackgroundTaskManager() - self._setup_heartbeats_task() - - def _setup_heartbeats_task(self) -> None: - async def do_close_websocket() -> None: - await self.close_websocket( - self._ws_wrapper, - should_retry=not self._is_server, - ) - await self._begin_close_session_countdown() - + def _setup_heartbeats_task( + self, + do_close_websocket: Callable[[], Awaitable[None]], + ) -> None: def increment_and_get_heartbeat_misses() -> int: self._heartbeat_misses += 1 return self._heartbeat_misses From 2db0827f92470a8f162786f625924e275aaf1bc9 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 19:03:56 -0700 Subject: [PATCH 12/23] Moving more fields from session to the specialized classes --- src/replit_river/client_session.py | 40 ++++++++++++++---------------- src/replit_river/server_session.py | 7 +++--- src/replit_river/session.py | 3 --- 3 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 482534a1..27b7ba80 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -88,24 +88,21 @@ async def serve(self) -> None: """Serve messages from the websocket.""" self._reset_session_close_countdown() try: - async with asyncio.TaskGroup() as tg: - try: - await self._handle_messages_from_ws(tg) - except ConnectionClosed: - if self._retry_connection_callback: - self._task_manager.create_task( - self._retry_connection_callback() - ) - - await self._begin_close_session_countdown() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") + try: + await self._handle_messages_from_ws() + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") except ExceptionGroup as eg: _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) if unhandled: @@ -118,9 +115,10 @@ async def _update_book_keeping(self, msg: TransportMessage) -> None: await self._remove_acked_messages_in_buffer() self._reset_session_close_countdown() - async def _handle_messages_from_ws( - self, tg: asyncio.TaskGroup | None = None - ) -> None: + async def _remove_acked_messages_in_buffer(self) -> None: + await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + + async def _handle_messages_from_ws(self) -> None: logger.debug( "%s start handling messages from ws %s", "client", diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 01ec233b..7c9239e1 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -116,9 +116,10 @@ async def _update_book_keeping(self, msg: TransportMessage) -> None: await self._remove_acked_messages_in_buffer() self._reset_session_close_countdown() - async def _handle_messages_from_ws( - self, tg: asyncio.TaskGroup | None = None - ) -> None: + async def _remove_acked_messages_in_buffer(self) -> None: + await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + + async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug( "%s start handling messages from ws %s", "server", diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 514af272..a8cb3e68 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -356,9 +356,6 @@ async def _add_msg_to_stream( except RuntimeError as e: raise InvalidMessageException(e) from e - async def _remove_acked_messages_in_buffer(self) -> None: - await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) - async def close(self) -> None: """Close the session and all associated streams.""" logger.info( From a68b8ffc14e500d0cc74c1c27d7218e4a56218b8 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 19:16:35 -0700 Subject: [PATCH 13/23] Resolving circular import --- src/replit_river/common_session.py | 30 ++++++++++++++++++++++++++++-- src/replit_river/session.py | 29 ++--------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 2388059d..2f6df56b 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -1,14 +1,40 @@ import asyncio +import enum import logging -from typing import Awaitable, Callable +from typing import Any, Awaitable, Callable, Protocol + +from opentelemetry.trace import Span from replit_river.messages import FailedSendingMessageException from replit_river.rpc import ACK_BIT -from replit_river.session import SendMessage, SessionState logger = logging.getLogger(__name__) +class SendMessage(Protocol): + async def __call__( + self, + *, + stream_id: str, + payload: dict[Any, Any] | str, + control_flags: int, + service_name: str | None, + procedure_name: str | None, + span: Span | None, + ) -> None: ... + + +class SessionState(enum.Enum): + """The state a session can be in. + + Can only transition from ACTIVE to CLOSING to CLOSED. + """ + + ACTIVE = 0 + CLOSING = 1 + CLOSED = 2 + + async def setup_heartbeat( session_id: str, heartbeat_ms: float, diff --git a/src/replit_river/session.py b/src/replit_river/session.py index a8cb3e68..b19f7f33 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,7 +1,6 @@ import asyncio -import enum import logging -from typing import Any, Awaitable, Callable, Coroutine, Protocol +from typing import Any, Awaitable, Callable, Coroutine import nanoid # type: ignore import websockets @@ -9,7 +8,7 @@ from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from replit_river.common_session import setup_heartbeat +from replit_river.common_session import SessionState, setup_heartbeat from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, @@ -36,30 +35,6 @@ trace_setter = TransportMessageTracingSetter() -class SendMessage(Protocol): - async def __call__( - self, - *, - stream_id: str, - payload: dict[Any, Any] | str, - control_flags: int, - service_name: str | None, - procedure_name: str | None, - span: Span | None, - ) -> None: ... - - -class SessionState(enum.Enum): - """The state a session can be in. - - Can only transition from ACTIVE to CLOSING to CLOSED. - """ - - ACTIVE = 0 - CLOSING = 1 - CLOSED = 2 - - class Session: """Common functionality shared between client_session and server_session""" From 403a4470ca23ae800aba196cff12912bdbf65230 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:10:36 -0700 Subject: [PATCH 14/23] Bubble state out of check_to_close_session --- src/replit_river/common_session.py | 26 ++++++++++++++++++++ src/replit_river/session.py | 38 ++++++++++++------------------ 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 2f6df56b..52271004 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -85,3 +85,29 @@ async def setup_heartbeat( except FailedSendingMessageException: # this is expected during websocket closed period continue + + +async def check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], Awaitable[None]], +) -> None: + while True: + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() != SessionState.ACTIVE: + # already closing + return + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. + current_time = await get_current_time() + close_session_after_time_secs = get_close_session_after_time_secs() + if not close_session_after_time_secs: + continue + if current_time > close_session_after_time_secs: + logger.info("Grace period ended for %s, closing session", transport_id) + await do_close() + return diff --git a/src/replit_river/session.py b/src/replit_river/session.py index b19f7f33..3021f865 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -8,7 +8,11 @@ from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from replit_river.common_session import SessionState, setup_heartbeat +from replit_river.common_session import ( + SessionState, + check_to_close_session, + setup_heartbeat, +) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, @@ -101,7 +105,16 @@ def increment_and_get_heartbeat_misses() -> int: increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, ) ) - self._task_manager.create_task(self._check_to_close_session()) + self._task_manager.create_task( + check_to_close_session( + self._transport_id, + self._transport_options.close_session_check_interval_ms, + lambda: self._state, + self._get_current_time, + lambda: self._close_session_after_time_secs, + self.close, + ) + ) async def is_session_open(self) -> bool: async with self._state_lock: @@ -150,27 +163,6 @@ def _reset_session_close_countdown(self) -> None: self._heartbeat_misses = 0 self._close_session_after_time_secs = None - async def _check_to_close_session(self) -> None: - while True: - await asyncio.sleep( - self._transport_options.close_session_check_interval_ms / 1000 - ) - if self._state != SessionState.ACTIVE: - # already closing - return - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. - current_time = await self._get_current_time() - if not self._close_session_after_time_secs: - continue - if current_time > self._close_session_after_time_secs: - logger.info( - "Grace period ended for %s, closing session", self._transport_id - ) - await self.close() - return - async def _send_buffered_messages( self, websocket: websockets.WebSocketCommonProtocol ) -> None: From b3952d22fd26cdb3f0fcf4d4b2ffa2020f6e1e63 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:14:46 -0700 Subject: [PATCH 15/23] Moving add_msg_to_stream out --- src/replit_river/client_session.py | 3 ++- src/replit_river/common_session.py | 24 +++++++++++++++++++++++- src/replit_river/server_session.py | 3 ++- src/replit_river/session.py | 21 --------------------- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 27b7ba80..2c2b9a8d 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -11,6 +11,7 @@ from opentelemetry.trace import Span from websockets.exceptions import ConnectionClosed +from replit_river.common_session import add_msg_to_stream from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, @@ -146,7 +147,7 @@ async def _handle_messages_from_ws(self) -> None: raise IgnoreMessageException( "no stream for message, ignoring" ) - await self._add_msg_to_stream(msg, stream) + await add_msg_to_stream(msg, stream) else: raise InvalidMessageException( "Client should not receive stream open bit" diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py index 52271004..7193733d 100644 --- a/src/replit_river/common_session.py +++ b/src/replit_river/common_session.py @@ -3,10 +3,12 @@ import logging from typing import Any, Awaitable, Callable, Protocol +from aiochannel import Channel, ChannelClosed from opentelemetry.trace import Span from replit_river.messages import FailedSendingMessageException -from replit_river.rpc import ACK_BIT +from replit_river.rpc import ACK_BIT, STREAM_CLOSED_BIT, TransportMessage +from replit_river.seq_manager import InvalidMessageException logger = logging.getLogger(__name__) @@ -111,3 +113,23 @@ async def check_to_close_session( logger.info("Grace period ended for %s, closing session", transport_id) await do_close() return + + +async def add_msg_to_stream( + msg: TransportMessage, + stream: Channel, +) -> None: + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + return + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 7c9239e1..9fc81f7b 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -7,6 +7,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from websockets.exceptions import ConnectionClosed +from replit_river.common_session import add_msg_to_stream from replit_river.messages import ( FailedSendingMessageException, parse_transport_msg, @@ -147,7 +148,7 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: raise IgnoreMessageException( "no stream for message, ignoring" ) - await self._add_msg_to_stream(msg, stream) + await add_msg_to_stream(msg, stream) else: # TODO(dstewart) This looks like it opens a new call to handler # on ever ws message, instead of demuxing and diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 3021f865..2ae38cf2 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -20,7 +20,6 @@ send_transport_message, ) from replit_river.seq_manager import ( - InvalidMessageException, SeqManager, ) from replit_river.task_manager import BackgroundTaskManager @@ -303,26 +302,6 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) - async def _add_msg_to_stream( - self, - msg: TransportMessage, - stream: Channel, - ) -> None: - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - return - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. - pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - async def close(self) -> None: """Close the session and all associated streams.""" logger.info( From 8d9161e3ec14f5cd2fb87e96ddc2da0308489595 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:17:30 -0700 Subject: [PATCH 16/23] Moving send_responses_from_output_stream to server_session --- src/replit_river/server_session.py | 22 ++++++++++++++++++++++ src/replit_river/session.py | 25 +------------------------ 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 9fc81f7b..741fb592 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -228,3 +228,25 @@ async def _open_stream_and_call_handler( tg, ) return input_stream + + async def _send_responses_from_output_stream( + self, + stream_id: str, + output: Channel[Any], + is_streaming_output: bool, + ) -> None: + """Send serialized messages to the websockets.""" + try: + async for payload in output: + if not is_streaming_output: + await self.send_message(stream_id, payload, STREAM_CLOSED_BIT) + return + await self.send_message(stream_id, payload) + logger.debug("sent an end of stream %r", stream_id) + await self.send_message(stream_id, {"type": "CLOSE"}, STREAM_CLOSED_BIT) + except FailedSendingMessageException: + logger.exception("Error while sending responses") + except (RuntimeError, ChannelClosed): + logger.exception("Error while sending responses") + except Exception: + logger.exception("Unknown error while river sending responses back") diff --git a/src/replit_river/session.py b/src/replit_river/session.py index 2ae38cf2..a8b7cee1 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -4,7 +4,7 @@ import nanoid # type: ignore import websockets -from aiochannel import Channel, ChannelClosed +from aiochannel import Channel from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -27,7 +27,6 @@ from replit_river.websocket_wrapper import WebsocketWrapper from .rpc import ( - STREAM_CLOSED_BIT, TransportMessage, TransportMessageTracingSetter, ) @@ -268,28 +267,6 @@ async def send_message( "Failed sending message, waiting for retry from buffer", exc_info=True ) - async def _send_responses_from_output_stream( - self, - stream_id: str, - output: Channel[Any], - is_streaming_output: bool, - ) -> None: - """Send serialized messages to the websockets.""" - try: - async for payload in output: - if not is_streaming_output: - await self.send_message(stream_id, payload, STREAM_CLOSED_BIT) - return - await self.send_message(stream_id, payload) - logger.debug("sent an end of stream %r", stream_id) - await self.send_message(stream_id, {"type": "CLOSE"}, STREAM_CLOSED_BIT) - except FailedSendingMessageException: - logger.exception("Error while sending responses") - except (RuntimeError, ChannelClosed): - logger.exception("Error while sending responses") - except Exception: - logger.exception("Unknown error while river sending responses back") - async def close_websocket( self, ws_wrapper: WebsocketWrapper, should_retry: bool ) -> None: From 36aee31f905d5b679633cc08728e95638cc29edf Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:20:05 -0700 Subject: [PATCH 17/23] Turns out _send_buffered_messages was only used in one place --- src/replit_river/session.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/replit_river/session.py b/src/replit_river/session.py index a8b7cee1..d908bdda 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -152,24 +152,14 @@ async def replace_with_new_websocket( if new_ws.id != old_ws_id: await old_wrapper.close() self._ws_wrapper = WebsocketWrapper(new_ws) - await self._send_buffered_messages(new_ws) - async def _get_current_time(self) -> float: - return asyncio.get_event_loop().time() - - def _reset_session_close_countdown(self) -> None: - self._heartbeat_misses = 0 - self._close_session_after_time_secs = None - - async def _send_buffered_messages( - self, websocket: websockets.WebSocketCommonProtocol - ) -> None: + # Send buffered messages to the new ws buffered_messages = list(self._buffer.buffer) for msg in buffered_messages: try: await self._send_transport_message( msg, - websocket, + new_ws, ) except WebsocketClosedException: logger.info( @@ -180,6 +170,13 @@ async def _send_buffered_messages( logger.exception("Error while sending buffered messages") break + async def _get_current_time(self) -> float: + return asyncio.get_event_loop().time() + + def _reset_session_close_countdown(self) -> None: + self._heartbeat_misses = 0 + self._close_session_after_time_secs = None + async def _send_transport_message( self, msg: TransportMessage, From 628a8ae9f825b89a1ff15df1475526b7134a80f0 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:22:55 -0700 Subject: [PATCH 18/23] Unused --- src/replit_river/server_session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 741fb592..bf1e0865 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -80,9 +80,6 @@ async def do_close_websocket() -> None: self._setup_heartbeats_task(do_close_websocket) - async def start_serve_responses(self) -> None: - self._task_manager.create_task(self.serve()) - async def serve(self) -> None: """Serve messages from the websocket.""" self._reset_session_close_countdown() From a4263534e1e6b8f3cb472dfd8d7ef3b8c035456a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:26:18 -0700 Subject: [PATCH 19/23] Inline --- src/replit_river/client_session.py | 5 +---- src/replit_river/server_session.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 2c2b9a8d..dbb661d3 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -113,11 +113,8 @@ async def serve(self) -> None: async def _update_book_keeping(self, msg: TransportMessage) -> None: await self._seq_manager.check_seq_and_update(msg) - await self._remove_acked_messages_in_buffer() - self._reset_session_close_countdown() - - async def _remove_acked_messages_in_buffer(self) -> None: await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + self._reset_session_close_countdown() async def _handle_messages_from_ws(self) -> None: logger.debug( diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index bf1e0865..54fd6d0e 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -111,11 +111,8 @@ async def serve(self) -> None: async def _update_book_keeping(self, msg: TransportMessage) -> None: await self._seq_manager.check_seq_and_update(msg) - await self._remove_acked_messages_in_buffer() - self._reset_session_close_countdown() - - async def _remove_acked_messages_in_buffer(self) -> None: await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) + self._reset_session_close_countdown() async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug( From f9a33f3b9550d7c268488e38c685fd1ad9379606 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:28:00 -0700 Subject: [PATCH 20/23] Inline update_bookkeeping --- src/replit_river/client_session.py | 14 +++++++------- src/replit_river/server_session.py | 13 +++++++------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index dbb661d3..599ac5c3 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -40,7 +40,6 @@ InitType, RequestType, ResponseType, - TransportMessage, ) logger = logging.getLogger(__name__) @@ -111,11 +110,6 @@ async def serve(self) -> None: "Unhandled exceptions on River server", unhandled.exceptions ) - async def _update_book_keeping(self, msg: TransportMessage) -> None: - await self._seq_manager.check_seq_and_update(msg) - await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) - self._reset_session_close_countdown() - async def _handle_messages_from_ws(self) -> None: logger.debug( "%s start handling messages from ws %s", @@ -133,7 +127,13 @@ async def _handle_messages_from_ws(self) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) - await self._update_book_keeping(msg) + # Update bookkeeping + await self._seq_manager.check_seq_and_update(msg) + await self._buffer.remove_old_messages( + self._seq_manager.receiver_ack, + ) + self._reset_session_close_countdown() + if msg.controlFlags & ACK_BIT != 0: continue async with self._stream_lock: diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 54fd6d0e..45cc54fa 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -109,11 +109,6 @@ async def serve(self) -> None: "Unhandled exceptions on River server", unhandled.exceptions ) - async def _update_book_keeping(self, msg: TransportMessage) -> None: - await self._seq_manager.check_seq_and_update(msg) - await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) - self._reset_session_close_countdown() - async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug( "%s start handling messages from ws %s", @@ -131,7 +126,13 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: logger.debug(f"{self._transport_id} got a message %r", msg) - await self._update_book_keeping(msg) + # Update bookkeeping + await self._seq_manager.check_seq_and_update(msg) + await self._buffer.remove_old_messages( + self._seq_manager.receiver_ack, + ) + self._reset_session_close_countdown() + if msg.controlFlags & ACK_BIT != 0: continue async with self._stream_lock: From 46afc33f34314de5e9929c09392e488e3230defd Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 18 Mar 2025 20:35:54 -0700 Subject: [PATCH 21/23] Unnest ExpectedSessionState constructor --- src/replit_river/client_transport.py | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 059aa6b0..4140afc4 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic, Mapping +from typing import Generic, Mapping, assert_never import nanoid import websockets @@ -319,24 +319,27 @@ async def _establish_handshake( ControlMessageHandshakeResponse, ]: try: + expectedSessionState: ExpectedSessionState + match old_session: + case None: + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=0, + nextSentSeq=0, + ) + case ClientSession(): + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=await old_session.get_next_expected_seq(), + nextSentSeq=await old_session.get_next_sent_seq(), + ) + case other: + assert_never(other) handshake_request = await self._send_handshake_request( transport_id=transport_id, to_id=to_id, session_id=session_id, handshake_metadata=handshake_metadata, websocket=websocket, - expected_session_state=ExpectedSessionState( - nextExpectedSeq=( - await old_session.get_next_expected_seq() - if old_session is not None - else 0 - ), - nextSentSeq=( - await old_session.get_next_sent_seq() - if old_session is not None - else 0 - ), - ), + expected_session_state=expectedSessionState, ) except FailedSendingMessageException as e: raise RiverException( From 3bcfefc58316b304f0434df5b9fae7886557e5d4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 19 Mar 2025 16:17:17 -0700 Subject: [PATCH 22/23] Inlining no-longer-invariant _sessions access --- src/replit_river/client_transport.py | 16 +++++----------- src/replit_river/server_transport.py | 14 ++++---------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 4140afc4..374fcdcd 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,7 +1,7 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic, Mapping, assert_never +from typing import Generic, assert_never import nanoid import websockets @@ -71,11 +71,8 @@ def __init__( # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() - async def _close_all_sessions( - self, - get_all_sessions: Callable[[], Mapping[str, Session]], - ) -> None: - sessions = get_all_sessions().values() + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() logger.info( f"start closing sessions {self._transport_id}, number sessions : " f"{len(sessions)}" @@ -94,7 +91,7 @@ def generate_nanoid(self) -> str: async def close(self) -> None: self._rate_limiter.close() - await self._close_all_sessions(self._get_all_sessions) + await self._close_all_sessions() async def get_or_create_session(self) -> ClientSession: async with self._create_session_lock: @@ -235,7 +232,7 @@ async def _create_new_session( async def _retry_connection(self) -> ClientSession: if not self._transport_options.transparent_reconnect: - await self._close_all_sessions(self._get_all_sessions) + await self._close_all_sessions() return await self.get_or_create_session() async def _send_handshake_request( @@ -378,9 +375,6 @@ async def _establish_handshake( ) return handshake_request, handshake_response - def _get_all_sessions(self) -> Mapping[str, Session]: - return self._sessions - async def _delete_session(self, session: Session) -> None: async with self._session_lock: if session._to_id in self._sessions: diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 3facc587..878eb0b7 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, Callable, Mapping +from typing import Any import nanoid # type: ignore # type: ignore from pydantic import ValidationError @@ -51,11 +51,8 @@ def __init__( self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] = {} self._session_lock = asyncio.Lock() - async def _close_all_sessions( - self, - get_all_sessions: Callable[[], Mapping[str, Session]], - ) -> None: - sessions = get_all_sessions().values() + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() logger.info( f"start closing sessions {self._transport_id}, number sessions : " f"{len(sessions)}" @@ -111,7 +108,7 @@ async def handshake_to_get_session( raise WebsocketClosedException("No handshake message received") async def close(self) -> None: - await self._close_all_sessions(self._get_all_sessions) + await self._close_all_sessions() async def _get_or_create_session( self, @@ -315,9 +312,6 @@ async def _establish_handshake( return handshake_request, handshake_response - def _get_all_sessions(self) -> Mapping[str, Session]: - return self._sessions - async def _delete_session(self, session: Session) -> None: async with self._session_lock: if session._to_id in self._sessions: From 68224c7488c8665266da616b80d1f01bffe11d7c Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Wed, 19 Mar 2025 17:17:31 -0700 Subject: [PATCH 23/23] Thank you for your service --- src/replit_river/server_session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py index 45cc54fa..868ff0fb 100644 --- a/src/replit_river/server_session.py +++ b/src/replit_river/server_session.py @@ -145,9 +145,6 @@ async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: ) await add_msg_to_stream(msg, stream) else: - # TODO(dstewart) This looks like it opens a new call to handler - # on ever ws message, instead of demuxing and - # routing. _stream = await self._open_stream_and_call_handler(msg, tg) if not stream: async with self._stream_lock: