Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/packages/core/agent_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
)
from .exceptions import (
MiddlewareException,
UserInputRequiredException,
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowException,
Expand Down Expand Up @@ -291,6 +292,7 @@
"TypeCompatibilityError",
"UpdateT",
"UsageDetails",
"UserInputRequiredException",
"ValidationTypeEnum",
"Workflow",
"WorkflowAgent",
Expand Down
25 changes: 15 additions & 10 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
map_chat_to_agent_update,
normalize_messages,
)
from .exceptions import AgentInvalidResponseException
from .exceptions import AgentInvalidResponseException, UserInputRequiredException
from .observability import AgentTelemetryLayer

if sys.version_info >= (3, 13):
Expand Down Expand Up @@ -538,14 +538,16 @@ async def agent_wrapper(**kwargs: Any) -> str:

if stream_callback is None:
# Use non-streaming mode
return (
await self.run(
input_text,
stream=False,
session=parent_session,
**forwarded_kwargs,
)
).text
response = await self.run(
input_text,
stream=False,
session=parent_session,
**forwarded_kwargs,
)

if response.user_input_requests:
raise UserInputRequiredException(contents=response.user_input_requests)
return response.text

# Use streaming mode - accumulate updates and create final response
response_updates: list[AgentResponseUpdate] = []
Expand All @@ -557,7 +559,10 @@ async def agent_wrapper(**kwargs: Any) -> str:
stream_callback(update)

# Create final text from accumulated updates
return AgentResponse.from_updates(response_updates).text
final_response = AgentResponse.from_updates(response_updates)
if final_response.user_input_requests:
raise UserInputRequiredException(contents=final_response.user_input_requests)
return final_response.text

agent_tool: FunctionTool = FunctionTool(
name=tool_name,
Expand Down
49 changes: 40 additions & 9 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from pydantic import BaseModel, Field, ValidationError, create_model

from ._serialization import SerializationMixin
from .exceptions import ToolException
from .exceptions import ToolException, UserInputRequiredException
from .observability import (
OPERATION_DURATION_BUCKET_BOUNDARIES,
OtelAttr,
Expand Down Expand Up @@ -1228,6 +1228,8 @@ async def _auto_invoke_function(
result=function_result,
additional_properties=function_call_content.additional_properties,
)
except UserInputRequiredException:
raise
except Exception as exc:
message = "Error: Function failed."
if config.get("include_detailed_errors", False):
Expand Down Expand Up @@ -1274,6 +1276,8 @@ async def final_function_handler(context_obj: Any) -> Any:
additional_properties=function_call_content.additional_properties,
)
raise
except UserInputRequiredException:
raise
except Exception as exc:
message = "Error: Function failed."
if config.get("include_detailed_errors", False):
Expand Down Expand Up @@ -1384,8 +1388,8 @@ async def _try_execute_function_calls(
async def invoke_with_termination_handling(
function_call: Content,
seq_idx: int,
) -> tuple[Content, bool]:
"""Invoke function and catch MiddlewareTermination, returning (result, should_terminate)."""
) -> tuple[list[Content], bool]:
"""Invoke function and catch MiddlewareTermination, returning (results, should_terminate)."""
try:
result = await _auto_invoke_function(
function_call_content=function_call, # type: ignore[arg-type]
Expand All @@ -1396,24 +1400,48 @@ async def invoke_with_termination_handling(
middleware_pipeline=middleware_pipeline,
config=config,
)
return (result, False)
return ([result], False)
except MiddlewareTermination as exc:
# Middleware requested termination - return result as Content
# exc.result may already be a Content (set by _auto_invoke_function) or raw value
if isinstance(exc.result, Content):
return (exc.result, True)
return ([exc.result], True)
result_content = Content.from_function_result(
call_id=function_call.call_id, # type: ignore[arg-type]
result=exc.result,
)
return (result_content, True)
return ([result_content], True)
except UserInputRequiredException as exc:
# Sub-agent requires user input — propagate the Content items so
# _handle_function_call_results can surface them to the parent response.
if exc.contents:
propagated: list[Content] = []
for idx, item in enumerate(exc.contents):
item.call_id = function_call.call_id # type: ignore[attr-defined]
if not item.id: # type: ignore[attr-defined]
item.id = f"{function_call.call_id}:{idx}" # type: ignore[attr-defined]
propagated.append(item)
if propagated:
return (propagated, False)
return (
[
Content.from_function_result(
call_id=function_call.call_id, # type: ignore[arg-type]
result="Tool requires user input but no request details were provided.",
exception="UserInputRequiredException",
)
],
False,
)

execution_results = await asyncio.gather(*[
invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls)
])

# Unpack results - each is (Content, terminate_flag)
contents: list[Content] = [result[0] for result in execution_results]
# Flatten results in original function_calls order — each task returns (list[Content], terminate_flag)
contents: list[Content] = []
for result_contents, _ in execution_results:
contents.extend(result_contents)
# If any function requested termination, terminate the loop
should_terminate = any(result[1] for result in execution_results)
return (contents, should_terminate)
Expand Down Expand Up @@ -1645,7 +1673,10 @@ def _handle_function_call_results(
) -> FunctionRequestResult:
from ._types import Message

if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results):
if any(
fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request
for fccr in function_call_results
):
# Only add items that aren't already in the message (e.g. function_approval_request wrappers).
# Declaration-only function_call items are already present from the LLM response.
new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"]
Expand Down
35 changes: 34 additions & 1 deletion python/packages/core/agent_framework/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
and guidance on choosing the correct exception class.
"""

from __future__ import annotations

import logging
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from ._types import Content

logger = logging.getLogger("agent_framework")

Expand Down Expand Up @@ -180,6 +185,34 @@ class ToolExecutionException(ToolException):
pass


class UserInputRequiredException(ToolException):
"""Raised when a tool wrapping a sub-agent requires user input to proceed.

