Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/agent/run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class ContextWrapper(Generic[TContext]):
messages: list[Message] = Field(default_factory=list)
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
tool_call_timeout: int = 60 # Default tool call timeout in seconds
tool_call_approval: dict[str, Any] = Field(default_factory=dict)
"""Tool call approval runtime configuration."""


NoContext = ContextWrapper[None]
39 changes: 39 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext
from ..tool_call_approval import (
ToolCallApprovalContext,
request_tool_call_approval,
)
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner

Expand Down Expand Up @@ -659,6 +663,41 @@ async def _handle_function_tools(
# 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args

approval_cfg = self.run_context.tool_call_approval
if approval_cfg.get("enable", False):
event = getattr(self.run_context.context, "event", None)
if event is None:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=(
f"error: tool call approval is enabled, but event context is unavailable for `{func_tool_name}`."
),
),
)
continue
approval_result = await request_tool_call_approval(
config=approval_cfg,
ctx=ToolCallApprovalContext(
event=event,
tool_name=func_tool_name,
tool_args=valid_params,
tool_call_id=func_tool_id,
),
)
if not approval_result.approved:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=approval_result.to_tool_result_text(
func_tool_name
),
),
)
continue

try:
await self.agent_hooks.on_tool_start(
self.run_context,
Expand Down
248 changes: 248 additions & 0 deletions astrbot/core/agent/tool_call_approval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
from __future__ import annotations

import secrets
import string
import typing as T
from abc import ABC, abstractmethod
from dataclasses import dataclass

from astrbot import logger
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.session_waiter import (
FILTERS,
DefaultSessionFilter,
SessionController,
SessionWaiter,
)

ApprovalReason = T.Literal[
"approved",
"rejected",
"timeout",
"unsupported_strategy",
"error",
]


@dataclass(slots=True)
class ToolCallApprovalContext:
event: AstrMessageEvent
tool_name: str
tool_args: dict[str, T.Any]
tool_call_id: str


@dataclass(slots=True)
class ToolCallApprovalResult:
approved: bool
reason: ApprovalReason
detail: str = ""

def to_tool_result_text(self, tool_name: str) -> str:
if self.approved:
return f"tool call approval passed: {tool_name}"
if self.reason == "timeout":
return (
f"error: tool call approval timed out for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "unsupported_strategy":
return (
f"error: tool call approval strategy is unsupported for `{tool_name}`. "
"The tool call was cancelled."
)
if self.reason == "error":
return (
f"error: tool call approval failed for `{tool_name}` ({self.detail}). "
"The tool call was cancelled."
)
return (
f"error: user rejected tool call approval for `{tool_name}`. "
"The tool call was cancelled."
)


class BaseToolCallApprovalStrategy(ABC):
@property
@abstractmethod
def name(self) -> str: ...

@abstractmethod
async def request(
self,
ctx: ToolCallApprovalContext,
config: dict[str, T.Any],
) -> ToolCallApprovalResult: ...


class DynamicCodeApprovalStrategy(BaseToolCallApprovalStrategy):
@property
def name(self) -> str:
return "dynamic_code"

async def request(
self,
ctx: ToolCallApprovalContext,
config: dict[str, T.Any],
) -> ToolCallApprovalResult:
timeout_seconds = _safe_int(config.get("timeout", 60), default=60, minimum=1)
dynamic_cfg = config.get("dynamic_code", {})
if not isinstance(dynamic_cfg, dict):
dynamic_cfg = {}
code_length = _safe_int(dynamic_cfg.get("code_length", 6), default=6, minimum=4)
case_sensitive = bool(dynamic_cfg.get("case_sensitive", False))

code = "".join(secrets.choice(string.digits) for _ in range(code_length))

await ctx.event.send(
MessageChain().message(
"Tool call needs your approval before execution.\n"
f"Tool: `{ctx.tool_name}`\n"
f"Approval code: `{code}`\n"
"Please send this code to continue. "
"Any other message will cancel this tool call."
)
)

try:
result = await _wait_for_code_input(
event=ctx.event,
expected_code=code,
timeout=timeout_seconds,
case_sensitive=case_sensitive,
)
except Exception as exc: # noqa: BLE001
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The except Exception as exc: # noqa: BLE001 is too broad. It's generally better to catch more specific exceptions to avoid masking unexpected errors. Consider catching more specific exceptions that _wait_for_code_input might raise, or at least asyncio.TimeoutError if that's the primary concern for register_wait.

logger.error(
"Tool call approval failed unexpectedly for %s: %s",
ctx.tool_name,
exc,
exc_info=True,
)
return ToolCallApprovalResult(
approved=False,
reason="error",
detail=str(exc),
)

if not result.approved:
if result.reason == "timeout":
await ctx.event.send(
MessageChain().message(
f"Tool call `{ctx.tool_name}` approval timed out. This call was cancelled."
)
)
else:
await ctx.event.send(
MessageChain().message(
f"Tool call `{ctx.tool_name}` was cancelled."
)
)
return result


_STRATEGY_REGISTRY: dict[str, BaseToolCallApprovalStrategy] = {}


def register_tool_call_approval_strategy(
strategy: BaseToolCallApprovalStrategy,
) -> None:
_STRATEGY_REGISTRY[strategy.name] = strategy


def _register_builtin_strategies() -> None:
register_tool_call_approval_strategy(DynamicCodeApprovalStrategy())


_register_builtin_strategies()


async def request_tool_call_approval(
*,
config: dict[str, T.Any] | None,
ctx: ToolCallApprovalContext,
) -> ToolCallApprovalResult:
if not config or not bool(config.get("enable", False)):
return ToolCallApprovalResult(approved=True, reason="approved")

strategy_name = (
str(config.get("strategy", "dynamic_code")).strip() or "dynamic_code"
)
strategy = _STRATEGY_REGISTRY.get(strategy_name)
if not strategy:
logger.warning("Unsupported tool call approval strategy: %s", strategy_name)
return ToolCallApprovalResult(
approved=False,
reason="unsupported_strategy",
detail=strategy_name,
)
return await strategy.request(ctx, config)


async def _wait_for_code_input(
*,
event: AstrMessageEvent,
expected_code: str,
timeout: int,
case_sensitive: bool,
) -> ToolCallApprovalResult:
session_filter = DefaultSessionFilter()
FILTERS.append(session_filter)
waiter = SessionWaiter(
session_filter=session_filter,
session_id=event.unified_msg_origin,
record_history_chains=False,
)

async def _handler(
controller: SessionController, incoming: AstrMessageEvent
) -> None:
raw_input = (incoming.message_str or "").strip()
if _is_code_match(
expected=expected_code,
actual=raw_input,
case_sensitive=case_sensitive,
):
if not controller.future.done():
controller.future.set_result(
ToolCallApprovalResult(approved=True, reason="approved"),
)
else:
if not controller.future.done():
controller.future.set_result(
ToolCallApprovalResult(
approved=False,
reason="rejected",
detail=raw_input,
)
)
controller.stop()
Comment on lines +197 to +219
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The tool call approval handler is vulnerable to Denial of Service (DoS) and Broken Access Control in group chat environments.

  1. Denial of Service: The handler automatically rejects the tool call if ANY message is received in the session that does not match the expected code. In a group chat, any message from any user will trigger this rejection, causing the tool call to be cancelled. This allows any member of the group to prevent the bot from executing tools.
  2. Broken Access Control: The mechanism does not verify the identity of the user providing the approval code. Since the code is sent to the group, any member can see it and send it back to approve the tool call, bypassing the intent of 'user approval'.

Remediation: Modify the _handler to verify that the sender of the message is the same user who initiated the tool call request. Messages from other users should be ignored.

    async def _handler(
        controller: SessionController, incoming: AstrMessageEvent
    ) -> None:
        if incoming.get_sender_id() != event.get_sender_id():
            return
        raw_input = (incoming.message_str or "").strip()
        if _is_code_match(
            expected=expected_code,
            actual=raw_input,
            case_sensitive=case_sensitive,
        ):
            if not controller.future.done():
                controller.future.set_result(
                    ToolCallApprovalResult(approved=True, reason="approved"),
                )
        else:
            if not controller.future.done():
                controller.future.set_result(
                    ToolCallApprovalResult(
                        approved=False,
                        reason="rejected",
                        detail=raw_input,
                    )
                )
        controller.stop()


try:
result = await waiter.register_wait(handler=_handler, timeout=timeout)
except TimeoutError:
return ToolCallApprovalResult(approved=False, reason="timeout")

if isinstance(result, ToolCallApprovalResult):
return result
return ToolCallApprovalResult(
approved=False,
reason="error",
detail=f"Invalid approval result type: {type(result).__name__}",
)


def _is_code_match(*, expected: str, actual: str, case_sensitive: bool) -> bool:
if case_sensitive:
return actual == expected
return actual.casefold() == expected.casefold()


def _safe_int(value: T.Any, *, default: int, minimum: int) -> int:
try:
parsed = int(value)
if parsed < minimum:
return minimum
return parsed
except Exception: # noqa: BLE001
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous comment, catching a broad Exception can hide bugs. It's better to catch ValueError or TypeError if int() conversion is the only expected failure, or specify (ValueError, TypeError) if both are possible. If other exceptions are possible, they should be handled explicitly.

return default
3 changes: 3 additions & 0 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ class MainAgentBuildConfig:
timezone: str | None = None
max_quoted_fallback_images: int = 20
"""Maximum number of images injected from quoted-message fallback extraction."""
tool_call_approval: dict = field(default_factory=dict)
"""Tool call approval configuration."""


@dataclass(slots=True)
Expand Down Expand Up @@ -1118,6 +1120,7 @@ async def build_main_agent(
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=config.tool_call_timeout,
tool_call_approval=config.tool_call_approval,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
Expand Down
Loading