Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 153 additions & 44 deletions src/agents/sandbox/sandboxes/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,69 @@

logger = logging.getLogger(__name__)


# Non-seekable payloads are spooled to measure their length; keep small ones in
# RAM and spill larger ones to a temp file so a big upload can't OOM the process.
_STREAM_SPOOL_MAX_SIZE = 16 * 1024 * 1024


def _measure_stream(stream: io.IOBase) -> tuple[int, io.IOBase, io.IOBase | None]:
"""Return ``(length, readable_stream, spool_to_close)`` for a length-framed write.

Seekable streams are measured in place (and rewound); ``spool_to_close`` is
``None``. Non-seekable streams (e.g. an HTTP response body or pipe) are copied
into a ``SpooledTemporaryFile`` — kept in memory up to
``_STREAM_SPOOL_MAX_SIZE``, spilled to disk beyond it — so the byte length can
be determined without buffering the whole payload in RAM; the spool is returned
so the caller can close it.

Callers run this on the executor thread, never the event loop.
"""
try:
start = stream.tell()
stream.seek(0, io.SEEK_END)
end = stream.tell()
stream.seek(start)
# Clamp to 0: a stream positioned past its end has no readable bytes, and
# a negative count would become `head -c -N` ("all but the last N bytes"),
# which reads to EOF and re-hangs over a TLS stdin.
return max(0, end - start), stream, None
except (AttributeError, OSError, ValueError):
spool: Any = tempfile.SpooledTemporaryFile(max_size=_STREAM_SPOOL_MAX_SIZE)
length = 0
while True:
chunk = stream.read(1024 * 1024)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
length += len(chunk)
spool.write(chunk)
spool.seek(0)
return length, spool, spool


# POSIX sh that pipes exactly ``<n>`` bytes into the real command (``"$@"``).
# ``head -c`` bounds the read so completion never depends on a stdin half-close
# (unreliable over a TLS DOCKER_HOST). A plain ``head -c "$n" | "$@"`` pipeline
# reports only the *consumer's* exit status, so a missing/failed ``head`` — e.g. a
# custom image without coreutils/busybox ``head`` — would be masked: the consumer
# (``cat``/``tar``) sees an empty pipe, exits 0, and the write "succeeds" while
# creating an empty file. Capture the producer's status via a temp file and fail
# the exec (exit 98) if it is non-zero, so such writes surface as errors instead
# of silent data loss. ``pipefail`` is avoided because it is not POSIX (dash).
_LENGTH_FRAMED_STDIN_SCRIPT = (
"n=$1; shift; "
'status_file="${TMPDIR:-/tmp}/.agents-sandbox-framed.$$"; '
'{ head -c "$n"; echo "$?" >"$status_file"; } | "$@"; '
"consumer_status=$?; "
'producer_status="$(cat "$status_file" 2>/dev/null)"; '
'rm -f "$status_file"; '
'[ "$producer_status" = 0 ] || exit 98; '
'exit "$consumer_status"'
)


