Skip to content

Commit 2acd2cb

Browse files
committed
Add basic image generation support; introduce new ToolBuiltIn class
1 parent 1b9ac9b commit 2acd2cb

File tree

8 files changed

+194
-49
lines changed

8 files changed

+194
-49
lines changed

chatlas/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ._provider_portkey import ChatPortkey
3535
from ._provider_snowflake import ChatSnowflake
3636
from ._tokens import token_usage
37-
from ._tools import Tool, ToolRejectError
37+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
3838
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn
3939

4040
try:
@@ -84,6 +84,7 @@
8484
"Provider",
8585
"token_usage",
8686
"Tool",
87+
"ToolBuiltIn",
8788
"ToolRejectError",
8889
"Turn",
8990
"UserTurn",

chatlas/_chat.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from ._mcp_manager import MCPSessionManager
5050
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
5151
from ._tokens import compute_cost, get_token_pricing, tokens_log
52-
from ._tools import Tool, ToolRejectError
52+
from ._tools import Tool, ToolBuiltIn, ToolRejectError
5353
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
5454
from ._typing_extensions import TypedDict, TypeGuard
5555
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
@@ -132,7 +132,7 @@ def __init__(
132132
self.system_prompt = system_prompt
133133
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}
134134

135-
self._tools: dict[str, Tool] = {}
135+
self._tools: dict[str, Tool | ToolBuiltIn] = {}
136136
self._on_tool_request_callbacks = CallbackManager()
137137
self._on_tool_result_callbacks = CallbackManager()
138138
self._current_display: Optional[MarkdownDisplay] = None
@@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):
18801880

