Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import contextlib
import inspect
import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload

Expand Down Expand Up @@ -73,6 +74,54 @@
TOOLS_MCP_SERVER_NAME = "_agent_framework_tools"


# Type alias for the agent-level approval hook: a callable that takes one
# ``Content`` and returns either a ``bool`` or an awaitable resolving to one.
# The return type is spelled as a string so the ``bool | Awaitable[bool]`` union
# is not evaluated when this module-level assignment runs (presumably to stay
# importable on interpreters without runtime ``X | Y`` support -- TODO confirm).
# NOTE(review): the text below says ``FunctionCallContent``; in this module the
# object handed to the callback is a ``Content`` built with
# ``Content.from_function_call``.
FunctionApprovalCallback = Callable[[Content], "bool | Awaitable[bool]"]
"""Callback invoked by the agent before executing a FunctionTool that requires approval.

The callback receives a ``FunctionCallContent`` describing the pending call
(``name``, ``arguments``, and a synthetic ``call_id``) and must return ``True``
to allow execution or ``False`` to deny it. Both synchronous and ``await``-able
return values are supported.

The Claude Agent SDK manages its own tool-calling loop, so the framework cannot
round-trip a ``FunctionApprovalRequestContent`` / ``FunctionApprovalResponseContent``
pair the way the standard chat-client pipeline does. This callback is the
agent-level enforcement point for tools declared with
``approval_mode="always_require"``: when no callback is configured the agent
denies these calls by default.
"""


async def _resolve_function_approval(
    callback: FunctionApprovalCallback | None,
    func_tool: FunctionTool,
    arguments: Mapping[str, Any] | None,
) -> bool:
    """Run the agent-level approval callback for a pending tool call.

    The callback is handed a function-call ``Content`` carrying the tool name,
    a shallow copy of the arguments, and a synthetic ``call_id``. The
    ``call_id`` keeps the ``af-claude-approval::<tool name>`` prefix and adds a
    random per-invocation suffix, so repeated or concurrent calls to the same
    tool can be told apart by the callback (audit logs, approval UIs, ...).

    Args:
        callback: The configured approval callback, or ``None`` when the agent
            has none.
        func_tool: The tool whose execution is pending approval.
        arguments: The raw arguments of the pending call, or ``None``.

    Returns:
        ``True`` only when ``callback`` is configured and returns a truthy
        value (after awaiting it, if it is awaitable). A missing callback or
        any ``Exception`` raised by the callback is treated as a denial, so the
        secure-by-default policy holds even if the user code raises.
    """
    if callback is None:
        # Secure by default: nothing can approve the call, so deny it.
        return False

    # Function-scope import keeps this helper self-contained.
    import uuid

    request = Content.from_function_call(
        # Unique per invocation; the stable prefix still identifies the source.
        call_id=f"af-claude-approval::{func_tool.name}::{uuid.uuid4().hex}",
        name=func_tool.name,
        # Copy so the callback cannot mutate the caller's top-level mapping.
        arguments=None if arguments is None else dict(arguments),
    )
    try:
        outcome = callback(request)
        # Support both plain and async callbacks.
        if inspect.isawaitable(outcome):
            outcome = await outcome
    except Exception:
        logger.exception(
            "on_function_approval callback raised for tool '%s'; denying execution.",
            func_tool.name,
        )
        return False
    return bool(outcome)


