From 64706256ab9767e1505d5d2c326d155702c83be3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 15 Apr 2025 15:25:23 -0700 Subject: [PATCH 01/16] More logging --- src/replit_river/v2/session.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 46ab151d..d40851ef 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -293,11 +293,11 @@ def unbind_connecting_task() -> None: # This is safe because each individual function that is waiting on this # function completeing already has a reference, so we'll last a few ticks # before GC. - # - # Let's do our best to avoid clobbering other tasks by comparing the .name current_task = asyncio.current_task() if self._connecting_task is current_task: self._connecting_task = None + else: + logger.debug("unbind_connecting_task failed, id did not match") if not self._connecting_task: self._connecting_task = asyncio.create_task( @@ -1150,21 +1150,23 @@ async def websocket_closed_callback() -> None: raise err + logger.debug("Connected") # We did it! We're connected! last_error = None rate_limiter.start_restoring_budget(client_id) transition_connected(ws) break except Exception as e: - if ws: - close_ws_in_background(ws) - ws = None - last_error = e backoff_time = rate_limiter.get_backoff_ms(client_id) logger.exception( f"Error connecting, retrying with {backoff_time}ms backoff" ) + if ws: + close_ws_in_background(ws) + ws = None + last_error = e await asyncio.sleep(backoff_time / 1000) + logger.debug("Here, about to retry") unbind_connecting_task() if last_error is not None: From 6f73ebfa77233cd2698fd6e748feff4fdbbd91fb Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 15 Apr 2025 15:34:15 -0700 Subject: [PATCH 02/16] Waiting for terminating_task to be finished --- src/replit_river/v2/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index d40851ef..e12bb777 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -321,6 +321,8 @@ def unbind_connecting_task() -> None: ) await self._connecting_task + if self._terminating_task: + await self._terminating_task def is_closed(self) -> bool: """ From 7b374f49c97d4c5a3df1c749d2e62c2a00d8afff Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 15 Apr 2025 15:44:51 -0700 Subject: [PATCH 03/16] More grace for graceful cleanup --- src/replit_river/v2/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index e12bb777..4aa6b262 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -418,8 +418,8 @@ async def close( # ... message processor so it can exit cleanly self._process_messages.set() - # Wait a tick to permit the waiting tasks to shut down gracefully - await asyncio.sleep(0.01) + # Wait to permit the waiting tasks to shut down gracefully + await asyncio.sleep(0.25) await self._task_manager.cancel_all_tasks() From 8e28338cd4d13e3a099d0f951657ee6a24ae3cb4 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 15 Apr 2025 15:47:18 -0700 Subject: [PATCH 04/16] This is not async --- src/replit_river/v2/client_transport.py | 2 +- src/replit_river/v2/session.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 3dc96522..4c55d07b 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -80,7 +80,7 @@ async def _retry_connection(self) -> Session: logger.debug("Triggering get_or_create_session") return await self.get_or_create_session() - async def _delete_session(self, session: Session) -> None: + def _delete_session(self, session: Session) -> None: if self._session is session: self._session = None else: diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 4aa6b262..fd7bdf53 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -109,7 +109,7 @@ class ResultError(TypedDict): trace_propagator = TraceContextTextMapPropagator() trace_setter = TransportMessageTracingSetter() -CloseSessionCallback: TypeAlias = Callable[["Session"], Coroutine[Any, Any, Any]] +CloseSessionCallback: TypeAlias = Callable[["Session"], None] RetryConnectionCallback: TypeAlias = Callable[ [], Coroutine[Any, Any, Any], @@ -472,7 +472,7 @@ async def close( # Clear the session in transports # This will get us GC'd, so this should be the last thing. - await self._close_session_callback(self) + self._close_session_callback(self) def _start_buffered_message_sender( self, From e87db623c73bacb56b5f9ddf601c08b9c107998a Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Tue, 15 Apr 2025 15:54:30 -0700 Subject: [PATCH 05/16] Differentiating between closing and closed --- src/replit_river/v2/client_transport.py | 4 +++- src/replit_river/v2/session.py | 11 +++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/replit_river/v2/client_transport.py b/src/replit_river/v2/client_transport.py index 4c55d07b..2bfc60ee 100644 --- a/src/replit_river/v2/client_transport.py +++ b/src/replit_river/v2/client_transport.py @@ -54,8 +54,10 @@ async def get_or_create_session(self) -> Session: call ensure_connected on whatever session is active. """ existing_session = self._session - if not existing_session or existing_session.is_closed(): + if not existing_session or existing_session.is_terminal(): logger.info("Creating new session") + if existing_session: + await existing_session.close() new_session = Session( client_id=self._client_id, server_id=self._server_id, diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index fd7bdf53..7f17ec9a 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -324,7 +324,7 @@ def unbind_connecting_task() -> None: if self._terminating_task: await self._terminating_task - def is_closed(self) -> bool: + def is_terminal(self) -> bool: """ If the session is in a terminal state. Do not send messages, do not expect any more messages to be emitted, @@ -402,12 +402,15 @@ async def close( self, reason: Exception | None = None, current_state: SessionState | None = None ) -> None: """Close the session and all associated streams.""" - logger.info( - f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" - ) if (current_state or self._state) in TerminalStates: + while (current_state or self._state) != SessionState.CLOSED: + logger.debug("Session already closing, waiting...") + await asyncio.sleep(0.2) # already closing return + logger.info( + f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" + ) self._state = SessionState.CLOSING # We're closing, so we need to wake up... From 72a005d267e5bc34bb28cf43362282755759a553 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 17 Apr 2025 13:19:14 -0700 Subject: [PATCH 06/16] Prevent runaway attempts --- src/replit_river/v2/session.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 7f17ec9a..90a75137 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -4,7 +4,7 @@ from collections.abc import AsyncIterable from contextlib import asynccontextmanager from dataclasses import dataclass -from datetime import timedelta +from datetime import datetime, timedelta from typing import ( Any, AsyncGenerator, @@ -83,6 +83,8 @@ STREAM_CLOSED_BIT: STREAM_CLOSED_BIT_TYPE = 0b01000 +SESSION_CLOSE_TIMEOUT_SEC = 2 + _BackpressuredWaiter: TypeAlias = Callable[[], Awaitable[None]] @@ -403,7 +405,15 @@ async def close( ) -> None: """Close the session and all associated streams.""" if (current_state or self._state) in TerminalStates: + start = datetime.now() while (current_state or self._state) != SessionState.CLOSED: + elapsed = (datetime.now() - start).total_seconds() + if elapsed >= SESSION_CLOSE_TIMEOUT_SEC: + logger.warning( + f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " + "seconds to close, leaking", + ) + break logger.debug("Session already closing, waiting...") await asyncio.sleep(0.2) # already closing From ae69cdc71736fd4a30a071bc9a07541e7969250d Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 17 Apr 2025 15:07:44 -0700 Subject: [PATCH 07/16] Clarifying termination semantics --- src/replit_river/v2/session.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 90a75137..52cdad6a 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -494,7 +494,8 @@ def _start_buffered_message_sender( Building on buffered_message_sender's documentation, we implement backpressure per-stream by way of self._streams' - error_channel: Channel[Exception | None] + error_channel: Channel[Exception] + backpressured_waiter: Callable[[], Awaitable[None]] This is accomplished via the following strategy: - If buffered_message_sender encounters an error, we transition back to @@ -506,8 +507,11 @@ def _start_buffered_message_sender( - Alternately, if buffered_message_sender successfully writes back to the - Finally, if _recv_from_ws encounters an error (transport or deserialization), - we emit an informative error to close_session which gets emitted to all - backpressured client methods. + it transitions to NO_CONNECTION and defers to the client_transport to + reestablish a connection. + + The in-flight messages are still valid, as if we can reconnect to the server + in time, those responses can be marshalled to their respective callbacks. """ async def commit(msg: TransportMessage) -> None: @@ -789,7 +793,7 @@ async def send_upload[I, R, A]( # If this request is not closed and the session is killed, we should # throw exception here async for item in request: - # Block for backpressure and emission errors from the ws + # Block for backpressure await backpressured_waiter() try: payload = request_serializer(item) @@ -950,9 +954,9 @@ async def _encode_stream() -> None: assert request_serializer, "send_stream missing request_serializer" async for item in request: - # Block for backpressure (or errors) + # Block for backpressure await backpressured_waiter() - # If there are any errors so far, raise them + await self._enqueue_message( stream_id=stream_id, control_flags=0, From 99e161cbe5a84f1078ac6f0076251d7b2c1868d1 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Thu, 17 Apr 2025 15:08:52 -0700 Subject: [PATCH 08/16] Swapping outer try: for outer while: to avoid terminating recv_from_ws --- src/replit_river/v2/session.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 52cdad6a..2b4e07b1 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1221,8 +1221,8 @@ async def _recv_from_ws( """ our_task = asyncio.current_task() connection_attempts = 0 - try: - while our_task and not our_task.cancelling() and not our_task.cancelled(): + while our_task and not our_task.cancelling() and not our_task.cancelled(): + try: logger.debug(f"_recv_from_ws loop count={connection_attempts}") connection_attempts += 1 ws = None @@ -1359,21 +1359,22 @@ async def _recv_from_ws( logger.debug( "FailedSendingMessageException while serving", exc_info=True ) - break + break # Inner loop except Exception: logger.exception("caught exception at message iterator") - break + await transition_no_connection() + break # Inner loop logger.debug("_handle_messages_from_ws exiting") - except ExceptionGroup as eg: - _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) - if unhandled: - # We're in a task, there's not that much that can be done. - unhandled = ExceptionGroup( - "Unhandled exceptions on River server", unhandled.exceptions - ) - logger.exception( - "caught exception at message iterator", - exc_info=unhandled, - ) - raise unhandled + except ExceptionGroup as eg: + _, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed)) + if unhandled: + # We're in a task, there's not that much that can be done. + unhandled = ExceptionGroup( + "Unhandled exceptions on River server", unhandled.exceptions + ) + logger.exception( + "caught exception at message iterator", + exc_info=unhandled, + ) + raise unhandled logger.debug(f"_recv_from_ws exiting normally after {connection_attempts} loops") From 14a8d8793bc33995bd19f994d15dfce8d1a141a1 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 26 Apr 2025 11:00:07 -0700 Subject: [PATCH 09/16] Break out rate limiter protocol for testing --- src/replit_river/rate_limiter.py | 8 ++++++++ src/replit_river/v2/session.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/replit_river/rate_limiter.py b/src/replit_river/rate_limiter.py index 5e742ce9..e7593e19 100644 --- a/src/replit_river/rate_limiter.py +++ b/src/replit_river/rate_limiter.py @@ -2,6 +2,7 @@ import logging import random from contextvars import Context +from typing import Protocol from replit_river.error_schema import RiverException from replit_river.transport_options import ConnectionRetryOptions @@ -15,6 +16,13 @@ def __init__(self, code: str, message: str, client_id: str) -> None: self.client_id = client_id +class RateLimiter(Protocol): + def start_restoring_budget(self, user: str) -> None: ... + def get_backoff_ms(self, user: str) -> float: ... + def has_budget(self, user: str) -> bool: ... + def consume_budget(self, user: str) -> None: ... + + class LeakyBucketRateLimit: """Asynchronous leaky bucket rate limiter. diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 2b4e07b1..70bed958 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -54,7 +54,7 @@ parse_transport_msg, send_transport_message, ) -from replit_river.rate_limiter import LeakyBucketRateLimit +from replit_river.rate_limiter import RateLimiter from replit_river.rpc import ( ACK_BIT, STREAM_OPEN_BIT, @@ -143,7 +143,7 @@ class Session[HandshakeMetadata]: _wait_for_connected: asyncio.Event _client_id: str - _rate_limiter: LeakyBucketRateLimit + _rate_limiter: RateLimiter _uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ] @@ -176,7 +176,7 @@ def __init__( transport_options: TransportOptions, close_session_callback: CloseSessionCallback, client_id: str, - rate_limiter: LeakyBucketRateLimit, + rate_limiter: RateLimiter, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], @@ -1034,7 +1034,7 @@ async def _do_ensure_connected[HandshakeMetadata]( client_id: str, session_id: str, server_id: str, - rate_limiter: LeakyBucketRateLimit, + rate_limiter: RateLimiter, uri_and_metadata_factory: Callable[ [], Awaitable[UriAndMetadata[HandshakeMetadata]] ], From ea5f5cdac9ee348de592ddda6a627149f68c9f56 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 26 Apr 2025 12:06:31 -0700 Subject: [PATCH 10/16] Permitting graceful termination when we close intentionally --- src/replit_river/v2/session.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 70bed958..bb9fd7c6 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -26,6 +26,7 @@ from opentelemetry.trace import Span, use_span from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import ValidationError +from websockets import ConnectionClosedOK from websockets.asyncio.client import ClientConnection from websockets.exceptions import ConnectionClosed @@ -1119,6 +1120,9 @@ async def websocket_closed_callback() -> None: try: data = await ws.recv(decode=False) + except ConnectionClosedOK as e: + close_session(e) + continue except ConnectionClosed as e: logger.debug( "_do_ensure_connected: Connection closed during waiting " From 39e6f2f9fbdd6c6ce3d5100a72330359ded759f5 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 26 Apr 2025 12:06:58 -0700 Subject: [PATCH 11/16] Expose in-flight ws to the session state to permit close() --- src/replit_river/v2/session.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index bb9fd7c6..cd7413ac 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -268,7 +268,7 @@ def close_session(reason: Exception | None) -> None: self.close(reason, current_state=current_state), ) - def transition_connecting() -> None: + def transition_connecting(ws: ClientConnection) -> None: if self._state in TerminalStates: return logger.debug("transition_connecting") @@ -276,6 +276,9 @@ def transition_connecting() -> None: # "Clear" here means observers should wait until we are connected. self._wait_for_connected.clear() + # Expose the current ws to be collected by close() + self._ws = ws + def transition_connected(ws: ClientConnection) -> None: if self._state in TerminalStates: return @@ -1043,7 +1046,7 @@ async def _do_ensure_connected[HandshakeMetadata]( get_next_sent_seq: Callable[[], int], get_current_ack: Callable[[], int], get_state: Callable[[], SessionState], - transition_connecting: Callable[[], None], + transition_connecting: Callable[[ClientConnection], None], close_ws_in_background: Callable[[ClientConnection], None], transition_connected: Callable[[ClientConnection], None], unbind_connecting_task: Callable[[], None], @@ -1063,12 +1066,12 @@ async def _do_ensure_connected[HandshakeMetadata]( attempt_count += 1 rate_limiter.consume_budget(client_id) - transition_connecting() ws: ClientConnection | None = None try: uri_and_metadata = await uri_and_metadata_factory() ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"]) + transition_connecting(ws) try: handshake_request = ControlMessageHandshakeRequest[HandshakeMetadata]( From 810ed9f60adbd4dcb52072efff2e0df10c0943a3 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 26 Apr 2025 12:27:15 -0700 Subject: [PATCH 12/16] Short-circuit to prevent internal deadlocks --- src/replit_river/v2/session.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index cd7413ac..533b3cef 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -265,7 +265,11 @@ def close_session(reason: Exception | None) -> None: # during the cleanup procedure. self._terminating_task = asyncio.create_task( - self.close(reason, current_state=current_state), + self.close( + reason, + current_state=current_state, + _wait_for_closed=False, + ), ) def transition_connecting(ws: ClientConnection) -> None: @@ -405,12 +409,18 @@ async def _enqueue_message( self._process_messages.set() async def close( - self, reason: Exception | None = None, current_state: SessionState | None = None + self, + reason: Exception | None = None, + current_state: SessionState | None = None, + _wait_for_closed: bool = True, ) -> None: """Close the session and all associated streams.""" if (current_state or self._state) in TerminalStates: start = datetime.now() - while (current_state or self._state) != SessionState.CLOSED: + while ( + _wait_for_closed + and (current_state or self._state) != SessionState.CLOSED + ): elapsed = (datetime.now() - start).total_seconds() if elapsed >= SESSION_CLOSE_TIMEOUT_SEC: logger.warning( @@ -632,7 +642,7 @@ async def block_until_connected() -> None: get_state=lambda: self._state, get_ws=lambda: self._ws, transition_no_connection=transition_no_connection, - close_session=self.close, + close_session=lambda err: self.close(err, _wait_for_closed=False), assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), enqueue_message=self._enqueue_message, From bf02c0277ae87397b910ecdea7bc43a1d8145285 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 27 Apr 2025 12:03:59 -0700 Subject: [PATCH 13/16] Replacing "while" with an asyncio.Event --- src/replit_river/v2/session.py | 36 +++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 533b3cef..3a2eeff0 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -4,7 +4,7 @@ from collections.abc import AsyncIterable from contextlib import asynccontextmanager from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from typing import ( Any, AsyncGenerator, @@ -169,6 +169,7 @@ class Session[HandshakeMetadata]: # Terminating _terminating_task: asyncio.Task[None] | None + _closing_waiter: asyncio.Event | None def __init__( self, @@ -228,6 +229,7 @@ def __init__( # Terminating self._terminating_task = None + self._closing_waiter = None self._start_recv_from_ws() self._start_buffered_message_sender() @@ -415,27 +417,25 @@ async def close( _wait_for_closed: bool = True, ) -> None: """Close the session and all associated streams.""" - if (current_state or self._state) in TerminalStates: - start = datetime.now() - while ( - _wait_for_closed - and (current_state or self._state) != SessionState.CLOSED - ): - elapsed = (datetime.now() - start).total_seconds() - if elapsed >= SESSION_CLOSE_TIMEOUT_SEC: - logger.warning( - f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " - "seconds to close, leaking", - ) - break + if self._closing_waiter: + # Break early for internal callers + if not _wait_for_closed: + return + try: logger.debug("Session already closing, waiting...") - await asyncio.sleep(0.2) - # already closing + async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC): + await self._closing_waiter.wait() + except asyncio.TimeoutError: + logger.warning( + f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " + "seconds to close, leaking", + ) return logger.info( f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" ) self._state = SessionState.CLOSING + self._closing_waiter = asyncio.Event() # We're closing, so we need to wake up... # ... tasks waiting for connection to be established @@ -501,6 +501,10 @@ async def close( # This will get us GC'd, so this should be the last thing. self._close_session_callback(self) + # Release waiters, then release the event + self._closing_waiter.set() + self._closing_waiter = None + def _start_buffered_message_sender( self, ) -> None: From 2617a0bc395d5ad0aaec89c8b7cf99e88a6c2eaf Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sat, 26 Apr 2025 12:27:58 -0700 Subject: [PATCH 14/16] Force close_session on ConnectionClosedOK --- src/replit_river/v2/session.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3a2eeff0..37062c7d 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -1274,6 +1274,9 @@ async def _recv_from_ws( # is no @overrides in `websockets` to hint this. try: message = await ws.recv(decode=False) + except ConnectionClosedOK as e: + close_session(e) + continue except ConnectionClosed: # This triggers a break in the inner loop so we can get back to # the outer loop. From 313467eba92f44d4723d53c1a566facf30ead8ef Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Sun, 27 Apr 2025 12:53:17 -0700 Subject: [PATCH 15/16] Adding a raw connection test --- tests/v2/test_v2_session_lifecycle.py | 135 ++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/v2/test_v2_session_lifecycle.py diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py new file mode 100644 index 00000000..fb5982cc --- /dev/null +++ b/tests/v2/test_v2_session_lifecycle.py @@ -0,0 +1,135 @@ +import asyncio +from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict + +import pytest +from websockets import ConnectionClosedOK +from websockets.asyncio.server import ServerConnection, serve +from websockets.typing import Data + +from replit_river.messages import parse_transport_msg +from replit_river.rate_limiter import RateLimiter +from replit_river.rpc import TransportMessage +from replit_river.transport_options import TransportOptions, UriAndMetadata +from replit_river.v2.session import Session + + +class _PermissiveRateLimiter(RateLimiter): + def start_restoring_budget(self, user: str) -> None: + pass + + def get_backoff_ms(self, user: str) -> float: + return 0 + + def has_budget(self, user: str) -> bool: + return True + + def consume_budget(self, user: str) -> None: + pass + + +WsServerFixture: TypeAlias = tuple[ + Callable[[], Awaitable[UriAndMetadata[None]]], + asyncio.Queue[bytes], + Callable[[], ServerConnection | None], +] + + +class _WsServerState(TypedDict): + ipv4_laddr: tuple[str, int] | None + + +async def _ws_server_internal( + recv: asyncio.Queue[bytes], + set_conn: Callable[[ServerConnection], None], + state: _WsServerState, +) -> AsyncIterator[None]: + async def handle(websocket: ServerConnection) -> None: + set_conn(websocket) + datagram: Data + try: + while datagram := await websocket.recv(decode=False): + if isinstance(datagram, str): + continue + await recv.put(datagram) + except ConnectionClosedOK: + pass + + port: int | None = None + if state["ipv4_laddr"]: + port = state["ipv4_laddr"][1] + async with serve(handle, "localhost", port=port) as server: + for sock in server.sockets: + if (pair := sock.getsockname())[0] == "127.0.0.1": + if state["ipv4_laddr"] is None: + state["ipv4_laddr"] = pair + serve_forever = asyncio.create_task(server.serve_forever()) + yield None + serve_forever.cancel() + + +@pytest.fixture +async def ws_server() -> AsyncIterator[WsServerFixture]: + recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1) + connection: ServerConnection | None = None + state: _WsServerState = {"ipv4_laddr": None} + + def set_conn(new_conn: ServerConnection) -> None: + nonlocal connection + connection = new_conn + + server_generator = _ws_server_internal(recv, set_conn, state) + await anext(server_generator) + + async def urimeta() -> UriAndMetadata[None]: + ipv4_laddr = state["ipv4_laddr"] + assert ipv4_laddr + return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None) + + yield (urimeta, recv, lambda: connection) + + try: + await anext(server_generator) + except StopAsyncIteration: + pass + + +async def test_connect(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + session = Session( + server_id="SERVER", + session_id="SESSION1", + transport_options=TransportOptions(), + close_session_callback=lambda _: None, + client_id="CLIENT1", + rate_limiter=_PermissiveRateLimiter(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(session.ensure_connected()) + msg = parse_transport_msg(await recv.get()) + assert isinstance(msg, TransportMessage) + assert msg.payload["type"] == "HANDSHAKE_REQ" + await session.close() + await connecting + + +async def test_reconnect(ws_server: WsServerFixture) -> None: + (urimeta, recv, conn) = ws_server + + session = Session( + server_id="SERVER", + session_id="SESSION1", + transport_options=TransportOptions(), + close_session_callback=lambda _: None, + client_id="CLIENT1", + rate_limiter=_PermissiveRateLimiter(), + uri_and_metadata_factory=urimeta, + ) + + connecting = asyncio.create_task(session.ensure_connected()) + msg = parse_transport_msg(await recv.get()) + assert isinstance(msg, TransportMessage) + assert msg.payload["type"] == "HANDSHAKE_REQ" + await session.close() + await connecting From 41272a8af50236842e6224f6171b853e114fdee6 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 28 Apr 2025 12:46:39 -0700 Subject: [PATCH 16/16] Centralizing confusion around blocking vs non-blocking "close()" codepaths --- src/replit_river/v2/session.py | 199 +++++++++++++++++---------------- 1 file changed, 104 insertions(+), 95 deletions(-) diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 37062c7d..3fe948e8 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -252,28 +252,6 @@ def get_next_sent_seq() -> int: return self._send_buffer[0].seq return self.seq - def close_session(reason: Exception | None) -> None: - # If we're already closing, just let whoever's currently doing it handle it. - if self._state in TerminalStates: - return - - # Avoid closing twice - if self._terminating_task is None: - current_state = self._state - self._state = SessionState.CLOSING - - # We can't just call self.close() directly because - # we're inside a thread that will eventually be awaited - # during the cleanup procedure. - - self._terminating_task = asyncio.create_task( - self.close( - reason, - current_state=current_state, - _wait_for_closed=False, - ), - ) - def transition_connecting(ws: ClientConnection) -> None: if self._state in TerminalStates: return @@ -328,7 +306,7 @@ def unbind_connecting_task() -> None: close_ws_in_background=close_ws_in_background, transition_connected=transition_connected, unbind_connecting_task=unbind_connecting_task, - close_session=close_session, + close_session=self._close_internal_nowait, ) ) @@ -413,14 +391,9 @@ async def _enqueue_message( async def close( self, reason: Exception | None = None, - current_state: SessionState | None = None, - _wait_for_closed: bool = True, ) -> None: """Close the session and all associated streams.""" if self._closing_waiter: - # Break early for internal callers - if not _wait_for_closed: - return try: logger.debug("Session already closing, waiting...") async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC): @@ -431,79 +404,112 @@ async def close( "seconds to close, leaking", ) return - logger.info( - f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" - ) - self._state = SessionState.CLOSING - self._closing_waiter = asyncio.Event() + await self._close_internal(reason) - # We're closing, so we need to wake up... - # ... tasks waiting for connection to be established - self._wait_for_connected.set() - # ... consumers waiting to enqueue messages - self._space_available.set() - # ... message processor so it can exit cleanly - self._process_messages.set() + def _close_internal_nowait(self, reason: Exception | None = None) -> None: + """ + When calling close() from asyncio Tasks, we must not block. + + This function does so, deferring to the underlying infrastructure for + creating self._terminating_task. + """ + self._close_internal(reason) + + def _close_internal(self, reason: Exception | None = None) -> asyncio.Task[None]: + """ + Internal close method. Subsequent calls past the first do not block. + + This is intended to be the primary driver of a session being torn down + and returned to its initial state. + + NB: This function is intended to be the sole lifecycle manager of + self._terminating_task. Waiting on the completion of that task is optional, + but the population of that property is critical. + + NB: We must not await the task returned from this function from chained tasks + inside this session, otherwise we will create a thread loop. + """ + + async def do_close() -> None: + logger.info( + f"{self.session_id} closing session to {self._server_id}, " + f"ws: {self._ws}" + ) + self._state = SessionState.CLOSING + self._closing_waiter = asyncio.Event() - # Wait to permit the waiting tasks to shut down gracefully - await asyncio.sleep(0.25) + # We're closing, so we need to wake up... + # ... tasks waiting for connection to be established + self._wait_for_connected.set() + # ... consumers waiting to enqueue messages + self._space_available.set() + # ... message processor so it can exit cleanly + self._process_messages.set() + + # Wait to permit the waiting tasks to shut down gracefully + await asyncio.sleep(0.25) - await self._task_manager.cancel_all_tasks() + await self._task_manager.cancel_all_tasks() - for stream_meta in self._streams.values(): - stream_meta["output"].close() - # Wake up backpressured writers + for stream_meta in self._streams.values(): + stream_meta["output"].close() + # Wake up backpressured writers + try: + stream_meta["error_channel"].put_nowait( + reason + or SessionClosedRiverServiceException( + "river session is closed", + ) + ) + except ChannelFull: + logger.exception( + "Unable to tell the caller that the session is going away", + ) + stream_meta["release_backpressured_waiter"]() + # Before we GC the streams, let's wait for all tasks to be closed gracefully try: - stream_meta["error_channel"].put_nowait( - reason - or SessionClosedRiverServiceException( - "river session is closed", + async with asyncio.timeout( + self._transport_options.shutdown_all_streams_timeout_ms + ): + # Block for backpressure and emission errors from the ws + await asyncio.gather( + *[ + stream_meta["output"].join() + for stream_meta in self._streams.values() + ] ) - ) - except ChannelFull: + except asyncio.TimeoutError: + spans: list[Span] = [ + stream_meta["span"] + for stream_meta in self._streams.values() + if not stream_meta["output"].closed() + ] + span_ids = [span.get_span_context().span_id for span in spans] logger.exception( - "Unable to tell the caller that the session is going away", + "Timeout waiting for output streams to finallize", + extra={"span_ids": span_ids}, ) - stream_meta["release_backpressured_waiter"]() - # Before we GC the streams, let's wait for all tasks to be closed gracefully. - try: - async with asyncio.timeout( - self._transport_options.shutdown_all_streams_timeout_ms - ): - # Block for backpressure and emission errors from the ws - await asyncio.gather( - *[ - stream_meta["output"].join() - for stream_meta in self._streams.values() - ] - ) - except asyncio.TimeoutError: - spans: list[Span] = [ - stream_meta["span"] - for stream_meta in self._streams.values() - if not stream_meta["output"].closed() - ] - span_ids = [span.get_span_context().span_id for span in spans] - logger.exception( - "Timeout waiting for output streams to finallize", - extra={"span_ids": span_ids}, - ) - self._streams.clear() + self._streams.clear() - if self._ws: - # The Session isn't guaranteed to live much longer than this close() - # invocation, so let's await this close to avoid dropping the socket. - await self._ws.close() + if self._ws: + # The Session isn't guaranteed to live much longer than this close() + # invocation, so let's await this close to avoid dropping the socket. + await self._ws.close() - self._state = SessionState.CLOSED + self._state = SessionState.CLOSED - # Clear the session in transports - # This will get us GC'd, so this should be the last thing. - self._close_session_callback(self) + # Clear the session in transports + # This will get us GC'd, so this should be the last thing. + self._close_session_callback(self) - # Release waiters, then release the event - self._closing_waiter.set() - self._closing_waiter = None + # Release waiters, then release the event + self._closing_waiter.set() + self._closing_waiter = None + + if self._terminating_task: + return self._terminating_task + + return asyncio.create_task(do_close()) def _start_buffered_message_sender( self, @@ -646,7 +652,7 @@ async def block_until_connected() -> None: get_state=lambda: self._state, get_ws=lambda: self._ws, transition_no_connection=transition_no_connection, - close_session=lambda err: self.close(err, _wait_for_closed=False), + close_session=self._close_internal_nowait, assert_incoming_seq_bookkeeping=assert_incoming_seq_bookkeeping, get_stream=lambda stream_id: self._streams.get(stream_id), enqueue_message=self._enqueue_message, @@ -1137,8 +1143,11 @@ async def websocket_closed_callback() -> None: try: data = await ws.recv(decode=False) - except ConnectionClosedOK as e: - close_session(e) + except ConnectionClosedOK: + # In the case of a normal connection closure, we defer to + # the outer loop to determine next steps. + # A call to close(...) should set the SessionState to a terminal one, + # otherwise we should try again. continue except ConnectionClosed as e: logger.debug( @@ -1226,7 +1235,7 @@ async def _recv_from_ws( get_state: Callable[[], SessionState], get_ws: Callable[[], ClientConnection | None], transition_no_connection: Callable[[], Awaitable[None]], - close_session: Callable[[Exception | None], Awaitable[None]], + close_session: Callable[[Exception | None], None], assert_incoming_seq_bookkeeping: Callable[ [str, int, int], Literal[True] | _IgnoreMessage ], @@ -1361,7 +1370,7 @@ async def _recv_from_ws( stream_meta["output"].close() except OutOfOrderMessageException: logger.exception("Out of order message, closing connection") - await close_session( + close_session( SessionClosedRiverServiceException( "Out of order message, closing connection" ) @@ -1371,7 +1380,7 @@ async def _recv_from_ws( logger.exception( "Got invalid transport message, closing session", ) - await close_session( + close_session( SessionClosedRiverServiceException( "Out of order message, closing connection" )