Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,24 @@ In addition to returning text outputs, you can return one or many images or file
- Files: [`ToolOutputFileContent`][agents.tool.ToolOutputFileContent] (or the TypedDict version, [`ToolOutputFileContentDict`][agents.tool.ToolOutputFileContentDict])
- Text: either a string or stringable objects, or [`ToolOutputText`][agents.tool.ToolOutputText] (or the TypedDict version, [`ToolOutputTextDict`][agents.tool.ToolOutputTextDict])

### Instance methods as tools

You can decorate instance methods with `@function_tool` and pass the bound method from an instance. The `self` argument is supplied automatically and excluded from the tool's JSON schema:

```python
class Calculator:
def __init__(self, base: int):
self.base = base

@function_tool
def add_to_base(self, x: int) -> int:
"""Add x to the calculator's base."""
return self.base + x

calc = Calculator(base=10)
agent = Agent(name="Math", tools=[calc.add_to_base])
```

### Custom function tools

Sometimes, you don't want to use a Python function as a tool. You can directly create a [`FunctionTool`][agents.tool.FunctionTool] if you prefer. You'll need to provide:
Expand Down
23 changes: 22 additions & 1 deletion src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class FuncSchema:
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
skipped_self: bool = False
"""Whether a leading ``self`` parameter was skipped (method-backed function tools)."""

def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
Expand All @@ -50,8 +52,14 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
keyword_args: dict[str, Any] = {}
seen_var_positional = False

# Skip a leading `self` for method-backed tools; it is supplied by binding,
# not by the model, and is not part of the schema.
param_items = list(self.signature.parameters.items())
if self.skipped_self:
param_items = param_items[1:]

# Use enumerate() so we can skip the first parameter if it's context.
for idx, (name, param) in enumerate(self.signature.parameters.items()):
for idx, (name, param) in enumerate(param_items):
# If the function takes a RunContextWrapper and this is the first parameter, skip it.
if self.takes_context and idx == 0:
continue
Expand Down Expand Up @@ -228,6 +236,7 @@ def function_schema(
description_override: str | None = None,
use_docstring_info: bool = True,
strict_json_schema: bool = True,
skip_self: bool = False,
) -> FuncSchema:
"""
Given a Python function, extracts a `FuncSchema` from it, capturing the name, description,
Expand Down Expand Up @@ -286,6 +295,17 @@ def function_schema(
# 2. Inspect function signature and get type hints
sig = inspect.signature(func)
params = list(sig.parameters.items())

# Skip a leading `self` so method-backed tools (decorated instance methods)
# validate and serialize correctly: `self` is supplied by binding, and the
# next parameter is what should be evaluated for context detection. Gated by
# `skip_self` (set by the caller for method-backed tools) so an ordinary
# function whose first argument is literally named `self` is unaffected.
skipped_self = False
if skip_self and params and params[0][0] == "self":
params = params[1:]
skipped_self = True

takes_context = False
filtered_params = []

Expand Down Expand Up @@ -421,4 +441,5 @@ def function_schema(
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
skipped_self=skipped_self,
)
48 changes: 47 additions & 1 deletion src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from enum import Enum
from types import UnionType
from types import MethodType, UnionType
from typing import (
TYPE_CHECKING,
Annotated,
Expand Down Expand Up @@ -493,6 +493,11 @@ class FunctionTool:
_emit_tool_origin: bool = field(default=True, kw_only=True, repr=False)
"""Whether runtime item generation should emit tool origin metadata for this tool."""

_bind_to_instance: Callable[[Any], FunctionTool] | None = field(
default=None, kw_only=True, repr=False, compare=False
)
"""Internal: builds an instance-bound copy of a method-backed tool (see __get__)."""

@property
def qualified_name(self) -> str:
"""Return the public qualified name used to identify this function tool."""
Expand All @@ -510,6 +515,21 @@ def __post_init__(self):
)
_validate_function_tool_timeout_config(self)

def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool:
"""Descriptor hook so ``@function_tool`` works on instance methods.

When the tool is a class attribute accessed via an instance, return a copy
bound to that instance (``self`` is supplied automatically and excluded
from the JSON schema). Tools that are not method-backed return unchanged.

A fresh bound tool is built per access rather than cached, since caching on
the instance would require it to be weak-referenceable/hashable (not true
for every tool-holder class) and would otherwise retain instances.
"""
if instance is None or self._bind_to_instance is None:
return self
return self._bind_to_instance(instance)

def __copy__(self) -> FunctionTool:
copied_tool = dataclasses.replace(self)
dataclass_field_names = {tool_field.name for tool_field in dataclasses.fields(FunctionTool)}
Expand Down Expand Up @@ -658,6 +678,25 @@ def get_function_tool_origin(function_tool: FunctionTool) -> ToolOrigin | None:
return function_tool._tool_origin or ToolOrigin(type=ToolOriginType.FUNCTION)


def _looks_like_method(func: Any) -> bool:
"""Heuristic: is ``func`` an instance method decorated with ``@function_tool``?

