From f967744f4c22aa4f6ad88d493f209c54528003ba Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 20:31:51 -0700 Subject: [PATCH 1/8] test removing extraneous delete --- src/replit_river/client_transport.py | 1 - src/replit_river/server_transport.py | 1 - tests/river_fixtures/clientserver.py | 1 - 3 files changed, 3 deletions(-) diff --git a/src/replit_river/client_transport.py b/src/replit_river/client_transport.py index 7ae62b23..b9d3a331 100644 --- a/src/replit_river/client_transport.py +++ b/src/replit_river/client_transport.py @@ -344,7 +344,6 @@ async def _establish_handshake( # If the session status is mismatched, we should close the old session # and let the retry logic to create a new session. await old_session.close() - await self._delete_session(old_session) raise RiverException( ERROR_HANDSHAKE, diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index c2e30657..05e7082a 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -270,7 +270,6 @@ async def _establish_handshake( # we have an old session but the session id is different # just delete the old session await old_session.close() - await self._delete_session(old_session) old_session = None if not old_session and ( diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 7c9d3172..45df396b 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -62,7 +62,6 @@ async def websocket_uri_factory() -> UriAndMetadata[None]: logging.debug("Start closing test client : %s", "test_client") await client.close() finally: - await asyncio.sleep(1) logging.debug("Start closing test server") await server.close() # Server should close normally From 8e69ac6db0572e18a66f8205ba0a48e3a4f5034a Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 20:32:49 -0700 Subject: [PATCH 2/8] remove asyncio --- tests/river_fixtures/clientserver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 45df396b..74092152 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import AsyncGenerator, Literal From cb5eb527aa8c907ed1ef8030cc22176fbc08a1f5 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 20:37:07 -0700 Subject: [PATCH 3/8] add potential deadlock test --- Makefile | 3 +++ tests/test_communication.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/Makefile b/Makefile index 41959647..537d6337 100644 --- a/Makefile +++ b/Makefile @@ -8,3 +8,6 @@ lint: format: uv run ruff format src tests uv run ruff check src tests --fix + +test: + uv run pytest tests diff --git a/tests/test_communication.py b/tests/test_communication.py index 488e4b81..38b59691 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -268,3 +268,30 @@ async def test_ignore_flood_subscription(client: Client) -> None: timedelta(seconds=20), ) assert response == "Hello, Alice!" + +@pytest.mark.asyncio +@pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) +async def test_rpc_method_reconnect(client: Client) -> None: + response = await client.send_rpc( + "test_service", + "rpc_method", + "Alice", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + assert response == "Hello, Alice!" + + await client._transport._close_all_sessions() + response = await client.send_rpc( + "test_service", + "rpc_method", + "Bob", + serialize_request, + deserialize_response, + deserialize_error, + timedelta(seconds=20), + ) + + assert response == "Hello, Bob!" From c54fc7a053534b4a396cd11d73b3b7179571f789 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 20:39:17 -0700 Subject: [PATCH 4/8] fmt --- tests/test_communication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_communication.py b/tests/test_communication.py index 38b59691..2c88b0b7 100644 --- a/tests/test_communication.py +++ b/tests/test_communication.py @@ -269,6 +269,7 @@ async def test_ignore_flood_subscription(client: Client) -> None: ) assert response == "Hello, Alice!" + @pytest.mark.asyncio @pytest.mark.parametrize("handlers", [{**basic_rpc_method}]) async def test_rpc_method_reconnect(client: Client) -> None: From 9c08b02e496cd073e15a6fafca5c8c13289f4849 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 21:22:46 -0700 Subject: [PATCH 5/8] easiest deadlock fix? --- src/replit_river/server_transport.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 05e7082a..8438518b 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -97,9 +97,9 @@ async def _get_or_create_session( session_id: str, websocket: WebSocketCommonProtocol, ) -> Session: + session_to_close: Optional[Session] = None + new_session: Optional[Session] = None async with self._session_lock: - session_to_close: Optional[Session] = None - new_session: Optional[Session] = None if to_id not in self._sessions: logger.info( 'Creating new session with "%s" using ws: %s', to_id, websocket.id @@ -149,10 +149,12 @@ async def _get_or_create_session( except FailedSendingMessageException as e: raise e - if session_to_close: - logger.info("Closing stale session %s", session_to_close.session_id) - await session_to_close.close() self._set_session(new_session) + + if session_to_close: + logger.info("Closing stale session %s", session_to_close.session_id) + await session_to_close.close() + return new_session async def _send_handshake_response( From 1b027f93729a569e8183d8dffbefb131124ca2e9 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 21:47:09 -0700 Subject: [PATCH 6/8] move some locks around --- src/replit_river/server_transport.py | 186 ++++++++++++++------------- tests/river_fixtures/clientserver.py | 49 ++++--- 2 files changed, 122 insertions(+), 113 deletions(-) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 8438518b..1da49bb0 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -90,6 +90,10 @@ async def handshake_to_get_session( async def close(self) -> None: await self._close_all_sessions() + async def _get_existing_session(self, to_id: str) -> Optional[Session]: + async with self._session_lock: + return self._sessions.get(to_id) + async def _get_or_create_session( self, transport_id: str, @@ -97,12 +101,30 @@ async def _get_or_create_session( session_id: str, websocket: WebSocketCommonProtocol, ) -> Session: - session_to_close: Optional[Session] = None new_session: Optional[Session] = None - async with self._session_lock: - if to_id not in self._sessions: + old_session: Optional[Session] = await self._get_existing_session(to_id) + if not old_session: + logger.info( + 'Creating new session with "%s" using ws: %s', to_id, websocket.id + ) + new_session = Session( + transport_id, + to_id, + session_id, + websocket, + self._transport_options, + self._is_server, + self._handlers, + close_session_callback=self._delete_session, + ) + else: + if old_session.session_id != session_id: logger.info( - 'Creating new session with "%s" using ws: %s', to_id, websocket.id + 'Create new session with "%s" for session id %s' + " and close old session %s", + to_id, + session_id, + old_session.session_id, ) new_session = Session( transport_id, @@ -115,45 +137,25 @@ async def _get_or_create_session( close_session_callback=self._delete_session, ) else: - old_session = self._sessions[to_id] - if old_session.session_id != session_id: - logger.info( - 'Create new session with "%s" for session id %s' - " and close old session %s", - to_id, - session_id, - old_session.session_id, - ) - session_to_close = old_session - new_session = Session( - transport_id, - to_id, - session_id, - websocket, - self._transport_options, - self._is_server, - self._handlers, - close_session_callback=self._delete_session, - ) - else: - # If the instance id is the same, we reuse the session and assign - # a new websocket to it. - logger.debug( - 'Reuse old session with "%s" using new ws: %s', - to_id, - websocket.id, - ) - try: - await old_session.replace_with_new_websocket(websocket) - new_session = old_session - except FailedSendingMessageException as e: - raise e + # If the instance id is the same, we reuse the session and assign + # a new websocket to it. + logger.debug( + 'Reuse old session with "%s" using new ws: %s', + to_id, + websocket.id, + ) + try: + await old_session.replace_with_new_websocket(websocket) + new_session = old_session + except FailedSendingMessageException as e: + raise e - self._set_session(new_session) + if old_session and new_session != old_session: + logger.info("Closing stale session %s", old_session.session_id) + await old_session.close() - if session_to_close: - logger.info("Closing stale session %s", session_to_close.session_id) - await session_to_close.close() + async with self._session_lock: + self._set_session(new_session) return new_session @@ -230,67 +232,67 @@ async def _establish_handshake( ) raise InvalidMessageException("handshake request to wrong server") - async with self._session_lock: - old_session = self._sessions.get(request_message.from_, None) - client_next_expected_seq = ( - handshake_request.expectedSessionState.nextExpectedSeq - ) - client_next_sent_seq = ( - handshake_request.expectedSessionState.nextSentSeq or 0 - ) - if old_session and old_session.session_id == handshake_request.sessionId: - # check invariants - # ordering must be correct - our_next_seq = await old_session.get_next_sent_seq() - our_ack = await old_session.get_next_expected_seq() - - if client_next_sent_seq > our_ack: - message = ( - "client is in the future: " - f"server wanted {our_ack} but client has {client_next_sent_seq}" - ) - await self._send_handshake_response( - request_message, - HandShakeStatus(ok=False, reason=message), - websocket, - ) - raise SessionStateMismatchException(message) + old_session = await self._get_existing_session(request_message.from_) + client_next_expected_seq = ( + handshake_request.expectedSessionState.nextExpectedSeq + ) + client_next_sent_seq = ( + handshake_request.expectedSessionState.nextSentSeq or 0 + ) + if old_session and old_session.session_id == handshake_request.sessionId: + # check invariants + # ordering must be correct + our_next_seq = await old_session.get_next_sent_seq() + our_ack = await old_session.get_next_expected_seq() - if our_next_seq > client_next_expected_seq: - message = ( - "server is in the future: " - f"client wanted {client_next_expected_seq} " - f"but server has {our_next_seq}" - ) - await self._send_handshake_response( - request_message, - HandShakeStatus(ok=False, reason=message), - websocket, - ) - raise SessionStateMismatchException(message) - elif old_session: - # we have an old session but the session id is different - # just delete the old session - await old_session.close() - old_session = None + if client_next_sent_seq > our_ack: + message = ( + "client is in the future: " + f"server wanted {our_ack} but client has {client_next_sent_seq}" + ) + await self._send_handshake_response( + request_message, + HandShakeStatus(ok=False, reason=message), + websocket, + ) + raise SessionStateMismatchException(message) - if not old_session and ( - client_next_sent_seq > 0 or client_next_expected_seq > 0 - ): - message = "client is trying to resume a session but we don't have it" + if our_next_seq > client_next_expected_seq: + message = ( + "server is in the future: " + f"client wanted {client_next_expected_seq} " + f"but server has {our_next_seq}" + ) await self._send_handshake_response( request_message, HandShakeStatus(ok=False, reason=message), websocket, ) raise SessionStateMismatchException(message) + elif old_session: + # we have an old session but the session id is different + # just delete the old session + await old_session.close() + old_session = None + - # from this point on, we're committed to connecting - session_id = handshake_request.sessionId - handshake_response = await self._send_handshake_response( + if not old_session and ( + client_next_sent_seq > 0 or client_next_expected_seq > 0 + ): + message = "client is trying to resume a session but we don't have it" + await self._send_handshake_response( request_message, - HandShakeStatus(ok=True, sessionId=session_id), + HandShakeStatus(ok=False, reason=message), websocket, ) + raise SessionStateMismatchException(message) + + # from this point on, we're committed to connecting + session_id = handshake_request.sessionId + handshake_response = await self._send_handshake_response( + request_message, + HandShakeStatus(ok=True, sessionId=session_id), + websocket, + ) - return handshake_request, handshake_response + return handshake_request, handshake_response diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 74092152..8dee6bad 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import AsyncGenerator, Literal @@ -37,31 +38,37 @@ async def client( transport_options: TransportOptions, no_logging_error: NoErrors, ) -> AsyncGenerator[Client, None]: + binding = None try: - async with serve(server.serve, "127.0.0.1") as binding: - sockets = list(binding.sockets) - assert len(sockets) == 1, "Too many sockets!" - socket = sockets[0] + binding = await serve(server.serve, "127.0.0.1") + sockets = list(binding.sockets) + assert len(sockets) == 1, "Too many sockets!" + socket = sockets[0] - async def websocket_uri_factory() -> UriAndMetadata[None]: - return { - "uri": "ws://%s:%d" % socket.getsockname(), - "metadata": None, - } + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + yield client + finally: + logging.debug("Start closing test client : %s", "test_client") + await client.close() - client: Client[Literal[None]] = Client[None]( - uri_and_metadata_factory=websocket_uri_factory, - client_id="test_client", - server_id="test_server", - transport_options=transport_options, - ) - try: - yield client - finally: - logging.debug("Start closing test client : %s", "test_client") - await client.close() finally: logging.debug("Start closing test server") + if binding: + binding.close() await server.close() + if binding: + await binding.wait_closed() # Server should close normally - no_logging_error() + no_logging_error() \ No newline at end of file From 0d6f2c093d5110bd124360a3ac54e2abb60e20f2 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 21:47:24 -0700 Subject: [PATCH 7/8] undo local fixture change --- tests/river_fixtures/clientserver.py | 49 ++++++++++++---------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/tests/river_fixtures/clientserver.py b/tests/river_fixtures/clientserver.py index 8dee6bad..74092152 100644 --- a/tests/river_fixtures/clientserver.py +++ b/tests/river_fixtures/clientserver.py @@ -1,4 +1,3 @@ -import asyncio import logging from typing import AsyncGenerator, Literal @@ -38,37 +37,31 @@ async def client( transport_options: TransportOptions, no_logging_error: NoErrors, ) -> AsyncGenerator[Client, None]: - binding = None try: - binding = await serve(server.serve, "127.0.0.1") - sockets = list(binding.sockets) - assert len(sockets) == 1, "Too many sockets!" - socket = sockets[0] + async with serve(server.serve, "127.0.0.1") as binding: + sockets = list(binding.sockets) + assert len(sockets) == 1, "Too many sockets!" + socket = sockets[0] - async def websocket_uri_factory() -> UriAndMetadata[None]: - return { - "uri": "ws://%s:%d" % socket.getsockname(), - "metadata": None, - } - - client: Client[Literal[None]] = Client[None]( - uri_and_metadata_factory=websocket_uri_factory, - client_id="test_client", - server_id="test_server", - transport_options=transport_options, - ) - try: - yield client - finally: - logging.debug("Start closing test client : %s", "test_client") - await client.close() + async def websocket_uri_factory() -> UriAndMetadata[None]: + return { + "uri": "ws://%s:%d" % socket.getsockname(), + "metadata": None, + } + client: Client[Literal[None]] = Client[None]( + uri_and_metadata_factory=websocket_uri_factory, + client_id="test_client", + server_id="test_server", + transport_options=transport_options, + ) + try: + yield client + finally: + logging.debug("Start closing test client : %s", "test_client") + await client.close() finally: logging.debug("Start closing test server") - if binding: - binding.close() await server.close() - if binding: - await binding.wait_closed() # Server should close normally - no_logging_error() \ No newline at end of file + no_logging_error() From 3dccfc5d1adcf65dabf8e55765f12119c53d3647 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Mon, 17 Mar 2025 21:48:38 -0700 Subject: [PATCH 8/8] fmt --- src/replit_river/server_transport.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/replit_river/server_transport.py b/src/replit_river/server_transport.py index 1da49bb0..b2850721 100644 --- a/src/replit_river/server_transport.py +++ b/src/replit_river/server_transport.py @@ -236,9 +236,7 @@ async def _establish_handshake( client_next_expected_seq = ( handshake_request.expectedSessionState.nextExpectedSeq ) - client_next_sent_seq = ( - handshake_request.expectedSessionState.nextSentSeq or 0 - ) + client_next_sent_seq = handshake_request.expectedSessionState.nextSentSeq or 0 if old_session and old_session.session_id == handshake_request.sessionId: # check invariants # ordering must be correct @@ -275,7 +273,6 @@ async def _establish_handshake( await old_session.close() old_session = None - if not old_session and ( client_next_sent_seq > 0 or client_next_expected_seq > 0 ):