diff --git a/Makefile b/Makefile index 41959647..537d6337 100644 --- a/Makefile +++ b/Makefile @@ -8,3 +8,6 @@ lint: format: uv run ruff format src tests uv run ruff check src tests --fix + +test: + uv run pytest tests diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 7ae62b23..b9d3a331 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -344,7 +344,6 @@ async def _establish_handshake( # If the session status is mismatched, we should close the old session # and let the retry logic to create a new session. await old_session.close() - await self._delete_session(old_session) raise RiverException( ERROR_HANDSHAKE, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index c2e30657..b2850721 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -90,6 +90,10 @@ async def handshake_to_get_session( async def close(self) -> None: await self._close_all_sessions() + async def _get_existing_session(self, to_id: str) -> Optional[Session]: + async with self._session_lock: + return self._sessions.get(to_id) + async def _get_or_create_session( self, transport_id: str, @@ -97,12 +101,30 @@ async def _get_or_create_session( session_id: str, websocket: WebSocketCommonProtocol, ) -> Session: - async with self._session_lock: - session_to_close: Optional[Session] = None - new_session: Optional[Session] = None - if to_id not in self._sessions: + new_session: Optional[Session] = None + old_session: Optional[Session] = await self._get_existing_session(to_id) + if not old_session: + logger.info( + 'Creating new session with "%s" using ws: %s', to_id, websocket.id + ) + new_session = Session( + transport_id, + to_id, + session_id, + websocket, + self._transport_options, + self._is_server, + self._handlers, + close_session_callback=self._delete_session, + ) + else: + if old_session.session_id != session_id: logger.info( - 'Creating new session with "%s" using ws: %s', to_id, websocket.id + 'Create new session with "%s" for session id %s' + " and close old session %s", + to_id, + session_id, + old_session.session_id, ) new_session = Session( transport_id, @@ -115,44 +137,26 @@ async def _get_or_create_session( close_session_callback=self._delete_session, ) else: - old_session = self._sessions[to_id] - if old_session.session_id != session_id: - logger.info( - 'Create new session with "%s" for session id %s' - " and close old session %s", - to_id, - session_id, - old_session.session_id, - ) - session_to_close = old_session - new_session = Session( - transport_id, - to_id, - session_id, - websocket, - self._transport_options, - self._is_server, - self._handlers, - close_session_callback=self._delete_session, - ) - else: - # If the instance id is the same, we reuse the session and assign - # a new websocket to it. - logger.debug( - 'Reuse old session with "%s" using new ws: %s', - to_id, - websocket.id, - ) - try: - await old_session.replace_with_new_websocket(websocket) - new_session = old_session - except FailedSendingMessageException as e: - raise e + # If the instance id is the same, we reuse the session and assign + # a new websocket to it. + logger.debug( + 'Reuse old session with "%s" using new ws: %s', + to_id, + websocket.id, + ) + try: + await old_session.replace_with_new_websocket(websocket) + new_session = old_session + except FailedSendingMessageException as e: + raise e - if session_to_close: - logger.info("Closing stale session %s", session_to_close.session_id) - await session_to_close.close() + if old_session and new_session != old_session: + logger.info("Closing stale session %s", old_session.session_id) + await old_session.close() + + async with self._session_lock: self._set_session(new_session) + return new_session async def _send_handshake_response( @@ -228,68 +232,64 @@ async def _establish_handshake( ) raise InvalidMessageException("handshake request to wrong server") - async with self._session_lock: - old_session = self._sessions.get(request_message.from_, None) - client_next_expected_seq = ( - handshake_request.expectedSessionState.nextExpectedSeq - ) - client_next_sent_seq = ( - handshake_request.expectedSessionState.nextSentSeq or 0 - ) - if old_session and old_session.session_id == handshake_request.sessionId: - # check invariants - # ordering must be correct - our_next_seq = await old_session.get_next_sent_seq() - our_ack = await old_session.get_next_expected_seq() - - if client_next_sent_seq > our_ack: - message = ( - "client is in the future: " - f"server wanted {our_ack} but client has {client_next_sent_seq}" - ) - await self._send_handshake_response( - request_message, - HandShakeStatus(ok=False, reason=message), - websocket, - ) - raise SessionStateMismatchException(message) + old_session = await self._get_existing_session(request_message.from_) + client_next_expected_seq = ( + handshake_request.expectedSessionState.nextExpectedSeq + ) + client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0 + if old_session and old_session.session_id == handshake_request.sessionId: + # check invariants + # ordering must be correct + our_next_seq = await old_session.get_next_sent_seq() + our_ack = await old_session.get_next_expected_seq() - if our_next_seq > client_next_expected_seq: - message = ( - "server is in the future: " - f"client wanted {client_next_expected_seq} " - f"but server has {our_next_seq}" - ) - await self._send_handshake_response( - request_message, - HandShakeStatus(ok=False, reason=message), - websocket, - ) - raise SessionStateMismatchException(message) - elif old_session: - # we have an old session but the session id is different - # just delete the old session - await old_session.close() - await self._delete_session(old_session) - old_session = None + if client_next_sent_seq > our_ack: + message = ( + "client is in the future: " + f"server wanted {our_ack} but client has {client_next_sent_seq}" + ) + await self._send_handshake_response( + request_message, + HandShakeStatus(ok=False, reason=message), + websocket, + ) + raise SessionStateMismatchException(message) - if not old_session and ( - client_next_sent_seq > 0 or client_next_expected_seq > 0 - ): - message = "client is trying to resume a session but we don't have it" + if our_next_seq > client_next_expected_seq: + message = ( + "server is in the future: " + f"client wanted {client_next_expected_seq} " + f"but server has {our_next_seq}" + ) await self._send_handshake_response( request_message, HandShakeStatus(ok=False, reason=message), websocket, ) raise SessionStateMismatchException(message) + elif old_session: + # we have an old session but the session id is different + # just delete the old session + await old_session.close() + old_session = None - # from this point on, we're committed to connecting - session_id = handshake_request.sessionId - handshake_response = await self._send_handshake_response( + if not old_session and ( + client_next_sent_seq > 0 or client_next_expected_seq > 0 + ): + message = "client is trying to resume a session but we don't have it" + await self._send_handshake_response( request_message, - HandShakeStatus(ok=True, sessionId=session_id), + HandShakeStatus(ok=False, reason=message), websocket, ) + raise SessionStateMismatchException(message) + + # from this point on, we're committed to connecting + session_id = handshake_request.sessionId + handshake_response = await self._send_handshake_response( + request_message, + HandShakeStatus(ok=True, sessionId=session_id), + websocket, + ) - return handshake_request, handshake_response + return handshake_request, handshake_response diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 7c9d3172..74092152 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import AsyncGenerator, Literal @@ -62,7 +61,6 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: logging.debug("Start closing test client : %s", "test_client") await client.close() finally: - await asyncio.sleep(1) logging.debug("Start closing test server") await server.close() # Server should close normally diff --git a/tests/test_communication.py b/tests/test_communication.py index 488e4b81..2c88b0b7 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -268,3 +268,31 @@ async def test_ignore_flood_subscription(client: Client) -> None: timedelta(seconds=20), ) assert response == "Hello, Alice!" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_rpc_method_reconnect(client: Client) -> None: + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "Hello, Alice!" + + await client._transport._close_all_sessions() + response = await client.send_rpc( + "test_service", + "rpc_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + + assert response == "Hello, Bob!"