Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
return None


def _extract_field_info_from_metadata(metadata: tuple[Any, ...]) -> FieldInfo | None:
"""Returns the first FieldInfo in Annotated metadata, or None."""

for item in metadata:
if isinstance(item, FieldInfo):
return item
return None


def function_schema(
func: Callable[..., Any],
docstring_style: DocstringStyle | None = None,
Expand Down Expand Up @@ -252,13 +261,15 @@ def function_schema(
type_hints_with_extras = get_type_hints(func, include_extras=True)
type_hints: dict[str, Any] = {}
annotated_param_descs: dict[str, str] = {}
param_metadata: dict[str, tuple[Any, ...]] = {}

for name, annotation in type_hints_with_extras.items():
if name == "return":
continue

stripped_ann, metadata = _strip_annotated(annotation)
type_hints[name] = stripped_ann
param_metadata[name] = metadata

description = _extract_description_from_metadata(metadata)
if description is not None:
Expand Down Expand Up @@ -356,7 +367,20 @@ def function_schema(

else:
# Normal parameter
if default == inspect._empty:
metadata = param_metadata.get(name, ())
field_info_from_annotated = _extract_field_info_from_metadata(metadata)

if field_info_from_annotated is not None:
merged = FieldInfo.merge_field_infos(
field_info_from_annotated,
description=field_description or field_info_from_annotated.description,
)
if default != inspect._empty and not isinstance(default, FieldInfo):
merged = FieldInfo.merge_field_infos(merged, default=default)
elif isinstance(default, FieldInfo):
merged = FieldInfo.merge_field_infos(merged, default)
fields[name] = (ann, merged)
elif default == inspect._empty:
# Required field
fields[name] = (
ann,
Expand Down
179 changes: 179 additions & 0 deletions tests/test_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,3 +706,182 @@ def func_with_multiple_field_constraints(

with pytest.raises(ValidationError): # zero factor
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})


# --- Annotated + Field: same behavior as Field as default ---


def test_function_with_annotated_field_required_constraints():
"""Test function with required Annotated[int, Field(...)] parameter that has constraints."""

def func_with_annotated_field_constraints(
my_number: Annotated[int, Field(..., gt=10, le=100)],
) -> int:
return my_number * 2

fs = function_schema(func_with_annotated_field_constraints, use_docstring_info=False)

# Check that the schema includes the constraints
properties = fs.params_json_schema.get("properties", {})
my_number_schema = properties.get("my_number", {})
assert my_number_schema.get("type") == "integer"
assert my_number_schema.get("exclusiveMinimum") == 10 # gt=10
assert my_number_schema.get("maximum") == 100 # le=100

# Valid input should work
valid_input = {"my_number": 50}
parsed = fs.params_pydantic_model(**valid_input)
args, kwargs_dict = fs.to_call_args(parsed)
result = func_with_annotated_field_constraints(*args, **kwargs_dict)
assert result == 100

# Invalid input: too small (should violate gt=10)
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"my_number": 5})

# Invalid input: too large (should violate le=100)
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"my_number": 150})


def test_function_with_annotated_field_optional_with_default():
"""Optional Annotated[float, Field(...)] param with default and constraints."""

def func_with_annotated_optional_field(
required_param: str,
optional_param: Annotated[float, Field(default=5.0, ge=0.0)],
) -> str:
return f"{required_param}: {optional_param}"

fs = function_schema(func_with_annotated_optional_field, use_docstring_info=False)

# Check that the schema includes the constraints and description
properties = fs.params_json_schema.get("properties", {})
optional_schema = properties.get("optional_param", {})
assert optional_schema.get("type") == "number"
assert optional_schema.get("minimum") == 0.0 # ge=0.0
assert optional_schema.get("default") == 5.0

# Valid input with default
valid_input = {"required_param": "test"}
parsed = fs.params_pydantic_model(**valid_input)
args, kwargs_dict = fs.to_call_args(parsed)
result = func_with_annotated_optional_field(*args, **kwargs_dict)
assert result == "test: 5.0"

