diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 3916215e5b..388d0edc6e 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -383,6 +383,7 @@ def _build_local_mode_prompt() -> str: def _filter_skills_for_current_config( skills: list[SkillInfo], cfg: dict, + session_disabled: set[str] | None = None, ) -> list[SkillInfo]: plugin_set = cfg.get("plugin_set", ["*"]) allowed_plugins = ( @@ -404,7 +405,12 @@ def _filter_skills_for_current_config( plugin = plugin_by_root_dir.get(skill.plugin_name) if not plugin or not plugin.activated: continue - if plugin.reserved or allowed_plugins is None: + if plugin.reserved: + filtered.append(skill) + continue + if session_disabled and plugin.name in session_disabled: + continue + if allowed_plugins is None: filtered.append(skill) continue if plugin.name is not None and plugin.name in allowed_plugins: @@ -422,6 +428,19 @@ async def _ensure_persona_and_skills( if not req.conversation: return + from astrbot.core import sp + + session_id = event.unified_msg_origin + session_plugin_config = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_plugin_config", + default={}, + ) + session_disabled = set( + session_plugin_config.get(session_id, {}).get("disabled_plugins", []) + ) + ( persona_id, persona, @@ -454,7 +473,7 @@ async def _ensure_persona_and_skills( runtime = cfg.get("computer_use_runtime", "local") skill_manager = SkillManager() skills = skill_manager.list_skills(active_only=True, runtime=runtime) - skills = _filter_skills_for_current_config(skills, cfg) + skills = _filter_skills_for_current_config(skills, cfg, session_disabled) if skills: if persona and persona.get("skills") is not None: @@ -898,33 +917,58 @@ async def _decorate_llm_request( _apply_workspace_extra_prompt(event, req) -def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: +async def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: """根据事件中的插件设置,过滤请求中的工具列表。 注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留, 因为它们不属于任何插件,不应被插件过滤逻辑影响。 """ - if event.plugins_name is not None and req.func_tool: - new_tool_set = ToolSet() - for tool in req.func_tool.tools: - if isinstance(tool, MCPTool): - # 保留 MCP 工具 - new_tool_set.add_tool(tool) - continue - mp = tool.handler_module_path - if not mp: - # 没有 plugin 归属信息的工具(如 subagent transfer_to_*) - # 不应受到会话插件过滤影响。 - new_tool_set.add_tool(tool) - continue - plugin = star_map.get(mp) - if not plugin: - # 无法解析插件归属时,保守保留工具,避免误过滤。 - new_tool_set.add_tool(tool) - continue - if plugin.name in event.plugins_name or plugin.reserved: - new_tool_set.add_tool(tool) - req.func_tool = new_tool_set + if not req.func_tool: + return + + from astrbot.core import sp + + session_id = event.unified_msg_origin + session_plugin_config = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_plugin_config", + default={}, + ) + session_disabled = set( + session_plugin_config.get(session_id, {}).get("disabled_plugins", []) + ) + + global_whitelist = event.plugins_name # None 表示全部允许 + + new_tool_set = ToolSet() + for tool in req.func_tool.tools: + if isinstance(tool, MCPTool): + # 保留 MCP 工具 + new_tool_set.add_tool(tool) + continue + mp = tool.handler_module_path + if not mp: + # 没有 plugin 归属信息的工具(如 subagent transfer_to_*) + # 不应受到会话插件过滤影响。 + new_tool_set.add_tool(tool) + continue + plugin = star_map.get(mp) + if not plugin: + # 无法解析插件归属时,保守保留工具,避免误过滤。 + new_tool_set.add_tool(tool) + continue + if plugin.reserved: + new_tool_set.add_tool(tool) + continue + # 全局白名单过滤 + if global_whitelist is not None and plugin.name not in global_whitelist: + continue + # 会话级禁用过滤 + if plugin.name in session_disabled: + continue + new_tool_set.add_tool(tool) + req.func_tool = new_tool_set async def _handle_webchat( @@ -1372,7 +1416,7 @@ async def build_main_agent( if not req.session_id: req.session_id = event.unified_msg_origin - _plugin_tool_fix(event, req) + await _plugin_tool_fix(event, req) await _apply_web_search_tools(event, req, plugin_context) if config.llm_safety_mode: diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 9402ce3e62..28426be1ff 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -85,10 +85,13 @@ async def call_event_hook( # """ + from astrbot.core.star.session_plugin_manager import SessionPluginManager + handlers = star_handlers_registry.get_handlers_by_event_type( hook_type, plugins_name=event.plugins_name, ) + handlers = await SessionPluginManager.filter_handlers_by_session(event, handlers) for handler in handlers: try: assert inspect.iscoroutinefunction(handler.handler)