class ClaudeAgentSettings(TypedDict, total=False):
"""Claude Agent settings.

Expand Down Expand Up @@ -175,6 +224,13 @@ class ClaudeAgentOptions(TypedDict, total=False):
effort: Literal["low", "medium", "high", "max"]
"""Effort level for thinking depth."""

on_function_approval: FunctionApprovalCallback
"""Approval callback for ``FunctionTool`` instances declared with
``approval_mode="always_require"``. The callback is awaited (sync or async)
inside the SDK tool-handler before the tool is executed; a falsy return
value denies the call. If omitted, calls to such tools are denied with an
explanatory message returned to the model."""


OptionsT = TypeVar(
"OptionsT",
Expand Down Expand Up @@ -275,6 +331,7 @@ def __init__(
max_turns = opts.pop("max_turns", None)
max_budget_usd = opts.pop("max_budget_usd", None)
self._mcp_servers: dict[str, Any] = opts.pop("mcp_servers", None) or {}
self._function_approval_handler: FunctionApprovalCallback | None = opts.pop("on_function_approval", None)
Comment thread
eavanvalkenburg marked this conversation as resolved.

# Load settings from environment and options
self._settings = load_settings(
Expand Down Expand Up @@ -487,10 +544,29 @@ def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[
Returns:
An SdkMcpTool instance.
"""
approval_handler = self._function_approval_handler
requires_approval = func_tool.approval_mode == "always_require"

async def handler(args: dict[str, Any]) -> dict[str, Any]:
"""Handler that invokes the FunctionTool."""
try:
if requires_approval and not await _resolve_function_approval(approval_handler, func_tool, args):
deny_text = (
f"Tool '{func_tool.name}' requires human approval "
"(approval_mode='always_require') and the request was denied."
if approval_handler is not None
else (
f"Tool '{func_tool.name}' requires human approval "
"(approval_mode='always_require') but no on_function_approval "
"callback is configured on the agent; the request was denied."
)
)
logger.info(
"Denying execution of tool '%s' (approval_mode='always_require', %s)",
func_tool.name,
"callback denied" if approval_handler is not None else "no callback configured",
)
return {"content": [{"type": "text", "text": deny_text}]}
if func_tool.input_model:
args_instance = func_tool.input_model(**args)
result = await func_tool.invoke(arguments=args_instance)
Expand Down Expand Up @@ -538,6 +614,13 @@ async def _apply_runtime_options(self, options: dict[str, Any] | None) -> None:
if not options or not self._client:
return

if "on_function_approval" in options:
raise ValueError(
"on_function_approval is a security-sensitive option and must be set "
"via default_options at agent construction time. It cannot be overridden "
"per run."
)

if "model" in options:
await self._client.set_model(options["model"])

Expand Down
149 changes: 149 additions & 0 deletions python/packages/claude/tests/test_claude_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,141 @@ def failing_tool() -> str:
assert "Something went wrong" in result["content"][0]["text"]


# region Test ClaudeAgent Function Approval Enforcement


