From 409c7adfe36aa6b880c9a49addd24d09f6f8c796 Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Tue, 23 Jun 2026 16:03:12 -0700 Subject: [PATCH] fix: raise a clear error when on_invoke_tool gets a non-context --- src/agents/tool.py | 11 +++++++++++ tests/test_function_tool.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/agents/tool.py b/src/agents/tool.py index c8563e2e1b..dbe7e4d6c7 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -552,6 +552,17 @@ def __agents_bind_function_tool__( return bound_invoker async def __call__(self, ctx: ToolContext[Any], input: str) -> Any: + # Validate the context up front. A non-context value (most commonly None) would + # otherwise blow up deep in the invocation with a cryptic AttributeError, so fail + # fast here with an actionable message before any tool logic runs. The base + # RunContextWrapper is accepted because agent-as-tool invokers run with one; the + # message points to ToolContext since that is what function tools need. + if not isinstance(ctx, RunContextWrapper): + raise TypeError( + f"on_invoke_tool requires a ToolContext, got {type(ctx).__name__}. " + "Construct one with ToolContext(context=..., tool_name=..., " + "tool_call_id=..., tool_arguments=...) or invoke the tool through Runner." + ) try: return await self._invoke_tool_impl(ctx, input) except Exception as e: diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 60ae2558cc..116e2bafc7 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -155,6 +155,27 @@ async def test_simple_function(): ) +@pytest.mark.asyncio +async def test_on_invoke_tool_rejects_non_tool_context(): + tool = function_tool(simple_function) + + # A non-ToolContext (most commonly None) should fail fast with a clear TypeError + # instead of a cryptic AttributeError raised deep inside the invocation path. + with pytest.raises(TypeError, match="on_invoke_tool requires a ToolContext, got NoneType"): + await tool.on_invoke_tool(cast(Any, None), '{"a": 1, "b": 2}') + + # The error message names the offending type to help developers spot the mistake. + with pytest.raises(TypeError, match="got int"): + await tool.on_invoke_tool(cast(Any, 123), '{"a": 1, "b": 2}') + + # A valid ToolContext is unaffected by the guard. + result = await tool.on_invoke_tool( + ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'), + '{"a": 1, "b": 2}', + ) + assert result == 3 + + @pytest.mark.asyncio async def test_sync_function_runs_via_to_thread(monkeypatch: pytest.MonkeyPatch) -> None: calls = {"to_thread": 0, "func": 0}