|
10 | 10 |
|
11 | 11 | import pytest |
12 | 12 | from openai.types.responses import ResponseCustomToolCall, ResponseFunctionToolCall |
| 13 | +from openai.types.responses.response_input_param import ( |
| 14 | + ComputerCallOutput, |
| 15 | + LocalShellCallOutput, |
| 16 | +) |
13 | 17 | from openai.types.responses.response_output_item import LocalShellCall |
14 | 18 | from pydantic_core import ValidationError |
15 | 19 |
|
|
26 | 30 | from agents._run_impl import ( |
27 | 31 | NextStepInterruption, |
28 | 32 | ) |
29 | | -from agents.items import MessageOutputItem, ModelResponse |
| 33 | +from agents.items import MessageOutputItem, ModelResponse, ToolCallOutputItem |
30 | 34 | from agents.run_context import RunContextWrapper |
31 | 35 | from agents.run_state import RunState as RunStateClass |
32 | 36 | from agents.usage import Usage |
@@ -865,3 +869,159 @@ async def test_preserve_persisted_item_counter_when_resuming_streamed_runs(): |
865 | 869 | # Consume events to complete the run |
866 | 870 | async for _ in result.stream_events(): |
867 | 871 | pass |
| 872 | + |
| 873 | + |
@pytest.mark.asyncio
async def test_preserve_tool_output_types_during_serialization():
    """Check that tool output types survive a run state JSON round trip.

    Regression guarded against: when serializing a run state,
    `_convert_output_item_to_protocol` overwrote every tool output's `type` with
    `function_call_result`. On restore, `_deserialize_items` dispatches on this
    `type` to choose between `FunctionCallOutput`, `ComputerCallOutput`, or
    `LocalShellCallOutput`, so outputs that were originally
    `computer_call_output`/`local_shell_call_output` were rehydrated as
    `function_call_output` (or failed validation), losing the tool-specific
    payload and breaking resumption for those tools.

    The same round trip is exercised once per tool output type. The test fails
    while the bug exists (the serialized type is `function_call_result`) and
    passes once the original type is preserved.
    """

    model = FakeModel()
    agent = Agent(name="TestAgent", model=model, tools=[])

    context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
    state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=3)

    async def assert_type_round_trips(
        item: ToolCallOutputItem,
        expected_type: str,
        label: str,
        expected_class: str,
    ) -> None:
        """Serialize `state` holding only `item`, restore it, and check the raw item type.

        Args:
            item: The tool output item to place in the state's generated items.
            expected_type: The raw item `type` that must survive the round trip.
            label: Human-readable tool name used in assertion messages.
            expected_class: Name of the output type expected after restore, used in
                assertion messages only.
        """
        state._generated_items = [item]
        json_data = state.to_json()

        # Inspect the serialized payload first so a failure points at serialization
        # rather than at a downstream validation error during restore.
        generated_items_json = json_data.get("generatedItems", [])
        assert len(generated_items_json) == 1, f"{label} output item should be serialized"
        serialized_type = generated_items_json[0].get("rawItem", {}).get("type")
        assert serialized_type == expected_type, (
            f"Expected {expected_type} in serialized JSON, but got {serialized_type}. "
            f"The bug in _convert_output_item_to_protocol converts all tool outputs to "
            f"function_call_result during serialization, causing them to be incorrectly "
            f"deserialized as FunctionCallOutput instead of {expected_class}."
        )

        restored_state = await RunStateClass.from_json(agent, json_data)

        # With the bug present the item may be dropped entirely because it fails
        # validation against the wrong output type.
        assert len(restored_state._generated_items) == 1, (
            f"{label} output item should be deserialized. When the bug exists, it may be "
            f"skipped due to validation errors when trying to deserialize as "
            f"FunctionCallOutput instead of {expected_class}."
        )
        restored_item = restored_state._generated_items[0]
        assert isinstance(restored_item, ToolCallOutputItem)

        # The raw item may come back as a plain dict or as a model object.
        raw_item = restored_item.raw_item
        if isinstance(raw_item, dict):
            output_type = raw_item.get("type")
        else:
            assert hasattr(raw_item, "type")
            output_type = raw_item.type
        assert output_type == expected_type, (
            f"Expected {expected_type}, but got {output_type}. "
            f"The bug converts all tool outputs to function_call_result during serialization, "
            f"causing them to be incorrectly deserialized as FunctionCallOutput."
        )

    # Computer tool output.
    computer_output: ComputerCallOutput = {
        "type": "computer_call_output",
        "call_id": "call_computer_1",
        "output": {"type": "computer_screenshot", "image_url": "base64_screenshot_data"},
    }
    await assert_type_round_trips(
        ToolCallOutputItem(agent=agent, raw_item=computer_output, output="screenshot_data"),
        expected_type="computer_call_output",
        label="Computer",
        expected_class="ComputerCallOutput",
    )

    # Local shell tool output.
    # Note: The TypedDict definition requires "id" but runtime uses "call_id".
    # We use cast to match the actual runtime structure.
    shell_output = cast(
        LocalShellCallOutput,
        {
            "type": "local_shell_call_output",
            "id": "shell_1",
            "call_id": "call_shell_1",
            "output": "command output",
        },
    )
    await assert_type_round_trips(
        ToolCallOutputItem(agent=agent, raw_item=shell_output, output="command output"),
        expected_type="local_shell_call_output",
        label="Shell",
        expected_class="LocalShellCallOutput",
    )
0 commit comments