_PREPARE_USER_PTY_PID_SCRIPT = (
'pid_path="$1"\n'
'pid_user="$2"\n'
Expand Down Expand Up @@ -562,57 +625,103 @@ async def _stream_into_exec(
error_path: Path,
user: str | User | None = None,
) -> None:
# Frame the payload by length so the in-container reader terminates on a
# byte count rather than a stdin half-close. Docker's exec-attach stream
# does not carry a reliable stdin EOF over a TLS DOCKER_HOST: the
# ``shutdown(SHUT_WR)`` below is silently swallowed, so ``tar -x`` / ``cat``
# would block forever waiting for input that never ends (observed against
# Docker-in-Docker sidecars and remote daemons reached over TLS). Piping
# the real command through ``head -c <n>`` makes it stop after exactly
# ``<n>`` bytes, independent of transport, and keeps the deliberate
# avoidance of ``put_archive()`` (see ``write``) intact.
def _write() -> int | None:
container_client = self._container.client
assert container_client is not None
api = container_client.api
resp = api.exec_create(
self._container.id,
cmd,
stdin=True,
stdout=True,
stderr=True,
workdir=None,
user=self._coerce_exec_user(user) or "",
)
exec_socket = self._start_exec_socket(api=api, exec_id=cast(str, resp["Id"]))
sock = exec_socket.sock
raw_sock = exec_socket.raw_sock
try:
while True:
chunk = stream.read(1024 * 1024)
if not chunk:
break
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
elif not isinstance(chunk, bytes):
chunk = bytes(chunk)
if hasattr(raw_sock, "sendall"):
raw_sock.sendall(chunk)
else:
cast(Any, sock).write(chunk)

try:
if hasattr(raw_sock, "shutdown"):
raw_sock.shutdown(socket.SHUT_WR)
else:
cast(Any, sock).flush()
except Exception:
pass

# Measure/spool on this executor thread (never the event loop). A
# non-seekable stream is spooled to a SpooledTemporaryFile (bounded
# memory, then disk) rather than read whole into RAM.
payload_length, read_stream, spool = _measure_stream(stream)
try:
framed_cmd = [
"sh",
"-c",
_LENGTH_FRAMED_STDIN_SCRIPT,
"sh",
str(payload_length),
*cmd,
]
resp = api.exec_create(
self._container.id,
framed_cmd,
stdin=True,
stdout=True,
stderr=True,
workdir=None,
user=self._coerce_exec_user(user) or "",
)
exec_socket = self._start_exec_socket(api=api, exec_id=cast(str, resp["Id"]))
sock = exec_socket.sock
raw_sock = exec_socket.raw_sock
try:
if hasattr(raw_sock, "recv"):
while raw_sock.recv(1024 * 1024):
pass
else:
while cast(Any, sock).read(1024 * 1024):
pass
except Exception:
pass
# Send exactly ``payload_length`` bytes — the count the exec
# was framed with (``head -c "$n"``). Reading to EOF instead
# would desync if the stream changed after _measure_stream:
# extra bytes would pile up behind a ``head`` that already
# stopped, and a short read would leave ``head`` blocked on a
# TLS stdin that never EOFs (the original hang). If the stream
# ends early we fail loudly rather than truncate silently.
remaining = payload_length
while remaining > 0:
chunk = read_stream.read(min(1024 * 1024, remaining))
if not chunk:
raise WorkspaceArchiveWriteError(
path=error_path,
context={
"reason": "stream_shorter_than_measured",
"expected": str(payload_length),
"sent": str(payload_length - remaining),
},
)
if isinstance(chunk, str):
chunk = chunk.encode("utf-8")
elif not isinstance(chunk, bytes):
chunk = bytes(chunk)
if len(chunk) > remaining:
# Only reachable for multibyte text streams (never the
# byte streams these writes use); cap to the framed count.
chunk = chunk[:remaining]
if hasattr(raw_sock, "sendall"):
raw_sock.sendall(chunk)
else:
cast(Any, sock).write(chunk)
remaining -= len(chunk)

try:
if hasattr(raw_sock, "shutdown"):
raw_sock.shutdown(socket.SHUT_WR)
else:
cast(Any, sock).flush()
except Exception:
pass

try:
if hasattr(raw_sock, "recv"):
while raw_sock.recv(1024 * 1024):
pass
else:
while cast(Any, sock).read(1024 * 1024):
pass
except Exception:
pass
finally:
exec_socket.close()

return cast(int | None, api.exec_inspect(resp["Id"]).get("ExitCode"))
finally:
exec_socket.close()

return cast(int | None, api.exec_inspect(resp["Id"]).get("ExitCode"))
if spool is not None:
spool.close()

loop = asyncio.get_running_loop()
try:
Expand Down
Loading
Loading