diff --git a/src/agents/result.py b/src/agents/result.py index 8ae407003a..c555eea309 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -53,6 +53,8 @@ T = TypeVar("T") +_STREAMING_CANCEL_TASK_DRAIN_SECONDS = 0.25 + @dataclass(frozen=True) class AgentToolInvocation: @@ -677,15 +679,14 @@ def cancel(self, mode: Literal["immediate", "after_turn"] = "immediate") -> None if mode == "immediate": # Existing behavior - immediate shutdown self._cleanup_tasks() # Cancel all running tasks - self.is_complete = True # Mark the run as complete to stop event streaming while not self._input_guardrail_queue.empty(): self._input_guardrail_queue.get_nowait() # Unblock any streamers waiting on the event queue. - self._event_queue.put_nowait(QueueCompleteSentinel()) if not self._waiting_on_event_queue: self._drain_event_queue() + self._event_queue.put_nowait(QueueCompleteSentinel()) elif mode == "after_turn": # Soft cancel - just set the flag @@ -735,7 +736,8 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: if isinstance(item, QueueCompleteSentinel): # Await input guardrails if they are still running, so late # exceptions are captured. - await self._await_task_safely(self._input_guardrails_task) + if self._cancel_mode != "immediate": + await self._await_task_safely(self._input_guardrails_task) self._event_queue.task_done() @@ -752,6 +754,11 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: # Cancellation should return promptly, so avoid waiting on long-running tasks. # Tasks have already been cancelled above. self._cleanup_tasks() + self.is_complete = True + elif self._cancel_mode == "immediate": + await self._drain_cancelled_tasks() + self._check_errors() + self.is_complete = True else: # Ensure main execution completes before cleanup to avoid race conditions # with session operations. @@ -764,7 +771,7 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: # Safely terminate all background tasks after main execution has finished. self._cleanup_tasks() - if not cancelled: + if not cancelled and self._cancel_mode != "immediate": await self._run_sandbox_cleanup() finally: # Allow any pending callbacks (e.g., cancellation handlers) to enqueue their @@ -846,6 +853,45 @@ def _cleanup_tasks(self): if self._output_guardrails_task and not self._output_guardrails_task.done(): self._output_guardrails_task.cancel() + def _owned_background_tasks(self) -> list[asyncio.Task[Any]]: + return [ + task + for task in ( + self.run_loop_task, + self._input_guardrails_task, + self._output_guardrails_task, + ) + if task is not None + ] + + async def _drain_cancelled_tasks(self) -> None: + tasks = self._owned_background_tasks() + if not tasks: + return + + for task in tasks: + if not task.done(): + task.cancel() + + done, pending = await asyncio.wait( + tasks, + timeout=_STREAMING_CANCEL_TASK_DRAIN_SECONDS, + ) + if done: + await asyncio.gather(*done, return_exceptions=True) + + for task in pending: + task.add_done_callback(self._consume_background_task_result) + + @staticmethod + def _consume_background_task_result(task: asyncio.Task[Any]) -> None: + try: + task.result() + except asyncio.CancelledError: + pass + except Exception as exc: + logger.debug(f"Background streaming task failed after cancellation: {exc}") + def __str__(self) -> str: return pretty_print_run_result_streaming(self) diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py index 87c094947f..528b71e671 100644 --- a/tests/test_cancel_streaming.py +++ b/tests/test_cancel_streaming.py @@ -5,7 +5,8 @@ import pytest from openai.types.responses import ResponseCompletedEvent -from agents import Agent, Runner +from agents import Agent, Runner, RunResultStreaming +from agents.run_context import RunContextWrapper from agents.stream_events import RawResponsesStreamEvent from .fake_model import FakeModel @@ -126,7 +127,9 @@ async def test_cancel_cleans_up_resources(): async for _ in result.stream_events(): result.cancel() break + remaining_events = [event async for event in result.stream_events()] # After cancel, queues should be empty and is_complete True + assert remaining_events == [] assert result.is_complete, "Result should be marked complete after cancel." assert result._event_queue.empty(), "Event queue should be empty after cancel." assert result._input_guardrail_queue.empty(), ( @@ -134,6 +137,88 @@ async def test_cancel_cleans_up_resources(): ) +@pytest.mark.asyncio +async def test_cancel_immediate_drains_owned_tasks_before_marking_complete(): + result = RunResultStreaming( + input="hi", + new_items=[], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + context_wrapper=RunContextWrapper(context=None), + current_agent=Agent(name="A", model=FakeModel()), + current_turn=0, + max_turns=1, + _current_agent_output_schema=None, + trace=None, + ) + cleanup_finished = [asyncio.Event() for _ in range(3)] + + async def wait_until_cancelled(cleanup_event: asyncio.Event) -> None: + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + cleanup_event.set() + + tasks = [asyncio.create_task(wait_until_cancelled(event)) for event in cleanup_finished] + ( + result.run_loop_task, + result._input_guardrails_task, + result._output_guardrails_task, + ) = tasks + + await asyncio.sleep(0) + result.cancel(mode="immediate") + + assert result.is_complete is False + + events = [event async for event in result.stream_events()] + + assert events == [] + assert result.is_complete is True + assert all(task.done() for task in tasks) + assert all(event.is_set() for event in cleanup_finished) + + +@pytest.mark.asyncio +async def test_stream_events_timeout_marks_result_complete_without_sentinel(): + result = RunResultStreaming( + input="hi", + new_items=[], + raw_responses=[], + final_output=None, + input_guardrail_results=[], + output_guardrail_results=[], + tool_input_guardrail_results=[], + tool_output_guardrail_results=[], + context_wrapper=RunContextWrapper(context=None), + current_agent=Agent(name="A", model=FakeModel()), + current_turn=0, + max_turns=1, + _current_agent_output_schema=None, + trace=None, + ) + + async def wait_forever() -> None: + await asyncio.Event().wait() + + result.run_loop_task = asyncio.create_task(wait_forever()) + event_iter = result.stream_events().__aiter__() + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(event_iter.__anext__(), timeout=0.01) + + assert result.is_complete is True + + remaining_events = [event async for event in result.stream_events()] + + assert remaining_events == [] + + @pytest.mark.asyncio async def test_cancel_immediate_mode_explicit(): """Test explicit immediate mode behaves same as default.""" @@ -145,7 +230,9 @@ async def test_cancel_immediate_mode_explicit(): async for _ in result.stream_events(): result.cancel(mode="immediate") break + remaining_events = [event async for event in result.stream_events()] + assert remaining_events == [] assert result.is_complete assert result._event_queue.empty() assert result._cancel_mode == "immediate"