Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion chatlas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._provider_portkey import ChatPortkey
from ._provider_snowflake import ChatSnowflake
from ._tokens import token_usage
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn

try:
Expand Down Expand Up @@ -84,6 +84,7 @@
"Provider",
"token_usage",
"Tool",
"ToolBuiltIn",
"ToolRejectError",
"Turn",
"UserTurn",
Expand Down
58 changes: 42 additions & 16 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from ._mcp_manager import MCPSessionManager
from ._provider import ModelInfo, Provider, StandardModelParams, SubmitInputArgsT
from ._tokens import compute_cost, get_token_pricing, tokens_log
from ._tools import Tool, ToolRejectError
from ._tools import Tool, ToolBuiltIn, ToolRejectError
from ._turn import AssistantTurn, SystemTurn, Turn, UserTurn, user_turn
from ._typing_extensions import TypedDict, TypeGuard
from ._utils import MISSING, MISSING_TYPE, html_escape, wrap_async
Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(
self.system_prompt = system_prompt
self.kwargs_chat: SubmitInputArgsT = kwargs_chat or {}

self._tools: dict[str, Tool] = {}
self._tools: dict[str, Tool | ToolBuiltIn] = {}
self._on_tool_request_callbacks = CallbackManager()
self._on_tool_result_callbacks = CallbackManager()
self._current_display: Optional[MarkdownDisplay] = None
Expand Down Expand Up @@ -1880,7 +1880,7 @@ async def cleanup_mcp_tools(self, names: Optional[Sequence[str]] = None):

def register_tool(
self,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool,
func: Callable[..., Any] | Callable[..., Awaitable[Any]] | Tool | ToolBuiltIn,
*,
force: bool = False,
name: Optional[str] = None,
Expand Down Expand Up @@ -1982,23 +1982,30 @@ def add(a: int, b: int) -> int:
func.func, name=name, model=model, annotations=annotations
)
func = func.func
tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
else:
if isinstance(func, ToolBuiltIn):
tool = func
else:
tool = Tool.from_func(
func, name=name, model=model, annotations=annotations
Comment on lines +1985 to +1991
Copy link

Copilot AI Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for creating a Tool from a function is duplicated on lines 1985 and 1990-1992. Consider restructuring to eliminate this duplication by handling the ToolBuiltIn case first, then having a single Tool.from_func call for all other cases.

Copilot uses AI. Check for mistakes.
)

tool = Tool.from_func(func, name=name, model=model, annotations=annotations)
if tool.name in self._tools and not force:
raise ValueError(
f"Tool with name '{tool.name}' is already registered. "
"Set `force=True` to overwrite it."
)
self._tools[tool.name] = tool

def get_tools(self) -> list[Tool]:
def get_tools(self) -> list[Tool | ToolBuiltIn]:
"""
Get the list of registered tools.

Returns
-------
list[Tool]
A list of `Tool` instances that are currently registered with the chat.
list[Tool | ToolBuiltIn]
A list of `Tool` or `ToolBuiltIn` instances that are currently registered with the chat.
"""
return list(self._tools.values())

Expand Down Expand Up @@ -2522,7 +2529,7 @@ def _submit_turns(
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self._tools.values()):
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

def emit(text: str | Content):
Expand Down Expand Up @@ -2683,15 +2690,24 @@ def _collect_all_kwargs(

def _invoke_tool(self, request: ContentToolRequest):
tool = self._tools.get(request.name)
func = tool.func if tool is not None else None

if func is None:
if tool is None:
yield self._handle_tool_error_result(
request,
error=RuntimeError("Unknown tool."),
)
return

if isinstance(tool, ToolBuiltIn):
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

# First, invoke the request callbacks. If a ToolRejectError is raised,
# treat it like a tool failure (i.e., gracefully handle it).
result: ContentToolResult | None = None
Expand All @@ -2703,9 +2719,9 @@ def _invoke_tool(self, request: ContentToolRequest):

try:
if isinstance(request.arguments, dict):
res = func(**request.arguments)
res = tool.func(**request.arguments)
else:
res = func(request.arguments)
res = tool.func(request.arguments)

# Normalize res as a generator of results.
if not inspect.isgenerator(res):
Expand Down Expand Up @@ -2739,10 +2755,15 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
)
return

if tool._is_async:
func = tool.func
else:
func = wrap_async(tool.func)
if isinstance(tool, ToolBuiltIn):
yield self._handle_tool_error_result(
request,
error=RuntimeError(
f"Built-in tool '{request.name}' cannot be invoked directly. "
"It should be handled by the provider."
),
)
return

# First, invoke the request callbacks. If a ToolRejectError is raised,
# treat it like a tool failure (i.e., gracefully handle it).
Expand All @@ -2753,6 +2774,11 @@ async def _invoke_tool_async(self, request: ContentToolRequest):
yield self._handle_tool_error_result(request, e)
return

if tool._is_async:
func = tool.func
else:
func = wrap_async(tool.func)

# Invoke the tool (if it hasn't been rejected).
try:
if isinstance(request.arguments, dict):
Expand Down
42 changes: 32 additions & 10 deletions chatlas/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ._typing_extensions import TypedDict

if TYPE_CHECKING:
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn


class ToolAnnotations(TypedDict, total=False):
Expand Down Expand Up @@ -104,15 +104,21 @@ class ToolInfo(BaseModel):
annotations: Optional[ToolAnnotations] = None

@classmethod
def from_tool(cls, tool: "Tool") -> "ToolInfo":
"""Create a ToolInfo from a Tool instance."""
func_schema = tool.schema["function"]
return cls(
name=tool.name,
description=func_schema.get("description", ""),
parameters=func_schema.get("parameters", {}),
annotations=tool.annotations,
)
def from_tool(cls, tool: "Tool | ToolBuiltIn") -> "ToolInfo":
    """Create a ToolInfo from a Tool or ToolBuiltIn instance."""
    from ._tools import ToolBuiltIn

    if isinstance(tool, ToolBuiltIn):
        # Built-in tools carry no function schema; fall back to the tool's
        # name as a placeholder description with an empty parameter spec.
        return cls(name=tool.name, description=tool.name, parameters={})

    # Regular tools: pull description and parameters out of the
    # OpenAI-style function schema attached to the tool.
    fn_schema = tool.schema["function"]
    return cls(
        name=tool.name,
        description=fn_schema.get("description", ""),
        parameters=fn_schema.get("parameters", {}),
        annotations=tool.annotations,
    )


ContentTypeEnum = Literal[
Expand Down Expand Up @@ -247,6 +253,22 @@ def __str__(self):
def _repr_markdown_(self):
return self.__str__()

def _repr_png_(self):
"""Display PNG images directly in Jupyter notebooks."""
if self.image_content_type == "image/png" and self.data:
import base64

return base64.b64decode(self.data)
return None

def _repr_jpeg_(self):
"""Display JPEG images directly in Jupyter notebooks."""
if self.image_content_type == "image/jpeg" and self.data:
import base64

return base64.b64decode(self.data)
return None

Comment on lines +256 to +271
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to think through non-notebook contexts as well -- maybe for another PR though

def __repr__(self, indent: int = 0):
n_bytes = len(self.data) if self.data else 0
return (
Expand Down
6 changes: 3 additions & 3 deletions chatlas/_mcp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional, Sequence

from ._tools import Tool
from ._tools import Tool, ToolBuiltIn

if TYPE_CHECKING:
from mcp import ClientSession
Expand All @@ -23,7 +23,7 @@ class SessionInfo(ABC):

# Primary derived attributes
session: ClientSession | None = None
tools: dict[str, Tool] = field(default_factory=dict)
tools: dict[str, Tool | ToolBuiltIn] = field(default_factory=dict)

# Background task management
ready_event: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -74,7 +74,7 @@ async def request_tools(self) -> None:
tool_names = tool_names.difference(exclude)

# Apply namespace and convert to chatlas.Tool instances
self_tools: dict[str, Tool] = {}
self_tools: dict[str, Tool | ToolBuiltIn] = {}
for tool in response.tools:
if tool.name not in tool_names:
continue
Expand Down
18 changes: 9 additions & 9 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pydantic import BaseModel

from ._content import Content
from ._tools import Tool
from ._tools import Tool, ToolBuiltIn
from ._turn import AssistantTurn, Turn
from ._typing_extensions import NotRequired, TypedDict

Expand Down Expand Up @@ -162,7 +162,7 @@ def chat_perform(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -174,7 +174,7 @@ def chat_perform(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT]: ...
Expand All @@ -185,7 +185,7 @@ def chat_perform(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> Iterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand All @@ -197,7 +197,7 @@ async def chat_perform_async(
*,
stream: Literal[False],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> ChatCompletionT: ...
Expand All @@ -209,7 +209,7 @@ async def chat_perform_async(
*,
stream: Literal[True],
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT]: ...
Expand All @@ -220,7 +220,7 @@ async def chat_perform_async(
*,
stream: bool,
turns: list[Turn],
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
Expand Down Expand Up @@ -259,15 +259,15 @@ def value_tokens(
def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

@abstractmethod
async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
tools: dict[str, Tool | ToolBuiltIn],
data_model: Optional[type[BaseModel]],
) -> int: ...

Expand Down
Loading
Loading