Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import contextvars
import mimetypes
import os
import shutil
Expand Down Expand Up @@ -162,7 +163,8 @@ def _wrapped() -> tuple[bool, Any]:
del exc
return False, (exc_type, exc_args)

ok, payload = self._executor.submit(_wrapped).result()
current_context = contextvars.copy_context()
ok, payload = self._executor.submit(lambda: current_context.run(_wrapped)).result()
if ok:
return cast(_T, payload)
exc_type, exc_args = cast(tuple[type[BaseException], tuple[str, ...]], payload)
Expand Down Expand Up @@ -830,12 +832,13 @@ async def _invoke() -> Any:
# registered callbacks synchronously via FFI, so this must be a sync function.
# We run the async call on a dedicated thread to avoid conflicts with any
# event loop that may be running on the current thread.
current_context = contextvars.copy_context()
result_box: list[Any] = [None]
error_box: list[BaseException] = []

def _run() -> None:
try:
result_box[0] = asyncio.run(_invoke())
result_box[0] = current_context.run(lambda: asyncio.run(_invoke()))
except BaseException as exc:
error_box.append(exc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
import contextlib
import dataclasses
import gc
import importlib
import importlib.metadata
import importlib.util
import inspect
import json
import sys
import threading
import time
from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence
from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, cast
from unittest.mock import patch

import pytest
from agent_framework import (
Expand All @@ -32,6 +34,10 @@
ResponseStream,
tool,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from pytest import fixture

from agent_framework_hyperlight import AllowedDomain, FileMount, HyperlightCodeActProvider, HyperlightExecuteCodeTool
from agent_framework_hyperlight import _execute_code_tool as execute_code_module
Expand Down Expand Up @@ -90,6 +96,54 @@ def _skip_if_hyperlight_integration_runtime_disabled() -> None:
)


@fixture
def span_exporter(monkeypatch) -> Generator[InMemorySpanExporter]:
env_vars = [
"ENABLE_INSTRUMENTATION",
"ENABLE_SENSITIVE_DATA",
"ENABLE_CONSOLE_EXPORTERS",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
"OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
"OTEL_EXPORTER_OTLP_LOGS_ENDPOINT",
"OTEL_EXPORTER_OTLP_PROTOCOL",
"OTEL_EXPORTER_OTLP_HEADERS",
"OTEL_EXPORTER_OTLP_TRACES_HEADERS",
"OTEL_EXPORTER_OTLP_METRICS_HEADERS",
"OTEL_EXPORTER_OTLP_LOGS_HEADERS",
"OTEL_SERVICE_NAME",
"OTEL_SERVICE_VERSION",
"OTEL_RESOURCE_ATTRIBUTES",
]

for key in env_vars:
monkeypatch.delenv(key, raising=False)
monkeypatch.setenv("ENABLE_INSTRUMENTATION", "True")
monkeypatch.setenv("ENABLE_SENSITIVE_DATA", "True")

import agent_framework.observability as observability
from opentelemetry import trace

importlib.reload(observability)
observability_settings = observability.ObservabilitySettings()
tracer_provider = TracerProvider(resource=observability.create_resource())
trace.set_tracer_provider(tracer_provider)
monkeypatch.setattr(observability, "OBSERVABILITY_SETTINGS", observability_settings, raising=False)

with (
patch("agent_framework.observability.OBSERVABILITY_SETTINGS", observability_settings),
patch("agent_framework.observability.configure_otel_providers"),
):
exporter = InMemorySpanExporter()
current_tracer_provider = trace.get_tracer_provider()
if not hasattr(current_tracer_provider, "add_span_processor"):
raise RuntimeError("Tracer provider does not support adding span processors.")

cast(Any, current_tracer_provider).add_span_processor(SimpleSpanProcessor(exporter))
yield exporter
exporter.clear()


@pytest.fixture(scope="module")
def shared_sandbox():
"""Long-lived sandbox with snapshot/restore for read-mostly tests.
Expand Down Expand Up @@ -379,6 +433,38 @@ def test_execute_code_tool_replaces_tools_with_the_same_name() -> None:
assert execute_code.approval_mode == "always_require"


async def test_execute_code_tool_parent_span_for_host_tools(
span_exporter: InMemorySpanExporter,
monkeypatch: pytest.MonkeyPatch,
) -> None:
_FakeSandbox.instances.clear()
monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox)

execute_code = HyperlightExecuteCodeTool(tools=[compute])

span_exporter.clear()
await execute_code.invoke(
arguments={"code": 'total = call_tool("compute", a=20, b=22)\nprint(total)'},
skip_parsing=True,
)

spans = span_exporter.get_finished_spans()
execute_code_spans = [span for span in spans if span.name == "execute_tool execute_code"]
compute_spans = [span for span in spans if span.name == "execute_tool compute"]

assert len(execute_code_spans) == 1
assert len(compute_spans) == 1

execute_code_span = execute_code_spans[0]
compute_span = compute_spans[0]

assert compute_span.context is not None
assert execute_code_span.context is not None
assert compute_span.context.trace_id == execute_code_span.context.trace_id
assert compute_span.parent is not None
assert compute_span.parent.span_id == execute_code_span.context.span_id


def test_execute_code_tool_accepts_string_and_tuple_file_mounts_without_mode_flags(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
Expand Down
Loading