diff --git a/docs/tools.md b/docs/tools.md index 3ece84b178..ba16a405bb 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -387,6 +387,24 @@ In addition to returning text outputs, you can return one or many images or file - Files: [`ToolOutputFileContent`][agents.tool.ToolOutputFileContent] (or the TypedDict version, [`ToolOutputFileContentDict`][agents.tool.ToolOutputFileContentDict]) - Text: either a string or stringable objects, or [`ToolOutputText`][agents.tool.ToolOutputText] (or the TypedDict version, [`ToolOutputTextDict`][agents.tool.ToolOutputTextDict]) +### Instance methods as tools + +You can decorate instance methods with `@function_tool` and pass the bound method from an instance. The `self` argument is supplied automatically and excluded from the tool's JSON schema: + +```python +class Calculator: + def __init__(self, base: int): + self.base = base + + @function_tool + def add_to_base(self, x: int) -> int: + """Add x to the calculator's base.""" + return self.base + x + +calc = Calculator(base=10) +agent = Agent(name="Math", tools=[calc.add_to_base]) +``` + ### Custom function tools Sometimes, you don't want to use a Python function as a tool. You can directly create a [`FunctionTool`][agents.tool.FunctionTool] if you prefer. You'll need to provide: diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 8fe52df320..a30edb9efd 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -40,6 +40,8 @@ class FuncSchema: strict_json_schema: bool = True """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" + skipped_self: bool = False + """Whether a leading ``self`` parameter was skipped (method-backed function tools).""" def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: """ @@ -50,8 +52,14 @@ def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: keyword_args: dict[str, Any] = {} seen_var_positional = False + # Skip a leading `self` for method-backed tools; it is supplied by binding, + # not by the model, and is not part of the schema. + param_items = list(self.signature.parameters.items()) + if self.skipped_self: + param_items = param_items[1:] + # Use enumerate() so we can skip the first parameter if it's context. - for idx, (name, param) in enumerate(self.signature.parameters.items()): + for idx, (name, param) in enumerate(param_items): # If the function takes a RunContextWrapper and this is the first parameter, skip it. if self.takes_context and idx == 0: continue @@ -228,6 +236,7 @@ def function_schema( description_override: str | None = None, use_docstring_info: bool = True, strict_json_schema: bool = True, + skip_self: bool = False, ) -> FuncSchema: """ Given a Python function, extracts a `FuncSchema` from it, capturing the name, description, @@ -286,6 +295,17 @@ def function_schema( # 2. Inspect function signature and get type hints sig = inspect.signature(func) params = list(sig.parameters.items()) + + # Skip a leading `self` so method-backed tools (decorated instance methods) + # validate and serialize correctly: `self` is supplied by binding, and the + # next parameter is what should be evaluated for context detection. Gated by + # `skip_self` (set by the caller for method-backed tools) so an ordinary + # function whose first argument is literally named `self` is unaffected. + skipped_self = False + if skip_self and params and params[0][0] == "self": + params = params[1:] + skipped_self = True + takes_context = False filtered_params = [] @@ -421,4 +441,5 @@ def function_schema( signature=sig, takes_context=takes_context, strict_json_schema=strict_json_schema, + skipped_self=skipped_self, ) diff --git a/src/agents/tool.py b/src/agents/tool.py index c8563e2e1b..8638f835f1 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -11,7 +11,7 @@ from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field from enum import Enum -from types import UnionType +from types import MethodType, UnionType from typing import ( TYPE_CHECKING, Annotated, @@ -493,6 +493,11 @@ class FunctionTool: _emit_tool_origin: bool = field(default=True, kw_only=True, repr=False) """Whether runtime item generation should emit tool origin metadata for this tool.""" + _bind_to_instance: Callable[[Any], FunctionTool] | None = field( + default=None, kw_only=True, repr=False, compare=False + ) + """Internal: builds an instance-bound copy of a method-backed tool (see __get__).""" + @property def qualified_name(self) -> str: """Return the public qualified name used to identify this function tool.""" @@ -510,6 +515,21 @@ def __post_init__(self): ) _validate_function_tool_timeout_config(self) + def __get__(self, instance: Any, owner: type[Any] | None = None) -> FunctionTool: + """Descriptor hook so ``@function_tool`` works on instance methods. + + When the tool is a class attribute accessed via an instance, return a copy + bound to that instance (``self`` is supplied automatically and excluded + from the JSON schema). Tools that are not method-backed return unchanged. + + A fresh bound tool is built per access rather than cached, since caching on + the instance would require it to be weak-referenceable/hashable (not true + for every tool-holder class) and would otherwise retain instances. + """ + if instance is None or self._bind_to_instance is None: + return self + return self._bind_to_instance(instance) + def __copy__(self) -> FunctionTool: copied_tool = dataclasses.replace(self) dataclass_field_names = {tool_field.name for tool_field in dataclasses.fields(FunctionTool)} @@ -658,6 +678,25 @@ def get_function_tool_origin(function_tool: FunctionTool) -> ToolOrigin | None: return function_tool._tool_origin or ToolOrigin(type=ToolOriginType.FUNCTION) +def _looks_like_method(func: Any) -> bool: + """Heuristic: is ``func`` an instance method decorated with ``@function_tool``? + + True only when the first parameter is ``self`` *and* the qualified name shows the + function is defined in a class body (e.g. ``Class.method``). This deliberately + excludes a plain module-level function whose first argument happens to be named + ``self`` (qualname has no class component), so its behavior is unchanged. + """ + try: + params = list(inspect.signature(func).parameters) + except (TypeError, ValueError): + return False + if not params or params[0] != "self": + return False + qualname = getattr(func, "__qualname__", "") + parts = qualname.split(".") + return len(parts) >= 2 and parts[-2] != "" + + @dataclass class FileSearchTool: """A hosted tool that lets the LLM search through a vector store. Currently only supported with @@ -1966,6 +2005,7 @@ def function_tool( def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: is_sync_function_tool = not inspect.iscoroutinefunction(the_func) + is_method = _looks_like_method(the_func) schema = function_schema( func=the_func, name_override=name_override, @@ -1973,6 +2013,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: docstring_style=docstring_style, use_docstring_info=use_docstring_info, strict_json_schema=strict_mode, + skip_self=is_method, ) async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: @@ -2035,6 +2076,11 @@ async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any: custom_data_extractor=custom_data_extractor, sync_invoker=is_sync_function_tool, ) + if is_method: + # Bind `self` when the tool is accessed via an instance (see __get__). + function_tool._bind_to_instance = lambda instance: _create_function_tool( + MethodType(the_func, instance) + ) return function_tool # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool_methods.py b/tests/test_function_tool_methods.py new file mode 100644 index 0000000000..0d12ba7af5 --- /dev/null +++ b/tests/test_function_tool_methods.py @@ -0,0 +1,111 @@ +"""@function_tool support on instance methods (#94).""" + +from __future__ import annotations + +import json + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.tool_context import ToolContext +from tests.fake_model import FakeModel +from tests.test_responses import get_function_tool_call, get_text_message + + +class Calculator: + def __init__(self, base: int) -> None: + self.base = base + + @function_tool + def add_to_base(self, x: int) -> int: + """Add x to the calculator's base.""" + return self.base + x + + @function_tool + def add_with_context(self, ctx: RunContextWrapper[int], x: int) -> int: + """Add x to the base and the run context value.""" + return self.base + x + ctx.context + + +def _ctx(tool: FunctionTool) -> ToolContext: + return ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments="") + + +def test_instance_access_binds_self_and_drops_it_from_schema() -> None: + calc = Calculator(10) + tool = calc.add_to_base # descriptor __get__ -> instance-bound tool + + assert isinstance(tool, FunctionTool) + properties = tool.params_json_schema.get("properties", {}) + assert "self" not in properties + assert "x" in properties + + +async def test_instance_method_tool_invokes_with_self() -> None: + calc = Calculator(10) + tool = calc.add_to_base + result = await tool.on_invoke_tool(_ctx(tool), json.dumps({"x": 5})) + assert result == 15 + + +async def test_distinct_instances_bind_independently() -> None: + ten, twenty = Calculator(10), Calculator(20) + assert await ten.add_to_base.on_invoke_tool(_ctx(ten.add_to_base), json.dumps({"x": 1})) == 11 + assert ( + await twenty.add_to_base.on_invoke_tool(_ctx(twenty.add_to_base), json.dumps({"x": 1})) + == 21 + ) + + +async def test_context_taking_method_binds_self_and_context() -> None: + # A method that takes RunContextWrapper after self must not raise at decoration + # and must receive both self and the run context when invoked. + calc = Calculator(10) + tool = calc.add_with_context + assert "self" not in tool.params_json_schema.get("properties", {}) + assert "ctx" not in tool.params_json_schema.get("properties", {}) + assert "x" in tool.params_json_schema.get("properties", {}) + + ctx: ToolContext[int] = ToolContext( + context=5, tool_name=tool.name, tool_call_id="1", tool_arguments="" + ) + result = await tool.on_invoke_tool(ctx, json.dumps({"x": 2})) + assert result == 17 # base 10 + x 2 + context 5 + + +def test_module_level_self_named_function_is_not_treated_as_method() -> None: + # A plain function whose first arg happens to be named `self` is unaffected: + # `self` stays in the schema and is supplied by the model. + @function_tool + def weird(self: int, x: int) -> int: + """A free function with an unfortunate first argument name.""" + return self + x + + assert "self" in weird.params_json_schema.get("properties", {}) + + +def test_class_access_returns_unbound_tool() -> None: + # Accessing via the class (no instance) returns the original tool unchanged. + assert isinstance(Calculator.add_to_base, FunctionTool) + + +def test_module_level_function_tool_unaffected() -> None: + @function_tool + def free(x: int) -> int: + """A free function.""" + return x + + assert isinstance(free, FunctionTool) + assert "x" in free.params_json_schema.get("properties", {}) + + +async def test_instance_method_tool_runs_in_agent() -> None: + calc = Calculator(100) + model = FakeModel() + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("add_to_base", json.dumps({"x": 5}))], + [get_text_message("done")], + ] + ) + agent = Agent(name="A", instructions="x", model=model, tools=[calc.add_to_base]) + result = await Runner.run(agent, "add 5") + assert result.final_output == "done"