diff --git a/src/replit_river/v2/session.py b/src/replit_river/v2/session.py index 3fe948e8..99b30e82 100644 --- a/src/replit_river/v2/session.py +++ b/src/replit_river/v2/session.py @@ -169,7 +169,6 @@ class Session[HandshakeMetadata]: # Terminating _terminating_task: asyncio.Task[None] | None - _closing_waiter: asyncio.Event | None def __init__( self, @@ -229,7 +228,6 @@ def __init__( # Terminating self._terminating_task = None - self._closing_waiter = None self._start_recv_from_ws() self._start_buffered_message_sender() @@ -393,11 +391,11 @@ async def close( reason: Exception | None = None, ) -> None: """Close the session and all associated streams.""" - if self._closing_waiter: + if self._terminating_task: try: logger.debug("Session already closing, waiting...") async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC): - await self._closing_waiter.wait() + await self._terminating_task except asyncio.TimeoutError: logger.warning( f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " @@ -436,7 +434,6 @@ async def do_close() -> None: f"ws: {self._ws}" ) self._state = SessionState.CLOSING - self._closing_waiter = asyncio.Event() # We're closing, so we need to wake up... # ... tasks waiting for connection to be established @@ -502,14 +499,11 @@ async def do_close() -> None: # This will get us GC'd, so this should be the last thing. self._close_session_callback(self) - # Release waiters, then release the event - self._closing_waiter.set() - self._closing_waiter = None - if self._terminating_task: return self._terminating_task - return asyncio.create_task(do_close()) + self._terminating_task = asyncio.create_task(do_close()) + return self._terminating_task def _start_buffered_message_sender( self, diff --git a/tests/v2/test_v2_session_lifecycle.py b/tests/v2/test_v2_session_lifecycle.py index fb5982cc..2fa8d0e2 100644 --- a/tests/v2/test_v2_session_lifecycle.py +++ b/tests/v2/test_v2_session_lifecycle.py @@ -6,6 +6,7 @@ from websockets.asyncio.server import ServerConnection, serve from websockets.typing import Data +from replit_river.common_session import SessionState from replit_river.messages import parse_transport_msg from replit_river.rate_limiter import RateLimiter from replit_river.rpc import TransportMessage @@ -114,14 +115,20 @@ async def test_connect(ws_server: WsServerFixture) -> None: await connecting -async def test_reconnect(ws_server: WsServerFixture) -> None: +async def test_close_race(ws_server: WsServerFixture) -> None: (urimeta, recv, conn) = ws_server + callcount = 0 + + def close_session_callback(_session: Session) -> None: + nonlocal callcount + callcount += 1 + session = Session( server_id="SERVER", session_id="SESSION1", transport_options=TransportOptions(), - close_session_callback=lambda _: None, + close_session_callback=close_session_callback, client_id="CLIENT1", rate_limiter=_PermissiveRateLimiter(), uri_and_metadata_factory=urimeta, @@ -132,4 +139,9 @@ async def test_reconnect(ws_server: WsServerFixture) -> None: assert isinstance(msg, TransportMessage) assert msg.payload["type"] == "HANDSHAKE_REQ" await session.close() + await session.close() + await session.close() + await session.close() await connecting + assert session._state == SessionState.CLOSED + assert callcount == 1