Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
AgentFrameworkHost,
HostedRunResult,
)
from opentelemetry import context as otel_context
from opentelemetry import trace
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.testclient import TestClient

from agent_framework_hosting_responses import ResponsesChannel
Expand Down Expand Up @@ -516,6 +519,41 @@ def test_stateful_call_and_result_content_coalesce_across_messages(self) -> None


class TestResponsesChannelStreaming:
def test_sse_streaming_uses_request_parent_span_context(self) -> None:
observed: dict[str, int] = {}
parent_ctx = trace.SpanContext(
trace_id=0xABCDEF00112233445566778899AABBCC,
span_id=0x1122334455667788,
is_remote=False,
trace_flags=trace.TraceFlags(0x01),
trace_state=trace.TraceState(),
)
parent_span = trace.NonRecordingSpan(parent_ctx)

class _SpanAwareAgent(_FakeAgent):
def run(self, messages: Any = None, *, stream: bool = False, **kwargs: Any) -> Any:
self.calls.append({"messages": messages, "stream": stream, "kwargs": kwargs})
if stream:
observed["run_span_id"] = trace.get_current_span().get_span_context().span_id
return _FakeStream(["chunk"])
return super().run(messages=messages, stream=stream, **kwargs)

async def _middleware_dispatch(request: Any, call_next: Any) -> Any:
token = otel_context.attach(trace.set_span_in_context(parent_span))
try:
return await call_next(request)
finally:
otel_context.detach(token)

host = AgentFrameworkHost(target=_SpanAwareAgent(), channels=[ResponsesChannel()])
host.app.add_middleware(BaseHTTPMiddleware, dispatch=_middleware_dispatch)

with TestClient(host.app) as client:
r = client.post("/responses", json={"input": "hi", "stream": True})

assert r.status_code == 200
assert observed["run_span_id"] == parent_ctx.span_id

