Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 88 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

import asyncio
import contextlib
import copy
import logging
from collections.abc import AsyncGenerator, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from inspect import isawaitable
from typing import Any

Expand All @@ -29,6 +31,78 @@

logger = logging.getLogger(__name__)

# Default idle interval after which a transport-level SSE keepalive comment is emitted.
DEFAULT_KEEPALIVE_INTERVAL_SECONDS = 15.0

# An SSE comment line (a line beginning with ``:``) is a protocol no-op: clients and
# parsers ignore it, but it forces a write on the wire so idle-timeout proxies in front
# of the endpoint keep the connection open during long silent gaps between events.
_SSE_KEEPALIVE_COMMENT = ": keepalive\n\n"


class _StreamEnd:
"""Sentinel marking normal completion of the upstream event stream."""


async def _with_sse_keepalive(
events: AsyncIterator[str],
interval_seconds: float,
) -> AsyncGenerator[str]:
"""Yield upstream SSE strings, inserting keepalive comments during idle gaps.

Real events flush immediately and in order; a keepalive comment is emitted only when no
upstream event arrives within ``interval_seconds``, so the heartbeat is limited to idle
periods and never delays or reorders genuine events.

The upstream generator is drained by a single dedicated task that owns its entire lifecycle
(creation, iteration, and cleanup). This keeps every ``__anext__`` and the terminal cleanup
hooks in one ``contextvars`` context, which the agent run/telemetry pipeline requires:
those hooks reset ``ContextVar`` tokens that must be reset in the same context that created
them. Racing each individual pull with ``asyncio.wait_for`` would instead scatter the pulls
across contexts and break that cleanup.
"""
# Bound the queue so it acts as backpressure rather than an unbounded buffer: once the
# consumer (the StreamingResponse) falls behind, ``put`` blocks and the upstream generator
# is throttled to the drain rate instead of buffering arbitrarily many chunks in memory for
# a fast agent. A maxsize of 1 is enough to decouple the single producer pull from the single
# consumer emit while still pausing the producer the moment the consumer stops draining; a
# full queue simply means there is no idle gap, so no keepalive is needed during that window.
queue: asyncio.Queue[str | type[_StreamEnd] | Exception] = asyncio.Queue(maxsize=1)

async def _drain() -> None:
try:
async for event in events:
await queue.put(event)
except asyncio.CancelledError:
# Consumer went away; let cancellation propagate so the upstream generator closes.
raise
except Exception as exc: # noqa: BLE001 - surfaced to the consumer via the queue
await queue.put(exc)
else:
await queue.put(_StreamEnd)

producer = asyncio.ensure_future(_drain())
try:
while True:
try:
item = await asyncio.wait_for(queue.get(), interval_seconds)
except asyncio.TimeoutError:
# Upstream produced nothing within the interval; emit a transport-level heartbeat.
yield _SSE_KEEPALIVE_COMMENT
continue
if item is _StreamEnd:
return
if isinstance(item, BaseException):
raise item
yield item # type: ignore[misc]
finally:
# Stop draining if the consumer goes away (e.g. client disconnect) and let the
# producer task observe cancellation so the upstream generator is closed cleanly.
if not producer.done():
producer.cancel()
with contextlib.suppress(asyncio.CancelledError):
await producer


