Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/agents/agent_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def validate_json(self, json_str: str) -> Any:
"""Validate a JSON string against the output type. Returns the validated object, or raises
a `ModelBehaviorError` if the JSON is invalid.
"""
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
validated = _json.validate_json(
json_str,
self._type_adapter,
partial=False,
strict=True if self._strict_json_schema else None,
)
if self._is_wrapped:
if not isinstance(validated, dict):
_error_tracing.attach_error_to_current_span(
Expand Down
1 change: 1 addition & 0 deletions src/agents/handoffs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ async def _invoke_handoff(
json_str=input_json,
type_adapter=type_adapter,
partial=False,
strict=True,
Comment thread
Om-Borse26 marked this conversation as resolved.
)
input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff)
result = input_func(ctx, validated_input)
Expand Down
1 change: 1 addition & 0 deletions src/agents/realtime/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ async def _invoke_handoff(
json_str=input_json,
type_adapter=type_adapter,
partial=False,
strict=True,
)
input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff)
if inspect.iscoroutinefunction(input_func):
Expand Down
9 changes: 7 additions & 2 deletions src/agents/util/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,17 @@
T = TypeVar("T")


def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
def validate_json(
json_str: str, type_adapter: TypeAdapter[T], partial: bool, strict: bool | None = None
) -> T:
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
"trailing-strings" if partial else False
)
try:
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
kwargs: dict[str, Any] = {"experimental_allow_partial": partial_setting}
if strict is not None:
kwargs["strict"] = strict
validated = type_adapter.validate_json(json_str, **kwargs)
return validated
except ValidationError as e:
attach_error_to_current_span(
Expand Down
28 changes: 28 additions & 0 deletions tests/realtime/test_realtime_handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest.mock import Mock

import pytest
from pydantic import BaseModel

from agents import Agent
from agents.exceptions import ModelBehaviorError, UserError
Expand Down Expand Up @@ -245,3 +246,30 @@ def on_handoff(ctx: RunContextWrapper[Any]) -> None:

assert result is rt
assert called == [True]


class StrictInput(BaseModel):
name: str
age: int


@pytest.mark.asyncio
async def test_realtime_handoff_strict_json_rejects_type_coercion():
"""With strict_json_schema=True (always on for realtime handoffs), string input for an
int field must raise ModelBehaviorError instead of being silently coerced."""
rt = RealtimeAgent(name="strict_test")

async def _on_handoff(ctx: RunContextWrapper[Any], data: StrictInput) -> None:
pass # pragma: no cover

handoff_obj = realtime_handoff(rt, on_handoff=_on_handoff, input_type=StrictInput)

# age is a string "25" — strict mode should reject this
malformed_json = '{"name": "Alice", "age": "25"}'
with pytest.raises(ModelBehaviorError, match="Invalid JSON"):
await handoff_obj.on_invoke_handoff(RunContextWrapper(None), malformed_json)

# Correctly typed input should still be accepted
valid_json = '{"name": "Alice", "age": 25}'
result = await handoff_obj.on_invoke_handoff(RunContextWrapper(None), valid_json)
assert result is rt
48 changes: 48 additions & 0 deletions tests/test_handoff_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,51 @@ async def test_handoff_is_enabled_filtering_integration():
agent_names = {h.agent_name for h in filtered_handoffs}
assert agent_names == {"agent_1", "agent_3"}
assert "agent_2" not in agent_names


class StrictInput(BaseModel):
name: str
age: int


@pytest.mark.asyncio
async def test_handoff_strict_json_rejects_type_coercion():
"""With strict_json_schema=True (default), string input for an int field must raise
ModelBehaviorError instead of being silently coerced."""

async def _on_handoff(ctx: RunContextWrapper[Any], input: StrictInput):
pass # pragma: no cover

agent = Agent(name="test")
obj = handoff(agent, input_type=StrictInput, on_handoff=_on_handoff)

# strict_json_schema defaults to True
assert obj.strict_json_schema is True

# age is a string "25" — strict mode should reject this
malformed_json = '{"name": "Alice", "age": "25"}'
with pytest.raises(ModelBehaviorError, match="Invalid JSON"):
await obj.on_invoke_handoff(RunContextWrapper(agent), malformed_json)

# Correctly typed input should still be accepted
valid_json = '{"name": "Alice", "age": 25}'
result = await obj.on_invoke_handoff(RunContextWrapper(agent), valid_json)
assert result == agent


@pytest.mark.asyncio
async def test_handoff_lenient_json_allows_type_coercion():
"""Without strict validation, Pydantic's default lenient mode silently coerces
string input for an int field — verifying backward compatibility."""
from pydantic import TypeAdapter

from agents.util._json import validate_json

type_adapter = TypeAdapter(StrictInput)

# age is a string "25" — lenient mode should coerce it to int 25
malformed_json = '{"name": "Alice", "age": "25"}'
result = validate_json(malformed_json, type_adapter, partial=False)
assert result.name == "Alice"
assert result.age == 25
assert isinstance(result.age, int)
37 changes: 37 additions & 0 deletions tests/test_output_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,40 @@ def test_custom_output_schema():
json_str = json.dumps({"foo": "bar"})
validated = output_schema.validate_json(json_str)
assert validated == ["some", "output"]


class StrictOutput(BaseModel):
name: str
age: int


def test_agent_output_schema_strict_rejects_type_coercion():
"""With strict_json_schema=True (default), string input for an int field must raise
ModelBehaviorError instead of being silently coerced."""
schema = AgentOutputSchema(output_type=StrictOutput, strict_json_schema=True)
assert schema.is_strict_json_schema()

# age is a string "25" — strict mode should reject this
malformed_json = '{"name": "Alice", "age": "25"}'
with pytest.raises(ModelBehaviorError, match="Invalid JSON"):
schema.validate_json(malformed_json)

# Correctly typed input should still be accepted
valid_json = '{"name": "Alice", "age": 25}'
result = schema.validate_json(valid_json)
assert result.name == "Alice"
assert result.age == 25


def test_agent_output_schema_lenient_allows_type_coercion():
"""With strict_json_schema=False, Pydantic's default lenient mode silently coerces
string input for an int field — verifying backward compatibility."""
schema = AgentOutputSchema(output_type=StrictOutput, strict_json_schema=False)
assert not schema.is_strict_json_schema()

# age is a string "25" — lenient mode should coerce it to int 25
coerced_json = '{"name": "Alice", "age": "25"}'
result = schema.validate_json(coerced_json)
assert result.name == "Alice"
assert result.age == 25
assert isinstance(result.age, int)