Skip to content

Commit c2d2fce

Browse files
committed
test: add tests to preserve tool output types during run state serialization
1 parent d6c2623 commit c2d2fce

File tree

1 file changed

+161
-1
lines changed

1 file changed

+161
-1
lines changed

tests/test_hitl_error_scenarios.py

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
import pytest
1212
from openai.types.responses import ResponseCustomToolCall, ResponseFunctionToolCall
13+
from openai.types.responses.response_input_param import (
14+
ComputerCallOutput,
15+
LocalShellCallOutput,
16+
)
1317
from openai.types.responses.response_output_item import LocalShellCall
1418
from pydantic_core import ValidationError
1519

@@ -26,7 +30,7 @@
2630
from agents._run_impl import (
2731
NextStepInterruption,
2832
)
29-
from agents.items import MessageOutputItem, ModelResponse
33+
from agents.items import MessageOutputItem, ModelResponse, ToolCallOutputItem
3034
from agents.run_context import RunContextWrapper
3135
from agents.run_state import RunState as RunStateClass
3236
from agents.usage import Usage
@@ -865,3 +869,159 @@ async def test_preserve_persisted_item_counter_when_resuming_streamed_runs():
865869
# Consume events to complete the run
866870
async for _ in result.stream_events():
867871
pass
872+
873+
874+
@pytest.mark.asyncio
875+
async def test_preserve_tool_output_types_during_serialization():
876+
"""Test that tool output types are preserved during run state serialization.
877+
878+
When serializing a run state, `_convert_output_item_to_protocol` unconditionally
879+
overwrites every tool output's `type` with `function_call_result`. On restore,
880+
`_deserialize_items` dispatches on this `type` to choose between
881+
`FunctionCallOutput`, `ComputerCallOutput`, or `LocalShellCallOutput`, so
882+
computer/shell/apply_patch outputs that were originally
883+
`computer_call_output`/`local_shell_call_output` are rehydrated as
884+
`function_call_output` (or fail validation), losing the tool-specific payload
885+
and breaking resumption for those tools.
886+
887+
This test will FAIL when the bug exists (output type will be function_call_result)
888+
and PASS when fixed (output type will be preserved as computer_call_output or
889+
local_shell_call_output).
890+
"""
891+
892+
model = FakeModel()
893+
agent = Agent(name="TestAgent", model=model, tools=[])
894+
895+
# Create a RunState with a computer tool output
896+
context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
897+
state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=3)
898+
899+
# Create a computer_call_output item
900+
computer_output: ComputerCallOutput = {
901+
"type": "computer_call_output",
902+
"call_id": "call_computer_1",
903+
"output": {"type": "computer_screenshot", "image_url": "base64_screenshot_data"},
904+
}
905+
computer_output_item = ToolCallOutputItem(
906+
agent=agent,
907+
raw_item=computer_output,
908+
output="screenshot_data",
909+
)
910+
state._generated_items = [computer_output_item]
911+
912+
# Serialize and deserialize the state
913+
json_data = state.to_json()
914+
915+
# Check what was serialized - the bug converts computer_call_output to function_call_result
916+
generated_items_json = json_data.get("generatedItems", [])
917+
assert len(generated_items_json) == 1, "Computer output item should be serialized"
918+
raw_item_json = generated_items_json[0].get("rawItem", {})
919+
serialized_type = raw_item_json.get("type")
920+
921+
# The bug: _convert_output_item_to_protocol converts all tool outputs to function_call_result
922+
# This test will FAIL when the bug exists (type will be function_call_result)
923+
# and PASS when fixed (type will be computer_call_output)
924+
assert serialized_type == "computer_call_output", (
925+
f"Expected computer_call_output in serialized JSON, but got {serialized_type}. "
926+
f"The bug in _convert_output_item_to_protocol converts all tool outputs to "
927+
f"function_call_result during serialization, causing them to be incorrectly "
928+
f"deserialized as FunctionCallOutput instead of ComputerCallOutput."
929+
)
930+
931+
deserialized_state = await RunStateClass.from_json(agent, json_data)
932+
933+
# Verify that the computer output type is preserved after deserialization
934+
# When the bug exists, the item may be skipped due to validation errors
935+
# When fixed, it should deserialize correctly
936+
assert len(deserialized_state._generated_items) == 1, (
937+
"Computer output item should be deserialized. When the bug exists, it may be skipped "
938+
"due to validation errors when trying to deserialize as FunctionCallOutput instead "
939+
"of ComputerCallOutput."
940+
)
941+
deserialized_item = deserialized_state._generated_items[0]
942+
assert isinstance(deserialized_item, ToolCallOutputItem)
943+
944+
# The raw_item should still be a ComputerCallOutput, not FunctionCallOutput
945+
raw_item = deserialized_item.raw_item
946+
if isinstance(raw_item, dict):
947+
output_type = raw_item.get("type")
948+
assert output_type == "computer_call_output", (
949+
f"Expected computer_call_output, but got {output_type}. "
950+
f"The bug converts all tool outputs to function_call_result during serialization, "
951+
f"causing them to be incorrectly deserialized as FunctionCallOutput."
952+
)
953+
else:
954+
# If it's a Pydantic model, check the type attribute
955+
assert hasattr(raw_item, "type")
956+
assert raw_item.type == "computer_call_output", (
957+
f"Expected computer_call_output, but got {raw_item.type}. "
958+
f"The bug converts all tool outputs to function_call_result during serialization, "
959+
f"causing them to be incorrectly deserialized as FunctionCallOutput."
960+
)
961+
962+
# Test with local_shell_call_output as well
963+
# Note: The TypedDict definition requires "id" but runtime uses "call_id"
964+
# We use cast to match the actual runtime structure
965+
shell_output = cast(
966+
LocalShellCallOutput,
967+
{
968+
"type": "local_shell_call_output",
969+
"id": "shell_1",
970+
"call_id": "call_shell_1",
971+
"output": "command output",
972+
},
973+
)
974+
shell_output_item = ToolCallOutputItem(
975+
agent=agent,
976+
raw_item=shell_output,
977+
output="command output",
978+
)
979+
state._generated_items = [shell_output_item]
980+
981+
# Serialize and deserialize again
982+
json_data = state.to_json()
983+
984+
# Check what was serialized - the bug converts local_shell_call_output to function_call_result
985+
generated_items_json = json_data.get("generatedItems", [])
986+
assert len(generated_items_json) == 1, "Shell output item should be serialized"
987+
raw_item_json = generated_items_json[0].get("rawItem", {})
988+
serialized_type = raw_item_json.get("type")
989+
990+
# The bug: _convert_output_item_to_protocol converts all tool outputs to function_call_result
991+
# This test will FAIL when the bug exists (type will be function_call_result)
992+
# and PASS when fixed (type will be local_shell_call_output)
993+
assert serialized_type == "local_shell_call_output", (
994+
f"Expected local_shell_call_output in serialized JSON, but got {serialized_type}. "
995+
f"The bug in _convert_output_item_to_protocol converts all tool outputs to "
996+
f"function_call_result during serialization, causing them to be incorrectly "
997+
f"deserialized as FunctionCallOutput instead of LocalShellCallOutput."
998+
)
999+
1000+
deserialized_state = await RunStateClass.from_json(agent, json_data)
1001+
1002+
# Verify that the shell output type is preserved after deserialization
1003+
# When the bug exists, the item may be skipped due to validation errors
1004+
# When fixed, it should deserialize correctly
1005+
assert len(deserialized_state._generated_items) == 1, (
1006+
"Shell output item should be deserialized. When the bug exists, it may be skipped "
1007+
"due to validation errors when trying to deserialize as FunctionCallOutput instead "
1008+
"of LocalShellCallOutput."
1009+
)
1010+
deserialized_item = deserialized_state._generated_items[0]
1011+
assert isinstance(deserialized_item, ToolCallOutputItem)
1012+
1013+
raw_item = deserialized_item.raw_item
1014+
if isinstance(raw_item, dict):
1015+
output_type = raw_item.get("type")
1016+
assert output_type == "local_shell_call_output", (
1017+
f"Expected local_shell_call_output, but got {output_type}. "
1018+
f"The bug converts all tool outputs to function_call_result during serialization, "
1019+
f"causing them to be incorrectly deserialized as FunctionCallOutput."
1020+
)
1021+
else:
1022+
assert hasattr(raw_item, "type")
1023+
assert raw_item.type == "local_shell_call_output", (
1024+
f"Expected local_shell_call_output, but got {raw_item.type}. "
1025+
f"The bug converts all tool outputs to function_call_result during serialization, "
1026+
f"causing them to be incorrectly deserialized as FunctionCallOutput."
1027+
)

0 commit comments

Comments
 (0)