Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions agentlightning/llm_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,16 @@ async def _serve_context(self) -> AsyncGenerator[None, None]:

logger.info("LLMProxy server is cleaning up.")

# Close any cached aiohttp sessions held by the store client to
# prevent resource leaks (file descriptors, TCP connections) that
# can accumulate across training restarts and cause MCP server
# timeouts. See https://github.com/microsoft/agent-lightning/issues/471
if self.store is not None and hasattr(self.store, "close"):
try:
await self.store.close()
except Exception:
logger.warning("Error closing store sessions during LLMProxy cleanup.", exc_info=True)

# Remove worker config to avoid stale references.
if self._config_file and os.path.exists(self._config_file):
os.unlink(self._config_file)
Expand Down
86 changes: 71 additions & 15 deletions agentlightning/store/client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,6 +1359,7 @@ def __init__(
self.server_address_root = server_address.rstrip("/")
self.server_address = self.server_address_root + API_V1_AGL_PREFIX
self._sessions: Dict[int, aiohttp.ClientSession] = {} # id(loop) -> ClientSession
self._session_loops: Dict[int, asyncio.AbstractEventLoop] = {} # id(loop) -> loop ref
self._lock = threading.Lock()

# retry config
Expand Down Expand Up @@ -1415,6 +1416,7 @@ def __setstate__(self, state: Dict[str, Any]):
self.server_address = state["server_address"]
self.server_address_root = state["server_address_root"]
self._sessions = {}
self._session_loops = {}
self._lock = threading.Lock()
self._retry_delays = state["_retry_delays"]
self._health_retry_delays = state["_health_retry_delays"]
Expand All @@ -1423,6 +1425,24 @@ def __setstate__(self, state: Dict[str, Any]):
self._dequeue_was_successful = False
self._dequeue_first_unsuccessful = True

def _close_session_sync(self, sess: aiohttp.ClientSession, label: str = "") -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'd better have tests for this.

"""Best-effort synchronous teardown of a ClientSession.

When a session's owning event loop is already closed we cannot
``await sess.close()``. Instead we close the underlying connector
directly (which releases sockets/FDs) and mark the session object
as closed so aiohttp does not emit *Unclosed client session*
warnings during garbage collection.
"""
try:
connector = sess.connector
if connector is not None:
connector.close()
# Prevent the destructor warning for an unclosed session.
sess._closed = True # type: ignore[attr-defined]
except Exception:
client_logger.debug("Error during synchronous session cleanup%s", label, exc_info=True)

async def _get_session(self) -> aiohttp.ClientSession:
# In the proxy process, FastAPI middleware calls
# client_store.get_next_span_sequence_id(...). With
Expand All @@ -1443,6 +1463,16 @@ async def _get_session(self) -> aiohttp.ClientSession:
loop = asyncio.get_running_loop()
key = id(loop)
with self._lock:
# Evict sessions whose event loops have been closed. This
# prevents resource leaks when loops are torn down and new
# ones are created (e.g. during RAG training restarts).
stale_keys = [k for k, cached_loop in self._session_loops.items() if cached_loop.is_closed()]
for k in stale_keys:
stale_sess = self._sessions.pop(k, None)
self._session_loops.pop(k, None)
if stale_sess is not None and not stale_sess.closed:
self._close_session_sync(stale_sess, label=f" for stale loop {k}")
Comment on lines +1469 to +1474

sess = self._sessions.get(key)
if sess is None or sess.closed:
timeout = aiohttp.ClientTimeout(
Expand All @@ -1453,6 +1483,7 @@ async def _get_session(self) -> aiohttp.ClientSession:
)
sess = aiohttp.ClientSession(timeout=timeout)
self._sessions[key] = sess
self._session_loops[key] = loop
return sess

async def _wait_until_healthy(self, session: aiohttp.ClientSession) -> bool:
Expand Down Expand Up @@ -1547,25 +1578,50 @@ async def _request_json(
raise last_exc

async def close(self):
"""Close the HTTP session."""
"""Close all cached HTTP sessions.

Sessions bound to the current event loop are closed with a proper
``await``. Sessions created on foreign (but still running) loops are
closed via ``run_coroutine_threadsafe`` so the connector teardown
happens on the correct loop. Sessions whose loops are already closed
have their connectors shut down synchronously to avoid leaking sockets.
"""
with self._lock:
sessions = list(self._sessions.values())
sessions = dict(self._sessions) # key -> session
loops = dict(self._session_loops) # key -> loop
self._sessions.clear()
self._session_loops.clear()

# close them on their own loops to avoid warnings
async def _close(sess: aiohttp.ClientSession):
if not sess.closed:
await sess.close()
current_loop: asyncio.AbstractEventLoop | None = None
try:
current_loop = asyncio.get_running_loop()
except RuntimeError:
pass

# If called from one loop, best-effort close here.
for s in sessions:
try:
await _close(s)
except RuntimeError:
# If created on a different loop/thread, schedule a thread-safe close
# Fallback: close without awaiting (library tolerates it in practice),
# or keep a per-loop shutdown hook where they were created.
pass
for key, sess in sessions.items():
if sess.closed:
continue
sess_loop = loops.get(key)

# Case 1: session belongs to the current loop -- await close
if sess_loop is not None and current_loop is not None and sess_loop is current_loop:
try:
await sess.close()
except Exception:
client_logger.debug("Error closing aiohttp session on current loop", exc_info=True)
continue

# Case 2: session's loop is still running -- schedule close there
if sess_loop is not None and not sess_loop.is_closed():
try:
asyncio.run_coroutine_threadsafe(sess.close(), sess_loop)
except RuntimeError:
# Loop was closed between our check and the call
self._close_session_sync(sess, label=f" for key {key}")
continue
Comment on lines +1614 to +1621

# Case 3: session's loop is already closed -- synchronous teardown
self._close_session_sync(sess, label=f" for key {key}")

async def start_rollout(
self,
Expand Down
103 changes: 103 additions & 0 deletions tests/store/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1693,3 +1693,106 @@ def counting_health_get(self: aiohttp.ClientSession, url: Any, *args: Any, **kwa
finally:
await client.close()
await server.stop()


async def test_get_session_evicts_stale_loop_entries() -> None:
"""Sessions whose event loops have been closed should be evicted on the
next _get_session() call so they don't pile up across training restarts.

This is the core leak the fix targets (see issue #471)."""
client = LightningStoreClient(
"http://127.0.0.1:9", # nothing should actually call out
retry_delays=(),
health_retry_delays=(),
)
try:
# Create a session bound to a "stale" loop that's already closed.
stale_loop = asyncio.new_event_loop()
try:
async with aiohttp.ClientSession() as stale_sess:
# Inject a stale entry by hand to bypass the need for a
# running loop. Use the client's private dicts directly.
client._sessions[id(stale_loop)] = stale_sess
client._session_loops[id(stale_loop)] = stale_loop
finally:
stale_loop.close()

assert id(stale_loop) in client._sessions
assert id(stale_loop) in client._session_loops

# Call _get_session from the current (still running) loop. The
# eviction path should drop the stale entry and create a new one
# for the running loop.
sess = await client._get_session()
assert sess is not None
assert id(stale_loop) not in client._sessions, "stale entry should have been evicted"
assert id(stale_loop) not in client._session_loops

# The new session should be bound to the current running loop.
current_loop = asyncio.get_running_loop()
assert id(current_loop) in client._sessions
assert id(current_loop) in client._session_loops
assert not sess.closed
finally:
await client.close()


async def test_close_handles_sessions_on_closed_loops() -> None:
"""close() should not raise when it encounters a session whose owning
loop has been closed. The connector should be torn down synchronously."""
client = LightningStoreClient(
"http://127.0.0.1:9",
retry_delays=(),
health_retry_delays=(),
)
try:
# Create a session on a separate, soon-to-be-closed loop.
foreign_loop = asyncio.new_event_loop()
try:
async with aiohttp.ClientSession() as foreign_sess:
client._sessions[id(foreign_loop)] = foreign_sess
client._session_loops[id(foreign_loop)] = foreign_loop
# Detach the session from its loop so aiohttp's __aexit__
# doesn't double-close it after we've cleaned up.
foreign_sess._connector = None # type: ignore[attr-defined]
finally:
foreign_loop.close()

# close() should run cleanly even though the foreign loop is closed.
await client.close()
assert client._sessions == {}
assert client._session_loops == {}
finally:
# Defensive: a second close() should also be a no-op.
await client.close()


def test_close_session_sync_releases_connector() -> None:
"""_close_session_sync should release the underlying connector sockets
synchronously, even when called outside the session's owning loop."""
client = LightningStoreClient(
"http://127.0.0.1:9",
retry_delays=(),
health_retry_delays=(),
)

# Build a fake session that records the close() call. We don't use
# aiohttp.ClientSession directly here because its __init__ requires a
# running event loop; the test exercises the cleanup logic in
# isolation, which only depends on the duck-typed interface.
closed = {"called": False}

class _FakeConnector:
def close(self) -> None:
closed["called"] = True

class _FakeSession:
def __init__(self) -> None:
self.connector = _FakeConnector()
self._closed = False

sess = _FakeSession() # type: ignore[assignment]
client._close_session_sync(sess, label=" unit-test") # type: ignore[arg-type]
assert closed["called"], "connector.close() should have been called"
# The session is marked as closed to suppress aiohttp's destructor warning.
assert sess._closed is True
Comment on lines +1794 to +1798