Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/agents/tracing/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import threading
from typing import TYPE_CHECKING

from ..logger import logger

if TYPE_CHECKING:
from .provider import TraceProvider

Expand All @@ -19,16 +21,31 @@ def _shutdown_global_trace_provider() -> None:


def set_trace_provider(provider: TraceProvider) -> None:
"""Set the global trace provider used by tracing utilities."""
"""Set the global trace provider used by tracing utilities.

If a provider is already set and is being replaced, the previous provider
is shut down so any background threads, network clients, or other resources
held by its processors are released.
"""
global GLOBAL_TRACE_PROVIDER
global _SHUTDOWN_HANDLER_REGISTERED

with _GLOBAL_TRACE_PROVIDER_LOCK:
previous = GLOBAL_TRACE_PROVIDER
GLOBAL_TRACE_PROVIDER = provider
if not _SHUTDOWN_HANDLER_REGISTERED:
atexit.register(_shutdown_global_trace_provider)
_SHUTDOWN_HANDLER_REGISTERED = True

# Shut down inside the lock so a concurrent `set_trace_provider(previous)`
# cannot reinstall `previous` between releasing the lock and the shutdown
# call, which would close the processors of the now-active provider.
if previous is not None and previous is not provider:
try:
previous.shutdown()
except Exception as exc:
logger.error(f"Error shutting down previous trace provider: {exc}")


def get_trace_provider() -> TraceProvider:
"""Get the global trace provider used by tracing utilities.
Expand Down
44 changes: 44 additions & 0 deletions tests/tracing/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,50 @@ def fake_register(callback: Any) -> Any:

assert cast(Any, tracing_setup.GLOBAL_TRACE_PROVIDER) is second
assert registrations == [tracing_setup._shutdown_global_trace_provider]
# The replaced provider should be shut down so it does not leak its
# background threads, network clients, or other processor resources.
assert first.shutdown_calls == 1
assert second.shutdown_calls == 0


def test_set_trace_provider_swallows_previous_shutdown_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _ExplodingProvider:
def __init__(self) -> None:
self.shutdown_calls = 0

def shutdown(self) -> None:
self.shutdown_calls += 1
raise RuntimeError("boom")

monkeypatch.setattr(atexit, "register", lambda callback: callback)
monkeypatch.setattr(tracing_setup, "GLOBAL_TRACE_PROVIDER", None)
monkeypatch.setattr(tracing_setup, "_SHUTDOWN_HANDLER_REGISTERED", False)

first = _ExplodingProvider()
second = _DummyProvider()
tracing_setup.set_trace_provider(cast(Any, first))
# Failure inside ``previous.shutdown()`` must not propagate or leave the
# global state pointing at the old provider.
tracing_setup.set_trace_provider(cast(Any, second))

assert first.shutdown_calls == 1
assert cast(Any, tracing_setup.GLOBAL_TRACE_PROVIDER) is second


def test_set_trace_provider_skips_shutdown_when_same_instance(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(atexit, "register", lambda callback: callback)
monkeypatch.setattr(tracing_setup, "GLOBAL_TRACE_PROVIDER", None)
monkeypatch.setattr(tracing_setup, "_SHUTDOWN_HANDLER_REGISTERED", False)

provider = _DummyProvider()
tracing_setup.set_trace_provider(cast(Any, provider))
tracing_setup.set_trace_provider(cast(Any, provider))

assert provider.shutdown_calls == 0


def test_get_trace_provider_returns_existing_provider(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down