diff --git a/chatlas/__init__.py b/chatlas/__init__.py
index db3229a4..76cdea52 100644
--- a/chatlas/__init__.py
+++ b/chatlas/__init__.py
@@ -34,7 +34,7 @@
 from ._provider_portkey import ChatPortkey
 from ._provider_snowflake import ChatSnowflake
 from ._tokens import token_usage
-from ._tools import Tool, ToolRejectError
+from ._tools import Tool, ToolBuiltIn, ToolRejectError
 from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn
 
 try:
@@ -84,6 +84,7 @@
     "Provider",
     "token_usage",
     "Tool",
+    "ToolBuiltIn",
     "ToolRejectError",
     "Turn",
     "UserTurn",
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 6886ce8f..97043a38 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -49,7 +49,7 @@
 from ._mcp_manager import MCPSessionManager
 from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
 from ._tokens import compute_cost, get_token_pricing, tokens_log
-from ._tools import Tool, ToolRejectError
+from ._tools import Tool, ToolBuiltIn, ToolRejectError
 from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
 from ._typing_extensions import TypedDict, TypeGuard
 from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
@@ -132,7 +132,7 @@ def __init__(
         self.system_prompt = system_prompt
         self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
 
-        self._tools: dict[str, Tool] = {}
+        self._tools: dict[str, Tool | ToolBuiltIn] = {}
         self._on_tool_request_callbacks = CallbackManager()
         self._on_tool_result_callbacks = CallbackManager()
         self._current_display: Optional[MarkdownDisplay] = None
@@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
 
     def register_tool(
         self,
-        func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
+        func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | ToolBuiltIn,
         *,
         force: bool = False,
         name: Optional[str] = None,
@@ -1982,8 +1982,15 @@ def add(a: int, b: int) -> int:
                 func.func, name=name, model=model, annotations=annotations
             )
             func = func.func
+            tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
+        else:
+            if isinstance(func, ToolBuiltIn):
+                tool = func
+            else:
+                tool = Tool.from_func(
+                    func, name=name, model=model, annotations=annotations
+                )
 
-        tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
         if tool.name in self._tools and not force:
             raise ValueError(
                 f"Tool with name '{tool.name}' is already registered. "
@@ -1991,14 +1998,14 @@ def add(a: int, b: int) -> int:
             )
         self._tools[tool.name] = tool
 
-    def get_tools(self) -> list[Tool]:
+    def get_tools(self) -> list[Tool | ToolBuiltIn]:
         """
         Get the list of registered tools.
 
         Returns
         -------
-        list[Tool]
-            A list of `Tool` instances that are currently registered with the chat.
+        list[Tool | ToolBuiltIn]
+            A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
""" return list(self._tools.values()) @@ -2522,7 +2529,7 @@ def _submit_turns( data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, ) -> Generator[str, None, None]: - if any(x._is_async for x in self._tools.values()): + if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()): raise ValueError("Cannot use async tools in a synchronous chat") def emit(text: str | Content): @@ -2683,15 +2690,24 @@ def _collect_all_kwargs( def _invoke_tool(self, request: ContentToolRequest): tool = self._tools.get(request.name) - func = tool.func if tool is not None else None - if func is None: + if tool is None: yield self._handle_tool_error_result( request, error=RuntimeError("Unknown tool."), ) return + if isinstance(tool, ToolBuiltIn): + yield self._handle_tool_error_result( + request, + error=RuntimeError( + f"Built-in tool '{request.name}' cannot be invoked directly. " + "It should be handled by the provider." + ), + ) + return + # First, invoke the request callbacks. If a ToolRejectError is raised, # treat it like a tool failure (i.e., gracefully handle it). result: ContentToolResult | None = None @@ -2703,9 +2719,9 @@ def _invoke_tool(self, request: ContentToolRequest): try: if isinstance(request.arguments, dict): - res = func(**request.arguments) + res = tool.func(**request.arguments) else: - res = func(request.arguments) + res = tool.func(request.arguments) # Normalize res as a generator of results. if not inspect.isgenerator(res): @@ -2739,10 +2755,15 @@ async def _invoke_tool_async(self, request: ContentToolRequest): ) return - if tool._is_async: - func = tool.func - else: - func = wrap_async(tool.func) + if isinstance(tool, ToolBuiltIn): + yield self._handle_tool_error_result( + request, + error=RuntimeError( + f"Built-in tool '{request.name}' cannot be invoked directly. " + "It should be handled by the provider." + ), + ) + return # First, invoke the request callbacks. If a ToolRejectError is raised, # treat it like a tool failure (i.e., gracefully handle it). @@ -2753,6 +2774,11 @@ async def _invoke_tool_async(self, request: ContentToolRequest): yield self._handle_tool_error_result(request, e) return + if tool._is_async: + func = tool.func + else: + func = wrap_async(tool.func) + # Invoke the tool (if it hasn't been rejected). 
         try:
             if isinstance(request.arguments, dict):
diff --git a/chatlas/_content.py b/chatlas/_content.py
index 7cd8de69..70c23e31 100644
--- a/chatlas/_content.py
+++ b/chatlas/_content.py
@@ -11,7 +11,7 @@
 from ._typing_extensions import TypedDict
 
 if TYPE_CHECKING:
-    from ._tools import Tool
+    from ._tools import Tool, ToolBuiltIn
 
 
 class ToolAnnotations(TypedDict, total=False):
@@ -104,15 +104,21 @@ class ToolInfo(BaseModel):
     annotations: Optional[ToolAnnotations] = None
 
     @classmethod
-    def from_tool(cls, tool: "Tool") -> "ToolInfo":
-        """Create a ToolInfo from a Tool instance."""
-        func_schema = tool.schema["function"]
-        return cls(
-            name=tool.name,
-            description=func_schema.get("description", ""),
-            parameters=func_schema.get("parameters", {}),
-            annotations=tool.annotations,
-        )
+    def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
+        """Create a ToolInfo from a Tool or ToolBuiltIn instance."""
+        from ._tools import ToolBuiltIn
+
+        if isinstance(tool, ToolBuiltIn):
+            return cls(name=tool.name, description=tool.name, parameters={})
+        else:
+            # For regular tools, extract from schema
+            func_schema = tool.schema["function"]
+            return cls(
+                name=tool.name,
+                description=func_schema.get("description", ""),
+                parameters=func_schema.get("parameters", {}),
+                annotations=tool.annotations,
+            )
 
 
 ContentTypeEnum = Literal[
@@ -247,6 +253,22 @@ def __str__(self):
     def _repr_markdown_(self):
         return self.__str__()
 
+    def _repr_png_(self):
+        """Display PNG images directly in Jupyter notebooks."""
+        if self.image_content_type == "image/png" and self.data:
+            import base64
+
+            return base64.b64decode(self.data)
+        return None
+
+    def _repr_jpeg_(self):
+        """Display JPEG images directly in Jupyter notebooks."""
+        if self.image_content_type == "image/jpeg" and self.data:
+            import base64
+
+            return base64.b64decode(self.data)
+        return None
+
     def __repr__(self, indent: int = 0):
         n_bytes = len(self.data) if self.data else 0
         return (
diff --git a/chatlas/_mcp_manager.py b/chatlas/_mcp_manager.py
index c4e7b2c9..15d34fdc 100644
--- a/chatlas/_mcp_manager.py
+++ b/chatlas/_mcp_manager.py
@@ -7,7 +7,7 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional, Sequence
 
-from ._tools import Tool
+from ._tools import Tool, ToolBuiltIn
 
 if TYPE_CHECKING:
     from mcp import ClientSession
@@ -23,7 +23,7 @@ class SessionInfo(ABC):
 
     # Primary derived attributes
     session: ClientSession | None = None
-    tools: dict[str, Tool] = field(default_factory=dict)
+    tools: dict[str, Tool | ToolBuiltIn] = field(default_factory=dict)
 
     # Background task management
     ready_event: asyncio.Event = field(default_factory=asyncio.Event)
@@ -74,7 +74,7 @@ async def request_tools(self) -> None:
             tool_names = tool_names.difference(exclude)
 
         # Apply namespace and convert to chatlas.Tool instances
-        self_tools: dict[str, Tool] = {}
+        self_tools: dict[str, Tool | ToolBuiltIn] = {}
         for tool in response.tools:
             if tool.name not in tool_names:
                 continue
diff --git a/chatlas/_provider.py b/chatlas/_provider.py
index 56b2a528..3dd5497a 100644
--- a/chatlas/_provider.py
+++ b/chatlas/_provider.py
@@ -16,7 +16,7 @@
 from pydantic import BaseModel
 
 from ._content import Content
-from ._tools import Tool
+from ._tools import Tool, ToolBuiltIn
 from ._turn import AssistantTurn, Turn
 from ._typing_extensions import NotRequired, TypedDict
 
@@ -162,7 +162,7 @@ def chat_perform(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...
@@ -174,7 +174,7 @@ def chat_perform(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT]: ...
@@ -185,7 +185,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -197,7 +197,7 @@ async def chat_perform_async(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> ChatCompletionT: ...
@@ -209,7 +209,7 @@ async def chat_perform_async(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT]: ...
@@ -220,7 +220,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
         kwargs: SubmitInputArgsT,
     ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -259,7 +259,7 @@ def value_tokens(
 
     def token_count(
         self,
        *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int: ...
@@ -267,7 +267,7 @@ def token_count(
 
     async def token_count_async(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int: ...
diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py
index 3796a3fb..2cf1a99e 100644
--- a/chatlas/_provider_anthropic.py
+++ b/chatlas/_provider_anthropic.py
@@ -15,7 +15,6 @@
 )
 
 import orjson
-from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
 
 from ._chat import Chat
@@ -38,7 +37,7 @@
     StandardModelParams,
 )
 from ._tokens import get_token_pricing
-from ._tools import Tool, basemodel_to_param_schema
+from ._tools import Tool, ToolBuiltIn, basemodel_to_param_schema
 from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
 from ._utils import split_http_client_kwargs
 
@@ -48,7 +47,7 @@
     from anthropic.types import (
         MessageParam,
         RawMessageStreamEvent,
         TextBlock,
-        ToolParam,
+        ToolUnionParam,
         ToolUseBlock,
     )
     from anthropic.types.cache_control_ephemeral_param import CacheControlEphemeralParam
@@ -304,7 +303,7 @@ def chat_perform(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -315,7 +314,7 @@ def chat_perform(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -325,7 +324,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -338,7 +337,7 @@ async def chat_perform_async(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -349,7 +348,7 @@ async def chat_perform_async(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -359,7 +358,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -370,12 +369,12 @@ def _chat_perform_args(
         self,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
         tool_schemas = [
-            self._anthropic_tool_schema(tool.schema) for tool in tools.values()
+            self._anthropic_tool_schema(tool) for tool in tools.values()
         ]
 
         # If data extraction is requested, add a "mock" tool with parameters inferred from the data model
@@ -395,7 +394,7 @@ def _structured_tool_call(**kwargs: Any):
             },
         }
 
-        tool_schemas.append(self._anthropic_tool_schema(data_model_tool.schema))
+        tool_schemas.append(self._anthropic_tool_schema(data_model_tool))
 
         if stream:
             stream = False
@@ -497,7 +496,7 @@ def value_tokens(self, completion):
     def token_count(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         kwargs = self._token_count_args(
@@ -511,7 +510,7 @@ def token_count(
     async def token_count_async(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         kwargs = self._token_count_args(
@@ -525,7 +524,7 @@ async def token_count_async(
     def _token_count_args(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> dict[str, Any]:
         turn = user_turn(*args)
@@ -655,11 +654,14 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
         raise ValueError(f"Unknown content type: {type(content)}")
 
     @staticmethod
-    def _anthropic_tool_schema(schema: "ChatCompletionToolParam") -> "ToolParam":
-        fn = schema["function"]
+    def _anthropic_tool_schema(tool: "Tool | ToolBuiltIn") -> "ToolUnionParam":
+        if isinstance(tool, ToolBuiltIn):
+            return tool.definition  # type: ignore
+
+        fn = tool.schema["function"]
         name = fn["name"]
 
-        res: "ToolParam" = {
+        res: "ToolUnionParam" = {
             "name": name,
             "input_schema": {
                 "type": "object",
diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py
index 6c00e952..5a26acbc 100644
--- a/chatlas/_provider_google.py
+++ b/chatlas/_provider_google.py
@@ -21,7 +21,7 @@
 from ._merge import merge_dicts
 from ._provider import ModelInfo, Provider, StandardModelParamNames, StandardModelParams
 from ._tokens import get_token_pricing
-from ._tools import Tool
+from ._tools import Tool, ToolBuiltIn
 from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
 
 if TYPE_CHECKING:
@@ -208,7 +208,7 @@ def chat_perform(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -219,7 +219,7 @@ def chat_perform(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -229,7 +229,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -245,7 +245,7 @@ async def chat_perform_async(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -256,7 +256,7 @@ async def chat_perform_async(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ): ...
@@ -266,7 +266,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -279,7 +279,7 @@ async def chat_perform_async(
     def _chat_perform_args(
         self,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
@@ -315,6 +315,9 @@ def _chat_perform_args(
                     callable=tool.func,
                 )
                 for tool in tools.values()
+                # TODO: to support built-in tools, we may need a way to map
+                # tool names (e.g., google_search) to types (e.g., google.genai.types.GoogleSearch())
+                if isinstance(tool, Tool)
             ]
         )
     ]
@@ -372,7 +375,7 @@ def value_tokens(self, completion):
     def token_count(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ):
         kwargs = self._token_count_args(
@@ -387,7 +390,7 @@ def token_count(
     async def token_count_async(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ):
         kwargs = self._token_count_args(
@@ -402,7 +405,7 @@ async def token_count_async(
     def _token_count_args(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> dict[str, Any]:
         turn = user_turn(*args)
@@ -545,6 +548,17 @@ def _as_turn(
                         ),
                     )
                 )
+            inline_data = part.get("inline_data")
+            if inline_data:
+                mime_type = inline_data.get("mime_type")
+                data = inline_data.get("data")
+                if mime_type and data:
+                    contents.append(
+                        ContentImageInline(
+                            data=data.decode("utf-8"),
+                            image_content_type=mime_type,  # type: ignore
+                        )
+                    )
 
         if isinstance(finish_reason, FinishReason):
             finish_reason = finish_reason.name
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index a7a58c55..a9046a76 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -24,7 +24,7 @@
 from ._provider import StandardModelParamNames, StandardModelParams
 from ._provider_openai_completions import load_tool_request_args
 from ._provider_openai_generic import BatchResult, OpenAIAbstractProvider
-from ._tools import Tool, basemodel_to_param_schema
+from ._tools import Tool, ToolBuiltIn, basemodel_to_param_schema
 from ._turn import AssistantTurn, Turn
 
 if TYPE_CHECKING:
@@ -169,7 +169,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -181,7 +181,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -192,7 +192,7 @@ def _chat_perform_args(
         self,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
@@ -204,13 +204,15 @@ def _chat_perform_args(
             **(kwargs or {}),
         }
 
-        tool_schemas = [tool.schema for tool in tools.values()]
-        if tool_schemas:
-            # Convert completion tool format to responses format
-            responses_tools: list["ToolParam"] = []
-            for schema in tool_schemas:
+        # Handle tools - both regular and built-in
+        tool_params: list["ToolParam"] = []
+        for tool in tools.values():
+            if isinstance(tool, ToolBuiltIn):
+                tool_params.append(cast(ToolParam, tool.definition))
+            else:
+                schema = tool.schema
                 func = schema["function"]
-                responses_tools.append(
+                tool_params.append(
                     {
                         "type": "function",
                         "name": func["name"],
@@ -219,8 +221,9 @@ def _chat_perform_args(
                         "strict": func.get("strict", True),
                     }
                 )
-            if responses_tools:
-                kwargs_full["tools"] = responses_tools
+
+        if tool_params:
+            kwargs_full["tools"] = tool_params
 
         # Add structured data extraction if present
         if data_model is not None:
@@ -323,6 +326,25 @@ def _response_as_turn(completion: Response, has_data_model: bool) -> AssistantTu
                     extra=output.model_dump(),
                 )
             )
+
+        elif output.type == "image_generation_call":
+            result = output.result
+            if result:
+                mime_type = "image/png"
+                if "image/jpeg" in result:
+                    mime_type = "image/jpeg"
+                elif "image/webp" in result:
+                    mime_type = "image/webp"
+                elif "image/gif" in result:
+                    mime_type = "image/gif"
+
+                contents.append(
+                    ContentImageInline(
+                        data=result,
+                        image_content_type=mime_type,
+                    )
+                )
+
         else:
             raise ValueError(f"Unknown output type: {output.type}")
diff --git a/chatlas/_provider_openai_completions.py b/chatlas/_provider_openai_completions.py
index 79bda038..1b7c66dc 100644
--- a/chatlas/_provider_openai_completions.py
+++ b/chatlas/_provider_openai_completions.py
@@ -31,7 +31,7 @@
 from ._merge import merge_dicts
 from ._provider import StandardModelParamNames, StandardModelParams
 from ._provider_openai_generic import BatchResult, OpenAIAbstractProvider
-from ._tools import Tool, basemodel_to_param_schema
+from ._tools import Tool, ToolBuiltIn, basemodel_to_param_schema
 from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn
 from ._utils import MISSING, MISSING_TYPE, is_testing
 
@@ -121,7 +121,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -133,7 +133,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ):
@@ -144,11 +144,17 @@ def _chat_perform_args(
         self,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["SubmitInputArgs"] = None,
     ) -> "SubmitInputArgs":
-        tool_schemas = [tool.schema for tool in tools.values()]
+
+        tool_schemas = []
+        for tool in tools.values():
+            if isinstance(tool, ToolBuiltIn):
+                tool_schemas.append(tool.definition)
+            else:
+                tool_schemas.append(tool.schema)
 
         kwargs_full: "SubmitInputArgs" = {
             "stream": stream,
diff --git a/chatlas/_provider_openai_generic.py b/chatlas/_provider_openai_generic.py
index 62738650..29f8c5f5 100644
--- a/chatlas/_provider_openai_generic.py
+++ b/chatlas/_provider_openai_generic.py
@@ -24,7 +24,7 @@
     SubmitInputArgsT,
 )
 from ._tokens import get_token_pricing
-from ._tools import Tool
+from ._tools import Tool, ToolBuiltIn
 from ._turn import AssistantTurn, Turn, UserTurn, user_turn
 from ._utils import split_http_client_kwargs
 
@@ -122,7 +122,7 @@ def list_models(self):
     def token_count(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         try:
@@ -154,7 +154,7 @@ def token_count(
     async def token_count_async(
         self,
         *args: Content | str,
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         return self.token_count(*args, tools=tools, data_model=data_model)
@@ -265,7 +265,7 @@ def _chat_perform_args(
         self,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> SubmitInputArgsT: ...
diff --git a/chatlas/_provider_snowflake.py b/chatlas/_provider_snowflake.py
index 33944b4d..67f1603b 100644
--- a/chatlas/_provider_snowflake.py
+++ b/chatlas/_provider_snowflake.py
@@ -21,7 +21,7 @@
 )
 from ._logging import log_model_default
 from ._provider import Provider, StandardModelParamNames, StandardModelParams
-from ._tools import Tool, basemodel_to_param_schema
+from ._tools import Tool, ToolBuiltIn, basemodel_to_param_schema
 from ._turn import AssistantTurn, Turn
 from ._utils import drop_none
 
@@ -205,7 +205,7 @@ def chat_perform(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ): ...
@@ -216,7 +216,7 @@ def chat_perform(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ): ...
@@ -226,7 +226,7 @@ def chat_perform(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ):
@@ -254,7 +254,7 @@ async def chat_perform_async(
         *,
         stream: Literal[False],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ): ...
@@ -265,7 +265,7 @@ async def chat_perform_async(
         *,
         stream: Literal[True],
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ): ...
@@ -275,7 +275,7 @@ async def chat_perform_async(
         *,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ):
@@ -303,7 +303,7 @@ def _complete_request(
         self,
         stream: bool,
         turns: list[Turn],
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional["CompleteRequest"] = None,
     ):
@@ -455,7 +455,7 @@ def value_tokens(self, completion):
     def token_count(
         self,
         *args: "Content | str",
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         raise NotImplementedError(
@@ -465,7 +465,7 @@ def token_count(
     async def token_count_async(
         self,
         *args: "Content | str",
-        tools: dict[str, Tool],
+        tools: dict[str, Tool | ToolBuiltIn],
         data_model: Optional[type[BaseModel]],
     ) -> int:
         raise NotImplementedError(
@@ -565,9 +565,15 @@ def _as_turn(self, completion: "Completion", has_data_model: bool) -> AssistantT
 
     # N.B. this is currently the best documentation I can find for how tool calling works
     # https://quickstarts.snowflake.com/guide/getting-started-with-tool-use-on-cortex-and-anthropic-claude/index.html#5
-    def _as_snowflake_tool(self, tool: Tool):
+    def _as_snowflake_tool(self, tool: Tool | ToolBuiltIn):
         import snowflake.core.cortex.inference_service._generated.models as models
 
+        if isinstance(tool, ToolBuiltIn):
+            raise NotImplementedError(
+                "Built-in tools are not yet supported for Snowflake. "
+                "Please use custom tools via Tool instances."
+            )
+
         func = tool.schema["function"]
         params = func.get("parameters", {})
diff --git a/chatlas/_tools.py b/chatlas/_tools.py
index 1e616819..666d4b30 100644
--- a/chatlas/_tools.py
+++ b/chatlas/_tools.py
@@ -25,6 +25,7 @@
 
 __all__ = (
     "Tool",
+    "ToolBuiltIn",
     "ToolRejectError",
 )
 
@@ -232,6 +233,27 @@ async def _call(**args: Any) -> AsyncGenerator[ContentToolResult, None]:
         )
 
 
+class ToolBuiltIn:
+    """
+    Define a built-in provider-specific tool
+
+    This class represents tools that are built into specific providers (like image
+    generation). Unlike regular `Tool` objects, `ToolBuiltIn` instances pass raw
+    provider-specific JSON directly through to the API.
+
+    Parameters
+    ----------
+    name
+        The name of the tool.
+    definition
+        The raw provider-specific tool definition as a dictionary.
+    """
+
+    def __init__(self, *, name: str, definition: dict[str, Any]):
+        self.name = name
+        self.definition = definition
+
+
 class ToolRejectError(Exception):
     """
     Error to represent a tool call being rejected.
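
Usage sketch (illustrative, not part of the patch): how the new ToolBuiltIn class might be registered. The provider choice and tool payload below are assumptions; built-in tool definitions are provider-specific and are forwarded to the API verbatim, as the provider-side handling above shows.

# Hypothetical example of registering a provider built-in tool.
# Assumes an OpenAI built-in web search tool; the exact definition
# payload ({"type": "web_search"}) is an assumption for illustration.
from chatlas import ChatOpenAI, ToolBuiltIn

chat = ChatOpenAI()

# ToolBuiltIn wraps a raw tool definition. chatlas never invokes it locally
# (_invoke_tool yields an error for built-ins); the provider runs it server-side.
web_search = ToolBuiltIn(name="web_search", definition={"type": "web_search"})
chat.register_tool(web_search)

# The definition is passed through unchanged in the request's tools array.
chat.chat("What is the latest chatlas release?")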