Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 102 additions & 21 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class _RealtimeSessionClosedSentinel:


_REALTIME_SESSION_CLOSED_SENTINEL = _RealtimeSessionClosedSentinel()
_BACKGROUND_TASK_CLEANUP_TIMEOUT = 1.0


def _serialize_tool_output(output: Any) -> str:
Expand Down Expand Up @@ -192,6 +193,8 @@ def __init__(
asyncio.Queue()
)
self._event_iterator_waiters = 0
self._cleanup_future: asyncio.Future[None] | None = None
self._closing = False
self._closed = False
self._stored_exception: BaseException | None = None
self._pending_tool_calls: dict[str, _PendingToolCall] = {}
Expand Down Expand Up @@ -1265,6 +1268,8 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:

def _enqueue_guardrail_task(self, text: str, response_id: str) -> None:
# Runs the guardrails in a separate task to avoid blocking the main loop
if self._closing or self._closed:
return

task = asyncio.create_task(self._run_output_guardrails(text, response_id))
self._guardrail_tasks.add(task)
Expand All @@ -1277,6 +1282,11 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
# Remove from tracking set
self._guardrail_tasks.discard(task)

if self._closing or self._closed:
if not task.cancelled():
task.exception()
return

# Check for exceptions and propagate as events
if not task.cancelled():
exception = task.exception()
Expand All @@ -1291,11 +1301,12 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
for task in self._guardrail_tasks:
if not task.done():
task.cancel()
self._guardrail_tasks.clear()
async def _cleanup_guardrail_tasks(
self, caller_task: asyncio.Task[Any] | None
) -> set[asyncio.Task[Any]]:
return await self._cancel_and_wait_for_tasks(
self._guardrail_tasks, "guardrail", caller_task
)

def _enqueue_tool_call_task(
self,
Expand All @@ -1307,6 +1318,9 @@ def _enqueue_tool_call_task(
call_id_reserved: bool = False,
) -> None:
"""Run tool calls in the background to avoid blocking realtime transport."""
if self._closing or self._closed:
return

handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot}
if dispatch_snapshot is not None:
handle_kwargs["dispatch_snapshot"] = dispatch_snapshot
Expand All @@ -1322,6 +1336,11 @@ def _enqueue_tool_call_task(
def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
self._tool_call_tasks.discard(task)

if self._closing or self._closed:
if not task.cancelled():
task.exception()
return

if task.cancelled():
return

Expand Down Expand Up @@ -1364,11 +1383,45 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_tool_call_tasks(self) -> None:
for task in self._tool_call_tasks:
async def _cleanup_tool_call_tasks(
self, caller_task: asyncio.Task[Any] | None
) -> set[asyncio.Task[Any]]:
return await self._cancel_and_wait_for_tasks(
self._tool_call_tasks, "tool call", caller_task
)

async def _cancel_and_wait_for_tasks(
self,
tasks: set[asyncio.Task[Any]],
label: str,
caller_task: asyncio.Task[Any] | None,
) -> set[asyncio.Task[Any]]:
tasks_to_wait: list[asyncio.Task[Any]] = []

for task in list(tasks):
if task is caller_task:
tasks.discard(task)
continue
Comment on lines +1402 to +1404

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Prevent self-closing tasks from resuming work

When the caller is one of the tracked background tasks, such as a function tool or output guardrail that captures the session and calls await session.close(), this branch skips cancelling that current task so cleanup can finish. After close() returns to the same coroutine, the rest of _handle_tool_call or _run_output_guardrails can continue and enqueue events or send tool outputs/interrupts after _model.close() has already run, and any resulting exception is then swallowed by the _closed callback path. The self-await guard should also abort or no-op the caller's remaining session work after close completes.

Useful? React with 👍 / 👎.

if not task.done():
task.cancel()
self._tool_call_tasks.clear()
tasks_to_wait.append(task)

if tasks_to_wait:
_done, pending = await asyncio.wait(
tasks_to_wait,
timeout=_BACKGROUND_TASK_CLEANUP_TIMEOUT,
)
if pending:
logger.warning(
"Timed out waiting for %d realtime %s background task(s) to stop.",
len(pending),
label,
)

tasks.difference_update(_done)
return pending

return set()

def _wake_event_iterators(self) -> None:
for _ in range(self._event_iterator_waiters):
Expand All @@ -1380,23 +1433,51 @@ async def _cleanup(self) -> None:
self._wake_event_iterators()
return

# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
self._cleanup_tool_call_tasks()
cleanup_future = self._cleanup_future
if cleanup_future is not None:
await asyncio.shield(cleanup_future)
return

# Remove ourselves as a listener
self._model.remove_listener(self)
cleanup_future = asyncio.get_running_loop().create_future()
self._cleanup_future = cleanup_future
self._closing = True
caller_task = asyncio.current_task()

# Close the model connection
await self._model.close()
try:
# Cancel and cleanup guardrail tasks
await asyncio.gather(
self._cleanup_guardrail_tasks(caller_task),
self._cleanup_tool_call_tasks(caller_task),
)

# Clear pending approval tracking
self._pending_tool_calls.clear()
self._pending_tool_outputs.clear()
# Remove ourselves as a listener
self._model.remove_listener(self)

# Mark as closed
self._closed = True
self._wake_event_iterators()
# Close the model connection
await self._model.close()

# Clear pending approval tracking
self._pending_tool_calls.clear()
self._pending_tool_outputs.clear()
self._guardrail_tasks.clear()
self._tool_call_tasks.clear()

# Mark as closed
self._closed = True
except BaseException as exc:
self._closing = False
if not cleanup_future.done():
cleanup_future.set_exception(exc)
cleanup_future.exception()
raise
else:
self._closing = False
if not cleanup_future.done():
cleanup_future.set_result(None)
self._wake_event_iterators()
finally:
if self._cleanup_future is cleanup_future:
self._cleanup_future = None

def _dispatch_snapshot_from_settings(
self,
Expand Down
Loading