class TestClaudeAgentFunctionApproval:
    """Tests that ``approval_mode='always_require'`` is enforced at the agent boundary."""

    async def test_handler_denies_when_no_callback_configured(self) -> None:
        """Approval-required tool must be denied without executing when no callback is set."""
        invocations: list[Any] = []

        @tool(approval_mode="always_require")
        def dangerous(path: str) -> str:
            """A tool that requires human approval."""
            invocations.append(path)
            return f"deleted {path}"

        agent = ClaudeAgent()
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(dangerous)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"path": "/critical"})

        assert invocations == []
        text = result["content"][0]["text"]
        assert "requires human approval" in text
        assert "no on_function_approval callback is configured" in text

    async def test_handler_denies_when_callback_returns_false(self) -> None:
        """Falsy callback return value must deny the call and skip execution."""
        invocations: list[Any] = []
        seen: list[Content] = []

        def deny(call: Content) -> bool:
            seen.append(call)
            return False

        @tool(approval_mode="always_require")
        def dangerous(path: str) -> str:
            """A tool that requires human approval."""
            invocations.append(path)
            return f"deleted {path}"

        agent = ClaudeAgent(default_options={"on_function_approval": deny})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(dangerous)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"path": "/critical"})

        assert invocations == []
        assert len(seen) == 1
        assert seen[0].type == "function_call"
        assert seen[0].name == "dangerous"  # type: ignore[attr-defined]
        assert seen[0].arguments == {"path": "/critical"}  # type: ignore[attr-defined]
        assert "denied" in result["content"][0]["text"].lower()

    async def test_handler_executes_when_callback_returns_true(self) -> None:
        """Truthy callback return value must allow the tool to execute normally."""

        def approve(call: Content) -> bool:
            return True

        @tool(approval_mode="always_require")
        def guarded(x: int) -> str:
            """A tool that requires human approval."""
            return f"result={x}"

        agent = ClaudeAgent(default_options={"on_function_approval": approve})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(guarded)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"x": 42})

        assert result["content"][0]["text"] == "result=42"

    async def test_handler_supports_async_callback(self) -> None:
        """Async callback must be awaited and respected."""
        seen: list[Content] = []

        async def approve(call: Content) -> bool:
            # Only runs if the coroutine is actually awaited.
            seen.append(call)
            return True

        @tool(approval_mode="always_require")
        def guarded(x: int) -> str:
            """A tool that requires human approval."""
            return f"async={x}"

        agent = ClaudeAgent(default_options={"on_function_approval": approve})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(guarded)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"x": 7})

        # An un-awaited coroutine object is truthy, so the result alone would
        # not prove the callback body ran; the recorded call does.
        assert len(seen) == 1
        assert result["content"][0]["text"] == "async=7"

    async def test_handler_denies_when_async_callback_returns_false(self) -> None:
        """Async callback resolving to ``False`` must deny the call and skip execution."""
        invocations: list[Any] = []

        async def deny(call: Content) -> bool:
            return False

        @tool(approval_mode="always_require")
        def dangerous(x: int) -> str:
            """A tool that requires human approval."""
            invocations.append(x)
            return f"x={x}"

        agent = ClaudeAgent(default_options={"on_function_approval": deny})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(dangerous)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"x": 3})

        assert invocations == []
        assert "denied" in result["content"][0]["text"].lower()

    async def test_callback_failure_denies_safely(self) -> None:
        """A callback that raises must result in denial, not in tool execution."""
        invocations: list[Any] = []

        def boom(call: Content) -> bool:
            raise RuntimeError("nope")

        @tool(approval_mode="always_require")
        def dangerous(x: int) -> str:
            """A tool that requires human approval."""
            invocations.append(x)
            return f"x={x}"

        agent = ClaudeAgent(default_options={"on_function_approval": boom})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(dangerous)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"x": 1})

        assert invocations == []
        assert "denied" in result["content"][0]["text"].lower()

    async def test_handler_does_not_invoke_callback_for_never_require(self) -> None:
        """Tools without approval_mode='always_require' must not trigger the callback."""
        callback_calls: list[Any] = []

        def approve(call: Content) -> bool:
            callback_calls.append(call)
            return True

        @tool
        def safe(x: int) -> str:
            """A tool that does not require approval."""
            return f"safe={x}"

        agent = ClaudeAgent(default_options={"on_function_approval": approve})
        sdk_tool = agent._function_tool_to_sdk_mcp_tool(safe)  # type: ignore[reportPrivateUsage]

        result = await sdk_tool.handler({"x": 5})

        assert callback_calls == []
        assert result["content"][0]["text"] == "safe=5"


# endregion


# region Test ClaudeAgent Permissions


Expand Down Expand Up @@ -786,6 +921,20 @@ async def test_apply_runtime_options_none(self) -> None:
mock_client.set_model.assert_not_called()
mock_client.set_permission_mode.assert_not_called()

async def test_apply_runtime_on_function_approval_rejected(self) -> None:
    """Supplying ``on_function_approval`` as a per-run option must raise ``ValueError``."""
    client_stub = MagicMock()
    client_stub.set_model = AsyncMock()
    client_stub.set_permission_mode = AsyncMock()

    agent = ClaudeAgent()
    agent._client = client_stub  # type: ignore[reportPrivateUsage]

    runtime_options = {"on_function_approval": lambda _c: True}
    with pytest.raises(ValueError, match="on_function_approval"):
        await agent._apply_runtime_options(runtime_options)  # type: ignore[reportPrivateUsage]

    # The option is rejected before anything is forwarded to the client.
    client_stub.set_model.assert_not_called()
    client_stub.set_permission_mode.assert_not_called()


# region Test ClaudeAgent Structured Output

Expand Down
Loading
Loading