def _get_snapshot_store(
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow,
Expand Down Expand Up @@ -82,6 +156,7 @@ def add_agent_framework_fastapi_endpoint(
dependencies: Sequence[Depends] | None = None,
snapshot_store: AGUIThreadSnapshotStore | None = None,
snapshot_scope_resolver: SnapshotScopeResolver | None = None,
keepalive_interval_seconds: float | None = DEFAULT_KEEPALIVE_INTERVAL_SECONDS,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.

Expand All @@ -103,7 +178,14 @@ def add_agent_framework_fastapi_endpoint(
explicit Snapshot Scope resolver.
snapshot_scope_resolver: Optional resolver for the application-defined Snapshot Scope. Required whenever
a snapshot store is configured because an AG-UI Thread id is not an authorization boundary.
keepalive_interval_seconds: Idle interval (in seconds) after which a transport-level SSE keepalive
comment (``: keepalive``) is written to the stream when no agent event has been produced.
This prevents idle-timeout proxies (e.g. Azure ingress, nginx, serverless front doors) from
dropping a healthy but silent event stream during long-running tools. Real events still flush
immediately. Defaults to ``15.0``; pass ``None`` to disable keepalives entirely.
"""
if keepalive_interval_seconds is not None and keepalive_interval_seconds <= 0:
raise ValueError("keepalive_interval_seconds must be a positive number or None.")
protocol_runner: AgentFrameworkAgent | AgentFrameworkWorkflow
if isinstance(agent, AgentFrameworkWorkflow):
protocol_runner = agent
Expand Down Expand Up @@ -208,8 +290,12 @@ async def event_generator() -> AsyncGenerator[str]:
except Exception:
logger.exception("[%s] Failed to encode RUN_ERROR event", path)

stream: AsyncGenerator[str] = event_generator()
if keepalive_interval_seconds is not None:
stream = _with_sse_keepalive(stream, keepalive_interval_seconds)

return StreamingResponse(
event_generator(),
stream,
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
Expand Down
118 changes: 118 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

"""Tests for FastAPI endpoint creation (_endpoint.py)."""

import asyncio
import json
from collections.abc import AsyncIterator
from typing import Any, cast

import pytest
Expand All @@ -22,6 +24,7 @@

from agent_framework_ag_ui import InMemoryAGUIThreadSnapshotStore, add_agent_framework_fastapi_endpoint
from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._endpoint import _SSE_KEEPALIVE_COMMENT, _with_sse_keepalive
from agent_framework_ag_ui._workflow import AgentFrameworkWorkflow


Expand Down Expand Up @@ -1844,3 +1847,118 @@ def factory(thread_id: str) -> Any:

runner.clear_thread_workflow("thread-1")
assert runner._resolve_workflow("thread-1", "tenant-b") is not workflow_b


async def test_sse_keepalive_emitted_during_idle_gap_and_real_events_pass_through():
"""A silent gap between upstream events yields keepalive comments without dropping real events."""
first_released = asyncio.Event()

async def upstream() -> AsyncIterator[str]:
# Emit immediately, then go silent until the test explicitly releases the next event,
# simulating a long-running tool that produces no AG-UI events for a while.
yield "data: A\n\n"
await first_released.wait()
yield "data: B\n\n"

# Keep the interval small so the idle gap trips a keepalive quickly, but large enough that a
# loaded CI runner reliably enqueues the first real event before the initial timeout fires.
wrapped = _with_sse_keepalive(upstream(), 0.05)

chunks: list[str] = []
chunks.append(await wrapped.__anext__()) # real event A flushes immediately

# While upstream is idle, the wrapper must surface keepalive comments.
keepalive = await wrapped.__anext__()
assert keepalive == _SSE_KEEPALIVE_COMMENT
chunks.append(keepalive)

# Release the second real event and drain the rest of the stream.
first_released.set()
async for chunk in wrapped:
chunks.append(chunk)

data_lines = [chunk for chunk in chunks if chunk.startswith("data: ")]
assert data_lines == ["data: A\n\n", "data: B\n\n"]
assert _SSE_KEEPALIVE_COMMENT in chunks


async def test_sse_keepalive_not_emitted_when_events_flow_without_gaps():
"""Back-to-back events must pass through untouched with no keepalive comments inserted."""

async def upstream() -> AsyncIterator[str]:
yield "data: X\n\n"
yield "data: Y\n\n"
yield "data: Z\n\n"

chunks = [chunk async for chunk in _with_sse_keepalive(upstream(), 0.05)]

assert chunks == ["data: X\n\n", "data: Y\n\n", "data: Z\n\n"]
assert _SSE_KEEPALIVE_COMMENT not in chunks


async def test_sse_keepalive_wrapper_handles_empty_stream():
"""An upstream that yields nothing terminates cleanly with no keepalives."""

async def upstream() -> AsyncIterator[str]:
return
yield # pragma: no cover - present only to make this an async generator

chunks = [chunk async for chunk in _with_sse_keepalive(upstream(), 0.01)]

assert chunks == []


async def test_sse_keepalive_wrapper_propagates_upstream_errors():
"""Errors raised by the upstream generator surface to the consumer rather than hanging."""

async def upstream() -> AsyncIterator[str]:
yield "data: A\n\n"
raise RuntimeError("boom")

wrapped = _with_sse_keepalive(upstream(), 0.05)
assert await wrapped.__anext__() == "data: A\n\n"
with pytest.raises(RuntimeError, match="boom"):
await wrapped.__anext__()


async def test_endpoint_accepts_keepalive_interval_and_streams_events(build_chat_client):
"""Endpoint accepts keepalive_interval_seconds and still streams the full event sequence."""
app = FastAPI()
agent = Agent(name="test", instructions="Test agent", client=build_chat_client("Keepalive response"))

add_agent_framework_fastapi_endpoint(app, agent, path="/keepalive", keepalive_interval_seconds=0.05)

client = TestClient(app)
response = client.post("/keepalive", json={"messages": [{"role": "user", "content": "Hello"}]})

assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

event_types = [event.get("type") for event in _decode_sse_events(response)]
assert "RUN_STARTED" in event_types
assert "TEXT_MESSAGE_CONTENT" in event_types
assert "RUN_FINISHED" in event_types


async def test_endpoint_keepalive_can_be_disabled(build_chat_client):
"""Passing keepalive_interval_seconds=None keeps the plain stream behavior."""
app = FastAPI()
agent = Agent(name="test", instructions="Test agent", client=build_chat_client())

add_agent_framework_fastapi_endpoint(app, agent, path="/no-keepalive", keepalive_interval_seconds=None)

client = TestClient(app)
response = client.post("/no-keepalive", json={"messages": [{"role": "user", "content": "Hello"}]})

assert response.status_code == 200
# No keepalive comment lines should appear when the feature is disabled.
assert ": keepalive" not in response.content.decode("utf-8")


async def test_endpoint_rejects_non_positive_keepalive_interval(build_chat_client):
"""A non-positive keepalive interval is rejected at registration time."""
app = FastAPI()
agent = Agent(name="test", instructions="Test agent", client=build_chat_client())

with pytest.raises(ValueError, match="keepalive_interval_seconds must be a positive number or None"):
add_agent_framework_fastapi_endpoint(app, agent, path="/bad-keepalive", keepalive_interval_seconds=0)