From edbec4e30b2a9bd3eb0b466be0fab69a0854bfaa Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:23:10 +0300 Subject: [PATCH 1/3] feat: Add tool origin tracking to ToolCallItem and ToolCallOutputItem - Add ToolOriginType enum and ToolOrigin dataclass - Add _tool_origin field to FunctionTool - Set tool_origin for MCP tools and agent-as-tool - Extract and set tool_origin in ToolCallItem and ToolCallOutputItem creation - Add comprehensive tests for tool origin tracking --- src/agents/agent.py | 7 + src/agents/items.py | 7 + src/agents/mcp/util.py | 9 +- src/agents/run_internal/run_loop.py | 9 +- src/agents/run_internal/tool_execution.py | 3 + src/agents/run_internal/turn_resolution.py | 9 +- src/agents/tool.py | 58 ++++ tests/test_tool_origin.py | 333 +++++++++++++++++++++ 8 files changed, 431 insertions(+), 4 deletions(-) create mode 100644 tests/test_tool_origin.py diff --git a/src/agents/agent.py b/src/agents/agent.py index b0368e8698..1afc33757b 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -45,6 +45,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, _extract_tool_argument_json_error, default_tool_error_function, ) @@ -802,6 +804,11 @@ async def _run_agent_tool(context: ToolContext, input_json: str) -> Any: ) run_agent_tool._is_agent_tool = True run_agent_tool._agent_instance = self + # Set origin tracking on run_agent (the FunctionTool returned by @function_tool) + run_agent_tool._tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=self, + ) return run_agent_tool diff --git a/src/agents/items.py b/src/agents/items.py index 94ab5daa35..64565b6037 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -49,6 +49,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -248,6 +249,9 @@ class ToolCallItem(RunItemBase[Any]): description: str | None = 
None """Optional tool description if known at item creation time.""" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -271,6 +275,9 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 9c9a59f683..a72c55ccaf 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -20,6 +20,8 @@ FunctionTool, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, ToolOutputImageDict, ToolOutputTextDict, default_tool_error_function, @@ -301,7 +303,7 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] ) = server._get_needs_approval_for_tool(tool, agent) - return FunctionTool( + function_tool = FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, @@ -309,6 +311,11 @@ async def invoke_func(ctx: ToolContext[Any], input_json: str) -> ToolOutput: strict_json_schema=is_strict, needs_approval=needs_approval, ) + function_tool._tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server=server, + ) + return function_tool @staticmethod def _merge_mcp_meta( diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index e807c0cb11..4404868ed8 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -49,7 +49,7 @@ RawResponsesStreamEvent, RunItemStreamEvent, ) -from ..tool import Tool, 
dispose_resolved_computers +from ..tool import FunctionTool, Tool, _get_tool_origin_info, dispose_resolved_computers from ..tracing import Span, SpanError, agent_span, get_current_trace from ..tracing.model_tracing import get_model_tracing_impl from ..tracing.span_data import AgentSpanData @@ -1216,13 +1216,18 @@ async def run_single_turn_streamed( # execution behavior in process_model_response). tool_name = getattr(output_item, "name", None) tool_description: str | None = None + tool_origin = None if isinstance(tool_name, str) and tool_name in tool_map: - tool_description = getattr(tool_map[tool_name], "description", None) + tool = tool_map[tool_name] + tool_description = getattr(tool, "description", None) + if isinstance(tool, FunctionTool): + tool_origin = _get_tool_origin_info(tool) tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, description=tool_description, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/run_internal/tool_execution.py b/src/agents/run_internal/tool_execution.py index bc370ea611..a22f9a5cdc 100644 --- a/src/agents/run_internal/tool_execution.py +++ b/src/agents/run_internal/tool_execution.py @@ -52,6 +52,7 @@ ShellCallOutcome, ShellCommandOutput, Tool, + _get_tool_origin_info, resolve_computer, ) from ..tool_context import ToolContext @@ -973,10 +974,12 @@ async def run_single_tool(func_tool: FunctionTool, tool_call: ResponseFunctionTo run_item: RunItem | None = None if not nested_interruptions: + tool_origin = _get_tool_origin_info(tool_run.function_tool) run_item = ToolCallOutputItem( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=agent, + tool_origin=tool_origin, ) else: # Skip tool output until nested interruptions are resolved. 
diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index fed661ea9a..86872f4d27 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -62,6 +62,7 @@ LocalShellTool, ShellTool, Tool, + _get_tool_origin_info, ) from ..tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from ..tracing import SpanError, handoff_span @@ -1473,8 +1474,14 @@ def process_model_response( raise ModelBehaviorError(error) func_tool = function_map[output.name] + tool_origin = _get_tool_origin_info(func_tool) items.append( - ToolCallItem(raw_item=output, agent=agent, description=func_tool.description) + ToolCallItem( + raw_item=output, + agent=agent, + description=func_tool.description, + tool_origin=tool_origin, + ) ) functions.append( ToolRunFunction( diff --git a/src/agents/tool.py b/src/agents/tool.py index 4f70adc0f8..06cc25a734 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import enum import inspect import json import weakref @@ -48,6 +49,7 @@ if TYPE_CHECKING: from .agent import Agent, AgentBase from .items import RunItem, ToolApprovalItem + from .mcp.server import MCPServer ToolParams = ParamSpec("ToolParams") @@ -182,6 +184,59 @@ class ComputerProvider(Generic[ComputerT]): ] +class ToolOriginType(str, enum.Enum): + """The type of tool origin.""" + + FUNCTION = "function" + """Regular Python function tool created via @function_tool decorator.""" + + MCP = "mcp" + """MCP server tool converted via MCPUtil.to_function_tool().""" + + AGENT_AS_TOOL = "agent_as_tool" + """Agent converted to tool via agent.as_tool().""" + + +@dataclass +class ToolOrigin: + """Information about the origin/source of a function tool.""" + + type: ToolOriginType + """The type of tool origin.""" + + mcp_server: MCPServer | None = None + """The MCP server object. 
Only set when type is MCP.""" + + agent_as_tool: Agent[Any] | None = None + """The agent object. Only set when type is AGENT_AS_TOOL.""" + + def __repr__(self) -> str: + """Custom repr that only includes relevant fields.""" + parts = [f"type={self.type.value!r}"] + if self.mcp_server is not None: + parts.append(f"mcp_server_name={self.mcp_server.name!r}") + if self.agent_as_tool is not None: + parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + return f"ToolOrigin({', '.join(parts)})" + + +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: + """Extract origin information from a FunctionTool. + + Args: + function_tool: The function tool to extract origin info from. + + Returns: + The tool's explicit ToolOrigin if set; otherwise a new ToolOrigin of type FUNCTION. + """ + origin = function_tool._tool_origin + if origin is None: + # Default to FUNCTION if not explicitly set + return ToolOrigin(type=ToolOriginType.FUNCTION) + + return origin + + @dataclass class FunctionToolResult: tool: FunctionTool @@ -264,6 +319,9 @@ class FunctionTool: _agent_instance: Any = field(default=None, init=False, repr=False) """Internal reference to the agent instance if this is an agent-as-tool.""" + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) + """Internal field tracking the origin of this tool (FUNCTION, MCP, or AGENT_AS_TOOL).""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..5800b87099 --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,333 @@ +"""Tests for tool origin tracking feature.""" + +from __future__ import annotations + +import sys +from typing import cast + +import pytest + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.items import ToolCallItem, ToolCallItemTypes, 
ToolCallOutputItem +from agents.tool import ToolOrigin, ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_function_tool_origin(): + """Test that regular function tools have FUNCTION origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_mcp_tool_origin(): + """Test that MCP tools have MCP origin with server name.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + 
[get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_call_items[0].tool_origin.agent_as_tool is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_items[0].tool_origin.agent_as_tool is None + + +@pytest.mark.asyncio +async def test_agent_as_tool_origin(): + """Test that agent-as-tool has AGENT_AS_TOOL origin with agent name.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 
+ assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_items[0].tool_origin.mcp_server is None + assert tool_call_items[0].tool_origin.agent_as_tool is not None + assert tool_call_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_items[0].tool_origin.mcp_server is None + assert tool_output_items[0].tool_origin.agent_as_tool is not None + assert tool_output_items[0].tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_multiple_tool_origins(): + """Test that multiple tools from different origins work together.""" + model = FakeModel() + nested_model = FakeModel() + + @function_tool + def func_tool(x: int) -> str: + """Function tool.""" + return f"function: {x}" + + mcp_server = FakeMCPServer(server_name="mcp_server") + mcp_server.add_tool("mcp_tool", {}) + + nested_agent = Agent(name="nested", model=nested_model, instructions="Nested agent") + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + agent_tool = nested_agent.as_tool(tool_name="agent_tool", tool_description="Agent tool") + + agent = Agent( + name="test", + model=model, + tools=[func_tool, agent_tool], + mcp_servers=[mcp_server], + ) + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call("func_tool", '{"x": 1}'), + get_function_tool_call("mcp_tool", ""), + get_function_tool_call("agent_tool", '{"input": "test"}'), + ], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in 
result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 3 + assert len(tool_output_items) == 3 + + # Find items by tool name + function_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "func_tool" + ) + mcp_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "mcp_tool" + ) + agent_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "agent_tool" + ) + + assert function_item.tool_origin is not None + assert function_item.tool_origin.type == ToolOriginType.FUNCTION + assert mcp_item.tool_origin is not None + assert mcp_item.tool_origin.type == ToolOriginType.MCP + assert mcp_item.tool_origin.mcp_server is not None + assert mcp_item.tool_origin.mcp_server.name == "mcp_server" + assert agent_item.tool_origin is not None + assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert agent_item.tool_origin.agent_as_tool is not None + assert agent_item.tool_origin.agent_as_tool.name == "nested" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_tool_origin_streaming(): + """Test that tool origin is populated correctly in streaming scenarios.""" + model = FakeModel() + server = FakeMCPServer(server_name="streaming_server") + server.add_tool("streaming_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("streaming_tool", "")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="test") + tool_call_items = [] + tool_output_items = [] + + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + if isinstance(event.item, ToolCallItem): + tool_call_items.append(event.item) + elif isinstance(event.item, ToolCallOutputItem): + tool_output_items.append(event.item) + + assert 
len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server is not None + assert tool_call_items[0].tool_origin.mcp_server.name == "streaming_server" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server is not None + assert tool_output_items[0].tool_origin.mcp_server.name == "streaming_server" + + +@pytest.mark.asyncio +async def test_tool_origin_repr(): + """Test that ToolOrigin repr only shows relevant fields.""" + # FUNCTION origin + function_origin = ToolOrigin(type=ToolOriginType.FUNCTION) + assert "mcp_server_name" not in repr(function_origin) + assert "agent_as_tool_name" not in repr(function_origin) + + # MCP origin + if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + test_server = FakeMCPServer(server_name="test_server") + mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server=test_server) + assert "mcp_server_name='test_server'" in repr(mcp_origin) + assert "agent_as_tool_name" not in repr(mcp_origin) + + # AGENT_AS_TOOL origin + model = FakeModel() + test_agent = Agent(name="test_agent", model=model, instructions="Test agent") + agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool=test_agent) + assert "agent_as_tool_name='test_agent'" in repr(agent_origin) + assert "mcp_server_name" not in repr(agent_origin) + + +@pytest.mark.asyncio +async def test_tool_origin_defaults_to_function(): + """Test that tools without explicit origin default to FUNCTION.""" + model = FakeModel() + + # Create a FunctionTool directly without using @function_tool decorator + async def test_func(ctx: RunContextWrapper, args: str) -> str: + return "result" + + tool = FunctionTool( + name="direct_tool", + description="Direct tool", + 
params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=test_func, + ) + + agent = Agent(name="test", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("direct_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + + assert len(tool_call_items) == 1 + # Even though _tool_origin is None, _get_tool_origin_info defaults to FUNCTION + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_non_function_tool_items_have_no_origin(): + """Test that non-FunctionTool items (computer, shell, etc.) don't have tool_origin.""" + model = FakeModel() + + @function_tool + def func_tool() -> str: + """Function tool.""" + return "result" + + agent = Agent(name="test", model=model, tools=[func_tool]) + + # Create a ToolCallItem for a non-function tool (simulating computer/shell tool) + computer_call = { + "type": "computer_use_preview", + "call_id": "call_123", + "actions": [], + } + + # This simulates what happens for non-FunctionTool items + # They should not have tool_origin set + item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, computer_call), + agent=agent, + ) + + assert item.tool_origin is None From e1b635702c2f3f2f6b3e8ec98fe187c00beab71f Mon Sep 17 00:00:00 2001 From: habema Date: Thu, 5 Feb 2026 15:30:28 +0300 Subject: [PATCH 2/3] fix memory leak in code review and add test --- src/agents/items.py | 12 +++++++++ src/agents/tool.py | 41 ++++++++++++++++++++++++++++-- tests/test_tool_origin.py | 53 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/src/agents/items.py b/src/agents/items.py index 64565b6037..7139e07f99 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -252,6 +252,12 @@ class 
ToolCallItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -278,6 +284,12 @@ class ToolCallOutputItem(RunItemBase[Any]): tool_origin: ToolOrigin | None = field(default=None, repr=False) """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def release_agent(self) -> None: + """Release agent references including tool_origin.agent_as_tool.""" + super().release_agent() + if self.tool_origin is not None: + self.tool_origin.release_agent() + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/tool.py b/src/agents/tool.py index 06cc25a734..2e0e043581 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -210,13 +210,50 @@ class ToolOrigin: agent_as_tool: Agent[Any] | None = None """The agent object. 
Only set when type is AGENT_AS_TOOL.""" + _agent_as_tool_ref: weakref.ReferenceType[Agent[Any]] | None = field( + default=None, init=False, repr=False + ) + """Weak reference to agent_as_tool for memory management.""" + + def __post_init__(self) -> None: + """Initialize weak reference for agent_as_tool.""" + if self.agent_as_tool is not None: + self._agent_as_tool_ref = weakref.ref(self.agent_as_tool) + + def __getattribute__(self, name: str) -> Any: + """Lazily resolve agent_as_tool via weakref when strong ref is cleared.""" + if name == "agent_as_tool": + # Check if strong reference still exists + value = object.__getattribute__(self, "__dict__").get("agent_as_tool") + if value is not None: + return value + # Try to resolve via weakref + ref = object.__getattribute__(self, "_agent_as_tool_ref") + if ref is not None: + agent = ref() + if agent is not None: + return agent + return None + return super().__getattribute__(name) + + def release_agent(self) -> None: + """Release the strong reference to agent_as_tool while keeping a weak reference.""" + if "agent_as_tool" not in self.__dict__: + return + agent = self.__dict__.get("agent_as_tool") + if agent is not None: + self._agent_as_tool_ref = weakref.ref(agent) + # Set to None instead of deleting so dataclass repr/asdict keep working. 
+ self.__dict__["agent_as_tool"] = None + def __repr__(self) -> str: """Custom repr that only includes relevant fields.""" parts = [f"type={self.type.value!r}"] if self.mcp_server is not None: parts.append(f"mcp_server_name={self.mcp_server.name!r}") - if self.agent_as_tool is not None: - parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}") + agent = self.agent_as_tool + if agent is not None: + parts.append(f"agent_as_tool_name={agent.name!r}") return f"ToolOrigin({', '.join(parts)})" diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py index 5800b87099..245f982491 100644 --- a/tests/test_tool_origin.py +++ b/tests/test_tool_origin.py @@ -2,7 +2,9 @@ from __future__ import annotations +import gc import sys +import weakref from typing import cast import pytest @@ -331,3 +333,54 @@ def func_tool() -> str: ) assert item.tool_origin is None + + +def test_tool_origin_release_agent_clears_strong_reference(): + """Test that release_agent() clears strong reference to agent_as_tool.""" + # Create a ToolOrigin with an agent_as_tool + nested_agent = Agent( + name="nested_agent", + model=FakeModel(), + instructions="You are a nested agent.", + ) + + tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool=nested_agent, + ) + + # Create a ToolCallItem with this tool_origin + tool_call_item = ToolCallItem( + raw_item=cast( + ToolCallItemTypes, + { + "type": "function_call", + "name": "test_tool", + "call_id": "call_123", + "arguments": "{}", + }, + ), + agent=nested_agent, + tool_origin=tool_origin, + ) + + # Verify agent_as_tool is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Create weak reference to verify GC behavior + nested_agent_ref = weakref.ref(nested_agent) + + # Release agent - this should clear strong reference in tool_origin + tool_call_item.release_agent() + + # After release, agent_as_tool should still be accessible via weakref + assert 
tool_call_item.tool_origin.agent_as_tool is nested_agent + + # Delete the agent and force GC + del nested_agent + gc.collect() + + # After GC, agent_as_tool should be None since strong refs were cleared + assert nested_agent_ref() is None + assert tool_call_item.tool_origin.agent_as_tool is None From 5b2835b82a4a1732ab17fd363f25a119bd4724bc Mon Sep 17 00:00:00 2001 From: habema Date: Sat, 7 Feb 2026 21:46:39 +0300 Subject: [PATCH 3/3] address code review and add test --- src/agents/run_state.py | 86 ++++++++- tests/test_tool_origin_serialization.py | 228 ++++++++++++++++++++++++ 2 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool_origin_serialization.py diff --git a/src/agents/run_state.py b/src/agents/run_state.py index d02d298140..08b23e2505 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -63,6 +63,8 @@ HostedMCPTool, LocalShellTool, ShellTool, + ToolOrigin, + ToolOriginType, ) from .tool_guardrails import ( AllowBehavior, @@ -635,6 +637,13 @@ def _serialize_item(self, item: RunItem) -> dict[str, Any]: result["tool_name"] = item.tool_name if hasattr(item, "description") and item.description is not None: result["description"] = item.description + if hasattr(item, "tool_origin") and item.tool_origin is not None: + tool_origin_data: dict[str, Any] = {"type": item.tool_origin.type.value} + if item.tool_origin.agent_as_tool is not None: + tool_origin_data["agent_as_tool"] = {"name": item.tool_origin.agent_as_tool.name} + if item.tool_origin.mcp_server is not None: + tool_origin_data["mcp_server"] = {"name": item.tool_origin.mcp_server.name} + result["tool_origin"] = tool_origin_data return result @@ -1918,6 +1927,67 @@ def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: return agent_map +def _deserialize_tool_origin( + tool_origin_data: dict[str, Any] | None, agent_map: dict[str, Agent[Any]], agent: Agent[Any] +) -> ToolOrigin | None: + """Deserialize ToolOrigin from JSON data. 
+ + Args: + tool_origin_data: Serialized tool origin dictionary. + agent_map: Map of agent names to agent instances. + agent: The agent associated with this item (used for MCP server lookup). + + Returns: + ToolOrigin instance or None if data is missing/invalid. + """ + if not tool_origin_data: + return None + + origin_type_str = tool_origin_data.get("type") + if not origin_type_str: + return None + + try: + origin_type = ToolOriginType(origin_type_str) + except ValueError: + logger.warning(f"Unknown tool origin type: {origin_type_str}") + return None + + agent_as_tool: Agent[Any] | None = None + mcp_server: Any | None = None + + if origin_type == ToolOriginType.AGENT_AS_TOOL: + agent_data = tool_origin_data.get("agent_as_tool") + if agent_data and isinstance(agent_data, Mapping): + agent_name = agent_data.get("name") + if agent_name: + agent_as_tool = agent_map.get(agent_name) + if not agent_as_tool: + logger.warning(f"Agent {agent_name} not found in agent map for tool_origin") + + elif origin_type == ToolOriginType.MCP: + mcp_data = tool_origin_data.get("mcp_server") + if mcp_data and isinstance(mcp_data, Mapping): + server_name = mcp_data.get("name") + if server_name: + # Try to find the MCP server from the agent's mcp_servers list + mcp_servers = getattr(agent, "mcp_servers", []) + for server in mcp_servers: + if hasattr(server, "name") and server.name == server_name: + mcp_server = server + break + if not mcp_server: + logger.debug( + f"MCP server {server_name} not found in agent's mcp_servers for tool_origin" + ) + + return ToolOrigin( + type=origin_type, + agent_as_tool=agent_as_tool, + mcp_server=mcp_server, + ) + + def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]: """Deserialize model responses from JSON data. 
@@ -2019,8 +2089,17 @@ def _resolve_agent_info( raw_item_tool = _deserialize_tool_call_raw_item(normalized_raw_item) # Preserve description if it was stored with the item description = item_data.get("description") + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( - ToolCallItem(agent=agent, raw_item=raw_item_tool, description=description) + ToolCallItem( + agent=agent, + raw_item=raw_item_tool, + description=description, + tool_origin=tool_origin, + ) ) elif item_type == "tool_call_output_item": @@ -2029,11 +2108,16 @@ def _resolve_agent_info( raw_item_output = _deserialize_tool_call_output_raw_item(normalized_raw_item) if raw_item_output is None: continue + # Preserve tool_origin if it was stored with the item + tool_origin = _deserialize_tool_origin( + item_data.get("tool_origin"), agent_map, agent + ) result.append( ToolCallOutputItem( agent=agent, raw_item=raw_item_output, output=item_data.get("output", ""), + tool_origin=tool_origin, ) ) diff --git a/tests/test_tool_origin_serialization.py b/tests/test_tool_origin_serialization.py new file mode 100644 index 0000000000..87bca9fcc4 --- /dev/null +++ b/tests/test_tool_origin_serialization.py @@ -0,0 +1,228 @@ +"""Tests for tool_origin serialization in RunState.""" + +from __future__ import annotations + +import sys + +import pytest + +from agents import Agent, Runner, function_tool +from agents.items import ToolCallItem, ToolCallOutputItem +from agents.run_state import RunState +from agents.tool import ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_function(): + """Test that FUNCTION tool_origin is serialized and deserialized.""" + model = FakeModel() + + @function_tool + def test_tool(x: 
int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.FUNCTION + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.FUNCTION + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_serialize_tool_origin_agent_as_tool(): + """Test that AGENT_AS_TOOL tool_origin is serialized and deserialized.""" + model = FakeModel() + 
nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_item.tool_origin.agent_as_tool is not None + assert tool_call_item.tool_origin.agent_as_tool.name == "nested_agent" + assert tool_output_item.tool_origin is not None + assert tool_output_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_item.tool_origin.agent_as_tool is not None + assert tool_output_item.tool_origin.agent_as_tool.name == "nested_agent" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=orchestrator, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(orchestrator, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, 
ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_call.tool_origin.agent_as_tool is not None + assert deserialized_tool_call.tool_origin.agent_as_tool.name == "nested_agent" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert deserialized_tool_output.tool_origin.agent_as_tool is not None + assert deserialized_tool_output.tool_origin.agent_as_tool.name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_serialize_tool_origin_mcp(): + """Test that MCP tool_origin is serialized and deserialized.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert len(tool_output_items) == 1 + + tool_call_item = tool_call_items[0] + tool_output_item = tool_output_items[0] + + # Verify tool_origin is set + assert tool_call_item.tool_origin is not None + assert tool_call_item.tool_origin.type == ToolOriginType.MCP + assert tool_call_item.tool_origin.mcp_server is not None + assert tool_call_item.tool_origin.mcp_server.name == "test_mcp_server" + assert tool_output_item.tool_origin is not None + assert 
tool_output_item.tool_origin.type == ToolOriginType.MCP + assert tool_output_item.tool_origin.mcp_server is not None + assert tool_output_item.tool_origin.mcp_server.name == "test_mcp_server" + + # Serialize and deserialize + context = result.context_wrapper + state = RunState( + context=context, + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._generated_items = [tool_call_item, tool_output_item] + + json_data = state.to_json() + deserialized_state = await RunState.from_json(agent, json_data) + + # Verify tool_origin was preserved + deserialized_tool_call = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallItem) + ) + deserialized_tool_output = next( + item for item in deserialized_state._generated_items if isinstance(item, ToolCallOutputItem) + ) + + assert deserialized_tool_call.tool_origin is not None + assert deserialized_tool_call.tool_origin.type == ToolOriginType.MCP + # MCP server should be reconstructed from agent's mcp_servers + assert deserialized_tool_call.tool_origin.mcp_server is not None + assert deserialized_tool_call.tool_origin.mcp_server.name == "test_mcp_server" + assert deserialized_tool_output.tool_origin is not None + assert deserialized_tool_output.tool_origin.type == ToolOriginType.MCP + assert deserialized_tool_output.tool_origin.mcp_server is not None + assert deserialized_tool_output.tool_origin.mcp_server.name == "test_mcp_server"