Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ def _prepare_tools_for_anthropic(self, options: Mapping[str, Any]) -> dict[str,
tool_mode = validate_tool_mode(options.get("tool_choice"))
if tool_mode is None:
return result or None
if "allowed_tools" in tool_mode:
logger.warning("allowed_tools is not supported by Anthropic; the setting will be ignored")
allow_multiple = options.get("allow_multiple_tool_calls")
match tool_mode.get("mode"):
case "auto":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def _prepare_options(

tool_config = self._prepare_tools(options.get("tools"))
if tool_mode := validate_tool_mode(options.get("tool_choice")):
if "allowed_tools" in tool_mode:
logger.warning("allowed_tools is not supported by Bedrock; the setting will be ignored")
match tool_mode.get("mode"):
case "none":
# Bedrock doesn't support toolChoice "none".
Expand Down
15 changes: 14 additions & 1 deletion python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3246,10 +3246,12 @@ class ToolMode(TypedDict, total=False):
Fields:
mode: One of "auto", "required", or "none".
required_function_name: Optional function name when `mode == "required"`.
allowed_tools: Optional list of tool names when `mode` is `"auto"` or `"required"`.
Comment thread
giles17 marked this conversation as resolved.
"""

mode: Literal["auto", "required", "none"]
required_function_name: str
allowed_tools: list[str]


# region TypedDict-based Chat Options
Expand Down Expand Up @@ -3482,7 +3484,7 @@ def validate_tool_mode(

Returns:
A ToolMode dict (contains keys: "mode", and optionally
"required_function_name"), or ``None`` when not provided.
"required_function_name" or "allowed_tools"), or ``None`` when not provided.

Raises:
ContentError: If the tool_choice string is invalid.
Expand All @@ -3499,6 +3501,17 @@ def validate_tool_mode(
raise ContentError(f"Invalid tool choice: {tool_choice['mode']}")
if tool_choice["mode"] != "required" and "required_function_name" in tool_choice:
raise ContentError("tool_choice with mode other than 'required' cannot have 'required_function_name'")
if tool_choice["mode"] not in ("auto", "required") and "allowed_tools" in tool_choice:
raise ContentError("tool_choice 'allowed_tools' is only valid when mode is 'auto' or 'required'")
if "allowed_tools" in tool_choice:
allowed_tools = tool_choice["allowed_tools"]
if isinstance(allowed_tools, str) or not isinstance(allowed_tools, Sequence):
raise ContentError("tool_choice 'allowed_tools' must be a non-string sequence of strings")
if not all(isinstance(tool_name, str) for tool_name in allowed_tools):
raise ContentError("tool_choice 'allowed_tools' must contain only strings")
normalized_tool_choice = dict(tool_choice)
normalized_tool_choice["allowed_tools"] = list(allowed_tools)
return cast(ToolMode, normalized_tool_choice)
return tool_choice


Expand Down
43 changes: 43 additions & 0 deletions python/packages/core/tests/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,16 +1087,20 @@ def test_chat_tool_mode():
required_any: ToolMode = {"mode": "required"}
required_mode: ToolMode = {"mode": "required", "required_function_name": "example_function"}
none_mode: ToolMode = {"mode": "none"}
allowed_mode: ToolMode = {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}

# Check the type and content
assert auto_mode["mode"] == "auto"
assert "required_function_name" not in auto_mode
assert "allowed_tools" not in auto_mode
assert required_any["mode"] == "required"
assert "required_function_name" not in required_any
assert required_mode["mode"] == "required"
assert required_mode["required_function_name"] == "example_function"
assert none_mode["mode"] == "none"
assert "required_function_name" not in none_mode
assert allowed_mode["mode"] == "auto"
assert allowed_mode["allowed_tools"] == ["get_weather", "search_docs"]

# equality of dicts
assert {"mode": "required", "required_function_name": "example_function"} == {
Expand Down Expand Up @@ -1154,6 +1158,45 @@ def test_chat_options_tool_choice_validation():
with raises(ContentError):
validate_tool_mode({"mode": "auto", "required_function_name": "should_not_be_here"})

# Valid allowed_tools
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather"]}) == {
"mode": "auto",
"allowed_tools": ["get_weather"],
}
assert validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]}) == {
"mode": "auto",
"allowed_tools": ["get_weather", "search_docs"],
}

# allowed_tools valid with required mode
assert validate_tool_mode({"mode": "required", "allowed_tools": ["get_weather"]}) == {
"mode": "required",
"allowed_tools": ["get_weather"],
}

# allowed_tools invalid with none mode
with raises(ContentError):
validate_tool_mode({"mode": "none", "allowed_tools": ["get_weather"]})

# allowed_tools must be a non-string sequence of strings
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": "get_weather"})
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": 123})
with raises(ContentError):
validate_tool_mode({"mode": "auto", "allowed_tools": ["get_weather", 123]})

# Empty list is valid (caller explicitly allows no tools)
assert validate_tool_mode({"mode": "auto", "allowed_tools": []}) == {
"mode": "auto",
"allowed_tools": [],
}

# Tuple is normalized to list
result = validate_tool_mode({"mode": "auto", "allowed_tools": ("get_weather",)})
assert result is not None
assert result["allowed_tools"] == ["get_weather"]


def test_chat_options_merge(tool_tool, ai_tool) -> None:
"""Test merge_chat_options utility function."""
Expand Down
15 changes: 12 additions & 3 deletions python/packages/gemini/agent_framework_gemini/_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,19 +823,28 @@ def _prepare_tool_config(self, tool_choice: Any) -> types.ToolConfig | None:

match tool_mode.get("mode"):
case "auto":
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
if "allowed_tools" in tool_mode:
function_calling_mode = types.FunctionCallingConfigMode.VALIDATED
allowed_names = list(tool_mode["allowed_tools"])
else:
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.AUTO, None
case "none":
function_calling_mode, allowed_names = types.FunctionCallingConfigMode.NONE, None
case "required":
function_calling_mode = types.FunctionCallingConfigMode.ANY
name = tool_mode.get("required_function_name")
allowed_names = [name] if name else None
if name:
allowed_names = [name]
elif "allowed_tools" in tool_mode:
allowed_names = list(tool_mode["allowed_tools"])
else:
allowed_names = None
case unknown_mode:
logger.warning("Unsupported tool_choice mode for Gemini: %s", unknown_mode)
return None

function_calling_kwargs: dict[str, Any] = {"mode": function_calling_mode}
if allowed_names:
if allowed_names is not None:
function_calling_kwargs["allowed_function_names"] = allowed_names

return types.ToolConfig(function_calling_config=types.FunctionCallingConfig(**function_calling_kwargs))
Expand Down
80 changes: 80 additions & 0 deletions python/packages/gemini/tests/test_gemini_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,86 @@ async def test_unknown_tool_choice_mode_is_ignored() -> None:
assert not hasattr(config, "tool_config") or config.tool_config is None


async def test_tool_choice_auto_with_allowed_tools_uses_VALIDATED() -> None:
"""Maps auto + allowed_tools to FunctionCallingConfigMode.VALIDATED with allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "auto", "allowed_tools": ["dummy", "other"]},
},
)