18811881
def register_tool(
18821882
self,
1883-
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
1883+
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | "ToolBuiltIn",
18841884
*,
18851885
force: bool = False,
18861886
name: Optional[str] = None,
@@ -1974,31 +1974,39 @@ def add(a: int, b: int) -> int:
19741974
ValueError
19751975
If a tool with the same name already exists and `force` is `False`.
19761976
"""
1977-
if isinstance(func, Tool):
1977+
if isinstance(func, ToolBuiltIn):
1978+
# ToolBuiltIn objects are stored directly without conversion
1979+
tool = func
1980+
tool_name = tool.name
1981+
elif isinstance(func, Tool):
19781982
name = name or func.name
19791983
annotations = annotations or func.annotations
19801984
if model is not None:
19811985
func = Tool.from_func(
19821986
func.func, name=name, model=model, annotations=annotations
19831987
)
19841988
func = func.func
1989+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1990+
tool_name = tool.name
1991+
else:
1992+
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1993+
tool_name = tool.name
19851994

1986-
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
1987-
if tool.name in self._tools and not force:
1995+
if tool_name in self._tools and not force:
19881996
raise ValueError(
1989-
f"Tool with name '{tool.name}' is already registered. "
1997+
f"Tool with name '{tool_name}' is already registered. "
19901998
"Set `force=True` to overwrite it."
19911999
)
1992-
self._tools[tool.name] = tool
2000+
self._tools[tool_name] = tool
19932001

1994-
def get_tools(self) -> list[Tool]:
2002+
def get_tools(self) -> list[Tool | ToolBuiltIn]:
19952003
"""
19962004
Get the list of registered tools.
19972005
19982006
Returns
19992007
-------
2000-
list[Tool]
2001-
A list of `Tool` instances that are currently registered with the chat.
2008+
list[Tool | ToolBuiltIn]
2009+
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
20022010
"""
20032011
return list(self._tools.values())
20042012

@@ -2522,7 +2530,7 @@ def _submit_turns(
25222530
data_model: type[BaseModel] | None = None,
25232531
kwargs: Optional[SubmitInputArgsT] = None,
25242532
) -> Generator[str, None, None]:
2525-
if any(x._is_async for x in self._tools.values()):
2533+
if any(hasattr(x, "_is_async") and x._is_async for x in self._tools.values()):
25262534
raise ValueError("Cannot use async tools in a synchronous chat")
25272535

25282536
def emit(text: str | Content):
@@ -2683,15 +2691,27 @@ def _collect_all_kwargs(
26832691

26842692
def _invoke_tool(self, request: ContentToolRequest):
26852693
tool = self._tools.get(request.name)
2686-
func = tool.func if tool is not None else None
26872694

2688-
if func is None:
2695+
if tool is None:
26892696
yield self._handle_tool_error_result(
26902697
request,
26912698
error=RuntimeError("Unknown tool."),
26922699
)
26932700
return
26942701

2702+
if isinstance(tool, ToolBuiltIn):
2703+
# Built-in tools are handled by the provider, not invoked directly
2704+
yield self._handle_tool_error_result(
2705+
request,
2706+
error=RuntimeError(
2707+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2708+
"It should be handled by the provider."
2709+
),
2710+
)
2711+
return
2712+
2713+
func = tool.func
2714+
26952715
# First, invoke the request callbacks. If a ToolRejectError is raised,
26962716
# treat it like a tool failure (i.e., gracefully handle it).
26972717
result: ContentToolResult | None = None
@@ -2739,6 +2759,17 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
27392759
)
27402760
return
27412761

2762+
if isinstance(tool, ToolBuiltIn):
2763+
# Built-in tools are handled by the provider, not invoked directly
2764+
yield self._handle_tool_error_result(
2765+
request,
2766+
error=RuntimeError(
2767+
f"Built-in tool '{request.name}' cannot be invoked directly. "
2768+
"It should be handled by the provider."
2769+
),
2770+
)
2771+
return
2772+
27422773
if tool._is_async:
27432774
func = tool.func
27442775
else:

chatlas/_content.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ._typing_extensions import TypedDict
1212

1313
if TYPE_CHECKING:
14-
from ._tools import Tool
14+
from ._tools import Tool, ToolBuiltIn
1515

1616

1717
class ToolAnnotations(TypedDict, total=False):
@@ -104,15 +104,28 @@ class ToolInfo(BaseModel):
104104
annotations: Optional[ToolAnnotations] = None
105105

106106
@classmethod
107-
def from_tool(cls, tool: "Tool") -> "ToolInfo":
108-
"""Create a ToolInfo from a Tool instance."""
109-
func_schema = tool.schema["function"]
110-
return cls(
111-
name=tool.name,
112-
description=func_schema.get("description", ""),
113-
parameters=func_schema.get("parameters", {}),
114-
annotations=tool.annotations,
115-
)
107+
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
108+
"""Create a ToolInfo from a Tool or ToolBuiltIn instance."""
109+
from ._tools import ToolBuiltIn
110+
111+
if isinstance(tool, ToolBuiltIn):
112+
# For built-in tools, extract info from the definition
113+
defn = tool.definition
114+
return cls(
115+
name=tool.name,
116+
description=defn.get("description", ""),
117+
parameters=defn.get("parameters", {}),
118+
annotations=None,
119+
)
120+
else:
121+
# For regular tools, extract from schema
122+
func_schema = tool.schema["function"]
123+
return cls(
124+
name=tool.name,
125+
description=func_schema.get("description", ""),
126+
parameters=func_schema.get("parameters", {}),
127+
annotations=tool.annotations,
128+
)
116129

117130

118131
ContentTypeEnum = Literal[
@@ -247,6 +260,22 @@ def __str__(self):
247260
def _repr_markdown_(self):
248261
return self.__str__()
249262

263+
def _repr_png_(self):
264+
"""Display PNG images directly in Jupyter notebooks."""
265+
if self.image_content_type == "image/png" and self.data:
266+
import base64
267+
268+
return base64.b64decode(self.data)
269+
return None
270+
271+
def _repr_jpeg_(self):
272+
"""Display JPEG images directly in Jupyter notebooks."""
273+
if self.image_content_type == "image/jpeg" and self.data:
274+
import base64
275+
276+
return base64.b64decode(self.data)
277+
return None
278+
250279
def __repr__(self, indent: int = 0):
251280
n_bytes = len(self.data) if self.data else 0
252281
return (

chatlas/_provider.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel
1717

1818
from ._content import Content
19-
from ._tools import Tool
19+
from ._tools import Tool, ToolBuiltIn
2020
from ._turn import AssistantTurn, Turn
2121
from ._typing_extensions import NotRequired, TypedDict
2222

@@ -162,7 +162,7 @@ def chat_perform(
162162
*,
163163
stream: Literal[False],
164164
turns: list[Turn],
165-
tools: dict[str, Tool],
165+
tools: dict[str, Tool | ToolBuiltIn],
166166
data_model: Optional[type[BaseModel]],
167167
kwargs: SubmitInputArgsT,
168168
) -> ChatCompletionT: ...
@@ -174,7 +174,7 @@ def chat_perform(
174174
*,
175175
stream: Literal[True],
176176
turns: list[Turn],
177-
tools: dict[str, Tool],
177+
tools: dict[str, Tool | ToolBuiltIn],
178178
data_model: Optional[type[BaseModel]],
179179
kwargs: SubmitInputArgsT,
180180
) -> Iterable[ChatCompletionChunkT]: ...
@@ -185,7 +185,7 @@ def chat_perform(
185185
*,
186186
stream: bool,
187187
turns: list[Turn],
188-
tools: dict[str, Tool],
188+
tools: dict[str, Tool | ToolBuiltIn],
189189
data_model: Optional[type[BaseModel]],
190190
kwargs: SubmitInputArgsT,
191191
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -197,7 +197,7 @@ async def chat_perform_async(
197197
*,
198198
stream: Literal[False],
199199
turns: list[Turn],
200-
tools: dict[str, Tool],
200+
tools: dict[str, Tool | ToolBuiltIn],
201201
data_model: Optional[type[BaseModel]],
202202
kwargs: SubmitInputArgsT,
203203
) -> ChatCompletionT: ...
@@ -209,7 +209,7 @@ async def chat_perform_async(
209209
*,
210210
stream: Literal[True],
211211
turns: list[Turn],
212-
tools: dict[str, Tool],
212+
tools: dict[str, Tool | ToolBuiltIn],
213213
data_model: Optional[type[BaseModel]],
214214
kwargs: SubmitInputArgsT,
215215
) -> AsyncIterable[ChatCompletionChunkT]: ...
@@ -220,7 +220,7 @@ async def chat_perform_async(
220220
*,
221221
stream: bool,
222222
turns: list[Turn],
223-
tools: dict[str, Tool],
223+
tools: dict[str, Tool | ToolBuiltIn],
224224
data_model: Optional[type[BaseModel]],
225225
kwargs: SubmitInputArgsT,
226226
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
@@ -259,15 +259,15 @@ def value_tokens(
259259
def token_count(
260260
self,
261261
*args: Content | str,
262-
tools: dict[str, Tool],
262+
tools: dict[str, Tool | ToolBuiltIn],
263263
data_model: Optional[type[BaseModel]],
264264
) -> int: ...
265265

266266
@abstractmethod
267267
async def token_count_async(
268268
self,
269269
*args: Content | str,
270-
tools: dict[str, Tool],
270+
tools: dict[str, Tool | ToolBuiltIn],
271271
data_model: Optional[type[BaseModel]],
272272
) -> int: ...
273273

chatlas/_provider_google.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -307,17 +307,25 @@ def _chat_perform_args(
307307
config.response_mime_type = "application/json"
308308

309309
if tools:
310-
config.tools = [
311-
GoogleTool(
312-
function_declarations=[
310+
from ._tools import ToolBuiltIn
311+
312+
function_declarations = []
313+
for tool in tools.values():
314+
if isinstance(tool, ToolBuiltIn):
315+
# For built-in tools, pass the raw definition through
316+
# This allows provider-specific tools like image generation
317+
# Note: Google's API expects these in a specific format
318+
continue # Built-in tools are not yet fully supported for Google
319+
else:
320+
function_declarations.append(
313321
FunctionDeclaration.from_callable(
314322
client=self._client._api_client,
315323
callable=tool.func,
316324
)
317-
for tool in tools.values()
318-
]
319-
)
320-
]
325+
)
326+
327+
if function_declarations:
328+
config.tools = [GoogleTool(function_declarations=function_declarations)]
321329

322330
kwargs_full["config"] = config
323331

@@ -545,6 +553,20 @@ def _as_turn(
545553
),
546554
)
547555
)
556+
inline_data = part.get("inlineData") or part.get("inline_data")
557+
if inline_data:
558+
# Handle image generation responses
559+
mime_type = inline_data.get("mimeType") or inline_data.get("mime_type")
560+
data = inline_data.get("data")
561+
if mime_type and data:
562+
# Ensure data is a string (should be base64 encoded)
563+
data_str = data if isinstance(data, str) else str(data)
564+
contents.append(
565+
ContentImageInline(
566+
image_content_type=mime_type, # type: ignore
567+
data=data_str,
568+
)
569+
)
548570

549571
if isinstance(finish_reason, FinishReason):
550572
finish_reason = finish_reason.name

0 commit comments

Comments
 (0)