|
49 | 49 | from ._mcp_manager import MCPSessionManager |
50 | 50 | from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT |
51 | 51 | from ._tokens import compute_cost, get_token_pricing, tokens_log |
52 | | -from ._tools import Tool, ToolRejectError |
| 52 | +from ._tools import Tool, ToolBuiltIn, ToolRejectError |
53 | 53 | from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn |
54 | 54 | from ._typing_extensions import TypedDict, TypeGuard |
55 | 55 | from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async |
@@ -132,7 +132,7 @@ def __init__( |
132 | 132 | self.system_prompt = system_prompt |
133 | 133 | self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {} |
134 | 134 |
|
135 | | - self._tools: dict[str, Tool] = {} |
| 135 | + self._tools: dict[str, Tool | ToolBuiltIn] = {} |
136 | 136 | self._on_tool_request_callbacks = CallbackManager() |
137 | 137 | self._on_tool_result_callbacks = CallbackManager() |
138 | 138 | self._current_display: Optional[MarkdownDisplay] = None |
@@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None): |
1880 | 1880 |
|
1881 | 1881 | def register_tool( |
1882 | 1882 | self, |
1883 | | - func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool, |
| 1883 | + func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn", |
1884 | 1884 | *, |
1885 | 1885 | force: bool = False, |
1886 | 1886 | name: Optional[str] = None, |
@@ -1974,31 +1974,39 @@ def add(a: int, b: int) -> int: |
1974 | 1974 | ValueError |
1975 | 1975 | If a tool with the same name already exists and `force` is `False`. |
1976 | 1976 | """ |
1977 | | - if isinstance(func, Tool): |
| 1977 | + if isinstance(func, ToolBuiltIn): |
| 1978 | + # ToolBuiltIn objects are stored directly without conversion |
| 1979 | + tool = func |
| 1980 | + tool_name = tool.name |
| 1981 | + elif isinstance(func, Tool): |
1978 | 1982 | name = name or func.name |
1979 | 1983 | annotations = annotations or func.annotations |
1980 | 1984 | if model is not None: |
1981 | 1985 | func = Tool.from_func( |
1982 | 1986 | func.func, name=name, model=model, annotations=annotations |
1983 | 1987 | ) |
1984 | 1988 | func = func.func |
| 1989 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1990 | + tool_name = tool.name |
| 1991 | + else: |
| 1992 | + tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
| 1993 | + tool_name = tool.name |
1985 | 1994 |
|
1986 | | - tool = Tool.from_func(func, name=name, model=model, annotations=annotations) |
1987 | | - if tool.name in self._tools and not force: |
| 1995 | + if tool_name in self._tools and not force: |
1988 | 1996 | raise ValueError( |
1989 | | - f"Tool with name '{tool.name}' is already registered. " |
| 1997 | + f"Tool with name '{tool_name}' is already registered. " |
1990 | 1998 | "Set `force=True` to overwrite it." |
1991 | 1999 | ) |
1992 | | - self._tools[tool.name] = tool |
| 2000 | + self._tools[tool_name] = tool |
1993 | 2001 |
|
1994 | | - def get_tools(self) -> list[Tool]: |
| 2002 | + def get_tools(self) -> list[Tool | ToolBuiltIn]: |
1995 | 2003 | """ |
1996 | 2004 | Get the list of registered tools. |
1997 | 2005 |
|
1998 | 2006 | Returns |
1999 | 2007 | ------- |
2000 | | - list[Tool] |
2001 | | - A list of `Tool` instances that are currently registered with the chat. |
| 2008 | + list[Tool | ToolBuiltIn] |
| 2009 | + A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat. |
2002 | 2010 | """ |
2003 | 2011 | return list(self._tools.values()) |
2004 | 2012 |
|
@@ -2522,7 +2530,7 @@ def _submit_turns( |
2522 | 2530 | data_model: type[BaseModel] | None = None, |
2523 | 2531 | kwargs: Optional[SubmitInputArgsT] = None, |
2524 | 2532 | ) -> Generator[str, None, None]: |
2525 | | - if any(x._is_async for x in self._tools.values()): |
| 2533 | + if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()): |
2526 | 2534 | raise ValueError("Cannot use async tools in a synchronous chat") |
2527 | 2535 |
|
2528 | 2536 | def emit(text: str | Content): |
@@ -2683,15 +2691,27 @@ def _collect_all_kwargs( |
2683 | 2691 |
|
2684 | 2692 | def _invoke_tool(self, request: ContentToolRequest): |
2685 | 2693 | tool = self._tools.get(request.name) |
2686 | | - func = tool.func if tool is not None else None |
2687 | 2694 |
|
2688 | | - if func is None: |
| 2695 | + if tool is None: |
2689 | 2696 | yield self._handle_tool_error_result( |
2690 | 2697 | request, |
2691 | 2698 | error=RuntimeError("Unknown tool."), |
2692 | 2699 | ) |
2693 | 2700 | return |
2694 | 2701 |
|
| 2702 | + if isinstance(tool, ToolBuiltIn): |
| 2703 | + # Built-in tools are handled by the provider, not invoked directly |
| 2704 | + yield self._handle_tool_error_result( |
| 2705 | + request, |
| 2706 | + error=RuntimeError( |
| 2707 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2708 | + "It should be handled by the provider." |
| 2709 | + ), |
| 2710 | + ) |
| 2711 | + return |
| 2712 | + |
| 2713 | + func = tool.func |
| 2714 | + |
2695 | 2715 | # First, invoke the request callbacks. If a ToolRejectError is raised, |
2696 | 2716 | # treat it like a tool failure (i.e., gracefully handle it). |
2697 | 2717 | result: ContentToolResult | None = None |
@@ -2739,6 +2759,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest): |
2739 | 2759 | ) |
2740 | 2760 | return |
2741 | 2761 |
|
| 2762 | + if isinstance(tool, ToolBuiltIn): |
| 2763 | + # Built-in tools are handled by the provider, not invoked directly |
| 2764 | + yield self._handle_tool_error_result( |
| 2765 | + request, |
| 2766 | + error=RuntimeError( |
| 2767 | + f"Built-in tool '{request.name}' cannot be invoked directly. " |
| 2768 | + "It should be handled by the provider." |
| 2769 | + ), |
| 2770 | + ) |
| 2771 | + return |
| 2772 | + |
2742 | 2773 | if tool._is_async: |
2743 | 2774 | func = tool.func |
2744 | 2775 | else: |
|
0 commit comments