diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 3b186e5502..36ab799fa8 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -82,6 +82,7 @@ class _RealtimeSessionClosedSentinel: _REALTIME_SESSION_CLOSED_SENTINEL = _RealtimeSessionClosedSentinel() +_BACKGROUND_TASK_CLEANUP_TIMEOUT = 1.0 def _serialize_tool_output(output: Any) -> str: @@ -192,6 +193,8 @@ def __init__( asyncio.Queue() ) self._event_iterator_waiters = 0 + self._cleanup_future: asyncio.Future[None] | None = None + self._closing = False self._closed = False self._stored_exception: BaseException | None = None self._pending_tool_calls: dict[str, _PendingToolCall] = {} @@ -1265,6 +1268,8 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool: def _enqueue_guardrail_task(self, text: str, response_id: str) -> None: # Runs the guardrails in a separate task to avoid blocking the main loop + if self._closing or self._closed: + return task = asyncio.create_task(self._run_output_guardrails(text, response_id)) self._guardrail_tasks.add(task) @@ -1277,6 +1282,11 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: # Remove from tracking set self._guardrail_tasks.discard(task) + if self._closing or self._closed: + if not task.cancelled(): + task.exception() + return + # Check for exceptions and propagate as events if not task.cancelled(): exception = task.exception() @@ -1291,11 +1301,12 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_guardrail_tasks(self) -> None: - for task in self._guardrail_tasks: - if not task.done(): - task.cancel() - self._guardrail_tasks.clear() + async def _cleanup_guardrail_tasks( + self, caller_task: asyncio.Task[Any] | None + ) -> set[asyncio.Task[Any]]: + return await self._cancel_and_wait_for_tasks( + self._guardrail_tasks, "guardrail", caller_task + ) def _enqueue_tool_call_task( self, @@ -1307,6 +1318,9 @@ def _enqueue_tool_call_task( call_id_reserved: bool = False, ) -> None: """Run tool calls in the background to avoid blocking realtime transport.""" + if self._closing or self._closed: + return + handle_kwargs: dict[str, Any] = {"agent_snapshot": agent_snapshot} if dispatch_snapshot is not None: handle_kwargs["dispatch_snapshot"] = dispatch_snapshot @@ -1322,6 +1336,11 @@ def _enqueue_tool_call_task( def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: self._tool_call_tasks.discard(task) + if self._closing or self._closed: + if not task.cancelled(): + task.exception() + return + if task.cancelled(): return @@ -1364,11 +1383,45 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: ) ) - def _cleanup_tool_call_tasks(self) -> None: - for task in self._tool_call_tasks: + async def _cleanup_tool_call_tasks( + self, caller_task: asyncio.Task[Any] | None + ) -> set[asyncio.Task[Any]]: + return await self._cancel_and_wait_for_tasks( + self._tool_call_tasks, "tool call", caller_task + ) + + async def _cancel_and_wait_for_tasks( + self, + tasks: set[asyncio.Task[Any]], + label: str, + caller_task: asyncio.Task[Any] | None, + ) -> set[asyncio.Task[Any]]: + tasks_to_wait: list[asyncio.Task[Any]] = [] + + for task in list(tasks): + if task is caller_task: + tasks.discard(task) + continue if not task.done(): task.cancel() - self._tool_call_tasks.clear() + tasks_to_wait.append(task) + + if tasks_to_wait: + _done, pending = await asyncio.wait( + tasks_to_wait, + timeout=_BACKGROUND_TASK_CLEANUP_TIMEOUT, + ) + if pending: + logger.warning( + "Timed out waiting for %d realtime %s background task(s) to stop.", + len(pending), + label, + ) + + tasks.difference_update(_done) + return pending + + return set() def _wake_event_iterators(self) -> None: for _ in range(self._event_iterator_waiters): @@ -1380,23 +1433,51 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return - # Cancel and cleanup guardrail tasks - self._cleanup_guardrail_tasks() - self._cleanup_tool_call_tasks() + cleanup_future = self._cleanup_future + if cleanup_future is not None: + await asyncio.shield(cleanup_future) + return - # Remove ourselves as a listener - self._model.remove_listener(self) + cleanup_future = asyncio.get_running_loop().create_future() + self._cleanup_future = cleanup_future + self._closing = True + caller_task = asyncio.current_task() - # Close the model connection - await self._model.close() + try: + # Cancel and cleanup guardrail tasks + await asyncio.gather( + self._cleanup_guardrail_tasks(caller_task), + self._cleanup_tool_call_tasks(caller_task), + ) - # Clear pending approval tracking - self._pending_tool_calls.clear() - self._pending_tool_outputs.clear() + # Remove ourselves as a listener + self._model.remove_listener(self) - # Mark as closed - self._closed = True - self._wake_event_iterators() + # Close the model connection + await self._model.close() + + # Clear pending approval tracking + self._pending_tool_calls.clear() + self._pending_tool_outputs.clear() + self._guardrail_tasks.clear() + self._tool_call_tasks.clear() + + # Mark as closed + self._closed = True + except BaseException as exc: + self._closing = False + if not cleanup_future.done(): + cleanup_future.set_exception(exc) + cleanup_future.exception() + raise + else: + self._closing = False + if not cleanup_future.done(): + cleanup_future.set_result(None) + self._wake_event_iterators() + finally: + if self._cleanup_future is cleanup_future: + self._cleanup_future = None def _dispatch_snapshot_from_settings( self, diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 018f63b344..4ebf1a9d15 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -11,6 +11,7 @@ from agents.exceptions import ToolTimeoutError, UserError from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail from agents.handoffs import Handoff +from agents.realtime import session as session_module from agents.realtime.agent import RealtimeAgent from agents.realtime.config import RealtimeRunConfig, RealtimeSessionModelSettings from agents.realtime.events import ( @@ -208,6 +209,242 @@ async def test_aiter_exits_waiting_iterators_when_session_closes(): task.result() +@pytest.mark.asyncio +async def test_cleanup_awaits_cancelled_task_finalizers_before_model_close(): + close_order: list[str] = [] + + class _CloseRecordingModel(_DummyModel): + async def close(self): + close_order.append("model_close") + + model = _CloseRecordingModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + guardrail_started = asyncio.Event() + tool_started = asyncio.Event() + + async def tracked_task(label: str, started: asyncio.Event) -> None: + started.set() + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + close_order.append(label) + + guardrail = asyncio.create_task(tracked_task("guardrail", guardrail_started)) + tool_call = asyncio.create_task(tracked_task("tool", tool_started)) + session._guardrail_tasks.add(guardrail) + session._tool_call_tasks.add(tool_call) + + await guardrail_started.wait() + await tool_started.wait() + + await session._cleanup() + + try: + assert close_order[-1] == "model_close" + assert set(close_order[:2]) == {"guardrail", "tool"} + assert len(session._guardrail_tasks) == 0 + assert len(session._tool_call_tasks) == 0 + finally: + await asyncio.gather(guardrail, tool_call, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_cleanup_bounds_wait_for_cancellation_resistant_tasks(monkeypatch): + monkeypatch.setattr(session_module, "_BACKGROUND_TASK_CLEANUP_TIMEOUT", 0.01, raising=False) + + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + started = asyncio.Event() + cancel_seen = asyncio.Event() + release = asyncio.Event() + + async def cancellation_resistant_task() -> None: + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancel_seen.set() + await release.wait() + + task = asyncio.create_task(cancellation_resistant_task()) + session._guardrail_tasks.add(task) + await started.wait() + + try: + await asyncio.wait_for(session._cleanup(), timeout=1) + assert cancel_seen.is_set() + assert session._closed is True + assert not task.done() + assert len(session._guardrail_tasks) == 0 + finally: + release.set() + task.cancel() + await asyncio.gather(task, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_cleanup_retains_timed_out_tasks_when_model_close_fails(monkeypatch): + monkeypatch.setattr(session_module, "_BACKGROUND_TASK_CLEANUP_TIMEOUT", 0.01, raising=False) + + class _FailOnceCloseModel(_DummyModel): + def __init__(self) -> None: + super().__init__() + self.close_calls = 0 + + async def close(self): + self.close_calls += 1 + if self.close_calls == 1: + raise RuntimeError("model close failed") + + model = _FailOnceCloseModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + started = asyncio.Event() + cancel_seen = asyncio.Event() + release = asyncio.Event() + + async def cancellation_resistant_task() -> None: + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancel_seen.set() + await release.wait() + + task = asyncio.create_task(cancellation_resistant_task()) + session._guardrail_tasks.add(task) + await started.wait() + + try: + with pytest.raises(RuntimeError, match="model close failed"): + await asyncio.wait_for(session._cleanup(), timeout=1) + + assert cancel_seen.is_set() + assert session._closed is False + assert task in session._guardrail_tasks + + await asyncio.wait_for(session._cleanup(), timeout=1) + + assert session._closed is True + assert task.done() + assert task not in session._guardrail_tasks + assert model.close_calls == 2 + finally: + release.set() + task.cancel() + await asyncio.gather(task, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_tracked_task_can_close_session_without_awaiting_itself(): + class _CloseCountingModel(_DummyModel): + def __init__(self) -> None: + super().__init__() + self.close_count = 0 + + async def close(self): + self.close_count += 1 + + model = _CloseCountingModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + started = asyncio.Event() + finished = asyncio.Event() + + async def close_from_tracked_task() -> None: + started.set() + await session.close() + finished.set() + + task = asyncio.create_task(close_from_tracked_task()) + session._tool_call_tasks.add(task) + await started.wait() + + await asyncio.wait_for(task, timeout=1) + + assert finished.is_set() + assert session._closed is True + assert model.close_count == 1 + assert task not in session._tool_call_tasks + + +@pytest.mark.asyncio +async def test_concurrent_close_waits_for_in_flight_cleanup_failure(): + class _BlockingFailCloseModel(_DummyModel): + def __init__(self) -> None: + super().__init__() + self.close_entered = asyncio.Event() + self.release_close = asyncio.Event() + + async def close(self): + self.close_entered.set() + await self.release_close.wait() + raise RuntimeError("model close failed") + + model = _BlockingFailCloseModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + first_close = asyncio.create_task(session.close()) + await model.close_entered.wait() + + second_close = asyncio.create_task(session.close()) + await asyncio.sleep(0) + + assert not second_close.done() + + model.release_close.set() + results = await asyncio.gather(first_close, second_close, return_exceptions=True) + + assert [type(result) for result in results] == [RuntimeError, RuntimeError] + assert [str(result) for result in results] == ["model close failed", "model close failed"] + assert session._closed is False + + +@pytest.mark.asyncio +async def test_late_background_task_failures_after_cleanup_do_not_mutate_closed_session( + monkeypatch, +): + monkeypatch.setattr(session_module, "_BACKGROUND_TASK_CLEANUP_TIMEOUT", 0.01, raising=False) + + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + guardrail_started = asyncio.Event() + tool_started = asyncio.Event() + release = asyncio.Event() + + async def fail_after_cleanup_timeout(started: asyncio.Event) -> None: + started.set() + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + await release.wait() + raise RuntimeError("late background failure") from None + + guardrail_task = asyncio.create_task(fail_after_cleanup_timeout(guardrail_started)) + tool_task = asyncio.create_task(fail_after_cleanup_timeout(tool_started)) + session._guardrail_tasks.add(guardrail_task) + session._tool_call_tasks.add(tool_task) + guardrail_task.add_done_callback(session._on_guardrail_task_done) + tool_task.add_done_callback(session._on_tool_call_task_done) + await guardrail_started.wait() + await tool_started.wait() + + await asyncio.wait_for(session._cleanup(), timeout=1) + release.set() + await asyncio.gather(guardrail_task, tool_task, return_exceptions=True) + await asyncio.sleep(0) + + assert session._stored_exception is None + assert session._event_queue.empty() + assert guardrail_task not in session._guardrail_tasks + assert tool_task not in session._tool_call_tasks + + @pytest.mark.asyncio async def test_transcription_completed_adds_new_user_item(): model = _DummyModel()