diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 374fcdcd..56158fcf 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -366,7 +366,6 @@ async def _establish_handshake( # If the session status is mismatched, we should close the old session # and let the retry logic to create a new session. await old_session.close() - await self._delete_session(old_session) raise RiverException( ERROR_HANDSHAKE, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 878eb0b7..88e848a7 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -117,10 +117,11 @@ async def _get_or_create_session( session_id: str, websocket: WebSocketCommonProtocol, ) -> ServerSession: + new_session: ServerSession | None = None + old_session: ServerSession | None = None async with self._session_lock: - session_to_close: Session | None = None - new_session: ServerSession | None = None - if to_id not in self._sessions: + old_session = self._sessions.get(to_id) + if not old_session: logger.info( 'Creating new session with "%s" using ws: %s', to_id, websocket.id ) @@ -134,7 +135,6 @@ async def _get_or_create_session( close_session_callback=self._delete_session, ) else: - old_session = self._sessions[to_id] if old_session.session_id != session_id: logger.info( 'Create new session with "%s" for session id %s' @@ -143,7 +143,6 @@ async def _get_or_create_session( session_id, old_session.session_id, ) - session_to_close = old_session new_session = ServerSession( transport_id, to_id, @@ -167,10 +166,12 @@ async def _get_or_create_session( except FailedSendingMessageException as e: raise e - if session_to_close: - logger.info("Closing stale session %s", session_to_close.session_id) - await session_to_close.close() self._sessions[new_session._to_id] = new_session + + if old_session and new_session != old_session: + logger.info("Closing stale session %s", old_session.session_id) + await old_session.close() + return new_session async def _send_handshake_response( @@ -247,7 +248,7 @@ async def _establish_handshake( raise InvalidMessageException("handshake request to wrong server") async with self._session_lock: - old_session = self._sessions.get(request_message.from_, None) + old_session = self._sessions.get(request_message.from_) client_next_expected_seq = ( handshake_request.expectedSessionState.nextExpectedSeq ) @@ -285,10 +286,6 @@ async def _establish_handshake( ) raise SessionStateMismatchException(message) elif old_session: - # we have an old session but the session id is different - # just delete the old session - await old_session.close() - await self._delete_session(old_session) old_session = None if not old_session and ( diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 3f4ff446..cf1e1e29 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import AsyncGenerator, Literal @@ -64,7 +63,6 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: await client.close() finally: - await asyncio.sleep(1) logging.debug("Start closing test server") if binding: binding.close() diff --git a/tests/test_communication.py b/tests/test_communication.py index 488e4b81..87de1811 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -268,3 +268,32 @@ async def test_ignore_flood_subscription(client: Client) -> None: timedelta(seconds=20), ) assert response == "Hello, Alice!" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_rpc_method_reconnect(client: Client) -> None: + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "Hello, Alice!" + + await client._transport._close_all_sessions() + + response = await client.send_rpc( + "test_service", + "rpc_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + + assert response == "Hello, Bob!"