config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "VALIDATED"
assert function_calling_config.allowed_function_names == ["dummy", "other"]


async def test_tool_choice_auto_with_empty_allowed_tools_uses_VALIDATED() -> None:
"""Maps auto + empty allowed_tools to VALIDATED with empty allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "auto", "allowed_tools": []},
},
)

config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "VALIDATED"
assert function_calling_config.allowed_function_names == []


async def test_tool_choice_required_with_allowed_tools_uses_ANY() -> None:
"""Maps required + allowed_tools to ANY with allowed_function_names."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "required", "allowed_tools": ["dummy"]},
},
)

config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "ANY"
assert function_calling_config.allowed_function_names == ["dummy"]


async def test_tool_choice_required_function_name_takes_precedence_over_allowed_tools() -> None:
"""When both required_function_name and allowed_tools are present, required_function_name wins."""
tool = _make_dummy_tool()
client, mock = _make_gemini_client()
mock.aio.models.generate_content = AsyncMock(return_value=_make_response([_make_part(text="Hi")]))

await client.get_response(
messages=[Message(role="user", contents=[Content.from_text("Hi")])],
options={
"tools": [tool],
"tool_choice": {"mode": "required", "required_function_name": "dummy", "allowed_tools": ["other"]},
},
)

config: types.GenerateContentConfig = mock.aio.models.generate_content.call_args.kwargs["config"]
function_calling_config = config.tool_config.function_calling_config
assert function_calling_config.mode == "ANY"
assert function_calling_config.allowed_function_names == ["dummy"]