def test_sse_emits_created_delta_completed(self) -> None:
agent = _FakeAgent(reply="hello world", chunks=["hello", " ", "world"])
host = AgentFrameworkHost(target=agent, channels=[ResponsesChannel()])
Expand Down
57 changes: 53 additions & 4 deletions python/packages/hosting/agent_framework_hosting/_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import os
import uuid
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable, Mapping, Sequence
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager
from contextlib import AbstractContextManager, ExitStack, asynccontextmanager, contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -42,6 +42,8 @@
Workflow,
WorkflowEvent,
)
from opentelemetry import context as otel_context
from opentelemetry import trace
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
Expand Down Expand Up @@ -187,6 +189,20 @@ async def _apply_response_hook(
return out


def _capture_current_otel_context() -> object | None:
"""Capture the current OTel context when a valid span is active.

Streaming channels can defer target iteration until after the route handler
has returned (for example, `StreamingResponse`). Capturing the current OTel
context at stream-construction time lets the host restore strict parent-child
span linkage during deferred pulls and finalization.
"""
current_span_context = trace.get_current_span().get_span_context()
if not current_span_context.is_valid:
return None
return otel_context.get_current()


def _workflow_event_to_update(event: WorkflowEvent[Any]) -> AgentResponseUpdate | None:
"""Map a :class:`WorkflowEvent` to a channel-friendly :class:`AgentResponseUpdate`.

Expand Down Expand Up @@ -288,9 +304,16 @@ class _BoundResponseStream:
already closed the stack.
"""

def __init__(self, inner: Any, stack: ExitStack) -> None:
def __init__(
self,
inner: Any,
stack: ExitStack,
*,
otel_context_snapshot: object | None = None,
) -> None:
self._inner = inner
self._stack = stack
self._otel_context_snapshot = otel_context_snapshot
self._closed = False

def _close(self) -> None:
Expand All @@ -299,6 +322,18 @@ def _close(self) -> None:
self._closed = True
self._stack.close()

@contextmanager
def _activate_otel_context(self) -> Any:
"""Re-activate the captured OTel parent context for deferred work."""
if self._otel_context_snapshot is None:
yield
return
token = otel_context.attach(cast("Any", self._otel_context_snapshot))
try:
yield
finally:
otel_context.detach(token)

async def aclose(self) -> None:
"""Idempotently release the bound request context.

Expand All @@ -321,15 +356,23 @@ def __aiter__(self) -> AsyncIterator[Any]:
return self._wrap()

async def _wrap(self) -> AsyncIterator[Any]:
with self._activate_otel_context():
iterator = self._inner.__aiter__()
try:
async for item in self._inner:
while True:
try:
with self._activate_otel_context():
item = await iterator.__anext__()
except StopAsyncIteration:
break
yield item
finally:
self._close()

async def get_final_response(self) -> Any:
try:
return await self._inner.get_final_response()
with self._activate_otel_context():
return await self._inner.get_final_response()
finally:
self._close()

Expand Down Expand Up @@ -1126,9 +1169,15 @@ def _invoke_stream(self, request: ChannelRequest) -> ResponseStream[AgentRespons
# stream in an adapter that holds the binding open across the
# iteration lifecycle.
binder = self._bind_request_context(request)
# Capture the request-parent OTel context BEFORE ``target.run``.
# Python evaluates positional args before keyword args, so doing
# this inline in the ``_BoundResponseStream(...)`` call would run
# ``target.run(...)`` first and may capture a shifted context.
otel_context_snapshot = _capture_current_otel_context()
return _BoundResponseStream( # type: ignore[return-value]
self.target.run(self._wrap_input(request), stream=True, **run_kwargs),
binder,
otel_context_snapshot=otel_context_snapshot,
)
Comment thread
eavanvalkenburg marked this conversation as resolved.

def _resolve_checkpoint_storage(self, request: ChannelRequest) -> CheckpointStorage | None:
Expand Down
113 changes: 113 additions & 0 deletions python/packages/hosting/tests/hosting/test_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import pytest
from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Content, Message, ResponseStream
from agent_framework._workflows._events import WorkflowEvent
from opentelemetry import context as otel_context
from opentelemetry import trace
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute, Route
Expand Down Expand Up @@ -1186,6 +1188,117 @@ async def test_await_path_routes_through_get_final_response(self) -> None:
assert names.count("enter") == 1
assert names.count("exit") == 1

async def test_deferred_streaming_keeps_captured_otel_parent_context(self) -> None:
"""`run_stream()` captures the current OTel context and reuses it for deferred pulls.

Reproduces channel behavior where stream consumption starts later than stream
construction (for example via StreamingResponse body iteration).
"""

class _SpanRecordingAgent:
id = "span-recorder"
name: str | None = "SpanRecorder"
description: str | None = "Records active span ids during stream pulls/finalization."

def __init__(self) -> None:
self.seen_span_ids: list[int] = []

def run(self, messages: Any = None, *, stream: bool = False, **kwargs: Any) -> Any:
if not stream:
raise AssertionError("non-streaming path not exercised here")

async def _gen() -> AsyncIterator[AgentResponseUpdate]:
self.seen_span_ids.append(trace.get_current_span().get_span_context().span_id)
yield AgentResponseUpdate(contents=[Content.from_text("chunk")], role="assistant")

async def _finalize(items: Sequence[AgentResponseUpdate]) -> AgentResponse: # noqa: RUF029
self.seen_span_ids.append(trace.get_current_span().get_span_context().span_id)
return AgentResponse.from_updates(items)

return ResponseStream(_gen(), finalizer=_finalize)

agent = _SpanRecordingAgent()
ch = _RecordingChannel(name="responses")
host = AgentFrameworkHost(target=cast(Any, agent), channels=[ch])
_ = host.app
assert ch.context is not None

req = ChannelRequest(
channel="responses",
operation="op",
input="hi",
stream=True,
attributes={"response_id": "resp_otel"},
)

parent_ctx = trace.SpanContext(
trace_id=0x123456789ABCDEF0123456789ABCDEF0,
span_id=0x123456789ABCDEF0,
is_remote=False,
trace_flags=trace.TraceFlags(0x01),
trace_state=trace.TraceState(),
)
parent_span = trace.NonRecordingSpan(parent_ctx)
token = otel_context.attach(trace.set_span_in_context(parent_span))
try:
stream = await ch.context.run_stream(req)
finally:
otel_context.detach(token)

# Consumption happens after the caller context has ended.
chunks = [u.text async for u in stream]
final = await stream.get_final_response()

assert chunks == ["chunk"]
assert final.text == "chunk"
assert agent.seen_span_ids == [parent_ctx.span_id, parent_ctx.span_id]

async def test_run_stream_captures_otel_context_before_target_run(self, monkeypatch: Any) -> None:
"""Guard the evaluation-order pitfall called out in review.

``_invoke_stream`` must capture OTel context before calling
``target.run(...)``. If that order flips, deferred streaming can bind to
the wrong parent context.
"""

from agent_framework_hosting import _host as host_module

order: list[str] = []

def _capture() -> None:
order.append("capture")
return

monkeypatch.setattr(host_module, "_capture_current_otel_context", _capture)

class _OrderAgent:
id = "order-agent"
name: str | None = "OrderAgent"
description: str | None = "Records call order."

def run(self, messages: Any = None, *, stream: bool = False, **kwargs: Any) -> Any:
order.append("run")

async def _gen() -> AsyncIterator[AgentResponseUpdate]:
yield AgentResponseUpdate(contents=[Content.from_text("chunk")], role="assistant")

async def _finalize(items: Sequence[AgentResponseUpdate]) -> AgentResponse: # noqa: RUF029
return AgentResponse.from_updates(items)

return ResponseStream(_gen(), finalizer=_finalize)

ch = _RecordingChannel(name="responses")
host = AgentFrameworkHost(target=cast(Any, _OrderAgent()), channels=[ch])
_ = host.app
assert ch.context is not None

stream = await ch.context.run_stream(
ChannelRequest(channel="responses", operation="op", input="hi", stream=True),
)
await cast(Any, stream).aclose()

assert order[:2] == ["capture", "run"]


# --------------------------------------------------------------------------- #
# `_wrap_input` — list[Message] LAST-message metadata stamping #
Expand Down
Loading