diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index 81e4a27302..a75d29abc4 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -124,14 +124,28 @@ def _request_payload_from_request_event(request_event: Any) -> dict[str, Any] | def _extract_responses_from_messages(messages: list[Message]) -> dict[str, Any]: - """Extract request-info responses from incoming tool/function-result messages.""" + """Extract request-info responses from incoming messages. + + Handles both ``function_result`` content (keyed by ``call_id``) and + ``function_approval_response`` content (keyed by ``id``), so that + approval decisions sent via messages are forwarded into the workflow + responses map. + """ responses: dict[str, Any] = {} for message in messages: for content in message.contents: - if content.type != "function_result" or not content.call_id: - continue - value = _coerce_json_value(content.result) - responses[str(content.call_id)] = value + if content.type == "function_result" and content.call_id: + value = _coerce_json_value(content.result) + responses[str(content.call_id)] = value + elif content.type == "function_approval_response" and getattr(content, "id", None): + approval_value: dict[str, Any] = { + "approved": getattr(content, "approved", False), + "id": str(content.id), # type: ignore[union-attr] + } + func_call = getattr(content, "function_call", None) + if func_call is not None: + approval_value["function_call"] = make_json_safe(func_call.to_dict()) + responses[str(content.id)] = approval_value # type: ignore[union-attr] return responses diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index 8ebd8fcaaa..26b44b03ba 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -33,6 +33,7 @@ _custom_event_value, _details_code, _details_message, + _extract_responses_from_messages, _interrupt_entry_for_request_event, _latest_assistant_contents, _latest_user_text, @@ -1172,9 +1173,253 @@ def test_details_without_error_type(self): assert _details_code(details) is None +class TestExtractResponsesFromMessages: + """Tests for _extract_responses_from_messages helper.""" + + def test_function_result_extracted(self): + """function_result content is extracted keyed by call_id.""" + result = Content.from_function_result(call_id="call-1", result="ok") + messages = [Message(role="tool", contents=[result])] + responses = _extract_responses_from_messages(messages) + assert responses == {"call-1": "ok"} + + def test_function_result_without_call_id_skipped(self): + """function_result with no call_id is ignored.""" + result = Content.from_function_result(call_id="", result="ok") + messages = [Message(role="tool", contents=[result])] + responses = _extract_responses_from_messages(messages) + assert responses == {} + + def test_function_approval_response_extracted(self): + """function_approval_response content is extracted keyed by id.""" + func_call = Content.from_function_call( + call_id="call-1", + name="do_action", + arguments={"x": 1}, + ) + approval = Content.from_function_approval_response( + approved=True, + id="approval-1", + function_call=func_call, + ) + messages = [Message(role="user", contents=[approval])] + responses = _extract_responses_from_messages(messages) + assert "approval-1" in responses + assert responses["approval-1"]["approved"] is True + assert responses["approval-1"]["id"] == "approval-1" + assert "function_call" in responses["approval-1"] + + def test_denied_approval_response_extracted(self): + """Denied function_approval_response is extracted with approved=False.""" + func_call = Content.from_function_call( + call_id="call-2", + name="delete_item", + arguments={}, + ) + approval = Content.from_function_approval_response( + approved=False, + id="approval-2", + function_call=func_call, + ) + messages = [Message(role="user", contents=[approval])] + responses = _extract_responses_from_messages(messages) + assert "approval-2" in responses + assert responses["approval-2"]["approved"] is False + + def test_mixed_result_and_approval(self): + """Both function_result and function_approval_response are extracted.""" + result = Content.from_function_result(call_id="call-1", result="done") + func_call = Content.from_function_call( + call_id="call-2", + name="submit", + arguments={}, + ) + approval = Content.from_function_approval_response( + approved=True, + id="approval-1", + function_call=func_call, + ) + messages = [ + Message(role="tool", contents=[result]), + Message(role="user", contents=[approval]), + ] + responses = _extract_responses_from_messages(messages) + assert "call-1" in responses + assert responses["call-1"] == "done" + assert "approval-1" in responses + assert responses["approval-1"]["approved"] is True + + def test_mixed_result_and_approval_same_message(self): + """Both function_result and function_approval_response in the same message are extracted.""" + result = Content.from_function_result(call_id="call-1", result="done") + func_call = Content.from_function_call( + call_id="call-2", + name="submit", + arguments={}, + ) + approval = Content.from_function_approval_response( + approved=True, + id="approval-1", + function_call=func_call, + ) + messages = [Message(role="tool", contents=[result, approval])] + responses = _extract_responses_from_messages(messages) + assert "call-1" in responses + assert responses["call-1"] == "done" + assert "approval-1" in responses + assert responses["approval-1"]["approved"] is True + + def test_text_content_skipped(self): + """Non-result, non-approval content is ignored.""" + text = Content.from_text(text="hello") + messages = [Message(role="user", contents=[text])] + responses = _extract_responses_from_messages(messages) + assert responses == {} + + def test_empty_messages(self): + """Empty message list returns empty responses.""" + assert _extract_responses_from_messages([]) == {} + + # ── Stream integration tests ── +async def test_workflow_run_approval_via_messages_approved() -> None: + """Approval response sent via messages (function_approvals) should satisfy the pending request.""" + + class ApprovalExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="approval_executor") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + del message + function_call = Content.from_function_call( + call_id="refund-call", + name="submit_refund", + arguments={"order_id": "12345", "amount": "$89.99"}, + ) + approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call) + await ctx.request_info(approval_request, Content, request_id="approval-1") + + @response_handler + async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: + del original_request + status = "approved" if bool(response.approved) else "rejected" + await ctx.yield_output(f"Refund {status}.") + + workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() + first_events = [ + event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow) + ] + first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump() + interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt")) + assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 + + # Second turn: send approval via function_approvals on a message (not resume.interrupts) + resumed_events = [ + event + async for event in run_workflow_stream( + { + "messages": [ + { + "role": "user", + "content": "", + "function_approvals": [ + { + "approved": True, + "id": "approval-1", + "call_id": "refund-call", + "name": "submit_refund", + "arguments": {"order_id": "12345", "amount": "$89.99"}, + } + ], + } + ], + }, + workflow, + ) + ] + + resumed_types = [event.type for event in resumed_events] + assert "RUN_STARTED" in resumed_types + assert "RUN_FINISHED" in resumed_types + assert "RUN_ERROR" not in resumed_types + assert "TEXT_MESSAGE_CONTENT" in resumed_types + text_deltas = [event.delta for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT"] + assert any("approved" in delta for delta in text_deltas) + resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump() + assert not resumed_finished.get("interrupt") + + +async def test_workflow_run_approval_via_messages_denied() -> None: + """Denied approval response sent via messages (function_approvals) should satisfy the pending request.""" + + class ApprovalExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="approval_executor") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + del message + function_call = Content.from_function_call( + call_id="delete-call", + name="delete_record", + arguments={"record_id": "abc"}, + ) + approval_request = Content.from_function_approval_request(id="deny-1", function_call=function_call) + await ctx.request_info(approval_request, Content, request_id="deny-1") + + @response_handler + async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: + del original_request + status = "approved" if bool(response.approved) else "rejected" + await ctx.yield_output(f"Delete {status}.") + + workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() + first_events = [ + event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow) + ] + first_finished = [event for event in first_events if event.type == "RUN_FINISHED"][0].model_dump() + interrupt_payload = cast(list[dict[str, Any]], first_finished.get("interrupt")) + assert isinstance(interrupt_payload, list) and len(interrupt_payload) == 1 + + # Second turn: send denial via function_approvals on a message (not resume.interrupts) + resumed_events = [ + event + async for event in run_workflow_stream( + { + "messages": [ + { + "role": "user", + "content": "", + "function_approvals": [ + { + "approved": False, + "id": "deny-1", + "call_id": "delete-call", + "name": "delete_record", + "arguments": {"record_id": "abc"}, + } + ], + } + ], + }, + workflow, + ) + ] + + resumed_types = [event.type for event in resumed_events] + assert "RUN_STARTED" in resumed_types + assert "RUN_FINISHED" in resumed_types + assert "RUN_ERROR" not in resumed_types + assert "TEXT_MESSAGE_CONTENT" in resumed_types + text_deltas = [event.delta for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT"] + assert any("rejected" in delta for delta in text_deltas) + resumed_finished = [event for event in resumed_events if event.type == "RUN_FINISHED"][0].model_dump() + assert not resumed_finished.get("interrupt") + + async def test_workflow_run_available_interrupts_logged(): """available_interrupts in input data should be logged without errors."""