# built-in tool factories


Expand Down
6 changes: 6 additions & 0 deletions python/packages/openai/agent_framework_openai/_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,12 @@ async def _prepare_options(
"type": "function",
"name": func_name,
}
elif mode == "auto" and (allowed := tool_mode.get("allowed_tools")) is not None:
run_options["tool_choice"] = {
"type": "allowed_tools",
"mode": "auto",
"tools": [{"type": "function", "name": name} for name in allowed],
}
else:
run_options["tool_choice"] = mode
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,12 @@ def _prepare_options(self, messages: Sequence[Message], options: Mapping[str, An
"type": "function",
"function": {"name": func_name},
}
elif mode in ("auto", "required") and tool_mode.get("allowed_tools") is not None:
logger.warning(
"allowed_tools is not supported by the Chat Completions API; "
"the setting will be ignored. Use OpenAIChatClient (Responses API) instead."
)
run_options["tool_choice"] = mode
else:
run_options["tool_choice"] = mode

Expand Down
90 changes: 90 additions & 0 deletions python/packages/openai/tests/openai/test_openai_chat_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4259,6 +4259,12 @@ async def get_api_key() -> str:
True,
id="tool_choice_required",
),
param(
"tool_choice",
{"mode": "auto", "allowed_tools": ["get_weather"]},
True,
id="tool_choice_allowed_tools",
),
param("response_format", OutputStruct, True, id="response_format_pydantic"),
param(
"response_format",
Expand Down Expand Up @@ -4813,6 +4819,90 @@ async def test_prepare_options_excludes_continuation_token() -> None:
assert run_options["background"] is True


async def test_prepare_options_allowed_tools() -> None:
"""Test that _prepare_options converts allowed_tools to OpenAI API format."""
client = OpenAIChatClient(model="test-model", api_key="test-key")

@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"

@tool
def search_docs(query: str) -> str:
"""Search documentation."""
return f"Results for {query}"

messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather, search_docs],
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather"]},
}

run_options = await client._prepare_options(messages, options)

assert run_options["tool_choice"] == {
"type": "allowed_tools",
"mode": "auto",
"tools": [{"type": "function", "name": "get_weather"}],
}


async def test_prepare_options_allowed_tools_multiple() -> None:
"""Test that _prepare_options converts multiple allowed_tools correctly."""
client = OpenAIChatClient(model="test-model", api_key="test-key")

@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"

@tool
def search_docs(query: str) -> str:
"""Search documentation."""
return f"Results for {query}"

messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather, search_docs],
"tool_choice": {"mode": "auto", "allowed_tools": ["get_weather", "search_docs"]},
}

run_options = await client._prepare_options(messages, options)

assert run_options["tool_choice"] == {
"type": "allowed_tools",
"mode": "auto",
"tools": [
{"type": "function", "name": "get_weather"},
{"type": "function", "name": "search_docs"},
],
}


async def test_prepare_options_auto_without_allowed_tools() -> None:
"""Test that auto mode without allowed_tools still returns plain 'auto' string."""
client = OpenAIChatClient(model="test-model", api_key="test-key")

@tool
def get_weather(city: str) -> str:
"""Get the weather for a city."""
return f"Sunny in {city}"

messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
options: dict[str, Any] = {
"model": "test-model",
"tools": [get_weather],
"tool_choice": {"mode": "auto"},
}

run_options = await client._prepare_options(messages, options)

assert run_options["tool_choice"] == "auto"


# endregion


Expand Down
Loading
Loading