Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""AgentFrameworkAgent wrapper for AG-UI protocol."""

from collections import OrderedDict
from collections.abc import AsyncGenerator
from typing import Any, cast

Expand Down Expand Up @@ -101,6 +102,14 @@ def __init__(
require_confirmation=require_confirmation,
)

# Server-side registry of pending approval requests.
# Keys are "{thread_id}:{request_id}", values are the function name.
# Populated when approval requests are emitted; consumed when responses arrive.
# Prevents bypass, function name spoofing, and replay attacks.
# Bounded to prevent unbounded growth from abandoned approval requests.
self._pending_approvals: OrderedDict[str, str] = OrderedDict()
self._pending_approvals_max_size: int = 10_000

async def run(
self,
input_data: dict[str, Any],
Expand All @@ -113,5 +122,7 @@ async def run(
Yields:
AG-UI events
"""
async for event in run_agent_stream(input_data, self.agent, self.config):
async for event in run_agent_stream(
input_data, self.agent, self.config, pending_approvals=self._pending_approvals
):
yield event
96 changes: 95 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,28 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
return events


def _evict_oldest_approvals(registry: dict[str, str], max_size: int = 10_000) -> None:
    """Trim the pending-approvals registry down to at most ``max_size`` entries.

    Entries are evicted in insertion order, oldest first (FIFO). Both plain
    ``dict`` and ``OrderedDict`` preserve insertion order, so either kind of
    registry is bounded. Overwriting an existing key does not move it to the
    end, so this is first-in-first-out rather than true LRU.

    Args:
        registry: Mapping of registry key to function name. Mutated in place.
        max_size: Maximum number of entries to keep. Values below zero are
            treated as zero, i.e. the registry is emptied.
    """
    limit = max(max_size, 0)
    excess = len(registry) - limit
    if excess <= 0:
        return

    # Take a fresh iterator for every removal: the mapping is mutated between
    # removals, and the first key yielded is always the oldest surviving one.
    for _ in range(excess):
        del registry[next(iter(registry))]


async def _resolve_approval_responses(
messages: list[Any],
tools: list[Any],
agent: SupportsAgentRun,
run_kwargs: dict[str, Any],
pending_approvals: dict[str, str] | None = None,
thread_id: str = "",
) -> None:
"""Execute approved function calls and replace approval content with results.

Expand All @@ -385,13 +402,71 @@ async def _resolve_approval_responses(
tools: List of available tools
agent: The agent instance (to get client and config)
run_kwargs: Kwargs for tool execution
pending_approvals: Server-side registry of pending approval requests.
Keys are ``{thread_id}:{request_id}``, values are function names.
When provided, every approval response is validated against this
registry to prevent bypass, function name spoofing, and replay.
thread_id: The conversation thread ID used to scope registry keys.
"""
fcc_todo = _collect_approval_responses(messages)
if not fcc_todo:
return

approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
rejected_responses = [resp for resp in fcc_todo.values() if not resp.approved]

# Validate every approval response (approved AND rejected) against the
# pending approvals registry. Invalid responses are stripped from messages
# entirely — not converted to rejection results, which would inject
# attacker-controlled content into the LLM conversation.
if pending_approvals is not None and (approved_responses or rejected_responses):
validated: list[Any] = []
validated_rejected: list[Any] = []
invalid_ids: set[str] = set()
for resp in approved_responses + rejected_responses:
resp_id = resp.id or ""
resp_name = resp.function_call.name if resp.function_call else None
registry_key = f"{thread_id}:{resp_id}"

if registry_key not in pending_approvals:
logger.warning(
"Rejected approval response id=%s: no matching pending approval request",
resp_id,
)
invalid_ids.add(resp_id)
continue

pending_name = pending_approvals[registry_key]
if resp_name != pending_name:
logger.warning(
"Rejected approval response id=%s: function name mismatch (response=%s, pending=%s)",
resp_id,
resp_name,
pending_name,
)
invalid_ids.add(resp_id)
continue

# Valid — consume entry to prevent replay
del pending_approvals[registry_key]
if resp.approved:
validated.append(resp)
else:
validated_rejected.append(resp)

# Strip invalid approval responses from messages and fcc_todo so
# _replace_approval_contents_with_results never sees them.
if invalid_ids:
for inv_id in invalid_ids:
fcc_todo.pop(inv_id, None)
for msg in messages:
msg.contents = [
c for c in msg.contents if not (c.type == "function_approval_response" and c.id in invalid_ids)
]

approved_responses = validated
rejected_responses = validated_rejected

approved_function_results: list[Any] = []

# Execute approved tool calls
Expand Down Expand Up @@ -597,6 +672,7 @@ async def run_agent_stream(
input_data: dict[str, Any],
agent: SupportsAgentRun,
config: AgentConfig,
pending_approvals: dict[str, str] | None = None,
) -> AsyncGenerator[BaseEvent]:
"""Run agent and yield AG-UI events.

Expand All @@ -607,6 +683,10 @@ async def run_agent_stream(
input_data: AG-UI request data with messages, state, tools, etc.
agent: The Agent Framework agent to run
config: Agent configuration
pending_approvals: Optional server-side registry of pending approval
requests. Keys are ``{thread_id}:{request_id}``, values are
function names. When provided, approval responses are validated
against this registry to prevent bypass, spoofing, and replay.

Yields:
AG-UI events
Expand Down Expand Up @@ -707,7 +787,7 @@ async def run_agent_stream(
# Resolve approval responses (execute approved tools, replace approvals with results)
# This must happen before running the agent so it sees the tool results
tools_for_execution = tools if tools is not None else server_tools
await _resolve_approval_responses(messages, tools_for_execution, agent, run_kwargs)
await _resolve_approval_responses(messages, tools_for_execution, agent, run_kwargs, pending_approvals, thread_id)

# Defense-in-depth: replace approval payloads in snapshot with actual tool results
# so CopilotKit does not re-send stale approval content on subsequent turns.
Expand Down Expand Up @@ -782,6 +862,20 @@ async def run_agent_stream(
for content in update.contents:
content_type = getattr(content, "type", None)
logger.debug(f"Processing content type={content_type}, message_id={flow.message_id}")

# Register pending approval requests so we can validate responses later
if content_type == "function_approval_request" and pending_approvals is not None:
if content.id and content.function_call and content.function_call.name:
pending_approvals[f"{thread_id}:{content.id}"] = content.function_call.name
# Evict oldest entries if the registry exceeds a safe bound (LRU)
_evict_oldest_approvals(pending_approvals, max_size=10_000)
else:
logger.warning(
"Approval request not registered: missing id=%s, function_call=%s, or function name",
getattr(content, "id", None),
getattr(content, "function_call", None),
)

for event in _emit_content(
content,
flow,
Expand Down
Loading