diff --git a/src/replit_river/__init__.py b/src/replit_river/__init__.py index f837f16d..bc166e3e 100644 --- a/src/replit_river/__init__.py +++ b/src/replit_river/__init__.py @@ -1,7 +1,7 @@ from .client import Client from .error_schema import RiverError from .rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, GrpcContext, rpc_method_handler, stream_method_handler, @@ -15,7 +15,7 @@ "Server", "GrpcContext", "RiverError", - "GenericRpcHandler", + "GenericRpcHandlerBuilder", "rpc_method_handler", "subscription_method_handler", "upload_method_handler", diff --git a/src/replit_river/client_session.py b/src/replit_river/client_session.py index 479a9f50..599ac5c3 100644 --- a/src/replit_river/client_session.py +++ b/src/replit_river/client_session.py @@ -2,13 +2,16 @@ import logging from collections.abc import AsyncIterable from datetime import timedelta -from typing import Any, AsyncGenerator, Callable +from typing import Any, AsyncGenerator, Callable, Coroutine import nanoid # type: ignore +import websockets from aiochannel import Channel from aiochannel.errors import ChannelClosed from opentelemetry.trace import Span +from websockets.exceptions import ConnectionClosed +from replit_river.common_session import add_msg_to_stream from replit_river.error_schema import ( ERROR_CODE_CANCEL, ERROR_CODE_STREAM_CLOSED, @@ -17,10 +20,20 @@ StreamClosedRiverServiceException, exception_from_message, ) +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + OutOfOrderMessageException, +) from replit_river.session import Session -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions from .rpc import ( + ACK_BIT, STREAM_CLOSED_BIT, STREAM_OPEN_BIT, ErrorType, @@ -33,6 +46,129 @@ class ClientSession(Session): + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], + retry_connection_callback: ( + Callable[ + [], + Coroutine[Any, Any, Any], + ] + | None + ) = None, + ) -> None: + super().__init__( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + transport_options=transport_options, + close_session_callback=close_session_callback, + retry_connection_callback=retry_connection_callback, + ) + + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=True, + ) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + + async def start_serve_responses(self) -> None: + self._task_manager.create_task(self.serve()) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + try: + await self._handle_messages_from_ws() + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task(self._retry_connection_callback()) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _handle_messages_from_ws(self) -> None: + logger.debug( + "%s start handling messages from ws %s", + "client", + self._ws_wrapper.id, + ) + try: + ws_wrapper = self._ws_wrapper + async for message in ws_wrapper.ws: + try: + if not await ws_wrapper.is_open(): + # We should not process messages if the websocket is closed. + break + msg = parse_transport_msg(message, self._transport_options) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + # Update bookkeeping + await self._seq_manager.check_seq_and_update(msg) + await self._buffer.remove_old_messages( + self._seq_manager.receiver_ack, + ) + self._reset_session_close_countdown() + + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await add_msg_to_stream(msg, stream) + else: + raise InvalidMessageException( + "Client should not receive stream open bit" + ) + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await ws_wrapper.close() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + return + except ConnectionClosed as e: + raise e + async def send_rpc( self, service_name: str, diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 79552b58..374fcdcd 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -1,8 +1,9 @@ import asyncio import logging from collections.abc import Awaitable, Callable -from typing import Generic +from typing import Generic, assert_never +import nanoid import websockets from pydantic import ValidationError from websockets import ( @@ -36,7 +37,7 @@ IgnoreMessageException, InvalidMessageException, ) -from replit_river.transport import Transport +from replit_river.session import Session from replit_river.transport_options import ( HandshakeMetadataType, TransportOptions, @@ -46,7 +47,9 @@ logger = logging.getLogger(__name__) -class ClientTransport(Transport, Generic[HandshakeMetadataType]): +class ClientTransport(Generic[HandshakeMetadataType]): + _sessions: dict[str, ClientSession] + def __init__( self, uri_and_metadata_factory: Callable[[], Awaitable[UriAndMetadata]], @@ -54,11 +57,11 @@ def __init__( server_id: str, transport_options: TransportOptions, ): - super().__init__( - transport_id=client_id, - transport_options=transport_options, - is_server=False, - ) + self._sessions = {} + self._transport_id = client_id + self._transport_options = transport_options + self._session_lock = asyncio.Lock() + self._uri_and_metadata_factory = uri_and_metadata_factory self._client_id = client_id self._server_id = server_id @@ -68,6 +71,24 @@ def __init__( # We want to make sure there's only one session creation at a time self._create_session_lock = asyncio.Lock() + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() + logger.info( + f"start closing sessions {self._transport_id}, number sessions : " + f"{len(sessions)}" + ) + sessions_to_close = list(sessions) + + # closing sessions requires access to the session lock, so we need to close + # them one by one to be safe + for session in sessions_to_close: + await session.close() + + logger.info(f"Transport closed {self._transport_id}") + + def generate_nanoid(self) -> str: + return str(nanoid.generate()) + async def close(self) -> None: self._rate_limiter.close() await self._close_all_sessions() @@ -201,13 +222,11 @@ async def _create_new_session( session_id=hs_request.sessionId, websocket=new_ws, transport_options=self._transport_options, - is_server=False, close_session_callback=self._delete_session, retry_connection_callback=self._retry_connection, - handlers={}, ) - self._set_session(new_session) + self._sessions[new_session._to_id] = new_session await new_session.start_serve_responses() return new_session @@ -297,24 +316,27 @@ async def _establish_handshake( ControlMessageHandshakeResponse, ]: try: + expectedSessionState: ExpectedSessionState + match old_session: + case None: + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=0, + nextSentSeq=0, + ) + case ClientSession(): + expectedSessionState = ExpectedSessionState( + nextExpectedSeq=await old_session.get_next_expected_seq(), + nextSentSeq=await old_session.get_next_sent_seq(), + ) + case other: + assert_never(other) handshake_request = await self._send_handshake_request( transport_id=transport_id, to_id=to_id, session_id=session_id, handshake_metadata=handshake_metadata, websocket=websocket, - expected_session_state=ExpectedSessionState( - nextExpectedSeq=( - await old_session.get_next_expected_seq() - if old_session is not None - else 0 - ), - nextSentSeq=( - await old_session.get_next_sent_seq() - if old_session is not None - else 0 - ), - ), + expected_session_state=expectedSessionState, ) except FailedSendingMessageException as e: raise RiverException( @@ -352,3 +374,8 @@ async def _establish_handshake( + f"{handshake_response.status.reason}", ) return handshake_request, handshake_response + + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] diff --git a/src/replit_river/codegen/client.py b/src/replit_river/codegen/client.py index 8d2791b9..0c837e12 100644 --- a/src/replit_river/codegen/client.py +++ b/src/replit_river/codegen/client.py @@ -67,7 +67,6 @@ FILE_HEADER = dedent( """\ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/src/replit_river/codegen/server.py b/src/replit_river/codegen/server.py index acee8106..c8e765a9 100644 --- a/src/replit_river/codegen/server.py +++ b/src/replit_river/codegen/server.py @@ -342,7 +342,7 @@ def add_{service.name}Servicer_to_server( ) -> None: rpc_method_handlers: Mapping[ tuple[str, str], - tuple[str, river.GenericRpcHandler] + tuple[str, river.GenericRpcHandlerBuilder] ] = {{ """ ), diff --git a/src/replit_river/common_session.py b/src/replit_river/common_session.py new file mode 100644 index 00000000..7193733d --- /dev/null +++ b/src/replit_river/common_session.py @@ -0,0 +1,135 @@ +import asyncio +import enum +import logging +from typing import Any, Awaitable, Callable, Protocol + +from aiochannel import Channel, ChannelClosed +from opentelemetry.trace import Span + +from replit_river.messages import FailedSendingMessageException +from replit_river.rpc import ACK_BIT, STREAM_CLOSED_BIT, TransportMessage +from replit_river.seq_manager import InvalidMessageException + +logger = logging.getLogger(__name__) + + +class SendMessage(Protocol): + async def __call__( + self, + *, + stream_id: str, + payload: dict[Any, Any] | str, + control_flags: int, + service_name: str | None, + procedure_name: str | None, + span: Span | None, + ) -> None: ... + + +class SessionState(enum.Enum): + """The state a session can be in. + + Can only transition from ACTIVE to CLOSING to CLOSED. + """ + + ACTIVE = 0 + CLOSING = 1 + CLOSED = 2 + + +async def setup_heartbeat( + session_id: str, + heartbeat_ms: float, + heartbeats_until_dead: int, + get_state: Callable[[], SessionState], + get_closing_grace_period: Callable[[], float | None], + close_websocket: Callable[[], Awaitable[None]], + send_message: SendMessage, + increment_and_get_heartbeat_misses: Callable[[], int], +) -> None: + logger.debug("Start heartbeat") + while True: + await asyncio.sleep(heartbeat_ms / 1000) + state = get_state() + if state != SessionState.ACTIVE: + logger.debug( + "Session is closed, no need to send heartbeat, state : " + "%r close_session_after_this: %r", + {state}, + {get_closing_grace_period()}, + ) + # session is closing / closed, no need to send heartbeat anymore + return + try: + await send_message( + stream_id="heartbeat", + # TODO: make this a message class + # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 + payload={ + "ack": 0, + }, + control_flags=ACK_BIT, + procedure_name=None, + service_name=None, + span=None, + ) + + if increment_and_get_heartbeat_misses() > heartbeats_until_dead: + if get_closing_grace_period() is not None: + # already in grace period, no need to set again + continue + logger.info( + "%r closing websocket because of heartbeat misses", + session_id, + ) + await close_websocket() + continue + except FailedSendingMessageException: + # this is expected during websocket closed period + continue + + +async def check_to_close_session( + transport_id: str, + close_session_check_interval_ms: float, + get_state: Callable[[], SessionState], + get_current_time: Callable[[], Awaitable[float]], + get_close_session_after_time_secs: Callable[[], float | None], + do_close: Callable[[], Awaitable[None]], +) -> None: + while True: + await asyncio.sleep(close_session_check_interval_ms / 1000) + if get_state() != SessionState.ACTIVE: + # already closing + return + # calculate the value now before comparing it so that there are no + # await points between the check and the comparison to avoid a TOCTOU + # race. + current_time = await get_current_time() + close_session_after_time_secs = get_close_session_after_time_secs() + if not close_session_after_time_secs: + continue + if current_time > close_session_after_time_secs: + logger.info("Grace period ended for %s, closing session", transport_id) + await do_close() + return + + +async def add_msg_to_stream( + msg: TransportMessage, + stream: Channel, +) -> None: + if ( + msg.controlFlags & STREAM_CLOSED_BIT != 0 + and msg.payload.get("type", None) == "CLOSE" + ): + # close message is not sent to the stream + return + try: + await stream.put(msg.payload) + except ChannelClosed: + # The client is no longer interested in this stream, + # just drop the message. + pass + except RuntimeError as e: + raise InvalidMessageException(e) from e diff --git a/src/replit_river/rpc.py b/src/replit_river/rpc.py index 8415143a..0d1bd4d1 100644 --- a/src/replit_river/rpc.py +++ b/src/replit_river/rpc.py @@ -43,7 +43,7 @@ _MetadataType: TypeAlias = grpc.aio.Metadata | Sequence[tuple[str, str | bytes]] -GenericRpcHandler = Callable[ +GenericRpcHandlerBuilder = Callable[ [str, Channel[Any], Channel[Any]], Coroutine[None, None, None] ] ACK_BIT = 0x0001 @@ -220,7 +220,7 @@ def rpc_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -277,7 +277,7 @@ def subscription_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -336,7 +336,7 @@ def upload_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], @@ -414,7 +414,7 @@ def stream_method_handler( ], request_deserializer: Callable[[Any], RequestType], response_serializer: Callable[[ResponseType], Any], -) -> GenericRpcHandler: +) -> GenericRpcHandlerBuilder: async def wrapped( peer: str, input: Channel[Any], diff --git a/src/replit_river/server.py b/src/replit_river/server.py index 2bdf05b9..64974fc3 100644 --- a/src/replit_river/server.py +++ b/src/replit_river/server.py @@ -8,12 +8,12 @@ from replit_river.messages import WebsocketClosedException from replit_river.seq_manager import SessionStateMismatchException +from replit_river.server_session import ServerSession from replit_river.server_transport import ServerTransport -from replit_river.session import Session -from replit_river.transport import TransportOptions +from replit_river.transport_options import TransportOptions from .rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, ) logger = logging.getLogger(__name__) @@ -35,13 +35,13 @@ async def close(self) -> None: def add_rpc_handlers( self, - rpc_handlers: Mapping[tuple[str, str], tuple[str, GenericRpcHandler]], + rpc_handlers: Mapping[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], ) -> None: self._transport._handlers.update(rpc_handlers) async def _handshake_to_get_session( self, websocket: WebSocketServerProtocol - ) -> Session | None: + ) -> ServerSession | None: """This is a wrapper to make sentry happy, sentry doesn't recognize the exception handling outside of a task or asyncio.wait_for. So we need to catch the errors specifically here. diff --git a/src/replit_river/server_session.py b/src/replit_river/server_session.py new file mode 100644 index 00000000..868ff0fb --- /dev/null +++ b/src/replit_river/server_session.py @@ -0,0 +1,244 @@ +import asyncio +import logging +from typing import Any, Callable, Coroutine + +import websockets +from aiochannel import Channel, ChannelClosed +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from websockets.exceptions import ConnectionClosed + +from replit_river.common_session import add_msg_to_stream +from replit_river.messages import ( + FailedSendingMessageException, + parse_transport_msg, +) +from replit_river.seq_manager import ( + IgnoreMessageException, + InvalidMessageException, + OutOfOrderMessageException, +) +from replit_river.session import Session +from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions + +from .rpc import ( + ACK_BIT, + STREAM_CLOSED_BIT, + STREAM_OPEN_BIT, + GenericRpcHandlerBuilder, + TransportMessage, + TransportMessageTracingSetter, +) + +logger = logging.getLogger(__name__) + + +logger = logging.getLogger(__name__) + +trace_propagator = TraceContextTextMapPropagator() +trace_setter = TransportMessageTracingSetter() + + +class ServerSession(Session): + """A transport object that handles the websocket connection with a client.""" + + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] + + def __init__( + self, + transport_id: str, + to_id: str, + session_id: str, + websocket: websockets.WebSocketCommonProtocol, + transport_options: TransportOptions, + handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]], + close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]], + retry_connection_callback: ( + Callable[ + [], + Coroutine[Any, Any, Any], + ] + | None + ) = None, + ) -> None: + super().__init__( + transport_id=transport_id, + to_id=to_id, + session_id=session_id, + websocket=websocket, + transport_options=transport_options, + close_session_callback=close_session_callback, + retry_connection_callback=retry_connection_callback, + ) + self._handlers = handlers + + async def do_close_websocket() -> None: + await self.close_websocket( + self._ws_wrapper, + should_retry=False, + ) + await self._begin_close_session_countdown() + + self._setup_heartbeats_task(do_close_websocket) + + async def serve(self) -> None: + """Serve messages from the websocket.""" + self._reset_session_close_countdown() + try: + async with asyncio.TaskGroup() as tg: + try: + await self._handle_messages_from_ws(tg) + except ConnectionClosed: + if self._retry_connection_callback: + self._task_manager.create_task( + self._retry_connection_callback() + ) + + await self._begin_close_session_countdown() + logger.debug("ConnectionClosed while serving", exc_info=True) + except FailedSendingMessageException: + # Expected error if the connection is closed. + logger.debug( + "FailedSendingMessageException while serving", exc_info=True + ) + except Exception: + logger.exception("caught exception at message iterator") + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + raise ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + + async def _handle_messages_from_ws(self, tg: asyncio.TaskGroup) -> None: + logger.debug( + "%s start handling messages from ws %s", + "server", + self._ws_wrapper.id, + ) + try: + ws_wrapper = self._ws_wrapper + async for message in ws_wrapper.ws: + try: + if not await ws_wrapper.is_open(): + # We should not process messages if the websocket is closed. + break + msg = parse_transport_msg(message, self._transport_options) + + logger.debug(f"{self._transport_id} got a message %r", msg) + + # Update bookkeeping + await self._seq_manager.check_seq_and_update(msg) + await self._buffer.remove_old_messages( + self._seq_manager.receiver_ack, + ) + self._reset_session_close_countdown() + + if msg.controlFlags & ACK_BIT != 0: + continue + async with self._stream_lock: + stream = self._streams.get(msg.streamId, None) + if msg.controlFlags & STREAM_OPEN_BIT == 0: + if not stream: + logger.warning("no stream for %s", msg.streamId) + raise IgnoreMessageException( + "no stream for message, ignoring" + ) + await add_msg_to_stream(msg, stream) + else: + _stream = await self._open_stream_and_call_handler(msg, tg) + if not stream: + async with self._stream_lock: + self._streams[msg.streamId] = _stream + stream = _stream + + if msg.controlFlags & STREAM_CLOSED_BIT != 0: + if stream: + stream.close() + async with self._stream_lock: + del self._streams[msg.streamId] + except IgnoreMessageException: + logger.debug("Ignoring transport message", exc_info=True) + continue + except OutOfOrderMessageException: + logger.exception("Out of order message, closing connection") + await ws_wrapper.close() + return + except InvalidMessageException: + logger.exception("Got invalid transport message, closing session") + await self.close() + return + except ConnectionClosed as e: + raise e + + async def _open_stream_and_call_handler( + self, + msg: TransportMessage, + tg: asyncio.TaskGroup | None, + ) -> Channel: + if not msg.serviceName or not msg.procedureName: + raise IgnoreMessageException( + f"Service name or procedure name is missing in the message {msg}" + ) + key = (msg.serviceName, msg.procedureName) + handler = self._handlers.get(key, None) + if not handler: + raise IgnoreMessageException( + f"No handler for {key} handlers : {self._handlers.keys()}" + ) + method_type, handler_func = handler + is_streaming_output = method_type in ( + "subscription-stream", # subscription + "stream", + ) + is_streaming_input = method_type in ( + "upload-stream", # subscription + "stream", + ) + # New channel pair. + input_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 + ) + output_stream: Channel[Any] = Channel( + MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 + ) + if ( + msg.controlFlags & STREAM_CLOSED_BIT == 0 + or msg.payload.get("type", None) != "CLOSE" + ): + try: + await input_stream.put(msg.payload) + except (RuntimeError, ChannelClosed) as e: + raise InvalidMessageException(e) from e + # Start the handler. + self._task_manager.create_task( + handler_func(msg.from_, input_stream, output_stream), tg + ) + self._task_manager.create_task( + self._send_responses_from_output_stream( + msg.streamId, output_stream, is_streaming_output + ), + tg, + ) + return input_stream + + async def _send_responses_from_output_stream( + self, + stream_id: str, + output: Channel[Any], + is_streaming_output: bool, + ) -> None: + """Send serialized messages to the websockets.""" + try: + async for payload in output: + if not is_streaming_output: + await self.send_message(stream_id, payload, STREAM_CLOSED_BIT) + return + await self.send_message(stream_id, payload) + logger.debug("sent an end of stream %r", stream_id) + await self.send_message(stream_id, {"type": "CLOSE"}, STREAM_CLOSED_BIT) + except FailedSendingMessageException: + logger.exception("Error while sending responses") + except (RuntimeError, ChannelClosed): + logger.exception("Error while sending responses") + except Exception: + logger.exception("Unknown error while river sending responses back") diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 888e0ce3..878eb0b7 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import Any @@ -19,6 +20,7 @@ from replit_river.rpc import ( ControlMessageHandshakeRequest, ControlMessageHandshakeResponse, + GenericRpcHandlerBuilder, HandShakeStatus, TransportMessage, ) @@ -27,29 +29,47 @@ InvalidMessageException, SessionStateMismatchException, ) +from replit_river.server_session import ServerSession from replit_river.session import Session -from replit_river.transport import Transport from replit_river.transport_options import TransportOptions logger = logging.getLogger(__name__) -class ServerTransport(Transport): +class ServerTransport: + _sessions: dict[str, ServerSession] + _handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] + def __init__( self, transport_id: str, transport_options: TransportOptions, ) -> None: - super().__init__( - transport_id=transport_id, - transport_options=transport_options, - is_server=True, + self._sessions = {} + self._transport_id = transport_id + self._transport_options = transport_options + self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]] = {} + self._session_lock = asyncio.Lock() + + async def _close_all_sessions(self) -> None: + sessions = self._sessions.values() + logger.info( + f"start closing sessions {self._transport_id}, number sessions : " + f"{len(sessions)}" ) + sessions_to_close = list(sessions) + + # closing sessions requires access to the session lock, so we need to close + # them one by one to be safe + for session in sessions_to_close: + await session.close() + + logger.info(f"Transport closed {self._transport_id}") async def handshake_to_get_session( self, websocket: WebSocketServerProtocol, - ) -> Session: + ) -> ServerSession: async for message in websocket: try: msg = parse_transport_msg(message, self._transport_options) @@ -96,21 +116,20 @@ async def _get_or_create_session( to_id: str, session_id: str, websocket: WebSocketCommonProtocol, - ) -> Session: + ) -> ServerSession: async with self._session_lock: session_to_close: Session | None = None - new_session: Session | None = None + new_session: ServerSession | None = None if to_id not in self._sessions: logger.info( 'Creating new session with "%s" using ws: %s', to_id, websocket.id ) - new_session = Session( + new_session = ServerSession( transport_id, to_id, session_id, websocket, self._transport_options, - self._is_server, self._handlers, close_session_callback=self._delete_session, ) @@ -125,13 +144,12 @@ async def _get_or_create_session( old_session.session_id, ) session_to_close = old_session - new_session = Session( + new_session = ServerSession( transport_id, to_id, session_id, websocket, self._transport_options, - self._is_server, self._handlers, close_session_callback=self._delete_session, ) @@ -152,7 +170,7 @@ async def _get_or_create_session( if session_to_close: logger.info("Closing stale session %s", session_to_close.session_id) await session_to_close.close() - self._set_session(new_session) + self._sessions[new_session._to_id] = new_session return new_session async def _send_handshake_response( @@ -293,3 +311,8 @@ async def _establish_handshake( ) return handshake_request, handshake_response + + async def _delete_session(self, session: Session) -> None: + async with self._session_lock: + if session._to_id in self._sessions: + del self._sessions[session._to_id] diff --git a/src/replit_river/session.py b/src/replit_river/session.py index a94156a8..d908bdda 100644 --- a/src/replit_river/session.py +++ b/src/replit_river/session.py @@ -1,37 +1,32 @@ import asyncio -import enum import logging -from typing import Any, Callable, Coroutine +from typing import Any, Awaitable, Callable, Coroutine import nanoid # type: ignore import websockets -from aiochannel import Channel, ChannelClosed +from aiochannel import Channel from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator -from websockets.exceptions import ConnectionClosed +from replit_river.common_session import ( + SessionState, + check_to_close_session, + setup_heartbeat, +) from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError from replit_river.messages import ( FailedSendingMessageException, WebsocketClosedException, - parse_transport_msg, send_transport_message, ) from replit_river.seq_manager import ( - IgnoreMessageException, - InvalidMessageException, - OutOfOrderMessageException, SeqManager, ) from replit_river.task_manager import BackgroundTaskManager -from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions +from replit_river.transport_options import TransportOptions from replit_river.websocket_wrapper import WebsocketWrapper from .rpc import ( - ACK_BIT, - STREAM_CLOSED_BIT, - STREAM_OPEN_BIT, - GenericRpcHandler, TransportMessage, TransportMessageTracingSetter, ) @@ -42,19 +37,8 @@ trace_setter = TransportMessageTracingSetter() -class SessionState(enum.Enum): - """The state a session can be in. - - Can only transition from ACTIVE to CLOSING to CLOSED. - """ - - ACTIVE = 0 - CLOSING = 1 - CLOSED = 2 - - -class Session(object): - """A transport object that handles the websocket connection with a client.""" +class Session: + """Common functionality shared between client_session and server_session""" def __init__( self, @@ -63,8 +47,6 @@ def __init__( session_id: str, websocket: websockets.WebSocketCommonProtocol, transport_options: TransportOptions, - is_server: bool, - handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]], close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]], retry_connection_callback: ( Callable[ @@ -77,8 +59,6 @@ def __init__( self._transport_id = transport_id self._to_id = to_id self.session_id = session_id - self._handlers = handlers - self._is_server = is_server self._transport_options = transport_options # session state, only modified during closing @@ -103,11 +83,36 @@ def __init__( self._buffer = MessageBuffer(self._transport_options.buffer_size) self._task_manager = BackgroundTaskManager() - self._setup_heartbeats_task() + def _setup_heartbeats_task( + self, + do_close_websocket: Callable[[], Awaitable[None]], + ) -> None: + def increment_and_get_heartbeat_misses() -> int: + self._heartbeat_misses += 1 + return self._heartbeat_misses - def _setup_heartbeats_task(self) -> None: - self._task_manager.create_task(self._heartbeat()) - self._task_manager.create_task(self._check_to_close_session()) + self._task_manager.create_task( + setup_heartbeat( + self.session_id, + self._transport_options.heartbeat_ms, + self._transport_options.heartbeats_until_dead, + lambda: self._state, + lambda: self._close_session_after_time_secs, + close_websocket=do_close_websocket, + send_message=self.send_message, + increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses, + ) + ) + self._task_manager.create_task( + check_to_close_session( + self._transport_id, + self._transport_options.close_session_check_interval_ms, + lambda: self._state, + self._get_current_time, + lambda: self._close_session_after_time_secs, + self.close, + ) + ) async def is_session_open(self) -> bool: async with self._state_lock: @@ -138,100 +143,6 @@ async def _begin_close_session_countdown(self) -> None: ) self._close_session_after_time_secs = close_session_after_time_secs - async def serve(self) -> None: - """Serve messages from the websocket.""" - self._reset_session_close_countdown() - try: - async with asyncio.TaskGroup() as tg: - try: - await self._handle_messages_from_ws(tg) - except ConnectionClosed: - if self._retry_connection_callback: - self._task_manager.create_task( - self._retry_connection_callback() - ) - - await self._begin_close_session_countdown() - logger.debug("ConnectionClosed while serving", exc_info=True) - except FailedSendingMessageException: - # Expected error if the connection is closed. - logger.debug( - "FailedSendingMessageException while serving", exc_info=True - ) - except Exception: - logger.exception("caught exception at message iterator") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - raise ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - - async def _update_book_keeping(self, msg: TransportMessage) -> None: - await self._seq_manager.check_seq_and_update(msg) - await self._remove_acked_messages_in_buffer() - self._reset_session_close_countdown() - - async def _handle_messages_from_ws( - self, tg: asyncio.TaskGroup | None = None - ) -> None: - logger.debug( - "%s start handling messages from ws %s", - "server" if self._is_server else "client", - self._ws_wrapper.id, - ) - try: - ws_wrapper = self._ws_wrapper - async for message in ws_wrapper.ws: - try: - if not await ws_wrapper.is_open(): - # We should not process messages if the websocket is closed. - break - msg = parse_transport_msg(message, self._transport_options) - - logger.debug(f"{self._transport_id} got a message %r", msg) - - await self._update_book_keeping(msg) - if msg.controlFlags & ACK_BIT != 0: - continue - async with self._stream_lock: - stream = self._streams.get(msg.streamId, None) - if msg.controlFlags & STREAM_OPEN_BIT == 0: - if not stream: - logger.warning("no stream for %s", msg.streamId) - raise IgnoreMessageException( - "no stream for message, ignoring" - ) - await self._add_msg_to_stream(msg, stream) - else: - # TODO(dstewart) This looks like it opens a new call to handler - # on ever ws message, instead of demuxing and - # routing. - _stream = await self._open_stream_and_call_handler(msg, tg) - if not stream: - async with self._stream_lock: - self._streams[msg.streamId] = _stream - stream = _stream - - if msg.controlFlags & STREAM_CLOSED_BIT != 0: - if stream: - stream.close() - async with self._stream_lock: - del self._streams[msg.streamId] - except IgnoreMessageException: - logger.debug("Ignoring transport message", exc_info=True) - continue - except OutOfOrderMessageException: - logger.exception("Out of order message, closing connection") - await ws_wrapper.close() - return - except InvalidMessageException: - logger.exception("Got invalid transport message, closing session") - await self.close() - return - except ConnectionClosed as e: - raise e - async def replace_with_new_websocket( self, new_ws: websockets.WebSocketCommonProtocol ) -> None: @@ -241,94 +152,14 @@ async def replace_with_new_websocket( if new_ws.id != old_ws_id: await old_wrapper.close() self._ws_wrapper = WebsocketWrapper(new_ws) - await self._send_buffered_messages(new_ws) - # Server will call serve itself. - if not self._is_server: - await self.start_serve_responses() - - async def _get_current_time(self) -> float: - return asyncio.get_event_loop().time() - - def _reset_session_close_countdown(self) -> None: - self._heartbeat_misses = 0 - self._close_session_after_time_secs = None - - async def _check_to_close_session(self) -> None: - while True: - await asyncio.sleep( - self._transport_options.close_session_check_interval_ms / 1000 - ) - if self._state != SessionState.ACTIVE: - # already closing - return - # calculate the value now before comparing it so that there are no - # await points between the check and the comparison to avoid a TOCTOU - # race. - current_time = await self._get_current_time() - if not self._close_session_after_time_secs: - continue - if current_time > self._close_session_after_time_secs: - logger.info( - "Grace period ended for %s, closing session", self._transport_id - ) - await self.close() - return - - async def _heartbeat( - self, - ) -> None: - logger.debug("Start heartbeat") - while True: - await asyncio.sleep(self._transport_options.heartbeat_ms / 1000) - if self._state != SessionState.ACTIVE: - logger.debug( - "Session is closed, no need to send heartbeat, state : " - "%r close_session_after_this: %r", - {self._state}, - {self._close_session_after_time_secs}, - ) - # session is closing / closed, no need to send heartbeat anymore - return - try: - await self.send_message( - "heartbeat", - # TODO: make this a message class - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 - { - "ack": 0, - }, - ACK_BIT, - ) - self._heartbeat_misses += 1 - if ( - self._heartbeat_misses - > self._transport_options.heartbeats_until_dead - ): - if self._close_session_after_time_secs is not None: - # already in grace period, no need to set again - continue - logger.info( - "%r closing websocket because of heartbeat misses", - self.session_id, - ) - await self.close_websocket( - self._ws_wrapper, should_retry=not self._is_server - ) - await self._begin_close_session_countdown() - continue - except FailedSendingMessageException: - # this is expected during websocket closed period - continue - async def _send_buffered_messages( - self, websocket: websockets.WebSocketCommonProtocol - ) -> None: + # Send buffered messages to the new ws buffered_messages = list(self._buffer.buffer) for msg in buffered_messages: try: await self._send_transport_message( msg, - websocket, + new_ws, ) except WebsocketClosedException: logger.info( @@ -339,6 +170,13 @@ async def _send_buffered_messages( logger.exception("Error while sending buffered messages") break + async def _get_current_time(self) -> float: + return asyncio.get_event_loop().time() + + def _reset_session_close_countdown(self) -> None: + self._heartbeat_misses = 0 + self._close_session_after_time_secs = None + async def _send_transport_message( self, msg: TransportMessage, @@ -426,28 +264,6 @@ async def send_message( "Failed sending message, waiting for retry from buffer", exc_info=True ) - async def _send_responses_from_output_stream( - self, - stream_id: str, - output: Channel[Any], - is_streaming_output: bool, - ) -> None: - """Send serialized messages to the websockets.""" - try: - async for payload in output: - if not is_streaming_output: - await self.send_message(stream_id, payload, STREAM_CLOSED_BIT) - return - await self.send_message(stream_id, payload) - logger.debug("sent an end of stream %r", stream_id) - await self.send_message(stream_id, {"type": "CLOSE"}, STREAM_CLOSED_BIT) - except FailedSendingMessageException: - logger.exception("Error while sending responses") - except (RuntimeError, ChannelClosed): - logger.exception("Error while sending responses") - except Exception: - logger.exception("Unknown error while river sending responses back") - async def close_websocket( self, ws_wrapper: WebsocketWrapper, should_retry: bool ) -> None: @@ -460,85 +276,6 @@ async def close_websocket( if should_retry and self._retry_connection_callback: self._task_manager.create_task(self._retry_connection_callback()) - async def _open_stream_and_call_handler( - self, - msg: TransportMessage, - tg: asyncio.TaskGroup | None, - ) -> Channel: - if not self._is_server: - raise InvalidMessageException("Client should not receive stream open bit") - if not msg.serviceName or not msg.procedureName: - raise IgnoreMessageException( - f"Service name or procedure name is missing in the message {msg}" - ) - key = (msg.serviceName, msg.procedureName) - handler = self._handlers.get(key, None) - if not handler: - raise IgnoreMessageException( - f"No handler for {key} handlers : {self._handlers.keys()}" - ) - method_type, handler_func = handler - is_streaming_output = method_type in ( - "subscription-stream", # subscription - "stream", - ) - is_streaming_input = method_type in ( - "upload-stream", # subscription - "stream", - ) - # New channel pair. - input_stream: Channel[Any] = Channel( - MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1 - ) - output_stream: Channel[Any] = Channel( - MAX_MESSAGE_BUFFER_SIZE if is_streaming_output else 1 - ) - if ( - msg.controlFlags & STREAM_CLOSED_BIT == 0 - or msg.payload.get("type", None) != "CLOSE" - ): - try: - await input_stream.put(msg.payload) - except (RuntimeError, ChannelClosed) as e: - raise InvalidMessageException(e) from e - # Start the handler. - self._task_manager.create_task( - handler_func(msg.from_, input_stream, output_stream), tg - ) - self._task_manager.create_task( - self._send_responses_from_output_stream( - msg.streamId, output_stream, is_streaming_output - ), - tg, - ) - return input_stream - - async def _add_msg_to_stream( - self, - msg: TransportMessage, - stream: Channel, - ) -> None: - if ( - msg.controlFlags & STREAM_CLOSED_BIT != 0 - and msg.payload.get("type", None) == "CLOSE" - ): - # close message is not sent to the stream - return - try: - await stream.put(msg.payload) - except ChannelClosed: - # The client is no longer interested in this stream, - # just drop the message. - pass - except RuntimeError as e: - raise InvalidMessageException(e) from e - - async def _remove_acked_messages_in_buffer(self) -> None: - await self._buffer.remove_old_messages(self._seq_manager.receiver_ack) - - async def start_serve_responses(self) -> None: - self._task_manager.create_task(self.serve()) - async def close(self) -> None: """Close the session and all associated streams.""" logger.info( diff --git a/src/replit_river/transport.py b/src/replit_river/transport.py deleted file mode 100644 index f0e2b920..00000000 --- a/src/replit_river/transport.py +++ /dev/null @@ -1,53 +0,0 @@ -import asyncio -import logging - -import nanoid # type: ignore - -from replit_river.rpc import ( - GenericRpcHandler, -) -from replit_river.session import Session -from replit_river.transport_options import TransportOptions - -logger = logging.getLogger(__name__) - - -class Transport: - def __init__( - self, - transport_id: str, - transport_options: TransportOptions, - is_server: bool, - ) -> None: - self._transport_id = transport_id - self._transport_options = transport_options - self._is_server = is_server - self._sessions: dict[str, Session] = {} - self._handlers: dict[tuple[str, str], tuple[str, GenericRpcHandler]] = {} - self._session_lock = asyncio.Lock() - - async def _close_all_sessions(self) -> None: - sessions = self._sessions.values() - logger.info( - f"start closing sessions {self._transport_id}, number sessions : " - f"{len(sessions)}" - ) - sessions_to_close = list(sessions) - - # closing sessions requires access to the session lock, so we need to close - # them one by one to be safe - for session in sessions_to_close: - await session.close() - - logger.info(f"Transport closed {self._transport_id}") - - async def _delete_session(self, session: Session) -> None: - async with self._session_lock: - if session._to_id in self._sessions: - del self._sessions[session._to_id] - - def _set_session(self, session: Session) -> None: - self._sessions[session._to_id] = session - - def generate_nanoid(self) -> str: - return str(nanoid.generate()) diff --git a/tests/codegen/rpc/generated/test_service/rpc_method.py b/tests/codegen/rpc/generated/test_service/rpc_method.py index 91f0e562..f7dff38d 100644 --- a/tests/codegen/rpc/generated/test_service/rpc_method.py +++ b/tests/codegen/rpc/generated/test_service/rpc_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py index 23d2ab6d..eff77816 100644 --- a/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py +++ b/tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py index f2bb120f..a9b530f5 100644 --- a/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py +++ b/tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py index dbe6e51e..0204f9c2 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py index 75f00e1c..97559be3 100644 --- a/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py +++ b/tests/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py @@ -1,4 +1,3 @@ -# ruff: noqa # Code generated by river.codegen. DO NOT EDIT. from collections.abc import AsyncIterable, AsyncIterator import datetime diff --git a/tests/conftest.py b/tests/conftest.py index 529ffb23..b9b8cdf6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from replit_river.error_schema import RiverError from replit_river.rpc import ( - GenericRpcHandler, + GenericRpcHandlerBuilder, TransportMessage, ) @@ -17,7 +17,7 @@ pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"] HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"] -HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]] +HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandlerBuilder]] def transport_message(