Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 69 additions & 25 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def _build_local_mode_prompt() -> str:
def _filter_skills_for_current_config(
skills: list[SkillInfo],
cfg: dict,
session_disabled: set[str] | None = None,
) -> list[SkillInfo]:
plugin_set = cfg.get("plugin_set", ["*"])
allowed_plugins = (
Expand All @@ -404,7 +405,12 @@ def _filter_skills_for_current_config(
plugin = plugin_by_root_dir.get(skill.plugin_name)
if not plugin or not plugin.activated:
continue
if plugin.reserved or allowed_plugins is None:
if plugin.reserved:
filtered.append(skill)
continue
if session_disabled and plugin.name in session_disabled:
continue
if allowed_plugins is None:
filtered.append(skill)
continue
if plugin.name is not None and plugin.name in allowed_plugins:
Expand All @@ -422,6 +428,19 @@ async def _ensure_persona_and_skills(
if not req.conversation:
return

from astrbot.core import sp

session_id = event.unified_msg_origin
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_disabled = set(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider extracting shared helpers for session-disabled plugins and plugin enablement logic to remove duplication and keep the filtering rules in one place.

You can centralize the “session plugin disabling” concern and simplify the branching without changing behavior.

1. Extract session-level disabled plugins helper

Right now this logic is duplicated in _ensure_persona_and_skills and _plugin_tool_fix (and apparently in context_utils.call_event_hook).

Create a small helper and reuse it:

# somewhere shared, e.g. astrbot/core/session_plugins.py
from astrbot.core import sp

async def get_session_disabled_plugins(session_id: str) -> set[str]:
    session_plugin_config = await sp.get_async(
        scope="umo",
        scope_id=session_id,
        key="session_plugin_config",
        default={},
    )
    return set(session_plugin_config.get(session_id, {}).get("disabled_plugins", []))

Then in _ensure_persona_and_skills:

from astrbot.core.session_plugins import get_session_disabled_plugins

async def _ensure_persona_and_skills(...):
    ...
    session_id = event.unified_msg_origin
    session_disabled = await get_session_disabled_plugins(session_id)
    ...
    skills = _filter_skills_for_current_config(skills, cfg, session_disabled)

And in _plugin_tool_fix:

from astrbot.core.session_plugins import get_session_disabled_plugins

async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
    if not req.func_tool:
        return

    session_id = event.unified_msg_origin
    session_disabled = await get_session_disabled_plugins(session_id)
    global_whitelist = event.plugins_name
    ...

You can mirror this in context_utils.call_event_hook as well.

2. Centralize plugin enablement rules

The rules (reserved → always allow, global whitelist, session disabled) are currently embedded separately in _filter_skills_for_current_config and _plugin_tool_fix. A tiny predicate keeps this linear and shared:

def is_plugin_enabled(
    plugin,
    *,
    global_whitelist: set[str] | None,
    session_disabled: set[str] | None,
) -> bool:
    if plugin.reserved:
        return True
    if session_disabled and plugin.name in session_disabled:
        return False
    if global_whitelist is None:
        return True
    return plugin.name in global_whitelist

Use it in _filter_skills_for_current_config:

def _filter_skills_for_current_config(
    skills: list[SkillInfo],
    cfg: dict,
    session_disabled: set[str] | None = None,
) -> list[SkillInfo]:
    plugin_set = cfg.get("plugin_set", ["*"])
    allowed_plugins = (
        None
        if not isinstance(plugin_set, list) or "*" in plugin_set
        else {str(name) for name in plugin_set}
    )

    plugin_by_root_dir = {
        metadata.root_dir_name: metadata
        for metadata in star_registry
    }

    filtered: list[SkillInfo] = []
    for skill in skills:
        if skill.source_type != "plugin":
            filtered.append(skill)
            continue

        plugin = plugin_by_root_dir.get(skill.plugin_name)
        if not plugin or not plugin.activated:
            continue

        if is_plugin_enabled(
            plugin,
            global_whitelist=allowed_plugins,
            session_disabled=session_disabled,
        ):
            filtered.append(skill)

    return filtered

And in _plugin_tool_fix:

async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
    if not req.func_tool:
        return

    session_id = event.unified_msg_origin
    session_disabled = await get_session_disabled_plugins(session_id)
    global_whitelist = set(event.plugins_name) if event.plugins_name is not None else None

    new_tool_set = ToolSet()
    for tool in req.func_tool.tools:
        if isinstance(tool, MCPTool):
            new_tool_set.add_tool(tool)
            continue

        mp = tool.handler_module_path
        if not mp:
            new_tool_set.add_tool(tool)
            continue

        plugin = star_map.get(mp)
        if not plugin:
            new_tool_set.add_tool(tool)
            continue

        if is_plugin_enabled(
            plugin,
            global_whitelist=global_whitelist,
            session_disabled=session_disabled,
        ):
            new_tool_set.add_tool(tool)

    req.func_tool = new_tool_set

This keeps all current behavior (reserved always allowed, unknown plugin/tool preserved, MCP tools preserved, session and global filters respected) but removes the duplicated session config access and scattered condition sets, which should make future changes to the enablement rules safer and easier to reason about.

session_plugin_config.get(session_id, {}).get("disabled_plugins", [])
)

(
persona_id,
persona,
Expand Down Expand Up @@ -454,7 +473,7 @@ async def _ensure_persona_and_skills(
runtime = cfg.get("computer_use_runtime", "local")
skill_manager = SkillManager()
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
skills = _filter_skills_for_current_config(skills, cfg)
skills = _filter_skills_for_current_config(skills, cfg, session_disabled)

if skills:
if persona and persona.get("skills") is not None:
Expand Down Expand Up @@ -898,33 +917,58 @@ async def _decorate_llm_request(
_apply_workspace_extra_prompt(event, req)


def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
"""根据事件中的插件设置,过滤请求中的工具列表。

注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留,
因为它们不属于任何插件,不应被插件过滤逻辑影响。
"""
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
if isinstance(tool, MCPTool):
# 保留 MCP 工具
new_tool_set.add_tool(tool)
continue
mp = tool.handler_module_path
if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue
plugin = star_map.get(mp)
if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
if not req.func_tool:
return

from astrbot.core import sp

session_id = event.unified_msg_origin
session_plugin_config = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_plugin_config",
default={},
)
session_disabled = set(
session_plugin_config.get(session_id, {}).get("disabled_plugins", [])
)
Comment on lines +931 to +940
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for fetching session-level disabled plugins is identical to the one in _ensure_persona_and_skills (lines 433-442). Since both functions are called sequentially within build_main_agent, this results in redundant database queries. Consider fetching this configuration once in build_main_agent and passing it as an argument, or caching it within the event object to optimize performance.

References
  1. When implementing similar functionality for different cases, refactor the logic into a shared helper function to avoid code duplication.


global_whitelist = event.plugins_name # None 表示全部允许

new_tool_set = ToolSet()
for tool in req.func_tool.tools:
if isinstance(tool, MCPTool):
# 保留 MCP 工具
new_tool_set.add_tool(tool)
continue
mp = tool.handler_module_path
if not mp:
# 没有 plugin 归属信息的工具(如 subagent transfer_to_*)
# 不应受到会话插件过滤影响。
new_tool_set.add_tool(tool)
continue
plugin = star_map.get(mp)
if not plugin:
# 无法解析插件归属时,保守保留工具,避免误过滤。
new_tool_set.add_tool(tool)
continue
if plugin.reserved:
new_tool_set.add_tool(tool)
continue
# 全局白名单过滤
if global_whitelist is not None and plugin.name not in global_whitelist:
continue
# 会话级禁用过滤
if plugin.name in session_disabled:
continue
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set


async def _handle_webchat(
Expand Down Expand Up @@ -1372,7 +1416,7 @@ async def build_main_agent(
if not req.session_id:
req.session_id = event.unified_msg_origin

_plugin_tool_fix(event, req)
await _plugin_tool_fix(event, req)
await _apply_web_search_tools(event, req, plugin_context)

if config.llm_safety_mode:
Expand Down
3 changes: 3 additions & 0 deletions astrbot/core/pipeline/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ async def call_event_hook(
#

"""
from astrbot.core.star.session_plugin_manager import SessionPluginManager

handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type,
plugins_name=event.plugins_name,
)
handlers = await SessionPluginManager.filter_handlers_by_session(event, handlers)
for handler in handlers:
try:
assert inspect.iscoroutinefunction(handler.handler)
Expand Down
Loading