From 5c87f14fbd73b716150445a2eae181a12bbd6f9c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Wed, 1 Jul 2026 13:47:14 +0900 Subject: [PATCH] refactor: tighten runtime lifecycle and remove maintenance debt --- pyproject.toml | 9 +- .../experimental/codex/codex_tool.py | 8 +- .../extensions/experimental/codex/thread.py | 2 +- .../memory/advanced_sqlite_session.py | 23 ++-- .../extensions/memory/mongodb_session.py | 6 +- src/agents/extensions/sandbox/_rclone.py | 82 +++++++++++ .../extensions/sandbox/blaxel/sandbox.py | 38 ++---- .../extensions/sandbox/daytona/sandbox.py | 52 +++---- src/agents/extensions/sandbox/e2b/mounts.py | 73 +--------- src/agents/extensions/sandbox/e2b/sandbox.py | 9 -- .../extensions/sandbox/runloop/mounts.py | 72 +--------- src/agents/extensions/tool_output_trimmer.py | 5 +- .../openai_responses_compaction_session.py | 22 ++- src/agents/models/_openai_retry.py | 127 ++---------------- src/agents/models/_retry_runtime.py | 118 +++++++++++++++- src/agents/models/openai_responses.py | 11 +- src/agents/realtime/_default_tracker.py | 6 +- src/agents/realtime/openai_realtime.py | 6 +- src/agents/realtime/session.py | 71 +++++----- src/agents/retry.py | 4 +- src/agents/run_internal/model_retry.py | 121 ++--------------- src/agents/sandbox/sandboxes/docker.py | 56 ++++---- src/agents/sandbox/sandboxes/unix_local.py | 63 ++++----- src/agents/sandbox/session/pty_output.py | 50 +++++++ tests/extensions/sandbox/test_daytona.py | 37 +++++ tests/extensions/sandbox/test_e2b.py | 6 +- .../extensions/sandbox/test_runloop_mounts.py | 8 +- tests/models/test_openai_retry_helpers.py | 36 +++-- tests/realtime/test_openai_realtime.py | 6 +- tests/realtime/test_playback_tracker.py | 12 +- tests/realtime/test_session.py | 63 ++++++++- tests/realtime/test_session_exceptions.py | 19 +-- tests/sandbox/test_docker.py | 3 +- tests/sandbox/test_unix_local.py | 38 ++++++ 34 files changed, 642 insertions(+), 620 deletions(-) create mode 100644 src/agents/extensions/sandbox/_rclone.py create mode 100644 src/agents/sandbox/session/pty_output.py diff --git a/pyproject.toml b/pyproject.toml index 799391c73d..e5f03a57e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,11 @@ select = [ "I", # isort "B", # flake8-bugbear "C4", # flake8-comprehensions + "DTZ005", # datetime.now() without a timezone + "G004", # logging statement uses f-string + "RUF006", # unowned asyncio tasks + "RUF012", # mutable class attributes without ClassVar + "RUF100", # unused noqa directives "UP", # pyupgrade ] isort = { combine-as-imports = true, known-first-party = ["agents"] } @@ -127,7 +132,9 @@ isort = { combine-as-imports = true, known-first-party = ["agents"] } convention = "google" [tool.ruff.lint.per-file-ignores] -"examples/**/*.py" = ["E501"] +"examples/**/*.py" = ["DTZ005", "E501", "G004", "RUF006", "RUF012", "RUF100"] +"examples/**/*.ipynb" = ["RUF100"] +"tests/**/*.py" = ["RUF006", "RUF012", "RUF100"] [tool.mypy] strict = true diff --git a/src/agents/extensions/experimental/codex/codex_tool.py b/src/agents/extensions/experimental/codex/codex_tool.py index cf0eee08cd..20616e3894 100644 --- a/src/agents/extensions/experimental/codex/codex_tool.py +++ b/src/agents/extensions/experimental/codex/codex_tool.py @@ -532,7 +532,7 @@ def _validate_default_run_context_thread_id_suffix(value: str) -> str: def _parse_tool_input(parameters_model: type[BaseModel], input_json: str) -> BaseModel: try: json_data = json.loads(input_json) if input_json else {} - except Exception as exc: # noqa: BLE001 + except Exception as exc: if _debug.DONT_LOG_TOOL_DATA: logger.debug("Invalid JSON input for codex tool") else: @@ -933,7 +933,7 @@ def _store_thread_id_in_run_context( try: setattr(context, key, thread_id) - except Exception as exc: # noqa: BLE001 + except Exception as exc: raise UserError( f'Unable to store Codex thread_id in run context field "{key}". ' "Use a mutable dict context or set a writable attribute." @@ -965,7 +965,7 @@ def _set_pydantic_context_value(context: BaseModel, key: str, value: str) -> boo if key in model_fields: try: setattr(context, key, value) - except Exception: # noqa: BLE001 + except Exception: return False return True @@ -974,7 +974,7 @@ def _set_pydantic_context_value(context: BaseModel, key: str, value: str) -> boo return True except ValueError: pass - except Exception: # noqa: BLE001 + except Exception: return False state = getattr(context, "__dict__", None) diff --git a/src/agents/extensions/experimental/codex/thread.py b/src/agents/extensions/experimental/codex/thread.py index 2ba687dce0..d6f8d69b63 100644 --- a/src/agents/extensions/experimental/codex/thread.py +++ b/src/agents/extensions/experimental/codex/thread.py @@ -150,7 +150,7 @@ async def _run_streamed_internal( ) from exc try: parsed = _parse_event(item) - except Exception as exc: # noqa: BLE001 + except Exception as exc: raise RuntimeError(f"Failed to parse event: {item}") from exc if isinstance(parsed, ThreadStartedEvent): # Capture the thread id so callers can resume later. diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 704c85058e..d4d6b085c6 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -268,7 +268,7 @@ async def store_run_usage(self, result: RunResult) -> None: # Only update turn-level usage - session usage is aggregated on demand await self._update_turn_usage_internal(current_turn, result.context_wrapper.usage) except Exception as e: - self._logger.error(f"Failed to store usage for session {self.session_id}: {e}") + self._logger.error("Failed to store usage for session %s: %s", self.session_id, e) def _get_next_turn_number(self, branch_id: str) -> int: """Get the next turn number for a specific branch. @@ -491,7 +491,7 @@ def _cleanup_orphaned_messages_sync(self, conn: sqlite3.Connection) -> int: deleted_count = cursor.rowcount if deleted_count: - self._logger.info(f"Cleaned up {deleted_count} orphaned messages") + self._logger.info("Cleaned up %s orphaned messages", deleted_count) return deleted_count def _classify_message_type(self, item: TResponseInputItem) -> str: @@ -639,7 +639,11 @@ def _validate_turn(): self._current_branch_id = branch_name self._logger.debug( - f"Created branch '{branch_name}' from turn {turn_number} ('{turn_content}') in '{old_branch}'" # noqa: E501 + "Created branch '%s' from turn %s ('%s') in '%s'", + branch_name, + turn_number, + turn_content, + old_branch, ) return branch_name @@ -697,7 +701,7 @@ def _validate_branch(): old_branch = self._current_branch_id self._current_branch_id = branch_id - self._logger.info(f"Switched from branch '{old_branch}' to '{branch_id}'") + self._logger.info("Switched from branch '%s' to '%s'", old_branch, branch_id) async def delete_branch(self, branch_id: str, force: bool = False) -> None: """Delete a branch and all its associated data. @@ -778,8 +782,11 @@ def _delete_sync(): ) self._logger.info( - f"Deleted branch '{branch_id}': {structure_deleted} message entries, " - f"{usage_deleted} usage entries, {orphaned_messages_deleted} orphaned messages" + "Deleted branch '%s': %s message entries, %s usage entries, %s orphaned messages", + branch_id, + structure_deleted, + usage_deleted, + orphaned_messages_deleted, ) async def list_branches(self) -> list[dict[str, Any]]: @@ -1305,7 +1312,7 @@ def _update_sync(): try: input_details_json = json.dumps(usage_data.input_tokens_details.__dict__) except (TypeError, ValueError) as e: - self._logger.warning(f"Failed to serialize input tokens details: {e}") + self._logger.warning("Failed to serialize input tokens details: %s", e) input_details_json = None if ( @@ -1315,7 +1322,7 @@ def _update_sync(): try: output_details_json = json.dumps(usage_data.output_tokens_details.__dict__) except (TypeError, ValueError) as e: - self._logger.warning(f"Failed to serialize output tokens details: {e}") + self._logger.warning("Failed to serialize output tokens details: %s", e) output_details_json = None with closing(conn.cursor()) as cursor: diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 113acdc6af..07354577d6 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -35,7 +35,7 @@ import threading import weakref from datetime import datetime, timezone -from typing import Any +from typing import Any, ClassVar from ._optional_imports import raise_optional_dependency_error @@ -97,8 +97,8 @@ class MongoDBSession(SessionABC): # one across loops raises RuntimeError. create_index is idempotent, so # we only need the threading lock to guard the boolean done flag — no # async coordination is required. - _init_state: dict[int, dict[tuple[str, str, str], bool]] = {} - _init_guard: threading.Lock = threading.Lock() + _init_state: ClassVar[dict[int, dict[tuple[str, str, str], bool]]] = {} + _init_guard: ClassVar[threading.Lock] = threading.Lock() session_settings: SessionSettings | None = None diff --git a/src/agents/extensions/sandbox/_rclone.py b/src/agents/extensions/sandbox/_rclone.py new file mode 100644 index 0000000000..e3d652db97 --- /dev/null +++ b/src/agents/extensions/sandbox/_rclone.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from ...sandbox.entries.mounts.patterns import RcloneMountPattern +from ...sandbox.errors import MountConfigError +from ...sandbox.session.base_sandbox_session import BaseSandboxSession + +_APT = "DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes apt-get -o Dpkg::Use-Pty=0" +_RCLONE_CHECK = "command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone" +_INSTALL_RCLONE_COMMANDS = ( + f"{_APT} update -qq", + f"{_APT} install -y -qq curl unzip ca-certificates", + "curl -fsSL https://rclone.org/install.sh | bash", +) + + +async def ensure_rclone(session: BaseSandboxSession) -> None: + rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) + if rclone.ok(): + return + + apt = await session.exec("sh", "-lc", "command -v apt-get >/dev/null 2>&1", shell=False) + if not apt.ok(): + raise MountConfigError( + message="rclone is not installed and apt-get is unavailable; preinstall rclone", + context={"package": "rclone"}, + ) + + for command in _INSTALL_RCLONE_COMMANDS: + install = await session.exec( + "sh", + "-lc", + command, + shell=False, + timeout=300, + user="root", + ) + if not install.ok(): + raise MountConfigError( + message="failed to install rclone", + context={"package": "rclone", "exit_code": install.exit_code}, + ) + + rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) + if not rclone.ok(): + raise MountConfigError( + message="rclone was installed but is still not available on PATH", + context={"package": "rclone"}, + ) + + +async def _default_user_ids(session: BaseSandboxSession) -> tuple[str, str] | None: + result = await session.exec("sh", "-lc", "id -u; id -g", shell=False, timeout=30) + if not result.ok(): + return None + + lines = result.stdout.decode("utf-8", errors="replace").splitlines() + if len(lines) < 2 or not lines[0].isdigit() or not lines[1].isdigit(): + return None + return lines[0], lines[1] + + +def _append_option(args: list[str], option: str, *values: str) -> None: + if option not in args: + args.extend([option, *values]) + + +async def rclone_pattern_for_session( + session: BaseSandboxSession, + pattern: RcloneMountPattern, +) -> RcloneMountPattern: + if pattern.mode != "fuse": + return pattern + + extra_args = list(pattern.extra_args) + _append_option(extra_args, "--allow-other") + user_ids = await _default_user_ids(session) + if user_ids is not None: + uid, gid = user_ids + _append_option(extra_args, "--uid", uid) + _append_option(extra_args, "--gid", gid) + + return pattern.model_copy(update={"extra_args": extra_args}) diff --git a/src/agents/extensions/sandbox/blaxel/sandbox.py b/src/agents/extensions/sandbox/blaxel/sandbox.py index 8d5bf8197a..02c8e87b38 100644 --- a/src/agents/extensions/sandbox/blaxel/sandbox.py +++ b/src/agents/extensions/sandbox/blaxel/sandbox.py @@ -44,6 +44,7 @@ from ....sandbox.session.base_sandbox_session import BaseSandboxSession from ....sandbox.session.dependencies import Dependencies from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.pty_output import collect_pty_output from ....sandbox.session.pty_types import ( PTY_PROCESSES_MAX, PTY_PROCESSES_WARNING, @@ -52,7 +53,6 @@ clamp_pty_yield_time_ms, process_id_to_prune_from_meta, resolve_pty_write_yield_time_ms, - truncate_text_by_tokens, ) from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript from ....sandbox.session.sandbox_client import BaseSandboxClient @@ -959,34 +959,14 @@ async def _collect_pty_output( yield_time_ms: int, max_output_tokens: int | None, ) -> tuple[bytes, int | None]: - deadline = time.monotonic() + (yield_time_ms / 1000) - output = bytearray() - - while True: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - - if time.monotonic() >= deadline: - break - if entry.done: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - break - - remaining_s = deadline - time.monotonic() - if remaining_s <= 0: - break - try: - await asyncio.wait_for(entry.output_notify.wait(), timeout=remaining_s) - except asyncio.TimeoutError: - break - entry.output_notify.clear() - - text = output.decode("utf-8", errors="replace") - truncated, original_token_count = truncate_text_by_tokens(text, max_output_tokens) - return truncated.encode("utf-8", errors="replace"), original_token_count + return await collect_pty_output( + output_chunks=entry.output_chunks, + output_lock=entry.output_lock, + output_notify=entry.output_notify, + is_done=lambda: entry.done, + yield_time_ms=yield_time_ms, + max_output_tokens=max_output_tokens, + ) async def _finalize_pty_update( self, diff --git a/src/agents/extensions/sandbox/daytona/sandbox.py b/src/agents/extensions/sandbox/daytona/sandbox.py index 36bd195031..1083926ca5 100644 --- a/src/agents/extensions/sandbox/daytona/sandbox.py +++ b/src/agents/extensions/sandbox/daytona/sandbox.py @@ -43,6 +43,7 @@ from ....sandbox.session.base_sandbox_session import BaseSandboxSession from ....sandbox.session.dependencies import Dependencies from ....sandbox.session.manager import Instrumentation +from ....sandbox.session.pty_output import collect_pty_output from ....sandbox.session.pty_types import ( PTY_PROCESSES_MAX, PTY_PROCESSES_WARNING, @@ -51,7 +52,6 @@ clamp_pty_yield_time_ms, process_id_to_prune_from_meta, resolve_pty_write_yield_time_ms, - truncate_text_by_tokens, ) from ....sandbox.session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript from ....sandbox.session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions @@ -388,6 +388,7 @@ class _DaytonaPtySessionEntry: last_used: float = field(default_factory=time.monotonic) done: bool = False exit_code: int | None = None + worker_task: asyncio.Task[None] | None = None class DaytonaSandboxSession(BaseSandboxSession): @@ -672,7 +673,7 @@ async def _on_data(chunk: bytes | str) -> None: timeout=exec_timeout, ) entry.pty_handle = pty_handle - asyncio.create_task(self._run_pty_waiter(entry)) + entry.worker_task = asyncio.create_task(self._run_pty_waiter(entry)) await asyncio.wait_for(pty_handle.wait_for_connection(), timeout=exec_timeout) await asyncio.wait_for( pty_handle.send_input(cmd_str + "\n"), @@ -699,7 +700,7 @@ async def _on_data(chunk: bytes | str) -> None: timeout=exec_timeout, ) entry.cmd_id = resp.cmd_id - asyncio.create_task( + entry.worker_task = asyncio.create_task( self._run_session_reader( entry, daytona_session_id, @@ -885,36 +886,14 @@ async def _collect_pty_output( yield_time_ms: int, max_output_tokens: int | None, ) -> tuple[bytes, int | None]: - deadline = time.monotonic() + (yield_time_ms / 1000) - output = bytearray() - - while True: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - - if time.monotonic() >= deadline: - break - - if entry.done: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - break - - remaining_s = deadline - time.monotonic() - if remaining_s <= 0: - break - - try: - await asyncio.wait_for(entry.output_notify.wait(), timeout=remaining_s) - except asyncio.TimeoutError: - break - entry.output_notify.clear() - - text = output.decode("utf-8", errors="replace") - truncated, original_token_count = truncate_text_by_tokens(text, max_output_tokens) - return truncated.encode("utf-8", errors="replace"), original_token_count + return await collect_pty_output( + output_chunks=entry.output_chunks, + output_lock=entry.output_lock, + output_notify=entry.output_notify, + is_done=lambda: entry.done, + yield_time_ms=yield_time_ms, + max_output_tokens=max_output_tokens, + ) def _prune_pty_sessions_if_needed(self) -> _DaytonaPtySessionEntry | None: if len(self._pty_sessions) < PTY_PROCESSES_MAX: @@ -937,6 +916,13 @@ async def _terminate_pty_entry(self, entry: _DaytonaPtySessionEntry) -> None: except Exception: pass + worker_task = entry.worker_task + entry.worker_task = None + if worker_task is not None and worker_task is not asyncio.current_task(): + if not worker_task.done(): + worker_task.cancel() + await asyncio.gather(worker_task, return_exceptions=True) + async def read(self, path: Path | str, *, user: str | User | None = None) -> io.IOBase: error_path = posix_path_as_path(coerce_posix_path(path)) if user is not None: diff --git a/src/agents/extensions/sandbox/e2b/mounts.py b/src/agents/extensions/sandbox/e2b/mounts.py index 3e37eda803..94b0a3bbb4 100644 --- a/src/agents/extensions/sandbox/e2b/mounts.py +++ b/src/agents/extensions/sandbox/e2b/mounts.py @@ -10,14 +10,11 @@ from ....sandbox.errors import MountConfigError from ....sandbox.materialization import MaterializedFile from ....sandbox.session.base_sandbox_session import BaseSandboxSession - -_APT = "DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes apt-get -o Dpkg::Use-Pty=0" -_RCLONE_CHECK = "command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone" -_INSTALL_RCLONE_COMMANDS = ( - f"{_APT} update -qq", - f"{_APT} install -y -qq curl unzip ca-certificates", - "curl -fsSL https://rclone.org/install.sh | bash", +from .._rclone import ( + ensure_rclone as _ensure_rclone, + rclone_pattern_for_session as _rclone_pattern_for_session, ) + _FUSE_ALLOW_OTHER = ( "chmod a+rw /dev/fuse && " "touch /etc/fuse.conf && " @@ -55,68 +52,6 @@ async def _ensure_fuse_support(session: BaseSandboxSession) -> None: ) -async def _ensure_rclone(session: BaseSandboxSession) -> None: - rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) - if rclone.ok(): - return - - apt = await session.exec("sh", "-lc", "command -v apt-get >/dev/null 2>&1", shell=False) - if not apt.ok(): - raise MountConfigError( - message="rclone is not installed and apt-get is unavailable; preinstall rclone", - context={"package": "rclone"}, - ) - - for command in _INSTALL_RCLONE_COMMANDS: - install = await session.exec("sh", "-lc", command, shell=False, timeout=300, user="root") - if not install.ok(): - raise MountConfigError( - message="failed to install rclone", - context={"package": "rclone", "exit_code": install.exit_code}, - ) - - rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) - if not rclone.ok(): - raise MountConfigError( - message="rclone was installed but is still not available on PATH", - context={"package": "rclone"}, - ) - - -async def _default_user_ids(session: BaseSandboxSession) -> tuple[str, str] | None: - result = await session.exec("sh", "-lc", "id -u; id -g", shell=False, timeout=30) - if not result.ok(): - return None - - lines = result.stdout.decode("utf-8", errors="replace").splitlines() - if len(lines) < 2 or not lines[0].isdigit() or not lines[1].isdigit(): - return None - return lines[0], lines[1] - - -def _append_option(args: list[str], option: str, *values: str) -> None: - if option not in args: - args.extend([option, *values]) - - -async def _rclone_pattern_for_session( - session: BaseSandboxSession, - pattern: RcloneMountPattern, -) -> RcloneMountPattern: - if pattern.mode != "fuse": - return pattern - - extra_args = list(pattern.extra_args) - _append_option(extra_args, "--allow-other") - user_ids = await _default_user_ids(session) - if user_ids is not None: - uid, gid = user_ids - _append_option(extra_args, "--uid", uid) - _append_option(extra_args, "--gid", gid) - - return pattern.model_copy(update={"extra_args": extra_args}) - - def _assert_e2b_session(session: BaseSandboxSession) -> None: if type(session).__name__ != "E2BSandboxSession": raise MountConfigError( diff --git a/src/agents/extensions/sandbox/e2b/sandbox.py b/src/agents/extensions/sandbox/e2b/sandbox.py index 64cae2caa3..24ddf53ffd 100644 --- a/src/agents/extensions/sandbox/e2b/sandbox.py +++ b/src/agents/extensions/sandbox/e2b/sandbox.py @@ -380,15 +380,6 @@ async def _sandbox_write_file( ) -async def _sandbox_remove_file( - sandbox: object, - path: str, - *, - request_timeout: float | None = None, -) -> object: - return await _as_sandbox_api(sandbox).files.remove(path, request_timeout=request_timeout) - - async def _sandbox_make_dir( sandbox: object, path: str, diff --git a/src/agents/extensions/sandbox/runloop/mounts.py b/src/agents/extensions/sandbox/runloop/mounts.py index 4c1daec892..66116794c8 100644 --- a/src/agents/extensions/sandbox/runloop/mounts.py +++ b/src/agents/extensions/sandbox/runloop/mounts.py @@ -10,14 +10,12 @@ from ....sandbox.errors import MountConfigError from ....sandbox.materialization import MaterializedFile from ....sandbox.session.base_sandbox_session import BaseSandboxSession +from .._rclone import ( + ensure_rclone as _ensure_rclone, + rclone_pattern_for_session as _rclone_pattern_for_session, +) _APT = "DEBIAN_FRONTEND=noninteractive DEBCONF_NOWARNINGS=yes apt-get -o Dpkg::Use-Pty=0" -_RCLONE_CHECK = "command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone" -_INSTALL_RCLONE_COMMANDS = ( - f"{_APT} update -qq", - f"{_APT} install -y -qq curl unzip ca-certificates", - "curl -fsSL https://rclone.org/install.sh | bash", -) _INSTALL_FUSE_COMMANDS = ( f"{_APT} update -qq", f"{_APT} install -y -qq fuse3", @@ -100,68 +98,6 @@ async def _ensure_fuse_support(session: BaseSandboxSession) -> None: ) -async def _ensure_rclone(session: BaseSandboxSession) -> None: - rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) - if rclone.ok(): - return - - apt = await session.exec("sh", "-lc", "command -v apt-get >/dev/null 2>&1", shell=False) - if not apt.ok(): - raise MountConfigError( - message="rclone is not installed and apt-get is unavailable; preinstall rclone", - context={"package": "rclone"}, - ) - - for command in _INSTALL_RCLONE_COMMANDS: - install = await session.exec("sh", "-lc", command, shell=False, timeout=300, user="root") - if not install.ok(): - raise MountConfigError( - message="failed to install rclone", - context={"package": "rclone", "exit_code": install.exit_code}, - ) - - rclone = await session.exec("sh", "-lc", _RCLONE_CHECK, shell=False) - if not rclone.ok(): - raise MountConfigError( - message="rclone was installed but is still not available on PATH", - context={"package": "rclone"}, - ) - - -async def _default_user_ids(session: BaseSandboxSession) -> tuple[str, str] | None: - result = await session.exec("sh", "-lc", "id -u; id -g", shell=False, timeout=30) - if not result.ok(): - return None - - lines = result.stdout.decode("utf-8", errors="replace").splitlines() - if len(lines) < 2 or not lines[0].isdigit() or not lines[1].isdigit(): - return None - return lines[0], lines[1] - - -def _append_option(args: list[str], option: str, *values: str) -> None: - if option not in args: - args.extend([option, *values]) - - -async def _rclone_pattern_for_session( - session: BaseSandboxSession, - pattern: RcloneMountPattern, -) -> RcloneMountPattern: - if pattern.mode != "fuse": - return pattern - - extra_args = list(pattern.extra_args) - _append_option(extra_args, "--allow-other") - user_ids = await _default_user_ids(session) - if user_ids is not None: - uid, gid = user_ids - _append_option(extra_args, "--uid", uid) - _append_option(extra_args, "--gid", gid) - - return pattern.model_copy(update={"extra_args": extra_args}) - - def _assert_runloop_session(session: BaseSandboxSession) -> None: if type(session).__name__ != "RunloopSandboxSession": raise MountConfigError( diff --git a/src/agents/extensions/tool_output_trimmer.py b/src/agents/extensions/tool_output_trimmer.py index 26b307f14f..d4e9e9ec42 100644 --- a/src/agents/extensions/tool_output_trimmer.py +++ b/src/agents/extensions/tool_output_trimmer.py @@ -148,8 +148,9 @@ def __call__(self, data: CallModelData[Any]) -> ModelInputData: if trimmed_count > 0: logger.debug( - f"ToolOutputTrimmer: trimmed {trimmed_count} tool output(s), " - f"saved ~{chars_saved} chars" + "ToolOutputTrimmer: trimmed %s tool output(s), saved ~%s chars", + trimmed_count, + chars_saved, ) return _ModelInputData(input=new_items, instructions=model_data.instructions) diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index c112b706a1..2ec40663db 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -196,14 +196,18 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None if not should_compact: logger.debug( - f"skip: decision hook declined compaction for {self._response_id} " - f"(mode={resolved_mode})" + "skip: decision hook declined compaction for %s (mode=%s)", + self._response_id, + resolved_mode, ) return self._deferred_response_id = None logger.debug( - f"compact: start for {self._response_id} using {self.model} (mode={resolved_mode})" + "compact: start for %s using %s (mode=%s)", + self._response_id, + self.model, + resolved_mode, ) compact_kwargs: dict[str, Any] = {"model": self.model} @@ -228,9 +232,11 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None self._session_items = output_items logger.debug( - f"compact: done for {self._response_id} " - f"(mode={resolved_mode}, output={len(output_items)}, " - f"candidates={len(self._compaction_candidate_items)})" + "compact: done for %s (mode=%s, output=%s, candidates=%s)", + self._response_id, + resolved_mode, + len(output_items), + len(self._compaction_candidate_items), ) async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: @@ -367,7 +373,9 @@ async def _ensure_compaction_candidates( self._session_items = history logger.debug( - f"candidates: initialized (history={len(history)}, candidates={len(candidates)})" + "candidates: initialized (history=%s, candidates=%s)", + len(history), + len(candidates), ) return (candidates[:], history[:]) diff --git a/src/agents/models/_openai_retry.py b/src/agents/models/_openai_retry.py index 3efb577f66..f7a941be39 100644 --- a/src/agents/models/_openai_retry.py +++ b/src/agents/models/_openai_retry.py @@ -1,121 +1,16 @@ from __future__ import annotations -import time -from collections.abc import Iterator, Mapping -from email.utils import parsedate_to_datetime -from typing import Any - -import httpx -from openai import APIConnectionError, APIStatusError, APITimeoutError +from openai import APIConnectionError, APITimeoutError from ..retry import ModelRetryAdvice, ModelRetryAdviceRequest, ModelRetryNormalizedError - - -def _iter_error_chain(error: Exception) -> Iterator[Exception]: - current: Exception | None = error - seen: set[int] = set() - while current is not None and id(current) not in seen: - seen.add(id(current)) - yield current - next_error = current.__cause__ or current.__context__ - current = next_error if isinstance(next_error, Exception) else None - - -def _header_lookup(headers: Any, key: str) -> str | None: - normalized_key = key.lower() - if isinstance(headers, httpx.Headers): - value = headers.get(key) - return value if isinstance(value, str) else None - if isinstance(headers, Mapping): - for header_name, header_value in headers.items(): - if str(header_name).lower() == normalized_key and isinstance(header_value, str): - return header_value - return None - - -def _get_header_value(error: Exception, key: str) -> str | None: - for candidate in _iter_error_chain(error): - response = getattr(candidate, "response", None) - if isinstance(response, httpx.Response): - header_value = _header_lookup(response.headers, key) - if header_value is not None: - return header_value - - for attr_name in ("headers", "response_headers"): - header_value = _header_lookup(getattr(candidate, attr_name, None), key) - if header_value is not None: - return header_value - - return None - - -def _parse_retry_after_ms(value: str | None) -> float | None: - if value is None: - return None - try: - parsed = float(value) / 1000.0 - except ValueError: - return None - return parsed if parsed >= 0 else None - - -def _parse_retry_after(value: str | None) -> float | None: - if value is None: - return None - - try: - parsed = float(value) - except ValueError: - parsed = None - if parsed is not None: - return parsed if parsed >= 0 else None - - try: - retry_datetime = parsedate_to_datetime(value) - except (TypeError, ValueError, IndexError): - return None - - return max(retry_datetime.timestamp() - time.time(), 0.0) - - -def _get_status_code(error: Exception) -> int | None: - for candidate in _iter_error_chain(error): - if isinstance(candidate, APIStatusError): - return candidate.status_code - status_code = getattr(candidate, "status_code", None) - if isinstance(status_code, int): - return status_code - status = getattr(candidate, "status", None) - if isinstance(status, int): - return status - return None - - -def _get_request_id(error: Exception) -> str | None: - for candidate in _iter_error_chain(error): - request_id = getattr(candidate, "request_id", None) - if isinstance(request_id, str): - return request_id - return None - - -def _get_error_code(error: Exception) -> str | None: - for candidate in _iter_error_chain(error): - error_code = getattr(candidate, "code", None) - if isinstance(error_code, str): - return error_code - - body = getattr(candidate, "body", None) - if isinstance(body, Mapping): - nested_error = body.get("error") - if isinstance(nested_error, Mapping): - nested_code = nested_error.get("code") - if isinstance(nested_code, str): - return nested_code - body_code = body.get("code") - if isinstance(body_code, str): - return body_code - return None +from ._retry_runtime import ( + get_error_code as _get_error_code, + get_error_header as _get_header_value, + get_request_id as _get_request_id, + get_retry_after, + get_status_code as _get_status_code, + iter_error_chain as _iter_error_chain, +) def _is_stateful_request(request: ModelRetryAdviceRequest) -> bool: @@ -163,9 +58,7 @@ def get_openai_retry_advice(request: ModelRetryAdviceRequest) -> ModelRetryAdvic reason=str(error), ) - retry_after = _parse_retry_after_ms(_get_header_value(error, "retry-after-ms")) - if retry_after is None: - retry_after = _parse_retry_after(_get_header_value(error, "retry-after")) + retry_after = get_retry_after(error) normalized = _build_normalized_error(error, retry_after=retry_after) stateful_request = _is_stateful_request(request) diff --git a/src/agents/models/_retry_runtime.py b/src/agents/models/_retry_runtime.py index 795b5cc45e..268987129a 100644 --- a/src/agents/models/_retry_runtime.py +++ b/src/agents/models/_retry_runtime.py @@ -1,8 +1,14 @@ from __future__ import annotations -from collections.abc import Iterator +import time +from collections.abc import Iterator, Mapping from contextlib import contextmanager from contextvars import ContextVar +from email.utils import parsedate_to_datetime +from typing import Any + +import httpx +from openai import APIStatusError _DISABLE_PROVIDER_MANAGED_RETRIES: ContextVar[bool] = ContextVar( "disable_provider_managed_retries", @@ -38,3 +44,113 @@ def websocket_pre_event_retries_disabled(disabled: bool) -> Iterator[None]: def should_disable_websocket_pre_event_retries() -> bool: return _DISABLE_WEBSOCKET_PRE_EVENT_RETRIES.get() + + +def iter_error_chain(error: Exception) -> Iterator[Exception]: + current: Exception | None = error + seen: set[int] = set() + while current is not None and id(current) not in seen: + seen.add(id(current)) + yield current + next_error = current.__cause__ or current.__context__ + current = next_error if isinstance(next_error, Exception) else None + + +def header_lookup(headers: Any, key: str) -> str | None: + normalized_key = key.lower() + if isinstance(headers, httpx.Headers): + value = headers.get(key) + return value if isinstance(value, str) else None + if isinstance(headers, Mapping): + for header_name, header_value in headers.items(): + if str(header_name).lower() == normalized_key and isinstance(header_value, str): + return header_value + return None + + +def get_error_header(error: Exception, key: str) -> str | None: + for candidate in iter_error_chain(error): + response = getattr(candidate, "response", None) + if isinstance(response, httpx.Response): + header_value = header_lookup(response.headers, key) + if header_value is not None: + return header_value + + for attr_name in ("headers", "response_headers"): + header_value = header_lookup(getattr(candidate, attr_name, None), key) + if header_value is not None: + return header_value + return None + + +def parse_retry_after_ms(value: str | None) -> float | None: + if value is None: + return None + try: + parsed = float(value) / 1000.0 + except ValueError: + return None + return parsed if parsed >= 0 else None + + +def parse_retry_after_value(value: str | None) -> float | None: + if value is None: + return None + + try: + parsed = float(value) + except ValueError: + parsed = None + if parsed is not None: + return parsed if parsed >= 0 else None + + try: + retry_datetime = parsedate_to_datetime(value) + except (TypeError, ValueError, IndexError): + return None + return max(retry_datetime.timestamp() - time.time(), 0.0) + + +def get_retry_after(error: Exception) -> float | None: + retry_after = parse_retry_after_ms(get_error_header(error, "retry-after-ms")) + if retry_after is not None: + return retry_after + return parse_retry_after_value(get_error_header(error, "retry-after")) + + +def get_status_code(error: Exception) -> int | None: + for candidate in iter_error_chain(error): + if isinstance(candidate, APIStatusError): + return candidate.status_code + for attr_name in ("status_code", "status"): + value = getattr(candidate, attr_name, None) + if isinstance(value, int): + return value + return None + + +def get_request_id(error: Exception) -> str | None: + for candidate in iter_error_chain(error): + request_id = getattr(candidate, "request_id", None) + if isinstance(request_id, str): + return request_id + return None + + +def get_error_code(error: Exception) -> str | None: + for candidate in iter_error_chain(error): + error_code = getattr(candidate, "code", None) + if isinstance(error_code, str): + return error_code + + body = getattr(candidate, "body", None) + if isinstance(body, Mapping): + nested_error = body.get("error") + if isinstance(nested_error, Mapping): + nested_code = nested_error.get("code") + if isinstance(nested_code, str): + return nested_code + body_code = body.get("code") + if isinstance(body_code, str): + return body_code + return None diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index f66c6afb7b..4c5c77fc55 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -9,7 +9,7 @@ from contextvars import ContextVar from dataclasses import asdict, dataclass, is_dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeGuard, cast, get_args, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast, overload import httpx from openai import AsyncOpenAI, NotGiven, Omit, omit @@ -88,9 +88,6 @@ _HEADERS_OVERRIDE: ContextVar[dict[str, str] | None] = ContextVar( "openai_responses_headers_override", default=None ) -_RESPONSE_INCLUDABLE_VALUES = frozenset( - value for value in get_args(ResponseIncludable) if isinstance(value, str) -) class _NamespaceToolParam(TypedDict): @@ -132,10 +129,6 @@ def _require_responses_tool_param(value: object) -> ResponsesToolParam: return cast(ResponsesToolParam, value) -def _is_response_includable(value: object) -> TypeGuard[ResponseIncludable]: - return isinstance(value, str) and value in _RESPONSE_INCLUDABLE_VALUES - - def _coerce_response_includables(values: Sequence[str]) -> list[ResponseIncludable]: includables: list[ResponseIncludable] = [] for value in values: @@ -221,7 +214,7 @@ class OpenAIResponsesWebSocketOptions(TypedDict): class _ResponseStreamWithRequestId: """Wrap an SDK event stream and retain the originating request ID.""" - _TERMINAL_EVENT_TYPES = { + _TERMINAL_EVENT_TYPES: ClassVar[set[str]] = { "response.completed", "response.failed", "response.incomplete", diff --git a/src/agents/realtime/_default_tracker.py b/src/agents/realtime/_default_tracker.py index 8003c268da..dfc28e771f 100644 --- a/src/agents/realtime/_default_tracker.py +++ b/src/agents/realtime/_default_tracker.py @@ -1,7 +1,7 @@ from __future__ import annotations +import time from dataclasses import dataclass -from datetime import datetime from ._util import calculate_audio_length_ms from .config import RealtimeAudioFormat @@ -9,7 +9,7 @@ @dataclass class ModelAudioState: - initial_received_time: datetime + initial_received_time: float audio_length_ms: float @@ -35,7 +35,7 @@ def on_audio_delta(self, item_id: str, item_content_index: int, audio_bytes: byt self._last_audio_item = new_key if new_key not in self._states: - self._states[new_key] = ModelAudioState(datetime.now(), ms) + self._states[new_key] = ModelAudioState(time.monotonic(), ms) else: self._states[new_key].audio_length_ms += ms diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 3759f8150a..fe59933920 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -6,9 +6,9 @@ import json import math import os +import time from collections.abc import Callable, Mapping from dataclasses import dataclass -from datetime import datetime from typing import Annotated, Any, Literal, TypeAlias, cast import pydantic @@ -890,9 +890,7 @@ def _get_playback_state(self) -> RealtimePlaybackState: item_id, item_content_index = last_audio_item_id audio_state = self._audio_state_tracker.get_state(item_id, item_content_index) if audio_state: - elapsed_ms = ( - datetime.now() - audio_state.initial_received_time - ).total_seconds() * 1000 + elapsed_ms = (time.monotonic() - audio_state.initial_received_time) * 1000 return { "current_item_id": item_id, "current_item_content_index": item_content_index, diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 3b186e5502..dd015a7eac 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -499,6 +499,10 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None: """Put an event into the queue.""" await self._event_queue.put(event) + def _put_event_nowait(self, event: RealtimeSessionEvent) -> None: + """Put an event into the unbounded queue from a synchronous callback.""" + self._event_queue.put_nowait(event) + async def _function_needs_approval( self, function_tool: FunctionTool, tool_call: RealtimeModelToolCallEvent ) -> bool: @@ -1282,20 +1286,23 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None: exception = task.exception() if exception: # Create an exception event instead of raising - asyncio.create_task( - self._put_event( - RealtimeError( - info=self._event_info, - error={"message": f"Guardrail task failed: {str(exception)}"}, - ) + self._put_event_nowait( + RealtimeError( + info=self._event_info, + error={"message": f"Guardrail task failed: {str(exception)}"}, ) ) - def _cleanup_guardrail_tasks(self) -> None: - for task in self._guardrail_tasks: - if not task.done(): - task.cancel() - self._guardrail_tasks.clear() + @staticmethod + async def _cancel_and_wait_for_tasks(tasks: set[asyncio.Task[Any]]) -> None: + observed: set[asyncio.Task[Any]] = set() + while new_tasks := tuple(task for task in tasks if task not in observed): + observed.update(new_tasks) + for task in new_tasks: + if not task.done(): + task.cancel() + await asyncio.gather(*new_tasks, return_exceptions=True) + tasks.difference_update(observed) def _enqueue_tool_call_task( self, @@ -1335,17 +1342,14 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: exception.call_id, exc_info=exception, ) - asyncio.create_task( - self._put_event( - RealtimeError( - info=self._event_info, - error={ - "message": ( - "Tool output send failed; cached output will be retried: " - f"{exception}" - ) - }, - ) + self._put_event_nowait( + RealtimeError( + info=self._event_info, + error={ + "message": ( + f"Tool output send failed; cached output will be retried: {exception}" + ) + }, ) ) return @@ -1355,21 +1359,13 @@ def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None: if self._stored_exception is None: self._stored_exception = exception - asyncio.create_task( - self._put_event( - RealtimeError( - info=self._event_info, - error={"message": f"Tool call task failed: {exception}"}, - ) + self._put_event_nowait( + RealtimeError( + info=self._event_info, + error={"message": f"Tool call task failed: {exception}"}, ) ) - def _cleanup_tool_call_tasks(self) -> None: - for task in self._tool_call_tasks: - if not task.done(): - task.cancel() - self._tool_call_tasks.clear() - def _wake_event_iterators(self) -> None: for _ in range(self._event_iterator_waiters): self._event_queue.put_nowait(_REALTIME_SESSION_CLOSED_SENTINEL) @@ -1380,9 +1376,10 @@ async def _cleanup(self) -> None: self._wake_event_iterators() return - # Cancel and cleanup guardrail tasks - self._cleanup_guardrail_tasks() - self._cleanup_tool_call_tasks() + # Cancel each observed task once, await its finalizer, then rescan for tasks + # created while the observed batch was unwinding. + await self._cancel_and_wait_for_tasks(self._guardrail_tasks) + await self._cancel_and_wait_for_tasks(self._tool_call_tasks) # Remove ourselves as a listener self._model.remove_listener(self) diff --git a/src/agents/retry.py b/src/agents/retry.py index 4b122cbfa7..a3e9ba7b5b 100644 --- a/src/agents/retry.py +++ b/src/agents/retry.py @@ -145,8 +145,8 @@ def _mark_retry_capabilities( retries_safe_transport_errors: bool, retries_all_transient_errors: bool, ) -> RetryPolicy: - setattr(policy, _RETRIES_SAFE_TRANSPORT_ERRORS_ATTR, retries_safe_transport_errors) # noqa: B010 - setattr(policy, _RETRIES_ALL_TRANSIENT_ERRORS_ATTR, retries_all_transient_errors) # noqa: B010 + setattr(policy, _RETRIES_SAFE_TRANSPORT_ERRORS_ATTR, retries_safe_transport_errors) + setattr(policy, _RETRIES_ALL_TRANSIENT_ERRORS_ATTR, retries_all_transient_errors) return policy diff --git a/src/agents/run_internal/model_retry.py b/src/agents/run_internal/model_retry.py index 289daca0b4..f7bda380bf 100644 --- a/src/agents/run_internal/model_retry.py +++ b/src/agents/run_internal/model_retry.py @@ -2,18 +2,21 @@ import asyncio import random -import time -from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping -from email.utils import parsedate_to_datetime +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping from inspect import isawaitable from typing import Any import httpx -from openai import APIConnectionError, APIStatusError, APITimeoutError, BadRequestError +from openai import APIConnectionError, APITimeoutError, BadRequestError from ..items import ModelResponse, TResponseStreamEvent from ..logger import logger from ..models._retry_runtime import ( + get_error_code as _get_error_code, + get_request_id as _get_request_id, + get_retry_after as _get_retry_after, + get_status_code as _get_status_code, + iter_error_chain as _iter_error_chain, provider_managed_retries_disabled, websocket_pre_event_retries_disabled, ) @@ -44,120 +47,12 @@ _RETRY_SAFE_STREAM_EVENT_TYPES = frozenset({"response.created", "response.in_progress"}) -def _iter_error_chain(error: Exception) -> Iterator[Exception]: - current: Exception | None = error - seen: set[int] = set() - while current is not None and id(current) not in seen: - seen.add(id(current)) - yield current - next_error = current.__cause__ or current.__context__ - current = next_error if isinstance(next_error, Exception) else None - - def _is_conversation_locked_error(error: Exception) -> bool: return ( isinstance(error, BadRequestError) and getattr(error, "code", "") == "conversation_locked" ) -def _get_header_value(headers: Any, key: str) -> str | None: - normalized_key = key.lower() - if isinstance(headers, httpx.Headers): - value = headers.get(key) - return value if isinstance(value, str) else None - if isinstance(headers, Mapping): - for header_name, header_value in headers.items(): - if str(header_name).lower() == normalized_key and isinstance(header_value, str): - return header_value - return None - - -def _extract_headers(error: Exception) -> httpx.Headers | Mapping[str, str] | None: - for candidate in _iter_error_chain(error): - response = getattr(candidate, "response", None) - if isinstance(response, httpx.Response): - return response.headers - - for attr_name in ("headers", "response_headers"): - headers = getattr(candidate, attr_name, None) - if isinstance(headers, httpx.Headers | Mapping): - return headers - - return None - - -def _parse_retry_after(headers: httpx.Headers | Mapping[str, str] | None) -> float | None: - if headers is None: - return None - - retry_after_ms = _get_header_value(headers, "retry-after-ms") - if retry_after_ms is not None: - try: - parsed_ms = float(retry_after_ms) / 1000.0 - except ValueError: - parsed_ms = None - if parsed_ms is not None and parsed_ms >= 0: - return parsed_ms - - retry_after = _get_header_value(headers, "retry-after") - if retry_after is None: - return None - - try: - parsed_seconds = float(retry_after) - except ValueError: - parsed_seconds = None - if parsed_seconds is not None: - return parsed_seconds if parsed_seconds >= 0 else None - - try: - retry_datetime = parsedate_to_datetime(retry_after) - except (TypeError, ValueError, IndexError): - return None - - return max(retry_datetime.timestamp() - time.time(), 0.0) - - -def _get_status_code(error: Exception) -> int | None: - for candidate in _iter_error_chain(error): - if isinstance(candidate, APIStatusError): - return candidate.status_code - - for attr_name in ("status_code", "status"): - value = getattr(candidate, attr_name, None) - if isinstance(value, int): - return value - - return None - - -def _get_error_code(error: Exception) -> str | None: - for candidate in _iter_error_chain(error): - error_code = getattr(candidate, "code", None) - if isinstance(error_code, str): - return error_code - - body = getattr(candidate, "body", None) - if isinstance(body, Mapping): - nested_error = body.get("error") - if isinstance(nested_error, Mapping): - nested_code = nested_error.get("code") - if isinstance(nested_code, str): - return nested_code - body_code = body.get("code") - if isinstance(body_code, str): - return body_code - return None - - -def _get_request_id(error: Exception) -> str | None: - for candidate in _iter_error_chain(error): - request_id = getattr(candidate, "request_id", None) - if isinstance(request_id, str): - return request_id - return None - - def _is_abort_like_error(error: Exception) -> bool: if isinstance(error, asyncio.CancelledError): return True @@ -211,7 +106,7 @@ def _normalize_retry_error( error_code=_get_error_code(error), message=str(error), request_id=_get_request_id(error), - retry_after=_parse_retry_after(_extract_headers(error)), + retry_after=_get_retry_after(error), is_abort=_is_abort_like_error(error), is_network_error=_is_network_like_error(error), is_timeout=any( diff --git a/src/agents/sandbox/sandboxes/docker.py b/src/agents/sandbox/sandboxes/docker.py index ae160fc978..c54c95b843 100644 --- a/src/agents/sandbox/sandboxes/docker.py +++ b/src/agents/sandbox/sandboxes/docker.py @@ -49,6 +49,7 @@ from ..session.base_sandbox_session import BaseSandboxSession from ..session.dependencies import Dependencies from ..session.manager import Instrumentation +from ..session.pty_output import collect_pty_output from ..session.pty_types import ( PTY_PROCESSES_MAX, PTY_PROCESSES_WARNING, @@ -57,7 +58,6 @@ clamp_pty_yield_time_ms, process_id_to_prune_from_meta, resolve_pty_write_yield_time_ms, - truncate_text_by_tokens, ) from ..session.runtime_helpers import RESOLVE_WORKSPACE_PATH_HELPER, RuntimeHelperScript from ..session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions @@ -164,6 +164,7 @@ class DockerSandboxSession(BaseSandboxSession): _pty_lock: asyncio.Lock _pty_processes: dict[int, _DockerPtyProcessEntry] _reserved_pty_process_ids: set[int] + _cleanup_tasks: set[asyncio.Task[None]] state: DockerSandboxSessionState _ARCHIVE_STAGING_DIR: Path = posix_path_as_path( @@ -185,6 +186,7 @@ def __init__( self._pty_lock = asyncio.Lock() self._pty_processes = {} self._reserved_pty_process_ids = set() + self._cleanup_tasks = set() @classmethod def from_state( @@ -764,6 +766,12 @@ async def _shutdown_backend(self) -> None: # If the container is already gone/stopped, ignore. pass + async def _after_stop(self) -> None: + await self._wait_for_cleanup_tasks() + + async def _after_shutdown(self) -> None: + await self._wait_for_cleanup_tasks() + @staticmethod def _start_exec_socket(*, api: Any, exec_id: str, tty: bool = False) -> _DockerExecSocket: if not all( @@ -1068,36 +1076,14 @@ async def _collect_pty_output( yield_time_ms: int, max_output_tokens: int | None, ) -> tuple[bytes, int | None]: - deadline = time.monotonic() + (yield_time_ms / 1000) - output = bytearray() - - while True: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - - if time.monotonic() >= deadline: - break - - if entry.output_closed.is_set(): - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - break - - remaining_s = deadline - time.monotonic() - if remaining_s <= 0: - break - - try: - await asyncio.wait_for(entry.output_notify.wait(), timeout=remaining_s) - except asyncio.TimeoutError: - break - entry.output_notify.clear() - - text = output.decode("utf-8", errors="replace") - truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens) - return truncated_text.encode("utf-8", errors="replace"), original_token_count + return await collect_pty_output( + output_chunks=entry.output_chunks, + output_lock=entry.output_lock, + output_notify=entry.output_notify, + is_done=entry.output_closed.is_set, + yield_time_ms=yield_time_ms, + max_output_tokens=max_output_tokens, + ) async def _finalize_pty_update( self, @@ -1282,7 +1268,13 @@ async def hydrate_workspace(self, data: io.IOBase) -> None: def _schedule_rm_best_effort(self, path: Path) -> None: loop = asyncio.get_running_loop() - loop.create_task(self._rm_best_effort(path)) + task = loop.create_task(self._rm_best_effort(path)) + self._cleanup_tasks.add(task) + task.add_done_callback(self._cleanup_tasks.discard) + + async def _wait_for_cleanup_tasks(self) -> None: + while cleanup_tasks := tuple(self._cleanup_tasks): + await asyncio.gather(*cleanup_tasks, return_exceptions=True) def _workspace_archive_stream( self, diff --git a/src/agents/sandbox/sandboxes/unix_local.py b/src/agents/sandbox/sandboxes/unix_local.py index 6bedcb5fa0..240d041ff9 100644 --- a/src/agents/sandbox/sandboxes/unix_local.py +++ b/src/agents/sandbox/sandboxes/unix_local.py @@ -45,6 +45,7 @@ from ..session.base_sandbox_session import BaseSandboxSession from ..session.dependencies import Dependencies from ..session.manager import Instrumentation +from ..session.pty_output import collect_pty_output from ..session.pty_types import ( PTY_PROCESSES_MAX, PTY_PROCESSES_WARNING, @@ -53,7 +54,6 @@ clamp_pty_yield_time_ms, process_id_to_prune_from_meta, resolve_pty_write_yield_time_ms, - truncate_text_by_tokens, ) from ..session.sandbox_client import BaseSandboxClient, BaseSandboxClientOptions from ..session.workspace_payloads import coerce_write_payload @@ -70,6 +70,7 @@ _DEFAULT_MANIFEST_ROOT = cast(str, Manifest.model_fields["root"].default) _PTY_READ_CHUNK_BYTES = 16_384 _PTY_CHILD_SIGNAL_DEFAULTS = (signal.SIGINT, signal.SIGQUIT) +_PTY_FD_CLOSE_GRACE_SECONDS = 0.1 logger = logging.getLogger(__name__) @@ -130,6 +131,7 @@ class UnixLocalSandboxSession(BaseSandboxSession): _pty_lock: asyncio.Lock _pty_processes: dict[int, _UnixPtyProcessEntry] _reserved_pty_process_ids: set[int] + _fd_close_tasks: set[asyncio.Task[None]] def __init__(self, *, state: UnixLocalSandboxSessionState) -> None: self.state = state @@ -137,6 +139,7 @@ def __init__(self, *, state: UnixLocalSandboxSessionState) -> None: self._pty_lock = asyncio.Lock() self._pty_processes = {} self._reserved_pty_process_ids = set() + self._fd_close_tasks = set() @classmethod def from_state(cls, state: UnixLocalSandboxSessionState) -> "UnixLocalSandboxSession": @@ -192,10 +195,14 @@ async def provision_manifest_accounts(self) -> None: ) async def _after_shutdown(self) -> None: + await self._wait_for_fd_close_tasks() # Best-effort: mark session not running. We intentionally do not delete the workspace # directory here; cleanup is handled by the Client.delete(). self._running = False + async def _after_stop(self) -> None: + await self._wait_for_fd_close_tasks() + async def _resolve_exposed_port(self, port: int) -> ExposedPortEndpoint: return ExposedPortEndpoint(host="127.0.0.1", port=port, tls=False) @@ -476,36 +483,14 @@ async def _collect_pty_output( yield_time_ms: int, max_output_tokens: int | None, ) -> tuple[bytes, int | None]: - deadline = time.monotonic() + (yield_time_ms / 1000) - output = bytearray() - - while True: - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - - if time.monotonic() >= deadline: - break - - if entry.output_closed.is_set(): - async with entry.output_lock: - while entry.output_chunks: - output.extend(entry.output_chunks.popleft()) - break - - remaining_s = deadline - time.monotonic() - if remaining_s <= 0: - break - - try: - await asyncio.wait_for(entry.output_notify.wait(), timeout=remaining_s) - except asyncio.TimeoutError: - break - entry.output_notify.clear() - - text = output.decode("utf-8", errors="replace") - truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens) - return truncated_text.encode("utf-8", errors="replace"), original_token_count + return await collect_pty_output( + output_chunks=entry.output_chunks, + output_lock=entry.output_lock, + output_notify=entry.output_notify, + is_done=entry.output_closed.is_set, + yield_time_ms=yield_time_ms, + max_output_tokens=max_output_tokens, + ) async def _finalize_pty_update( self, @@ -564,9 +549,9 @@ async def _terminate_pty_entry(self, entry: _UnixPtyProcessEntry) -> None: if entry.tty: if primary_fd is not None: # On macOS we have observed os.close() on the PTY master fd block while a - # background reader thread is still inside os.read(). Close it off-thread so - # session teardown remains best-effort and non-blocking. - asyncio.create_task(asyncio.to_thread(_close_fd_quietly, primary_fd)) + # background reader thread is still inside os.read(). Keep the close task owned + # by the session without making PTY termination wait indefinitely for it. + self._schedule_fd_close(primary_fd) entry.output_closed.set() entry.output_notify.set() return @@ -577,6 +562,16 @@ async def _terminate_pty_entry(self, entry: _UnixPtyProcessEntry) -> None: if entry.wait_task is not None: await asyncio.gather(entry.wait_task, return_exceptions=True) + def _schedule_fd_close(self, fd: int) -> None: + task = asyncio.create_task(asyncio.to_thread(_close_fd_quietly, fd)) + self._fd_close_tasks.add(task) + task.add_done_callback(self._fd_close_tasks.discard) + + async def _wait_for_fd_close_tasks(self) -> None: + tasks = tuple(self._fd_close_tasks) + if tasks: + await asyncio.wait(tasks, timeout=_PTY_FD_CLOSE_GRACE_SECONDS) + def _confined_exec_command( self, *, diff --git a/src/agents/sandbox/session/pty_output.py b/src/agents/sandbox/session/pty_output.py new file mode 100644 index 0000000000..25cbe774e7 --- /dev/null +++ b/src/agents/sandbox/session/pty_output.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import asyncio +import time +from collections import deque +from collections.abc import Callable + +from .pty_types import truncate_text_by_tokens + + +async def collect_pty_output( + *, + output_chunks: deque[bytes], + output_lock: asyncio.Lock, + output_notify: asyncio.Event, + is_done: Callable[[], bool], + yield_time_ms: int, + max_output_tokens: int | None, +) -> tuple[bytes, int | None]: + """Collect and truncate PTY output until the deadline or provider completion.""" + deadline = time.monotonic() + (yield_time_ms / 1000) + output = bytearray() + + while True: + async with output_lock: + while output_chunks: + output.extend(output_chunks.popleft()) + + if time.monotonic() >= deadline: + break + + if is_done(): + async with output_lock: + while output_chunks: + output.extend(output_chunks.popleft()) + break + + remaining_s = deadline - time.monotonic() + if remaining_s <= 0: + break + + try: + await asyncio.wait_for(output_notify.wait(), timeout=remaining_s) + except asyncio.TimeoutError: + break + output_notify.clear() + + text = output.decode("utf-8", errors="replace") + truncated, original_token_count = truncate_text_by_tokens(text, max_output_tokens) + return truncated.encode("utf-8", errors="replace"), original_token_count diff --git a/tests/extensions/sandbox/test_daytona.py b/tests/extensions/sandbox/test_daytona.py index c8276cf545..407b8a9a62 100644 --- a/tests/extensions/sandbox/test_daytona.py +++ b/tests/extensions/sandbox/test_daytona.py @@ -1526,6 +1526,43 @@ async def test_session_reader_keeps_entry_live_when_logs_fail_without_exit_code( assert entry.done is False assert entry.exit_code is None + @pytest.mark.asyncio + async def test_terminate_pty_entry_awaits_worker_finalizer( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + daytona_module = _load_daytona_module(monkeypatch) + sandbox = _FakeDaytonaSandbox() + state = daytona_module.DaytonaSandboxSessionState( + manifest=Manifest(root=daytona_module.DEFAULT_DAYTONA_WORKSPACE_ROOT), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.id, + ) + session = daytona_module.DaytonaSandboxSession.from_state(state, sandbox=sandbox) + entry = daytona_module._DaytonaPtySessionEntry( # noqa: SLF001 + daytona_session_id="session-123", + pty_handle=object(), + tty=False, + cmd_id="cmd-123", + ) + finalizer_finished = asyncio.Event() + + async def worker() -> None: + try: + await asyncio.Event().wait() + finally: + await asyncio.sleep(0) + finalizer_finished.set() + + entry.worker_task = asyncio.create_task(worker()) + await asyncio.sleep(0) + + await session._terminate_pty_entry(entry) # noqa: SLF001 + + assert finalizer_finished.is_set() + assert entry.worker_task is None + assert sandbox.process.delete_session_calls == ["session-123"] + # --------------------------------------------------------------------------- # DaytonaCloudBucketMountStrategy tests diff --git a/tests/extensions/sandbox/test_e2b.py b/tests/extensions/sandbox/test_e2b.py index dca74be216..d96985afdb 100644 --- a/tests/extensions/sandbox/test_e2b.py +++ b/tests/extensions/sandbox/test_e2b.py @@ -16,12 +16,14 @@ from pydantic import Field, PrivateAttr import agents.extensions.sandbox.e2b.sandbox as e2b_module +from agents.extensions.sandbox._rclone import ( + ensure_rclone as _ensure_rclone, + rclone_pattern_for_session as _rclone_pattern_for_session, +) from agents.extensions.sandbox.e2b.mounts import ( E2BCloudBucketMountStrategy, _assert_e2b_session, _ensure_fuse_support, - _ensure_rclone, - _rclone_pattern_for_session, ) from agents.extensions.sandbox.e2b.sandbox import ( E2BSandboxClient, diff --git a/tests/extensions/sandbox/test_runloop_mounts.py b/tests/extensions/sandbox/test_runloop_mounts.py index e3eb55351a..3a071515dd 100644 --- a/tests/extensions/sandbox/test_runloop_mounts.py +++ b/tests/extensions/sandbox/test_runloop_mounts.py @@ -135,7 +135,7 @@ def test_runloop_session_guard_accepts_correct_type() -> None: @pytest.mark.asyncio async def test_runloop_ensure_rclone_installs_with_root_apt() -> None: - from agents.extensions.sandbox.runloop.mounts import _ensure_rclone + from agents.extensions.sandbox._rclone import ensure_rclone session = _FakeRunloopMountSession( [ @@ -147,7 +147,7 @@ async def test_runloop_ensure_rclone_installs_with_root_apt() -> None: ] ) - await _ensure_rclone(session) + await ensure_rclone(session) assert session.exec_calls[:2] == [ "sh -lc command -v rclone >/dev/null 2>&1 || test -x /usr/local/bin/rclone", @@ -215,10 +215,10 @@ async def test_runloop_ensure_fuse_installs_missing_fusermount() -> None: @pytest.mark.asyncio async def test_runloop_rclone_pattern_adds_fuse_access_args() -> None: - from agents.extensions.sandbox.runloop.mounts import _rclone_pattern_for_session + from agents.extensions.sandbox._rclone import rclone_pattern_for_session session = _FakeRunloopMountSession([_exec_ok(stdout=b"1000\n1000\n")]) - pattern = await _rclone_pattern_for_session(session, RcloneMountPattern(mode="fuse")) + pattern = await rclone_pattern_for_session(session, RcloneMountPattern(mode="fuse")) assert pattern.extra_args == ["--allow-other", "--uid", "1000", "--gid", "1000"] diff --git a/tests/models/test_openai_retry_helpers.py b/tests/models/test_openai_retry_helpers.py index 5c2252c275..b29605146c 100644 --- a/tests/models/test_openai_retry_helpers.py +++ b/tests/models/test_openai_retry_helpers.py @@ -12,16 +12,17 @@ import httpx -from agents.models._openai_retry import ( - _get_error_code, - _get_header_value, - _get_status_code, - _header_lookup, - _parse_retry_after, - _parse_retry_after_ms, - get_openai_retry_advice, +from agents.models._openai_retry import get_openai_retry_advice +from agents.models._retry_runtime import ( + get_error_code as _get_error_code, + get_error_header as _get_header_value, + get_status_code as _get_status_code, + header_lookup as _header_lookup, + parse_retry_after_ms as _parse_retry_after_ms, + parse_retry_after_value as _parse_retry_after, ) from agents.retry import ModelRetryAdviceRequest +from agents.run_internal.model_retry import _normalize_retry_error class _HeaderError(Exception): @@ -99,6 +100,25 @@ class _TopLevelBody(Exception): assert _get_error_code(Exception("none")) is None +def test_provider_and_runner_retry_normalization_share_metadata() -> None: + class _RetryableError(Exception): + status_code = 429 + request_id = "req_test" + body = {"error": {"code": "rate_limit_exceeded"}} + headers = {"retry-after-ms": "1500"} + + error = _RetryableError("slow down") + advice = get_openai_retry_advice(_make_request(error)) + runner_normalized = _normalize_retry_error(error, None) + + assert advice is not None + assert advice.normalized is not None + assert advice.normalized.status_code == runner_normalized.status_code + assert advice.normalized.error_code == runner_normalized.error_code + assert advice.normalized.request_id == runner_normalized.request_id + assert advice.normalized.retry_after == runner_normalized.retry_after + + def test_advice_unsafe_to_replay() -> None: error = Exception("cannot replay") error.unsafe_to_replay = True # type: ignore[attr-defined] diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index be1bcee7b7..169d8644ee 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,6 +1,6 @@ import asyncio import json -from datetime import datetime, timedelta +import time from types import SimpleNamespace from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch @@ -815,7 +815,7 @@ async def test_speech_started_skips_truncate_when_audio_complete(self, model, mo model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 48_000) state = model._audio_state_tracker.get_state("i1", 0) assert state is not None - state.initial_received_time = datetime.now() - timedelta(seconds=5) + state.initial_received_time = time.monotonic() - 5 monkeypatch.setattr( model, @@ -846,7 +846,7 @@ async def test_speech_started_truncates_when_response_ongoing(self, model, monke model._audio_state_tracker.on_audio_delta("i1", 0, b"a" * 48_000) state = model._audio_state_tracker.get_state("i1", 0) assert state is not None - state.initial_received_time = datetime.now() - timedelta(seconds=5) + state.initial_received_time = time.monotonic() - 5 model._ongoing_response = True monkeypatch.setattr( diff --git a/tests/realtime/test_playback_tracker.py b/tests/realtime/test_playback_tracker.py index a0a284b17a..783e1460e4 100644 --- a/tests/realtime/test_playback_tracker.py +++ b/tests/realtime/test_playback_tracker.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest @@ -135,6 +135,16 @@ def test_audio_state_accumulation_across_deltas(self): expected_length = (8 / (24_000 * 2)) * 1000 assert state.audio_length_ms == pytest.approx(expected_length, rel=0, abs=1e-6) + def test_audio_state_uses_monotonic_timestamp(self): + tracker = ModelAudioTracker() + + with patch("agents.realtime._default_tracker.time.monotonic", return_value=42.5): + tracker.on_audio_delta("item_1", 0, b"test") + + state = tracker.get_state("item_1", 0) + assert state is not None + assert state.initial_received_time == 42.5 + def test_state_cleanup_on_interruption(self): """Test both trackers properly reset state on interruption.""" diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 018f63b344..86778ba56a 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -322,14 +322,71 @@ async def failing_task(): session._on_guardrail_task_done(task) - # Allow event task to enqueue - await asyncio.sleep(0.01) - # Should have a RealtimeError queued err = await session._event_queue.get() assert isinstance(err, RealtimeError) +@pytest.mark.asyncio +async def test_close_awaits_task_finalizers_without_recancelling(): + session = RealtimeSession(_DummyModel(), RealtimeAgent(name="agent"), None) + finalizer_started = asyncio.Event() + release_finalizer = asyncio.Event() + finalizer_finished = asyncio.Event() + + async def background_task() -> None: + try: + await asyncio.Event().wait() + finally: + finalizer_started.set() + await release_finalizer.wait() + finalizer_finished.set() + + task = asyncio.create_task(background_task()) + session._guardrail_tasks.add(task) + await asyncio.sleep(0) + + close_task = asyncio.create_task(session.close()) + await finalizer_started.wait() + await asyncio.sleep(0) + assert not close_task.done() + + release_finalizer.set() + await close_task + + assert finalizer_finished.is_set() + assert session._guardrail_tasks == set() + + +@pytest.mark.asyncio +async def test_close_rescans_for_tasks_created_during_unwind(): + session = RealtimeSession(_DummyModel(), RealtimeAgent(name="agent"), None) + spawned_task_finished = asyncio.Event() + + async def spawned_task() -> None: + try: + await asyncio.Event().wait() + finally: + spawned_task_finished.set() + + async def original_task() -> None: + try: + await asyncio.Event().wait() + finally: + task = asyncio.create_task(spawned_task()) + session._tool_call_tasks.add(task) + await asyncio.sleep(0) + + task = asyncio.create_task(original_task()) + session._tool_call_tasks.add(task) + await asyncio.sleep(0) + + await session.close() + + assert spawned_task_finished.is_set() + assert session._tool_call_tasks == set() + + @pytest.mark.asyncio async def test_get_handoffs_async_is_enabled(monkeypatch): # Agent includes both a direct Handoff and a RealtimeAgent (auto-converted) diff --git a/tests/realtime/test_session_exceptions.py b/tests/realtime/test_session_exceptions.py index da93902368..84602e4181 100644 --- a/tests/realtime/test_session_exceptions.py +++ b/tests/realtime/test_session_exceptions.py @@ -249,16 +249,11 @@ async def test_exception_during_guardrail_processing( session = RealtimeSession(fake_model, fake_agent, None) - # Add some fake guardrail tasks - fake_task1 = Mock() - fake_task1.done.return_value = False - fake_task1.cancel = Mock() - - fake_task2 = Mock() - fake_task2.done.return_value = True - fake_task2.cancel = Mock() - - session._guardrail_tasks = {fake_task1, fake_task2} + # Add one pending and one completed guardrail task. + pending_task = asyncio.create_task(asyncio.Event().wait()) + completed_task = asyncio.create_task(asyncio.sleep(0)) + await completed_task + session._guardrail_tasks = {pending_task, completed_task} fake_model.set_next_events([exception_event]) @@ -268,8 +263,8 @@ async def test_exception_during_guardrail_processing( pass # Verify guardrail tasks were properly cleaned up - fake_task1.cancel.assert_called_once() - fake_task2.cancel.assert_not_called() # Already done + assert pending_task.cancelled() + assert not completed_task.cancelled() assert len(session._guardrail_tasks) == 0 @pytest.mark.asyncio diff --git a/tests/sandbox/test_docker.py b/tests/sandbox/test_docker.py index 52701274eb..3e7f890e21 100644 --- a/tests/sandbox/test_docker.py +++ b/tests/sandbox/test_docker.py @@ -677,9 +677,10 @@ async def test_docker_persist_workspace_defers_stage_cleanup_until_archive_close assert session.stage_cleanup_calls == [] _ = archive.read() - await asyncio.sleep(0) + await session._wait_for_cleanup_tasks() assert session.stage_cleanup_calls == [session.last_staging_parent] + assert session._cleanup_tasks == set() def test_docker_start_exec_socket_closes_underlying_http_response() -> None: diff --git a/tests/sandbox/test_unix_local.py b/tests/sandbox/test_unix_local.py index 6f34273f02..bb1680538a 100644 --- a/tests/sandbox/test_unix_local.py +++ b/tests/sandbox/test_unix_local.py @@ -1,7 +1,10 @@ from __future__ import annotations +import asyncio import signal from pathlib import Path +from types import SimpleNamespace +from typing import cast import pytest @@ -11,6 +14,7 @@ UnixLocalSandboxClient, UnixLocalSandboxSession, UnixLocalSandboxSessionState, + _UnixPtyProcessEntry, ) from agents.sandbox.snapshot import NoopSnapshot from agents.sandbox.types import ExecResult, User @@ -37,6 +41,40 @@ async def _exec_internal( class TestUnixLocalPty: + @pytest.mark.asyncio + async def test_tty_fd_close_is_owned_without_blocking_termination( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + session = _RecordingUnixLocalSession(tmp_path) + close_started = asyncio.Event() + release_close = asyncio.Event() + + async def blocked_to_thread(*args: object, **kwargs: object) -> None: + _ = (args, kwargs) + close_started.set() + await release_close.wait() + + monkeypatch.setattr(asyncio, "to_thread", blocked_to_thread) + process = cast( + asyncio.subprocess.Process, + SimpleNamespace(returncode=0, pid=None), + ) + entry = _UnixPtyProcessEntry(process=process, tty=True, primary_fd=123) + + await asyncio.wait_for(session._terminate_pty_entry(entry), timeout=0.5) + await close_started.wait() + + assert len(session._fd_close_tasks) == 1 + await asyncio.wait_for(session._wait_for_fd_close_tasks(), timeout=0.5) + assert len(session._fd_close_tasks) == 1 + + release_close.set() + await asyncio.gather(*session._fd_close_tasks) + await asyncio.sleep(0) + assert session._fd_close_tasks == set() + @pytest.mark.asyncio async def test_pty_exec_write_poll_and_unknown_session_errors(self, tmp_path: Path) -> None: client = UnixLocalSandboxClient()