This exception carries the ``user_input_request`` Content items emitted by
the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``)
so the tool invocation layer can propagate them to the parent agent's response
instead of swallowing them as a generic tool error.

Args:
contents: The user-input-request Content items from the sub-agent response.
message: Human-readable description of why user input is needed.
"""

def __init__(
self,
contents: list[Content],
message: str = "Tool requires user input to proceed.",
) -> None:
"""Create a UserInputRequiredException.

Args:
contents: The user-input-request Content items from the sub-agent response.
message: Human-readable description of why user input is needed.
"""
super().__init__(message, log_level=None)
self.contents: list[Content] = contents


# endregion

# region Middleware Exceptions
Expand Down
68 changes: 68 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,4 +1622,72 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm
assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)


# region as_tool user_input_request propagation


async def test_as_tool_raises_on_user_input_request_non_streaming(client: SupportsChatGetResponse) -> None:
"""Test that as_tool raises UserInputRequiredException when the sub-agent response has user_input_requests."""
from agent_framework.exceptions import UserInputRequiredException

# Configure mock client to return a response with oauth_consent_request content
consent_content = Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
client.responses = [ # type: ignore[attr-defined]
ChatResponse(messages=Message(role="assistant", contents=[consent_content])),
]

agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent")
agent_tool = agent.as_tool()

with raises(UserInputRequiredException) as exc_info:
await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something"))

assert len(exc_info.value.contents) == 1
assert exc_info.value.contents[0].type == "oauth_consent_request"
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"


async def test_as_tool_raises_on_user_input_request_streaming(client: SupportsChatGetResponse) -> None:
"""Test that as_tool raises UserInputRequiredException in streaming mode."""
from agent_framework.exceptions import UserInputRequiredException

consent_content = Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
client.streaming_responses = [ # type: ignore[attr-defined]
[ChatResponseUpdate(contents=[consent_content], role="assistant")],
]

collected_updates: list[AgentResponseUpdate] = []

def stream_callback(update: AgentResponseUpdate) -> None:
collected_updates.append(update)

agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent")
agent_tool = agent.as_tool(stream_callback=stream_callback)

with raises(UserInputRequiredException) as exc_info:
await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something"))

assert len(exc_info.value.contents) == 1
assert exc_info.value.contents[0].type == "oauth_consent_request"
# Stream callback should still have received the update before the exception
assert len(collected_updates) > 0


async def test_as_tool_returns_text_when_no_user_input_request(client: SupportsChatGetResponse) -> None:
"""Test that as_tool returns text normally when there are no user_input_requests."""
client.responses = [ # type: ignore[attr-defined]
ChatResponse(messages=Message(role="assistant", text="Here is the result")),
]

agent = Agent(client=client, name="NormalAgent", description="Normal agent")
agent_tool = agent.as_tool()

result = await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something"))

assert result == "Here is the result"


# endregion
57 changes: 57 additions & 0 deletions python/packages/core/tests/core/test_function_invocation_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3511,4 +3511,61 @@ def test_dict_overwrites_existing_conversation_id(self):
assert kwargs["chat_options"]["conversation_id"] == "new_id"


# region UserInputRequiredException propagation through tool invocation


async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse):
"""Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response.

This is an end-to-end test: sub-agent returns oauth_consent_request →
as_tool raises UserInputRequiredException → invoke_with_termination_handling catches it →
_handle_function_call_results returns "action": "return" → Content ends up in parent response.
"""
from agent_framework.exceptions import UserInputRequiredException

# Create a mock tool that simulates what as_tool does when the sub-agent
# returns user_input_request content
@tool(name="delegate_agent", approval_mode="never_require")
def delegate_tool(task: str) -> str:
raise UserInputRequiredException(
contents=[
Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
]
)

# Parent agent calls the tool, which raises UserInputRequiredException
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'),
],
)
),
]

response = await chat_client_base.get_response(
[Message(role="user", text="delegate this")],
options={"tool_choice": "auto", "tools": [delegate_tool]},
)

# The oauth_consent_request Content should be in the parent response's assistant message
user_requests = [
content
for msg in response.messages
for content in msg.contents
if isinstance(content, Content) and content.user_input_request
]
assert len(user_requests) == 1
assert user_requests[0].type == "oauth_consent_request"
assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent"
assert user_requests[0].user_input_request is True


# endregion


# endregion