True only when the first parameter is ``self`` *and* the qualified name shows the
function is defined in a class body (e.g. ``Class.method``). This deliberately
excludes a plain module-level function whose first argument happens to be named
``self`` (qualname has no class component), so its behavior is unchanged.
"""
try:
params = list(inspect.signature(func).parameters)
except (TypeError, ValueError):
return False
if not params or params[0] != "self":
return False
qualname = getattr(func, "__qualname__", "")
parts = qualname.split(".")
return len(parts) >= 2 and parts[-2] != "<locals>"


@dataclass
class FileSearchTool:
"""A hosted tool that lets the LLM search through a vector store. Currently only supported with
Expand Down Expand Up @@ -1966,13 +2005,15 @@ def function_tool(

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
is_sync_function_tool = not inspect.iscoroutinefunction(the_func)
is_method = _looks_like_method(the_func)
schema = function_schema(
func=the_func,
name_override=name_override,
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
skip_self=is_method,
)

async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
Expand Down Expand Up @@ -2035,6 +2076,11 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
custom_data_extractor=custom_data_extractor,
sync_invoker=is_sync_function_tool,
)
if is_method:
# Bind `self` when the tool is accessed via an instance (see __get__).
function_tool._bind_to_instance = lambda instance: _create_function_tool(
MethodType(the_func, instance)
)
return function_tool

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
111 changes: 111 additions & 0 deletions tests/test_function_tool_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""@function_tool support on instance methods (#94)."""

from __future__ import annotations

import json

from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool
from agents.tool_context import ToolContext
from tests.fake_model import FakeModel
from tests.test_responses import get_function_tool_call, get_text_message


class Calculator:
def __init__(self, base: int) -> None:
self.base = base

@function_tool
def add_to_base(self, x: int) -> int:
"""Add x to the calculator's base."""
return self.base + x

@function_tool
def add_with_context(self, ctx: RunContextWrapper[int], x: int) -> int:
"""Add x to the base and the run context value."""
return self.base + x + ctx.context


def _ctx(tool: FunctionTool) -> ToolContext:
return ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments="")


def test_instance_access_binds_self_and_drops_it_from_schema() -> None:
calc = Calculator(10)
tool = calc.add_to_base # descriptor __get__ -> instance-bound tool

assert isinstance(tool, FunctionTool)
properties = tool.params_json_schema.get("properties", {})
assert "self" not in properties
assert "x" in properties


async def test_instance_method_tool_invokes_with_self() -> None:
calc = Calculator(10)
tool = calc.add_to_base
result = await tool.on_invoke_tool(_ctx(tool), json.dumps({"x": 5}))
assert result == 15


async def test_distinct_instances_bind_independently() -> None:
ten, twenty = Calculator(10), Calculator(20)
assert await ten.add_to_base.on_invoke_tool(_ctx(ten.add_to_base), json.dumps({"x": 1})) == 11
assert (
await twenty.add_to_base.on_invoke_tool(_ctx(twenty.add_to_base), json.dumps({"x": 1}))
== 21
)


async def test_context_taking_method_binds_self_and_context() -> None:
# A method that takes RunContextWrapper after self must not raise at decoration
# and must receive both self and the run context when invoked.
calc = Calculator(10)
tool = calc.add_with_context
assert "self" not in tool.params_json_schema.get("properties", {})
assert "ctx" not in tool.params_json_schema.get("properties", {})
assert "x" in tool.params_json_schema.get("properties", {})

ctx: ToolContext[int] = ToolContext(
context=5, tool_name=tool.name, tool_call_id="1", tool_arguments=""
)
result = await tool.on_invoke_tool(ctx, json.dumps({"x": 2}))
assert result == 17 # base 10 + x 2 + context 5


def test_module_level_self_named_function_is_not_treated_as_method() -> None:
# A plain function whose first arg happens to be named `self` is unaffected:
# `self` stays in the schema and is supplied by the model.
@function_tool
def weird(self: int, x: int) -> int:
"""A free function with an unfortunate first argument name."""
return self + x

assert "self" in weird.params_json_schema.get("properties", {})


def test_class_access_returns_unbound_tool() -> None:
# Accessing via the class (no instance) returns the original tool unchanged.
assert isinstance(Calculator.add_to_base, FunctionTool)


def test_module_level_function_tool_unaffected() -> None:
@function_tool
def free(x: int) -> int:
"""A free function."""
return x

assert isinstance(free, FunctionTool)
assert "x" in free.params_json_schema.get("properties", {})


async def test_instance_method_tool_runs_in_agent() -> None:
calc = Calculator(100)
model = FakeModel()
model.add_multiple_turn_outputs(
[
[get_function_tool_call("add_to_base", json.dumps({"x": 5}))],
[get_text_message("done")],
]
)
agent = Agent(name="A", instructions="x", model=model, tools=[calc.add_to_base])
result = await Runner.run(agent, "add 5")
assert result.final_output == "done"