diff --git a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py index b4e5f8c9d3..0a7963dcf7 100644 --- a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py +++ b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextvars import mimetypes import os import shutil @@ -162,7 +163,8 @@ def _wrapped() -> tuple[bool, Any]: del exc return False, (exc_type, exc_args) - ok, payload = self._executor.submit(_wrapped).result() + current_context = contextvars.copy_context() + ok, payload = self._executor.submit(lambda: current_context.run(_wrapped)).result() if ok: return cast(_T, payload) exc_type, exc_args = cast(tuple[type[BaseException], tuple[str, ...]], payload) @@ -830,12 +832,13 @@ async def _invoke() -> Any: # registered callbacks synchronously via FFI, so this must be a sync function. # We run the async call on a dedicated thread to avoid conflicts with any # event loop that may be running on the current thread. + current_context = contextvars.copy_context() result_box: list[Any] = [None] error_box: list[BaseException] = [] def _run() -> None: try: - result_box[0] = asyncio.run(_invoke()) + result_box[0] = current_context.run(lambda: asyncio.run(_invoke())) except BaseException as exc: error_box.append(exc) diff --git a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py index 4fa79348f3..c72b870847 100644 --- a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py +++ b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py @@ -6,6 +6,7 @@ import contextlib import dataclasses import gc +import importlib import importlib.metadata import importlib.util import inspect @@ -13,11 +14,12 @@ import sys import threading import time -from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence +from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory from typing import Any, cast +from unittest.mock import patch import pytest from agent_framework import ( @@ -32,6 +34,10 @@ ResponseStream, tool, ) +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from pytest import fixture from agent_framework_hyperlight import AllowedDomain, FileMount, HyperlightCodeActProvider, HyperlightExecuteCodeTool from agent_framework_hyperlight import _execute_code_tool as execute_code_module @@ -90,6 +96,54 @@ def _skip_if_hyperlight_integration_runtime_disabled() -> None: ) +@fixture +def span_exporter(monkeypatch) -> Generator[InMemorySpanExporter]: + env_vars = [ + "ENABLE_INSTRUMENTATION", + "ENABLE_SENSITIVE_DATA", + "ENABLE_CONSOLE_EXPORTERS", + "OTEL_EXPORTER_OTLP_ENDPOINT", + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", + "OTEL_EXPORTER_OTLP_LOGS_ENDPOINT", + "OTEL_EXPORTER_OTLP_PROTOCOL", + "OTEL_EXPORTER_OTLP_HEADERS", + "OTEL_EXPORTER_OTLP_TRACES_HEADERS", + "OTEL_EXPORTER_OTLP_METRICS_HEADERS", + "OTEL_EXPORTER_OTLP_LOGS_HEADERS", + "OTEL_SERVICE_NAME", + "OTEL_SERVICE_VERSION", + "OTEL_RESOURCE_ATTRIBUTES", + ] + + for key in env_vars: + monkeypatch.delenv(key, raising=False) + monkeypatch.setenv("ENABLE_INSTRUMENTATION", "True") + monkeypatch.setenv("ENABLE_SENSITIVE_DATA", "True") + + import agent_framework.observability as observability + from opentelemetry import trace + + importlib.reload(observability) + observability_settings = observability.ObservabilitySettings() + tracer_provider = TracerProvider(resource=observability.create_resource()) + trace.set_tracer_provider(tracer_provider) + monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False) + + with ( + patch("agent_framework.observability.OBSERVABILITY_SETTINGS", observability_settings), + patch("agent_framework.observability.configure_otel_providers"), + ): + exporter = InMemorySpanExporter() + current_tracer_provider = trace.get_tracer_provider() + if not hasattr(current_tracer_provider, "add_span_processor"): + raise RuntimeError("Tracer provider does not support adding span processors.") + + cast(Any, current_tracer_provider).add_span_processor(SimpleSpanProcessor(exporter)) + yield exporter + exporter.clear() + + @pytest.fixture(scope="module") def shared_sandbox(): """Long-lived sandbox with snapshot/restore for read-mostly tests. @@ -379,6 +433,38 @@ def test_execute_code_tool_replaces_tools_with_the_same_name() -> None: assert execute_code.approval_mode == "always_require" +async def test_execute_code_tool_parent_span_for_host_tools( + span_exporter: InMemorySpanExporter, + monkeypatch: pytest.MonkeyPatch, +) -> None: + _FakeSandbox.instances.clear() + monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox) + + execute_code = HyperlightExecuteCodeTool(tools=[compute]) + + span_exporter.clear() + await execute_code.invoke( + arguments={"code": 'total = call_tool("compute", a=20, b=22)\nprint(total)'}, + skip_parsing=True, + ) + + spans = span_exporter.get_finished_spans() + execute_code_spans = [span for span in spans if span.name == "execute_tool execute_code"] + compute_spans = [span for span in spans if span.name == "execute_tool compute"] + + assert len(execute_code_spans) == 1 + assert len(compute_spans) == 1 + + execute_code_span = execute_code_spans[0] + compute_span = compute_spans[0] + + assert compute_span.context is not None + assert execute_code_span.context is not None + assert compute_span.context.trace_id == execute_code_span.context.trace_id + assert compute_span.parent is not None + assert compute_span.parent.span_id == execute_code_span.context.span_id + + def test_execute_code_tool_accepts_string_and_tuple_file_mounts_without_mode_flags( tmp_path: Path, monkeypatch: pytest.MonkeyPatch,