diff --git a/astrbot/core/agent/run_context.py b/astrbot/core/agent/run_context.py index 687ad22e5..c206f7e83 100644 --- a/astrbot/core/agent/run_context.py +++ b/astrbot/core/agent/run_context.py @@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]): messages: list[Message] = Field(default_factory=list) """This field stores the llm message context for the agent run, agent runners will maintain this field automatically.""" tool_call_timeout: int = 60 # Default tool call timeout in seconds + tool_call_approval: dict[str, Any] = Field(default_factory=dict) + """Tool call approval runtime configuration.""" NoContext = ContextWrapper[None] diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 10cf2e96c..29f383bf0 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -37,6 +37,10 @@ from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..response import AgentResponseData, AgentStats from ..run_context import ContextWrapper, TContext +from ..tool_call_approval import ( + ToolCallApprovalContext, + request_tool_call_approval, +) from ..tool_executor import BaseFunctionToolExecutor from .base import AgentResponse, AgentState, BaseAgentRunner @@ -659,6 +663,41 @@ async def _handle_function_tools( # 如果没有 handler(如 MCP 工具),使用所有参数 valid_params = func_tool_args + approval_cfg = self.run_context.tool_call_approval + if approval_cfg.get("enable", False): + event = getattr(self.run_context.context, "event", None) + if event is None: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=( + f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`." + ), + ), + ) + continue + approval_result = await request_tool_call_approval( + config=approval_cfg, + ctx=ToolCallApprovalContext( + event=event, + tool_name=func_tool_name, + tool_args=valid_params, + tool_call_id=func_tool_id, + ), + ) + if not approval_result.approved: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=approval_result.to_tool_result_text( + func_tool_name + ), + ), + ) + continue + try: await self.agent_hooks.on_tool_start( self.run_context, diff --git a/astrbot/core/agent/tool_call_approval.py b/astrbot/core/agent/tool_call_approval.py new file mode 100644 index 000000000..cc8ab502f --- /dev/null +++ b/astrbot/core/agent/tool_call_approval.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import secrets +import string +import typing as T +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from astrbot import logger +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.utils.session_waiter import ( + FILTERS, + DefaultSessionFilter, + SessionController, + SessionWaiter, +) + +ApprovalReason = T.Literal[ + "approved", + "rejected", + "timeout", + "unsupported_strategy", + "error", +] + + +@dataclass(slots=True) +class ToolCallApprovalContext: + event: AstrMessageEvent + tool_name: str + tool_args: dict[str, T.Any] + tool_call_id: str + + +@dataclass(slots=True) +class ToolCallApprovalResult: + approved: bool + reason: ApprovalReason + detail: str = "" + + def to_tool_result_text(self, tool_name: str) -> str: + if self.approved: + return f"tool call approval passed: {tool_name}" + if self.reason == "timeout": + return ( + f"error: tool call approval timed out for `{tool_name}`. " + "The tool call was cancelled." + ) + if self.reason == "unsupported_strategy": + return ( + f"error: tool call approval strategy is unsupported for `{tool_name}`. " + "The tool call was cancelled." + ) + if self.reason == "error": + return ( + f"error: tool call approval failed for `{tool_name}` ({self.detail}). " + "The tool call was cancelled." + ) + return ( + f"error: user rejected tool call approval for `{tool_name}`. " + "The tool call was cancelled." + ) + + +class BaseToolCallApprovalStrategy(ABC): + @property + @abstractmethod + def name(self) -> str: ... + + @abstractmethod + async def request( + self, + ctx: ToolCallApprovalContext, + config: dict[str, T.Any], + ) -> ToolCallApprovalResult: ... + + +class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy): + @property + def name(self) -> str: + return "dynamic_code" + + async def request( + self, + ctx: ToolCallApprovalContext, + config: dict[str, T.Any], + ) -> ToolCallApprovalResult: + timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1) + dynamic_cfg = config.get("dynamic_code", {}) + if not isinstance(dynamic_cfg, dict): + dynamic_cfg = {} + code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4) + case_sensitive = bool(dynamic_cfg.get("case_sensitive", False)) + + code = "".join(secrets.choice(string.digits) for _ in range(code_length)) + + await ctx.event.send( + MessageChain().message( + "Tool call needs your approval before execution.\n" + f"Tool: `{ctx.tool_name}`\n" + f"Approval code: `{code}`\n" + "Please send this code to continue. " + "Any other message will cancel this tool call." + ) + ) + + try: + result = await _wait_for_code_input( + event=ctx.event, + expected_code=code, + timeout=timeout_seconds, + case_sensitive=case_sensitive, + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Tool call approval failed unexpectedly for %s: %s", + ctx.tool_name, + exc, + exc_info=True, + ) + return ToolCallApprovalResult( + approved=False, + reason="error", + detail=str(exc), + ) + + if not result.approved: + if result.reason == "timeout": + await ctx.event.send( + MessageChain().message( + f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled." + ) + ) + else: + await ctx.event.send( + MessageChain().message( + f"Tool call `{ctx.tool_name}` was cancelled." + ) + ) + return result + + +_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {} + + +def register_tool_call_approval_strategy( + strategy: BaseToolCallApprovalStrategy, +) -> None: + _STRATEGY_REGISTRY[strategy.name] = strategy + + +def _register_builtin_strategies() -> None: + register_tool_call_approval_strategy(DynamicCodeApprovalStrategy()) + + +_register_builtin_strategies() + + +async def request_tool_call_approval( + *, + config: dict[str, T.Any] | None, + ctx: ToolCallApprovalContext, +) -> ToolCallApprovalResult: + if not config or not bool(config.get("enable", False)): + return ToolCallApprovalResult(approved=True, reason="approved") + + strategy_name = ( + str(config.get("strategy", "dynamic_code")).strip() or "dynamic_code" + ) + strategy = _STRATEGY_REGISTRY.get(strategy_name) + if not strategy: + logger.warning("Unsupported tool call approval strategy: %s", strategy_name) + return ToolCallApprovalResult( + approved=False, + reason="unsupported_strategy", + detail=strategy_name, + ) + return await strategy.request(ctx, config) + + +async def _wait_for_code_input( + *, + event: AstrMessageEvent, + expected_code: str, + timeout: int, + case_sensitive: bool, +) -> ToolCallApprovalResult: + session_filter = DefaultSessionFilter() + FILTERS.append(session_filter) + waiter = SessionWaiter( + session_filter=session_filter, + session_id=event.unified_msg_origin, + record_history_chains=False, + ) + + async def _handler( + controller: SessionController, incoming: AstrMessageEvent + ) -> None: + raw_input = (incoming.message_str or "").strip() + if _is_code_match( + expected=expected_code, + actual=raw_input, + case_sensitive=case_sensitive, + ): + if not controller.future.done(): + controller.future.set_result( + ToolCallApprovalResult(approved=True, reason="approved"), + ) + else: + if not controller.future.done(): + controller.future.set_result( + ToolCallApprovalResult( + approved=False, + reason="rejected", + detail=raw_input, + ) + ) + controller.stop() + + try: + result = await waiter.register_wait(handler=_handler, timeout=timeout) + except TimeoutError: + return ToolCallApprovalResult(approved=False, reason="timeout") + + if isinstance(result, ToolCallApprovalResult): + return result + return ToolCallApprovalResult( + approved=False, + reason="error", + detail=f"Invalid approval result type: {type(result).__name__}", + ) + + +def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool: + if case_sensitive: + return actual == expected + return actual.casefold() == expected.casefold() + + +def _safe_int(value: T.Any, *, default: int, minimum: int) -> int: + try: + parsed = int(value) + if parsed < minimum: + return minimum + return parsed + except Exception: # noqa: BLE001 + return default diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 7883dca8f..0ae4e4594 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -121,6 +121,8 @@ class MainAgentBuildConfig: timezone: str | None = None max_quoted_fallback_images: int = 20 """Maximum number of images injected from quoted-message fallback extraction.""" + tool_call_approval: dict = field(default_factory=dict) + """Tool call approval configuration.""" @dataclass(slots=True) @@ -1118,6 +1120,7 @@ async def build_main_agent( run_context=AgentContextWrapper( context=astr_agent_ctx, tool_call_timeout=config.tool_call_timeout, + tool_call_approval=config.tool_call_approval, ), tool_executor=FunctionToolExecutor(), agent_hooks=MAIN_AGENT_HOOKS, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index a151e7cfd..c0c2a9287 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -117,6 +117,15 @@ "max_agent_step": 30, "tool_call_timeout": 60, "tool_schema_mode": "full", + "tool_call_approval": { + "enable": False, + "strategy": "dynamic_code", + "timeout": 60, + "dynamic_code": { + "code_length": 6, + "case_sensitive": False, + }, + }, "llm_safety_mode": True, "safety_mode_strategy": "system_prompt", # TODO: llm judge "file_extract": { @@ -2330,6 +2339,31 @@ class ChatProviderTemplate(TypedDict): "tool_schema_mode": { "type": "string", }, + "tool_call_approval": { + "type": "object", + "items": { + "enable": { + "type": "bool", + }, + "strategy": { + "type": "string", + }, + "timeout": { + "type": "int", + }, + "dynamic_code": { + "type": "object", + "items": { + "code_length": { + "type": "int", + }, + "case_sensitive": { + "type": "bool", + }, + }, + }, + }, + }, "file_extract": { "type": "object", "items": { @@ -3066,6 +3100,50 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.tool_call_approval.enable": { + "description": "启用工具调用确认", + "type": "bool", + "hint": "开启后,工具调用需要用户确认后才会执行。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.tool_call_approval.strategy": { + "description": "工具调用确认策略", + "type": "string", + "options": ["dynamic_code"], + "labels": ["Dynamic Code(动态码)"], + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.tool_call_approval.enable": True, + }, + }, + "provider_settings.tool_call_approval.timeout": { + "description": "工具调用确认超时(秒)", + "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.tool_call_approval.enable": True, + }, + }, + "provider_settings.tool_call_approval.dynamic_code.code_length": { + "description": "动态确认码长度", + "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.tool_call_approval.enable": True, + "provider_settings.tool_call_approval.strategy": "dynamic_code", + }, + }, + "provider_settings.tool_call_approval.dynamic_code.case_sensitive": { + "description": "动态确认码区分大小写", + "type": "bool", + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.tool_call_approval.enable": True, + "provider_settings.tool_call_approval.strategy": "dynamic_code", + }, + }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", "type": "string", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 33908fa98..cb53d8d14 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -44,6 +44,7 @@ async def initialize(self, ctx: PipelineContext) -> None: ] self.max_step: int = settings.get("max_agent_step", 30) self.tool_call_timeout: int = settings.get("tool_call_timeout", 60) + self.tool_call_approval: dict = settings.get("tool_call_approval", {}) self.tool_schema_mode: str = settings.get("tool_schema_mode", "full") if self.tool_schema_mode not in ("skills_like", "full"): logger.warning( @@ -124,6 +125,7 @@ async def initialize(self, ctx: PipelineContext) -> None: subagent_orchestrator=conf.get("subagent_orchestrator", {}), timezone=self.ctx.plugin_manager.context.get_config().get("timezone"), max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20), + tool_call_approval=self.tool_call_approval, ) async def process( diff --git a/tests/test_tool_call_approval.py b/tests/test_tool_call_approval.py new file mode 100644 index 000000000..ea603b433 --- /dev/null +++ b/tests/test_tool_call_approval.py @@ -0,0 +1,110 @@ +import os +import sys + +import pytest + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from astrbot.core.agent.tool_call_approval import ( + ToolCallApprovalContext, + ToolCallApprovalResult, + request_tool_call_approval, +) + + +class DummyEvent: + def __init__(self) -> None: + self.unified_msg_origin = "test:friend:test_user" + self.sent_messages = [] + self.message_str = "" + + async def send(self, message) -> None: + self.sent_messages.append(message) + + +@pytest.mark.asyncio +async def test_request_tool_call_approval_disabled_returns_approved(): + event = DummyEvent() + result = await request_tool_call_approval( + config={"enable": False}, + ctx=ToolCallApprovalContext( + event=event, + tool_name="test_tool", + tool_args={}, + tool_call_id="call_1", + ), + ) + assert result.approved is True + assert result.reason == "approved" + assert len(event.sent_messages) == 0 + + +@pytest.mark.asyncio +async def test_dynamic_code_approval_passed(monkeypatch): + async def _mock_wait(*args, **kwargs): + return ToolCallApprovalResult(approved=True, reason="approved") + + monkeypatch.setattr( + "astrbot.core.agent.tool_call_approval._wait_for_code_input", + _mock_wait, + ) + + event = DummyEvent() + result = await request_tool_call_approval( + config={"enable": True, "strategy": "dynamic_code"}, + ctx=ToolCallApprovalContext( + event=event, + tool_name="test_tool", + tool_args={"query": "hello"}, + tool_call_id="call_2", + ), + ) + assert result.approved is True + assert result.reason == "approved" + assert len(event.sent_messages) == 1 + + +@pytest.mark.asyncio +async def test_dynamic_code_approval_rejected(monkeypatch): + async def _mock_wait(*args, **kwargs): + return ToolCallApprovalResult( + approved=False, + reason="rejected", + detail="not_code", + ) + + monkeypatch.setattr( + "astrbot.core.agent.tool_call_approval._wait_for_code_input", + _mock_wait, + ) + + event = DummyEvent() + result = await request_tool_call_approval( + config={"enable": True, "strategy": "dynamic_code"}, + ctx=ToolCallApprovalContext( + event=event, + tool_name="test_tool", + tool_args={}, + tool_call_id="call_3", + ), + ) + assert result.approved is False + assert result.reason == "rejected" + assert len(event.sent_messages) == 2 + + +@pytest.mark.asyncio +async def test_request_tool_call_approval_unsupported_strategy(): + event = DummyEvent() + result = await request_tool_call_approval( + config={"enable": True, "strategy": "unknown_strategy"}, + ctx=ToolCallApprovalContext( + event=event, + tool_name="test_tool", + tool_args={}, + tool_call_id="call_4", + ), + ) + assert result.approved is False + assert result.reason == "unsupported_strategy" + assert len(event.sent_messages) == 0