# Valid input with explicit value
valid_input2 = {"required_param": "test", "optional_param": 10.5}
parsed2 = fs.params_pydantic_model(**valid_input2)
args2, kwargs_dict2 = fs.to_call_args(parsed2)
result2 = func_with_annotated_optional_field(*args2, **kwargs_dict2)
assert result2 == "test: 10.5"

# Invalid input: negative value (should violate ge=0.0)
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0})


def test_function_with_annotated_field_string_constraints():
"""Annotated[str, Field(...)] parameter with string constraints (min/max length, pattern)."""

def func_with_annotated_string_field(
name: Annotated[
str,
Field(..., min_length=3, max_length=20, pattern=r"^[A-Za-z]+$"),
],
) -> str:
return f"Hello, {name}!"

fs = function_schema(func_with_annotated_string_field, use_docstring_info=False)

# Check that the schema includes string constraints
properties = fs.params_json_schema.get("properties", {})
name_schema = properties.get("name", {})
assert name_schema.get("type") == "string"
assert name_schema.get("minLength") == 3
assert name_schema.get("maxLength") == 20
assert name_schema.get("pattern") == r"^[A-Za-z]+$"

# Valid input
valid_input = {"name": "Alice"}
parsed = fs.params_pydantic_model(**valid_input)
args, kwargs_dict = fs.to_call_args(parsed)
result = func_with_annotated_string_field(*args, **kwargs_dict)
assert result == "Hello, Alice!"

# Invalid input: too short
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"name": "Al"})

# Invalid input: too long
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"name": "A" * 25})

# Invalid input: doesn't match pattern (contains numbers)
with pytest.raises(ValidationError):
fs.params_pydantic_model(**{"name": "Alice123"})


def test_function_with_annotated_field_multiple_constraints():
"""Test function with multiple Annotated params with Field having different constraint types."""

def func_with_annotated_multiple_field_constraints(
score: Annotated[
int,
Field(..., ge=0, le=100, description="Score from 0 to 100"),
],
name: Annotated[str, Field(default="Unknown", min_length=1, max_length=50)],
factor: Annotated[float, Field(default=1.0, gt=0.0, description="Positive multiplier")],
) -> str:
final_score = score * factor
return f"{name} scored {final_score}"

fs = function_schema(func_with_annotated_multiple_field_constraints, use_docstring_info=False)

# Check schema structure
properties = fs.params_json_schema.get("properties", {})

# Check score field
score_schema = properties.get("score", {})
assert score_schema.get("type") == "integer"
assert score_schema.get("minimum") == 0
assert score_schema.get("maximum") == 100
assert score_schema.get("description") == "Score from 0 to 100"

# Check name field
name_schema = properties.get("name", {})
assert name_schema.get("type") == "string"
assert name_schema.get("minLength") == 1
assert name_schema.get("maxLength") == 50
assert name_schema.get("default") == "Unknown"

# Check factor field
factor_schema = properties.get("factor", {})
assert factor_schema.get("type") == "number"
assert factor_schema.get("exclusiveMinimum") == 0.0
assert factor_schema.get("default") == 1.0
assert factor_schema.get("description") == "Positive multiplier"

# Valid input with defaults
valid_input = {"score": 85}
parsed = fs.params_pydantic_model(**valid_input)
args, kwargs_dict = fs.to_call_args(parsed)
result = func_with_annotated_multiple_field_constraints(*args, **kwargs_dict)
assert result == "Unknown scored 85.0"

# Valid input with all parameters
valid_input2 = {"score": 90, "name": "Alice", "factor": 1.5}
parsed2 = fs.params_pydantic_model(**valid_input2)
args2, kwargs_dict2 = fs.to_call_args(parsed2)
result2 = func_with_annotated_multiple_field_constraints(*args2, **kwargs_dict2)
assert result2 == "Alice scored 135.0"

# Test various validation errors
with pytest.raises(ValidationError): # score too high
fs.params_pydantic_model(**{"score": 150})

with pytest.raises(ValidationError): # empty name
fs.params_pydantic_model(**{"score": 50, "name": ""})

with pytest.raises(ValidationError): # zero factor
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})