diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 182355b0f8..f2274280b0 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -137,7 +137,12 @@ def validate_json(self, json_str: str) -> Any: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. """ - validated = _json.validate_json(json_str, self._type_adapter, partial=False) + validated = _json.validate_json( + json_str, + self._type_adapter, + partial=False, + strict=True if self._strict_json_schema else None, + ) if self._is_wrapped: if not isinstance(validated, dict): _error_tracing.attach_error_to_current_span( diff --git a/src/agents/handoffs/__init__.py b/src/agents/handoffs/__init__.py index 6d5e06dd93..106f46e6ce 100644 --- a/src/agents/handoffs/__init__.py +++ b/src/agents/handoffs/__init__.py @@ -292,6 +292,7 @@ async def _invoke_handoff( json_str=input_json, type_adapter=type_adapter, partial=False, + strict=True, ) input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) result = input_func(ctx, validated_input) diff --git a/src/agents/realtime/handoffs.py b/src/agents/realtime/handoffs.py index c1bede6ca5..a26126df97 100644 --- a/src/agents/realtime/handoffs.py +++ b/src/agents/realtime/handoffs.py @@ -163,6 +163,7 @@ async def _invoke_handoff( json_str=input_json, type_adapter=type_adapter, partial=False, + strict=True, ) input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) if inspect.iscoroutinefunction(input_func): diff --git a/src/agents/util/_json.py b/src/agents/util/_json.py index 1f4345bccf..67186328cd 100644 --- a/src/agents/util/_json.py +++ b/src/agents/util/_json.py @@ -13,12 +13,17 @@ T = TypeVar("T") -def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T: +def validate_json( + json_str: str, type_adapter: TypeAdapter[T], partial: bool, strict: bool | None = None +) -> T: partial_setting: bool | Literal["off", "on", "trailing-strings"] = ( "trailing-strings" if partial else False ) try: - validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting) + kwargs: dict[str, Any] = {"experimental_allow_partial": partial_setting} + if strict is not None: + kwargs["strict"] = strict + validated = type_adapter.validate_json(json_str, **kwargs) return validated except ValidationError as e: attach_error_to_current_span( diff --git a/tests/realtime/test_realtime_handoffs.py b/tests/realtime/test_realtime_handoffs.py index 41cf771631..42e18a77fb 100644 --- a/tests/realtime/test_realtime_handoffs.py +++ b/tests/realtime/test_realtime_handoffs.py @@ -7,6 +7,7 @@ from unittest.mock import Mock import pytest +from pydantic import BaseModel from agents import Agent from agents.exceptions import ModelBehaviorError, UserError @@ -245,3 +246,30 @@ def on_handoff(ctx: RunContextWrapper[Any]) -> None: assert result is rt assert called == [True] + + +class StrictInput(BaseModel): + name: str + age: int + + +@pytest.mark.asyncio +async def test_realtime_handoff_strict_json_rejects_type_coercion(): + """With strict_json_schema=True (always on for realtime handoffs), string input for an + int field must raise ModelBehaviorError instead of being silently coerced.""" + rt = RealtimeAgent(name="strict_test") + + async def _on_handoff(ctx: RunContextWrapper[Any], data: StrictInput) -> None: + pass # pragma: no cover + + handoff_obj = realtime_handoff(rt, on_handoff=_on_handoff, input_type=StrictInput) + + # age is a string "25" — strict mode should reject this + malformed_json = '{"name": "Alice", "age": "25"}' + with pytest.raises(ModelBehaviorError, match="Invalid JSON"): + await handoff_obj.on_invoke_handoff(RunContextWrapper(None), malformed_json) + + # Correctly typed input should still be accepted + valid_json = '{"name": "Alice", "age": 25}' + result = await handoff_obj.on_invoke_handoff(RunContextWrapper(None), valid_json) + assert result is rt diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index d0e8e0bb7e..9370ca9c34 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -468,3 +468,51 @@ async def test_handoff_is_enabled_filtering_integration(): agent_names = {h.agent_name for h in filtered_handoffs} assert agent_names == {"agent_1", "agent_3"} assert "agent_2" not in agent_names + + +class StrictInput(BaseModel): + name: str + age: int + + +@pytest.mark.asyncio +async def test_handoff_strict_json_rejects_type_coercion(): + """With strict_json_schema=True (default), string input for an int field must raise + ModelBehaviorError instead of being silently coerced.""" + + async def _on_handoff(ctx: RunContextWrapper[Any], input: StrictInput): + pass # pragma: no cover + + agent = Agent(name="test") + obj = handoff(agent, input_type=StrictInput, on_handoff=_on_handoff) + + # strict_json_schema defaults to True + assert obj.strict_json_schema is True + + # age is a string "25" — strict mode should reject this + malformed_json = '{"name": "Alice", "age": "25"}' + with pytest.raises(ModelBehaviorError, match="Invalid JSON"): + await obj.on_invoke_handoff(RunContextWrapper(agent), malformed_json) + + # Correctly typed input should still be accepted + valid_json = '{"name": "Alice", "age": 25}' + result = await obj.on_invoke_handoff(RunContextWrapper(agent), valid_json) + assert result == agent + + +@pytest.mark.asyncio +async def test_handoff_lenient_json_allows_type_coercion(): + """Without strict validation, Pydantic's default lenient mode silently coerces + string input for an int field — verifying backward compatibility.""" + from pydantic import TypeAdapter + + from agents.util._json import validate_json + + type_adapter = TypeAdapter(StrictInput) + + # age is a string "25" — lenient mode should coerce it to int 25 + malformed_json = '{"name": "Alice", "age": "25"}' + result = validate_json(malformed_json, type_adapter, partial=False) + assert result.name == "Alice" + assert result.age == 25 + assert isinstance(result.age, int) diff --git a/tests/test_output_tool.py b/tests/test_output_tool.py index 76daa12000..0b8e3aeaf4 100644 --- a/tests/test_output_tool.py +++ b/tests/test_output_tool.py @@ -195,3 +195,40 @@ def test_custom_output_schema(): json_str = json.dumps({"foo": "bar"}) validated = output_schema.validate_json(json_str) assert validated == ["some", "output"] + + +class StrictOutput(BaseModel): + name: str + age: int + + +def test_agent_output_schema_strict_rejects_type_coercion(): + """With strict_json_schema=True (default), string input for an int field must raise + ModelBehaviorError instead of being silently coerced.""" + schema = AgentOutputSchema(output_type=StrictOutput, strict_json_schema=True) + assert schema.is_strict_json_schema() + + # age is a string "25" — strict mode should reject this + malformed_json = '{"name": "Alice", "age": "25"}' + with pytest.raises(ModelBehaviorError, match="Invalid JSON"): + schema.validate_json(malformed_json) + + # Correctly typed input should still be accepted + valid_json = '{"name": "Alice", "age": 25}' + result = schema.validate_json(valid_json) + assert result.name == "Alice" + assert result.age == 25 + + +def test_agent_output_schema_lenient_allows_type_coercion(): + """With strict_json_schema=False, Pydantic's default lenient mode silently coerces + string input for an int field — verifying backward compatibility.""" + schema = AgentOutputSchema(output_type=StrictOutput, strict_json_schema=False) + assert not schema.is_strict_json_schema() + + # age is a string "25" — lenient mode should coerce it to int 25 + coerced_json = '{"name": "Alice", "age": "25"}' + result = schema.validate_json(coerced_json) + assert result.name == "Alice" + assert result.age == 25 + assert isinstance(result.age, int)