diff --git a/.env.template b/.env.template index e56b640ea..58f889295 100644 --- a/.env.template +++ b/.env.template @@ -29,6 +29,7 @@ YUXI_INSTANCE_ID= # # Servies # YUXI_SUPER_ADMIN_NAME= # YUXI_SUPER_ADMIN_PASSWORD= +# MCP_CREDENTIALS_MASTER_KEY= # # URL Whitelist (comma-separated domains/IPs, empty to disable URL parsing) # YUXI_URL_WHITELIST=github.com,docs.example.com,gitlab.example.com,127.0.0.1 @@ -73,4 +74,4 @@ YUXI_INSTANCE_ID= # SANDBOX_NODE_HOST=host.docker.internal # KUBECONFIG_PATH=/root/.kube/config # THREAD_PVC=yuxi-thread -# SKILLS_PVC=yuxi-skills # 当前代码会读取,但 Pod 挂载实际仍只使用 THREAD_PVC \ No newline at end of file +# SKILLS_PVC=yuxi-skills # 当前代码会读取,但 Pod 挂载实际仍只使用 THREAD_PVC diff --git a/.gitignore b/.gitignore index ec4c35cac..aa5dc6012 100644 --- a/.gitignore +++ b/.gitignore @@ -79,4 +79,7 @@ docs/vibe /models -.taskr/ \ No newline at end of file +.taskr/ + +.workbuddy +.worktrees/ diff --git a/backend/package/pyproject.toml b/backend/package/pyproject.toml index 9188d218f..7ef87fd41 100644 --- a/backend/package/pyproject.toml +++ b/backend/package/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "aiosqlite>=0.20.0", "argon2-cffi>=25.1.0", "asyncpg>=0.30.0", + "cachetools>=5.3.0", "chardet>=5.0.0", "colorlog>=6.9.0", "dashscope>=1.23.2", diff --git a/backend/package/yuxi/agents/__init__.py b/backend/package/yuxi/agents/__init__.py index dd7174a4f..8c654c59c 100644 --- a/backend/package/yuxi/agents/__init__.py +++ b/backend/package/yuxi/agents/__init__.py @@ -12,7 +12,7 @@ from yuxi.agents.toolkits.utils import get_tool_info # MCP - Agent 层统一入口(自动过滤 disabled_tools) -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools __all__ = [ # Base classes diff --git a/backend/package/yuxi/agents/buildin/chatbot/graph.py b/backend/package/yuxi/agents/buildin/chatbot/graph.py index d7dcdc2d2..db3d9b9ee 100644 --- a/backend/package/yuxi/agents/buildin/chatbot/graph.py +++ b/backend/package/yuxi/agents/buildin/chatbot/graph.py @@ -12,16 +12,19 @@ save_attachments_to_fs, ) from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware -from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware -from yuxi.services.mcp_service import get_tools_from_all_servers +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload +from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names +from yuxi.utils.logging_config import logger from .prompt import TODO_MID_PROMPT, build_prompt_with_context async def _build_middlewares(context): """构建中间件列表""" - all_mcp_tools = await get_tools_from_all_servers() # 因为异步加载,无法放在 RuntimeConfigMiddleware 的 __init__ 中 + preload_mcp_names = await collect_context_mcp_names_for_preload(context) + logger.info(f"ChatbotAgent MCP preload candidates: {preload_mcp_names}") + all_mcp_tools = await get_tools_from_all_servers(preload_mcp_names) # summary middleware # 主 Agent 上下文优化:90k tokens 触发压缩(128k context window 的 70%) diff --git a/backend/package/yuxi/agents/buildin/deep_agent/graph.py b/backend/package/yuxi/agents/buildin/deep_agent/graph.py index 706ec3b1f..4f100f0dc 100644 --- a/backend/package/yuxi/agents/buildin/deep_agent/graph.py +++ b/backend/package/yuxi/agents/buildin/deep_agent/graph.py @@ -15,9 +15,9 @@ save_attachments_to_fs, ) from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware -from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload from yuxi.agents.toolkits.buildin.tools import _create_tavily_search -from yuxi.services.mcp_service import get_tools_from_all_servers +from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names from yuxi.utils import logger @@ -57,7 +57,9 @@ async def get_graph(self, context=None, **kwargs): model = load_chat_model(context.model) sub_model = load_chat_model(context.subagents_model) search_tools = await self.get_tools() - all_mcp_tools = await get_tools_from_all_servers() + preload_mcp_names = await collect_context_mcp_names_for_preload(context) + logger.info(f"DeepAgent MCP preload candidates: {preload_mcp_names}") + all_mcp_tools = await get_tools_from_all_servers(preload_mcp_names) # 合并搜索工具和 MCP 工具 # 从数据库加载 subagent specs(工具名称已解析) diff --git a/backend/package/yuxi/agents/context.py b/backend/package/yuxi/agents/context.py index 851780b69..9050b173d 100644 --- a/backend/package/yuxi/agents/context.py +++ b/backend/package/yuxi/agents/context.py @@ -33,6 +33,20 @@ def update(self, data: dict): metadata={"name": "用户ID", "configurable": False, "description": "用来唯一标识一个用户"}, ) + work_id: str | None = field( + default=None, + metadata={ + "name": "工号", + "configurable": False, + "description": "用来匹配个人 MCP 连接绑定范围的工号", + }, + ) + + department_id: str | None = field( + default=None, + metadata={"name": "部门ID", "configurable": False, "description": "用来标识当前用户所属部门"}, + ) + system_prompt: Annotated[str, {"__template_metadata__": {"kind": "prompt"}}] = field( default="You are a helpful assistant.", metadata={"name": "系统提示词", "description": "用来描述智能体的角色和行为"}, diff --git a/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py b/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py index 4f8c3aac5..5b0d47820 100644 --- a/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py +++ b/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py @@ -3,7 +3,7 @@ from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse -from yuxi.services.mcp_service import get_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_mcp_tools from yuxi.utils import logger diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index a1b6e98f2..8a4ba6c92 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -4,14 +4,18 @@ from typing import Any from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse +from langchain.tools.tool_node import ToolCallRequest from langchain_core.messages import SystemMessage from yuxi.agents import load_chat_model from yuxi.agents.toolkits import get_all_tool_instances -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.utils.datetime_utils import shanghai_now from yuxi.utils.logging_config import logger +_RUNTIME_DYNAMIC_TOOLS_ATTR = "_runtime_config_dynamic_tools_by_name" + class RuntimeConfigMiddleware(AgentMiddleware): """运行时配置中间件 - 应用模型/工具/MCP/提示词配置 @@ -89,14 +93,19 @@ async def awrap_model_call( # 获取上下文配置的工具 enabled_tools = await self.get_tools_from_context(runtime_context) existing_tools = list(request.tools or []) - enabled_tool_names = {t.name for t in enabled_tools} managed_tool_names = {t.name for t in self.tools} merged_tools = [] for t_bind in existing_tools: - # (1) 已启用的工具保留 - # (2) 非本中间件管理的工具保留 - if t_bind.name in enabled_tool_names or t_bind.name not in managed_tool_names: + # 非本中间件管理的工具保留;本中间件管理的工具统一用本轮实时加载结果覆盖。 + if t_bind.name not in managed_tool_names: merged_tools.append(t_bind) + merged_tool_names = {t.name for t in merged_tools} + for tool in enabled_tools: + if tool.name in merged_tool_names: + continue + merged_tools.append(tool) + merged_tool_names.add(tool.name) + setattr(runtime_context, _RUNTIME_DYNAMIC_TOOLS_ATTR, {tool.name: tool for tool in enabled_tools}) overrides["tools"] = merged_tools logger.debug(f"RuntimeConfigMiddleware selected tools: {[t.name for t in merged_tools]}") @@ -116,6 +125,33 @@ async def awrap_model_call( return await handler(request) + async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], Any]): + """Allow ToolNode to execute runtime-auth MCP tools loaded during the last model call.""" + if request.tool is None: + runtime_context = getattr(request.runtime, "context", None) + dynamic_tools = getattr(runtime_context, _RUNTIME_DYNAMIC_TOOLS_ATTR, {}) or {} + tool = dynamic_tools.get(request.tool_call.get("name")) if isinstance(dynamic_tools, dict) else None + if tool is not None: + request = request.override(tool=tool) + + # NOTE: 注入当前的 AuthContext 以便于长连接拦截器 DynamicMCPTokenAuth 随时刷新 token + runtime_context = getattr(request.runtime, "context", None) + if runtime_context is not None: + user_id = getattr(runtime_context, "user_id", None) + work_id = getattr(runtime_context, "work_id", None) + dept_id = getattr(runtime_context, "department_id", None) + auth_context = AuthContext(user_id=user_id, department_id=dept_id, work_id=work_id) + + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + + token = mcp_auth_context_var.set(auth_context) + try: + return await handler(request) + finally: + mcp_auth_context_var.reset(token) + + return await handler(request) + async def get_tools_from_context(self, context) -> list: """从上下文配置中获取工具列表""" selected_tools = [] @@ -146,16 +182,41 @@ async def get_tools_from_context(self, context) -> list: all_mcp_names.append(server_name) selected_mcp_servers: set[str] = set() + selected_mcp_names: list[str] = [] + loaded_mcp_tools: dict[str, int] = {} + unavailable_mcp_servers: list[str] = [] + failed_mcp_servers: list[str] = [] for server_name in all_mcp_names: if server_name in selected_mcp_servers: continue selected_mcp_servers.add(server_name) + selected_mcp_names.append(server_name) try: - mcp_tools = await get_enabled_mcp_tools(server_name) + user_id = getattr(context, "user_id", None) + work_id = getattr(context, "work_id", None) + mcp_tools = await get_enabled_mcp_tools( + server_name, + auth_context=AuthContext( + user_id=user_id, + department_id=getattr(context, "department_id", None), + work_id=work_id, + ), + ) if not mcp_tools: - logger.warning(f"RuntimeConfigMiddleware: mcp dependency unavailable, skip: {server_name}") + unavailable_mcp_servers.append(server_name) + logger.debug(f"RuntimeConfigMiddleware: mcp dependency unavailable, skip: {server_name}") + else: + loaded_mcp_tools[server_name] = len(mcp_tools) selected_tools.extend(mcp_tools) except Exception as e: + failed_mcp_servers.append(server_name) logger.warning(f"RuntimeConfigMiddleware: failed to load mcp dependency '{server_name}': {e}") + if selected_mcp_names: + logger.info( + "RuntimeConfigMiddleware MCP runtime selection: " + f"selected={selected_mcp_names}, loaded={loaded_mcp_tools}, " + f"unavailable={unavailable_mcp_servers}, failed={failed_mcp_servers}" + ) + return selected_tools diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index 21a1d7543..7d90407b4 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -15,7 +15,8 @@ from yuxi.agents.toolkits import get_all_tool_instances from yuxi.repositories.skill_repository import SkillRepository -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.skill_service import _normalize_string_list, is_valid_skill_slug from yuxi.storage.postgres.manager import pg_manager from yuxi.utils.logging_config import logger @@ -79,6 +80,21 @@ async def get_dependency_map(db: AsyncSession | None = None) -> dict[str, SkillD return result +async def collect_context_mcp_names_for_preload(context, *, skills_context_name: str = "skills") -> list[str]: + """收集图构建阶段需要预注册的 MCP 名称。""" + names: list[str] = [] + names.extend(normalize_selected_skills(getattr(context, "mcps", None) or [])) + + dependency_map = await get_dependency_map() + configured_skills = normalize_selected_skills(getattr(context, skills_context_name, None) or []) + for slug in expand_skill_closure(configured_skills, dependency_map): + node = dependency_map.get(slug) + if node: + names.extend(node.get("mcps", [])) + + return normalize_selected_skills(names) + + def normalize_selected_skills(selected_skills: list[str] | None) -> list[str]: """规范化 skills 列表,去重并过滤无效值""" return _normalize_string_list(selected_skills) @@ -338,20 +354,42 @@ async def _get_mcp_tools_from_context( # 去重 unique_mcp_names = list(dict.fromkeys(all_mcp_names)) + loaded_mcp_tools: dict[str, int] = {} + unavailable_mcp_servers: list[str] = [] + failed_mcp_servers: list[str] = [] async def load_mcp_tools(server_name: str) -> list: """加载单个 MCP 服务器的工具""" try: - mcp_tools = await get_enabled_mcp_tools(server_name) + user_id = getattr(context, "user_id", None) + work_id = getattr(context, "work_id", None) + mcp_tools = await get_enabled_mcp_tools( + server_name, + auth_context=AuthContext( + user_id=user_id, + department_id=getattr(context, "department_id", None), + work_id=work_id, + ), + ) if not mcp_tools: - logger.warning(f"SkillsMiddleware: mcp dependency unavailable, skip: {server_name}") + unavailable_mcp_servers.append(server_name) + logger.debug(f"SkillsMiddleware: mcp dependency unavailable, skip: {server_name}") + else: + loaded_mcp_tools[server_name] = len(mcp_tools) return mcp_tools except Exception as e: + failed_mcp_servers.append(server_name) logger.warning(f"SkillsMiddleware: failed to load mcp dependency '{server_name}': {e}") return [] # 并行加载所有 MCP 工具 results = await asyncio.gather(*[load_mcp_tools(name) for name in unique_mcp_names]) + if unique_mcp_names: + logger.info( + "SkillsMiddleware MCP dependency selection: " + f"selected={unique_mcp_names}, loaded={loaded_mcp_tools}, " + f"unavailable={unavailable_mcp_servers}, failed={failed_mcp_servers}" + ) selected_tools = [] for tools in results: selected_tools.extend(tools) diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index 3430a3180..b91cc57e0 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -56,16 +56,47 @@ def _load_workspace_agents_prompt(thread_id: str, user_id: str) -> str: return prompt -async def _build_agent_input_context(agent_config: dict, *, thread_id: str, user_id: str) -> dict: +async def _build_agent_input_context( + agent_config: dict, + *, + thread_id: str, + current_user: User, +) -> dict: input_context = dict(agent_config or {}) - agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, user_id) + db_user_id = str(current_user.id) + work_id = current_user.user_id + department_id = current_user.department_id + agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, db_user_id) if agents_prompt: agents_section = f"用户工作区 agents/AGENTS.md 内容:\n{agents_prompt}" base_prompt = str(input_context.get("system_prompt") or "").rstrip() input_context["system_prompt"] = f"{base_prompt}\n\n{agents_section}" if base_prompt else agents_section - input_context.update({"user_id": user_id, "thread_id": thread_id}) + input_context.update( + { + "user_id": db_user_id, + "work_id": work_id, + "thread_id": thread_id, + "department_id": str(department_id) if department_id is not None else None, + } + ) + + # 将用户信息拼接到 system_prompt + user_info_parts = [] + if username := getattr(current_user, "username", None): + user_info_parts.append(f"姓名: {username}") + if role := getattr(current_user, "role", None): + user_info_parts.append(f"角色: {role}") + if work_id: + user_info_parts.append(f"工号: {work_id}") + + if user_info_parts: + user_info_block = "\n".join(user_info_parts) + current_prompt = str(input_context.get("system_prompt") or "").rstrip() + input_context["system_prompt"] = ( + f"{current_prompt}\n\n用户信息:\n{user_info_block}" if current_prompt else user_info_block + ) return input_context @@ -598,7 +629,11 @@ async def agent_chat( thread_id = str(uuid.uuid4()) logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") - input_context = await _build_agent_input_context(agent_config, thread_id=thread_id, user_id=user_id) + input_context = await _build_agent_input_context( + agent_config, + thread_id=thread_id, + current_user=current_user, + ) langfuse_run = _build_langfuse_run_context( current_user=current_user, thread_id=thread_id, @@ -814,7 +849,11 @@ def make_chunk(content=None, **kwargs): thread_id = str(uuid.uuid4()) logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") - input_context = await _build_agent_input_context(agent_config, thread_id=thread_id, user_id=user_id) + input_context = await _build_agent_input_context( + agent_config, + thread_id=thread_id, + current_user=current_user, + ) langfuse_run = _build_langfuse_run_context( current_user=current_user, thread_id=thread_id, @@ -1049,7 +1088,13 @@ def make_resume_chunk(content=None, **kwargs): return context = agent.context_schema() - context.update(await _build_agent_input_context(agent_config or {}, thread_id=thread_id, user_id=user_id)) + context.update( + await _build_agent_input_context( + agent_config or {}, + thread_id=thread_id, + current_user=current_user, + ) + ) graph = await agent.get_graph(context=context) langfuse_run = _build_langfuse_run_context( current_user=current_user, diff --git a/backend/package/yuxi/services/mcp/__init__.py b/backend/package/yuxi/services/mcp/__init__.py new file mode 100644 index 000000000..be19b5986 --- /dev/null +++ b/backend/package/yuxi/services/mcp/__init__.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from yuxi.services.mcp.cache_policy import ( + CachePolicyFactory, + DynamicProxyCachePolicy, + MCPCachePolicy, + StaticCachePolicy, + TokenInjectedCachePolicy, +) +from yuxi.services.mcp.client_pool import ( + MCPClientPool, + mcp_client_pool, +) +from yuxi.services.mcp.connection_service import ( + create_mcp_connection, + delete_mcp_connection, + get_mcp_connection, + list_mcp_connections, + reauthorize_mcp_connection, + set_mcp_connection_status, + test_mcp_connection, + update_mcp_connection, +) +from yuxi.services.mcp.server_service import ( + create_mcp_server, + delete_mcp_server, + ensure_builtin_mcp_servers_in_db, + get_all_mcp_servers, + get_enabled_mcp_server_config, + get_enabled_mcp_server_names, + get_mcp_server, + get_mcp_server_dependency_summary, + get_runtime_mcp_server_config, + get_servers_config, + set_server_enabled, + update_mcp_server, +) +from yuxi.services.mcp.tool_registry_service import ( + clear_mcp_cache, + clear_mcp_connection_tools_cache, + clear_mcp_server_tools_cache, + get_all_mcp_tools, + get_enabled_mcp_tools, + get_mcp_tools, + get_mcp_tools_stats, + get_tools_from_all_servers, + invalidate_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, + to_camel_case, +) + +__all__ = [ + # 策略模式与对象池 + "MCPCachePolicy", + "StaticCachePolicy", + "TokenInjectedCachePolicy", + "DynamicProxyCachePolicy", + "CachePolicyFactory", + "mcp_client_pool", + "MCPClientPool", + # Server CRUD + "ensure_builtin_mcp_servers_in_db", + "get_enabled_mcp_server_config", + "get_runtime_mcp_server_config", + "get_enabled_mcp_server_names", + "get_mcp_server", + "get_all_mcp_servers", + "create_mcp_server", + "update_mcp_server", + "delete_mcp_server", + "get_mcp_server_dependency_summary", + "set_server_enabled", + "get_servers_config", + # Connection CRUD + "get_mcp_connection", + "list_mcp_connections", + "create_mcp_connection", + "update_mcp_connection", + "delete_mcp_connection", + "set_mcp_connection_status", + "reauthorize_mcp_connection", + "test_mcp_connection", + # Tool Registry + "to_camel_case", + "get_mcp_tools", + "get_tools_from_all_servers", + "clear_mcp_cache", + "clear_mcp_server_tools_cache", + "clear_mcp_connection_tools_cache", + "invalidate_mcp_server_tools_cache", + "invalidate_mcp_connection_tools_cache", + "get_mcp_tools_stats", + "get_enabled_mcp_tools", + "get_all_mcp_tools", +] diff --git a/backend/package/yuxi/services/mcp/cache_policy.py b/backend/package/yuxi/services/mcp/cache_policy.py new file mode 100644 index 000000000..7df1d4805 --- /dev/null +++ b/backend/package/yuxi/services/mcp/cache_policy.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from yuxi.services.mcp_auth.orchestrator import AuthContext + from yuxi.storage.postgres.models_business import MCPConnection + + +class MCPCachePolicy(ABC): + """MCP 缓存策略抽象基类""" + + @abstractmethod + def should_cache_tool_object(self) -> bool: + """是否在内存中缓存底层的 Tool 实例对象""" + pass + + @abstractmethod + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + """ + 解析该连接应被划分到哪一个缓存分区中。 + + 返回: + tuple[partition_key, is_shared_across_users] + - partition_key: 用于区分 Redis 缓存或内存缓存隔离区段的 Key。 + - is_shared_across_users: 表明该分区下的缓存在不同用户间是否可以共享。 + """ + pass + + +class StaticCachePolicy(MCPCachePolicy): + """静态配置(无鉴权)缓存策略""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 静态服务无任何鉴权和状态变化,完全可以缓存 Tool 对象以提升性能 + return True + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + # NOTE: 静态连接全局共享同一个分区 + return "global", True + + +class TokenInjectedCachePolicy(MCPCachePolicy): + """静态凭据/环境变量注入型缓存策略(例如 bound_secret, stdio_env)""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 绑定了静态凭据或环境变量的连接,一旦 Connection 确定,其工具列表也是确定的,支持缓存 Tool 对象 + return True + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + if connection is None: + return "global", True + + # NOTE: 仅系统级别(system)是多用户共享的,部门和个人级别一律判定为独占 + is_shared = connection.scope_type == "system" + return f"connection:{connection.id}", is_shared + + +class DynamicProxyCachePolicy(MCPCachePolicy): + """动态 Token 鉴权代理缓存策略(例如 custom_http_token, authorization_code)""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 动态 Token 具有时效性且可能因用户身份变化,为了安全性,禁止在内存中缓存带有具体 Token 的 Tool 实例 + return False + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + if connection is None: + return "global", True + + # NOTE: 仅系统级别(system)是多用户共享的,部门和个人级别一律判定为独占 + is_shared = connection.scope_type == "system" + return f"connection:{connection.id}", is_shared + + +class CachePolicyFactory: + """缓存策略工厂,根据 auth_provider 获取匹配的 CachePolicy 实例""" + + @staticmethod + def get_policy(provider: str | None) -> MCPCachePolicy: + if not provider or provider == "legacy_static": + return StaticCachePolicy() + elif provider in ("bound_secret", "stdio_env"): + return TokenInjectedCachePolicy() + else: + # 默认为动态代理鉴权策略(支持 custom_http_token, client_credentials, authorization_code 等) + return DynamicProxyCachePolicy() diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py new file mode 100644 index 000000000..3b8f95dce --- /dev/null +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -0,0 +1,323 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any + +import httpx +from langchain_mcp_adapters.client import MultiServerMCPClient +from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + +if TYPE_CHECKING: + from mcp import ClientSession + +from cachetools import TTLCache + +# 缓存存储格式: (server_name, user_id, department_id) -> resolved_headers +_resolved_headers_cache: TTLCache = TTLCache(maxsize=1024, ttl=15.0) + + +def clear_resolved_headers_cache() -> None: + """清除解析后的 headers 缓存""" + _resolved_headers_cache.clear() + + +def clear_server_resolved_headers_cache(server_name: str) -> None: + """清除指定服务器的解析后 headers 缓存""" + stale_keys = [k for k in _resolved_headers_cache if k[0] == server_name] + for key in stale_keys: + _resolved_headers_cache.pop(key, None) + + +logger = logging.getLogger("yuxi.mcp.client_pool") + + +class DynamicMCPTokenAuth(httpx.Auth): + """动态 MCP Token 认证拦截器,每次 HTTP 请求前从 ContextVar 动态读取并注入 Authorization 头部""" + + def __init__(self, server_name: str): + self.server_name = server_name + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + auth_context = mcp_auth_context_var.get() + if auth_context: + try: + cache_key = (self.server_name, auth_context.user_id, auth_context.department_id) + cached_headers = _resolved_headers_cache.get(cache_key) + if cached_headers is not None: + for key, val in cached_headers.items(): + request.headers[key] = str(val) + yield request + return + + from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER, create_proxy_access_token + + if INTERNAL_PROXY_TOKEN_HEADER.lower() in request.headers: + # NOTE: 代理模式下,直接在本地生成新的代理 JWT,跳过 DB 事务 + new_token = create_proxy_access_token(self.server_name, auth_context) + request.headers[INTERNAL_PROXY_TOKEN_HEADER] = new_token + _resolved_headers_cache[cache_key] = dict(request.headers) + yield request + return + + # 导入数据库会话管理器以获取连接与 Token + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + + # NOTE: 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) + runtime_config = await get_runtime_mcp_server_config( + self.server_name, + auth_context=auth_context, + db=session, + ) + if runtime_config: + headers = dict(runtime_config.get("headers") or {}) + _resolved_headers_cache[cache_key] = headers + + for key, val in headers.items(): + request.headers[key] = str(val) + except Exception as exc: + logger.error(f"DynamicMCPTokenAuth failed to resolve token headers for '{self.server_name}': {exc}") + yield request + + +class LongLivedSession: + """长期存活的 MCP Client 及其 Session 生命周期管理器""" + + def __init__(self, client: MultiServerMCPClient, server_name: str): + self.client = client + self.server_name = server_name + self.session: ClientSession | None = None + self._running = False + self._loop_task: asyncio.Task | None = None + self._ready_event = asyncio.Event() + self._stop_event = asyncio.Event() + + async def start(self): + """在后台启动长连接 Session""" + if not hasattr(self.client, "session"): + self.session = self.client + self._ready_event.set() + return + + self._running = True + self._stop_event.clear() + self._ready_event.clear() + self._loop_task = asyncio.create_task(self._run_loop()) + # 等待 Session 成功连接并完成 initialize() + await self._ready_event.wait() + if not self.session: + raise RuntimeError(f"Failed to startup MCP ClientSession for {self.server_name}") + + async def _run_loop(self): + try: + # NOTE: 利用 client.session 会在退出上下文时自动释放底层的 Stdio 子进程或 HTTP Keep-Alive 连接 + async with self.client.session(self.server_name) as session: + self.session = session + self._ready_event.set() + # 挂起直到收到停止指令 + await self._stop_event.wait() + except Exception as exc: + if self.session is None: + logger.debug(f"Failed to start MCP session for {self.server_name}: {exc}", exc_info=True) + else: + logger.warning(f"MCP session loop stopped for {self.server_name}: {exc}") + logger.debug(f"Error in long-lived MCP session loop for {self.server_name}", exc_info=True) + finally: + self.session = None + self._running = False + self._ready_event.set() + + async def stop(self): + """停止长连接,回收子进程与 TCP 连接资源""" + self._stop_event.set() + if self._loop_task: + try: + await asyncio.wait_for(self._loop_task, timeout=5.0) + except TimeoutError: + logger.warning(f"Timeout waiting for long-lived session of {self.server_name} to stop.") + self._loop_task.cancel() + except Exception as exc: + logger.debug(f"Exception during long-lived session cleanup of {self.server_name}: {exc}") + self._loop_task = None + + +class MCPClientPool: + """MCP 客户端连接池实现""" + + def __init__(self): + # 缓存键格式: (server_name, partition_key) -> tuple[LongLivedSession, str] | asyncio.Future + self._sessions: dict[tuple[str, str], Any] = {} + self._dict_lock = asyncio.Lock() + + def _calculate_config_hash(self, config: dict[str, Any]) -> str: + """根据配置计算 Hash 用于比对配置是否脏变""" + clean_config = { + k: v + for k, v in config.items() + if k + not in { + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + "disabled_tools", + } + } + # 剔除 header 中可能随时变化的 token,以便准确比对静态配置 + from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER + + transient_header_names = {"authorization", INTERNAL_PROXY_TOKEN_HEADER.lower()} + headers = dict(clean_config.get("headers") or {}) + headers = {key: value for key, value in headers.items() if key.lower() not in transient_header_names} + if headers: + clean_config["headers"] = headers + elif "headers" in clean_config: + clean_config["headers"] = {} + + payload = json.dumps(clean_config, sort_keys=True, ensure_ascii=True, separators=(",", ":"), default=str) + return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] + + async def _get_mcp_client(self, server_configs: dict[str, Any] | None = None) -> MultiServerMCPClient | None: + try: + client = MultiServerMCPClient(server_configs) # pyright: ignore[reportArgumentType] + logger.info(f"Initialized MCP client with servers: {list(server_configs.keys() or [])}") + return client + except Exception as e: + logger.error(f"Failed to initialize MCP client: {e}") + return None + + async def get_session( + self, + server_name: str, + partition_key: str, + runtime_config: dict[str, Any], + ) -> ClientSession: + """获取或重建匹配当前配置的 ClientSession""" + config_hash = self._calculate_config_hash(runtime_config) + cache_key = (server_name, partition_key) + + while True: + async with self._dict_lock: + existing = self._sessions.get(cache_key) + + if existing is not None: + if isinstance(existing, asyncio.Future): + future = existing + stale_session = None + else: + ll_session, cached_hash = existing + if cached_hash == config_hash and ll_session.session is not None: + return ll_session.session + + self._sessions.pop(cache_key, None) + stale_session = ll_session + future = None + else: + future = None + stale_session = None + + stale_keys = [k for k in self._sessions if k[0] == server_name and k != cache_key] + stale_other_sessions = [] + for stale_key in stale_keys: + stale_val = self._sessions.pop(stale_key) + if not isinstance(stale_val, asyncio.Future): + stale_other_sessions.append(stale_val[0]) + + init_future = asyncio.get_running_loop().create_future() + self._sessions[cache_key] = init_future + break + + if future is not None: + await future + continue + + if stale_session is not None: + logger.info(f"Destroying stale/disconnected MCP session for {cache_key}") + await stale_session.stop() + continue + + for s_session in stale_other_sessions: + logger.info("Evicting stale MCP session") + await s_session.stop() + + try: + client_config = dict(runtime_config) + for magic_k in ( + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + "disabled_tools", + ): + client_config.pop(magic_k, None) + + if client_config.get("transport") in ("sse", "http", "streamable_http", "streamable-http"): + client_config["auth"] = DynamicMCPTokenAuth(server_name) + + logger.info( + f"Creating new long-lived MCP session for {cache_key} (transport: {client_config.get('transport')})" + ) + client = await self._get_mcp_client({server_name: client_config}) + if client is None: + raise RuntimeError(f"Failed to initialize MCP client for {server_name}") + ll_session = LongLivedSession(client, server_name) + await ll_session.start() + + result = (ll_session, config_hash) + init_future.set_result(result) + async with self._dict_lock: + self._sessions[cache_key] = result + return ll_session.session + + except BaseException as exc: + if not init_future.done(): + init_future.set_exception(exc) + init_future.exception() + async with self._dict_lock: + if self._sessions.get(cache_key) is init_future: + self._sessions.pop(cache_key, None) + raise + + async def remove_session(self, server_name: str, partition_key: str): + """移除指定 key 的连接,强制下一次请求重新创建""" + cache_key = (server_name, partition_key) + async with self._dict_lock: + val = self._sessions.pop(cache_key, None) + if val is not None and not isinstance(val, asyncio.Future): + ll_session, _ = val + logger.info(f"Removing invalid session for {cache_key} from pool") + await ll_session.stop() + + async def ensure_prewarm( + self, + server_name: str, + partition_key: str, + runtime_config: dict[str, Any], + ): + """后台异步预热加载,减少首次访问时的冷启动卡顿""" + try: + await self.get_session(server_name, partition_key, runtime_config) + except Exception as exc: + logger.warning(f"Failed to pre-warm MCP server '{server_name}': {exc}") + + async def shutdown(self): + """关闭并回收连接池中的所有连接""" + async with self._dict_lock: + sessions_to_stop = [] + for cache_key, val in list(self._sessions.items()): + if isinstance(val, asyncio.Future): + val.cancel() + else: + ll_session, _ = val + sessions_to_stop.append((cache_key, ll_session)) + self._sessions.clear() + + for cache_key, ll_session in sessions_to_stop: + logger.info(f"Stopping MCP session for {cache_key} during shutdown") + await ll_session.stop() + + +# 全局单例连接池 +mcp_client_pool = MCPClientPool() diff --git a/backend/package/yuxi/services/mcp/connection_service.py b/backend/package/yuxi/services/mcp/connection_service.py new file mode 100644 index 000000000..2a11a062b --- /dev/null +++ b/backend/package/yuxi/services/mcp/connection_service.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +import logging +from datetime import UTC, datetime +from typing import Any + +from sqlalchemy import String, and_, cast, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import Department, MCPConnection, User + +logger = logging.getLogger("yuxi.mcp.connection_service") + +_UNSET = object() +_VALID_MCP_CONNECTION_SCOPE_TYPES = {"system", "department", "user"} +_VALID_MCP_CONNECTION_STATUSES = {"active", "disabled", "reauth_required", "invalid"} +_MCP_CONNECTION_SCOPE_LABELS = { + "system": "全局共享", + "department": "部门共享", + "user": "个人专用", +} +_MCP_CONNECTION_HEALTH_FILTERS = {"all", "active", "attention", "disabled"} + + +def _resolve_scope_id(binding_scope: str, auth_context: AuthContext | None) -> str | None: + """依据 Scope 类别从 AuthContext 中解出匹配的 ID""" + if binding_scope == "inline": + return None + if binding_scope == "system": + return "global" + if auth_context is None: + raise ValueError(f"auth_context is required for MCP binding scope '{binding_scope}'") + if binding_scope == "department": + if not auth_context.department_id: + raise ValueError("department_id is required for department-scoped MCP auth") + return str(auth_context.department_id) + if binding_scope == "user": + if not auth_context.user_id: + raise ValueError("user_id is required for user-scoped MCP auth") + return str(auth_context.user_id) + raise ValueError(f"Unsupported MCP binding scope: {binding_scope}") + + +def requires_bound_mcp_connection(auth_config: MCPAuthConfig) -> bool: + """判断当前鉴权配置是否必须存在 active MCPConnection。""" + return auth_config.binding_scope != "inline" and bool(auth_config.get_secret_fields()) + + +def _normalize_mcp_connection_scope(scope_type: str, scope_id: str | None) -> tuple[str, str]: + normalized_scope_type = str(scope_type or "").strip().lower() + if normalized_scope_type not in _VALID_MCP_CONNECTION_SCOPE_TYPES: + raise ValueError("scope_type must be one of: system, department, user") + + normalized_scope_id = str(scope_id or "").strip() + if normalized_scope_type == "system": + return normalized_scope_type, "global" + if not normalized_scope_id: + raise ValueError(f"scope_id is required for {normalized_scope_type}-scoped MCP connections") + return normalized_scope_type, normalized_scope_id + + +def _normalize_mcp_connection_status(status: str) -> str: + normalized_status = str(status or "").strip().lower() + if normalized_status not in _VALID_MCP_CONNECTION_STATUSES: + raise ValueError("status must be one of: active, disabled, reauth_required, invalid") + return normalized_status + + +def _format_duplicate_connection_message(server_name: str, scope_type: str) -> str: + scope_label = _MCP_CONNECTION_SCOPE_LABELS.get(scope_type, "绑定") + return f'MCP "{server_name}" 的{scope_label}连接已存在,请直接编辑现有连接。' + + +def _configured_connection_scope(server) -> str | None: + payload = getattr(server, "auth_config_json", None) + if not payload: + return None + try: + auth_config = MCPAuthConfig.model_validate(payload) + except Exception: + return None + if auth_config.binding_scope in _VALID_MCP_CONNECTION_SCOPE_TYPES: + return auth_config.binding_scope + return None + + +def _ensure_connection_scope_matches_server(server, scope_type: str) -> None: + configured_scope = _configured_connection_scope(server) + if not configured_scope or scope_type == configured_scope: + return + scope_label = _MCP_CONNECTION_SCOPE_LABELS.get(configured_scope, "当前绑定类型") + server_name = getattr(server, "name", "") + raise ValueError(f'MCP "{server_name}" 当前绑定类型是{scope_label},只能使用{scope_label}连接。') + + +async def get_mcp_connection(db: AsyncSession, connection_id: int) -> MCPConnection | None: + """获取单个 Connection 记录""" + result = await db.execute(select(MCPConnection).where(MCPConnection.id == connection_id)) + return result.scalar_one_or_none() + + +def _auth_context_from_connection(connection: MCPConnection) -> AuthContext: + """基于连接绑定生成对应的 AuthContext 用于模拟联调与测试""" + if connection.scope_type == "department": + return AuthContext(department_id=connection.scope_id) + if connection.scope_type == "user": + return AuthContext(user_id=connection.scope_id, work_id=connection.scope_id) + return AuthContext() + + +async def list_mcp_connections( + db: AsyncSession, + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, +) -> list[MCPConnection]: + """多条件列表查询 Connection""" + stmt = select(MCPConnection) + if server_name is not None: + stmt = stmt.where(MCPConnection.server_name == server_name) + if scope_type is not None: + stmt = stmt.where(MCPConnection.scope_type == scope_type) + if scope_id is not None: + stmt = stmt.where(MCPConnection.scope_id == scope_id) + stmt = stmt.order_by(MCPConnection.id.asc()) + result = await db.execute(stmt) + return list(result.scalars().all()) + + +def _connection_has_credential_condition(): + return and_(MCPConnection.credential_blob.is_not(None), MCPConnection.credential_blob != "") + + +def _connection_missing_credential_condition(): + return or_(MCPConnection.credential_blob.is_(None), MCPConnection.credential_blob == "") + + +def _connection_search_condition(search: str): + keyword = str(search or "").strip() + like_keyword = f"%{keyword}%" + lowered_keyword = keyword.lower() + conditions = [ + MCPConnection.display_name.ilike(like_keyword), + MCPConnection.external_subject.ilike(like_keyword), + MCPConnection.scope_id.ilike(like_keyword), + MCPConnection.created_by.ilike(like_keyword), + MCPConnection.updated_by.ilike(like_keyword), + and_( + MCPConnection.scope_type == "department", + select(Department.id) + .where( + cast(Department.id, String) == MCPConnection.scope_id, + Department.name.ilike(like_keyword), + ) + .exists(), + ), + and_( + MCPConnection.scope_type == "user", + select(User.id) + .where( + or_( + cast(User.id, String) == MCPConnection.scope_id, + User.user_id == MCPConnection.scope_id, + ), + or_(User.username.ilike(like_keyword), User.user_id.ilike(like_keyword)), + ) + .exists(), + ), + ] + if any(token in lowered_keyword for token in ("system", "global", "全局", "共享", "全部")): + conditions.append(MCPConnection.scope_type == "system") + if any(token in lowered_keyword for token in ("department", "dept", "部门")): + conditions.append(MCPConnection.scope_type == "department") + if any(token in lowered_keyword for token in ("user", "个人", "用户")): + conditions.append(MCPConnection.scope_type == "user") + return or_(*conditions) + + +def _connection_health_condition( + status_filter: str, + *, + effective_scope_type: str | None, + credentials_required: bool, +): + normalized_filter = str(status_filter or "all").strip().lower() + if normalized_filter not in _MCP_CONNECTION_HEALTH_FILTERS: + raise ValueError("status filter must be one of: all, active, attention, disabled") + if normalized_filter == "all": + return None + if normalized_filter == "disabled": + return MCPConnection.status == "disabled" + + conditions = [] + if normalized_filter == "active": + conditions.append(MCPConnection.status == "active") + if effective_scope_type: + conditions.append(MCPConnection.scope_type == effective_scope_type) + if credentials_required: + conditions.append(_connection_has_credential_condition()) + return and_(*conditions) + + conditions.append(MCPConnection.status.in_(("reauth_required", "invalid"))) + if effective_scope_type: + conditions.append(MCPConnection.scope_type != effective_scope_type) + if credentials_required: + conditions.append(_connection_missing_credential_condition()) + return or_(*conditions) + + +def _build_mcp_connections_query( + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, + status_filter: str = "all", + effective_scope_type: str | None = None, + credentials_required: bool = False, + search: str | None = None, +): + conditions = [] + if server_name is not None: + conditions.append(MCPConnection.server_name == server_name) + if scope_type is not None: + conditions.append(MCPConnection.scope_type == scope_type) + if scope_id is not None: + conditions.append(MCPConnection.scope_id == scope_id) + + health_condition = _connection_health_condition( + status_filter, + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + ) + if health_condition is not None: + conditions.append(health_condition) + + if search and str(search).strip(): + conditions.append(_connection_search_condition(str(search).strip())) + + return conditions + + +async def count_mcp_connections( + db: AsyncSession, + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, + status_filter: str = "all", + effective_scope_type: str | None = None, + credentials_required: bool = False, + search: str | None = None, +) -> int: + """统计符合筛选条件的连接数量。""" + stmt = select(func.count()).select_from(MCPConnection) + for condition in _build_mcp_connections_query( + server_name=server_name, + scope_type=scope_type, + scope_id=scope_id, + status_filter=status_filter, + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + search=search, + ): + stmt = stmt.where(condition) + result = await db.execute(stmt) + return int(result.scalar_one() or 0) + + +async def list_mcp_connections_page( + db: AsyncSession, + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, + status_filter: str = "all", + effective_scope_type: str | None = None, + credentials_required: bool = False, + search: str | None = None, + page: int = 1, + page_size: int = 12, +) -> tuple[list[MCPConnection], int]: + """分页查询连接列表,供管理员连接页使用。""" + normalized_page = max(1, int(page or 1)) + normalized_page_size = min(max(1, int(page_size or 12)), 100) + conditions = _build_mcp_connections_query( + server_name=server_name, + scope_type=scope_type, + scope_id=scope_id, + status_filter=status_filter, + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + search=search, + ) + stmt = select(MCPConnection).order_by(MCPConnection.id.asc()) + count_stmt = select(func.count()).select_from(MCPConnection) + for condition in conditions: + stmt = stmt.where(condition) + count_stmt = count_stmt.where(condition) + stmt = stmt.limit(normalized_page_size).offset((normalized_page - 1) * normalized_page_size) + + total_result = await db.execute(count_stmt) + result = await db.execute(stmt) + return list(result.scalars().all()), int(total_result.scalar_one() or 0) + + +async def create_mcp_connection( + db: AsyncSession, + *, + server_name: str, + scope_type: str, + scope_id: str, + display_name: str | None = None, + external_subject: str | None = None, + status: str = "active", + credential_blob: str | None = None, + meta_json: dict[str, Any] | None = None, + created_by: str | None = None, +) -> MCPConnection: + """创建 MCP 绑定连接""" + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, server_name) + if server is None: + raise ValueError(f"Server '{server_name}' does not exist") + normalized_scope_type, normalized_scope_id = _normalize_mcp_connection_scope(scope_type, scope_id) + normalized_status = _normalize_mcp_connection_status(status) + _ensure_connection_scope_matches_server(server, normalized_scope_type) + + encrypted_credential_blob = ( + encrypt_credential_blob(credential_blob) + if isinstance(credential_blob, str) and credential_blob.strip() + else credential_blob + ) + + connection = MCPConnection( + server_name=server_name, + scope_type=normalized_scope_type, + scope_id=normalized_scope_id, + display_name=display_name, + external_subject=external_subject, + status=normalized_status, + credential_blob=encrypted_credential_blob, + meta_json=meta_json or {}, + created_by=created_by, + updated_by=created_by, + ) + db.add(connection) + from sqlalchemy.exc import IntegrityError + + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise ValueError(_format_duplicate_connection_message(server_name, normalized_scope_type)) + await db.refresh(connection) + return connection + + +async def update_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + display_name: str | None = None, + external_subject: str | None = None, + credential_blob: Any = _UNSET, + meta_json: dict[str, Any] | None = None, + status: str | None = None, + updated_by: str | None = None, +) -> MCPConnection: + """更新 MCP 绑定连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + should_clear_runtime_auth_cache = False + if display_name is not None: + connection.display_name = display_name + if external_subject is not None: + connection.external_subject = external_subject + if credential_blob is not _UNSET: + if isinstance(credential_blob, str) and credential_blob.strip(): + connection.credential_blob = encrypt_credential_blob(credential_blob) + else: + connection.credential_blob = credential_blob + should_clear_runtime_auth_cache = True + if meta_json is not None: + connection.meta_json = meta_json + if status is not None: + normalized_status = _normalize_mcp_connection_status(status) + if normalized_status == "active": + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + _ensure_connection_scope_matches_server(server, connection.scope_type) + connection.status = normalized_status + should_clear_runtime_auth_cache = True + if updated_by is not None: + connection.updated_by = updated_by + + await db.commit() + await db.refresh(connection) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + _invalidate_mcp_tools_cache_for_connection, + ) + + if should_clear_runtime_auth_cache: + await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) + return connection + + +async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: + """删除 MCP 绑定连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + return False + deleted_connection_id = connection.id + deleted_server_name = connection.server_name + deleted_scope_type = connection.scope_type + await db.delete(connection) + await db.commit() + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + invalidate_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, + ) + + await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) + if deleted_scope_type == "system": + await invalidate_mcp_server_tools_cache(deleted_server_name) + else: + await invalidate_mcp_connection_tools_cache(deleted_server_name, deleted_connection_id) + return True + + +async def set_mcp_connection_status( + db: AsyncSession, + connection_id: int, + *, + status: str, + updated_by: str | None = None, +) -> MCPConnection: + """设置 MCP 绑定状态""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + normalized_status = _normalize_mcp_connection_status(status) + if normalized_status == "active": + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + _ensure_connection_scope_matches_server(server, connection.scope_type) + + connection.status = normalized_status + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + _invalidate_mcp_tools_cache_for_connection, + ) + + await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) + return connection + + +async def reauthorize_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> MCPConnection: + """重置授权连接凭据缓存并重新开启连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + _ensure_connection_scope_matches_server(server, connection.scope_type) + + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + + cache = RedisTokenCache() + if getattr(connection, "id", None) is not None: + try: + await cache.delete_access_token(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection.id}: {exc}") + try: + await cache.release_refresh_lock(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") + + from yuxi.services.mcp.tool_registry_service import _invalidate_mcp_tools_cache_for_connection + + await _invalidate_mcp_tools_cache_for_connection(connection) + + connection.status = "active" + meta_json = dict(connection.meta_json or {}) + meta_json.pop("last_error", None) + connection.meta_json = meta_json + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return connection + + +async def test_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> dict[str, Any]: + """测试连接联调可用性,获取可用的工具列表""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + _ensure_connection_scope_matches_server(server, connection.scope_type) + + auth_context = _auth_context_from_connection(connection) + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + from yuxi.services.mcp.tool_registry_service import get_mcp_tools + + config = await get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) + if config is None: + raise ValueError(f"MCP server '{server.name}' runtime config unavailable") + + tools = await get_mcp_tools( + server.name, + additional_servers={server.name: config}, + disabled_tools=[], + cache=False, + force_refresh=True, + ) + + meta_json = dict(connection.meta_json or {}) + meta_json["last_success_at"] = datetime.now(tz=UTC).isoformat() + meta_json.pop("last_error", None) + connection.meta_json = meta_json + connection.status = "active" + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return {"tool_count": len(tools), "connection": connection} diff --git a/backend/package/yuxi/services/mcp/server_service.py b/backend/package/yuxi/services/mcp/server_service.py new file mode 100644 index 000000000..abc7cf15b --- /dev/null +++ b/backend/package/yuxi/services/mcp/server_service.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +import logging +import os +import traceback +from typing import Any + +import httpx +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.services.mcp_auth.proxy_service import ( + build_proxy_runtime_config, + should_use_internal_proxy, +) +from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill + +logger = logging.getLogger("yuxi.mcp.server_service") + +_UNSET = object() +_MCP_PROXY_BASE_URL_ENV = "YUXI_INTERNAL_MCP_PROXY_BASE_URL" + +_DEFAULT_MCP_SERVERS = { + "sequentialthinking": { + "url": "https://remote.mcpservers.org/sequentialthinking/mcp", + "transport": "streamable_http", + "description": "顺序思考工具,帮助 AI 将复杂问题分解为多个步骤", + "icon": "🧠", + "tags": ["内置", "AI"], + }, + "mcp-server-chart": { + "command": "npx", + "args": ["-y", "@antv/mcp-server-chart"], + "transport": "stdio", + "description": "图表生成工具,支持生成各类图表(柱状图、折线图、饼图等)", + "icon": "📊", + "tags": ["内置", "图表"], + }, +} + +_SYNCED_MCP_FIELDS = ( + "description", + "transport", + "url", + "command", + "args", + "env", + "headers", + "timeout", + "sse_read_timeout", + "tags", + "icon", +) + + +async def ensure_builtin_mcp_servers_in_db() -> None: + """同步代码预置的内置 MCP 服务器至数据库中""" + from yuxi.storage.postgres.manager import pg_manager + + try: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(func.count(MCPServer.name))) + count = result.scalar() + + if count == 0: + logger.info("No MCP servers in database, importing default configurations...") + for name, config in _DEFAULT_MCP_SERVERS.items(): + server = MCPServer( + name=name, + description=config.get("description"), + transport=config["transport"], + url=config.get("url"), + command=config.get("command"), + args=config.get("args"), + env=config.get("env"), + headers=config.get("headers"), + timeout=config.get("timeout"), + sse_read_timeout=config.get("sse_read_timeout"), + tags=config.get("tags"), + icon=config.get("icon"), + enabled=0, + created_by="system", + updated_by="system", + ) + session.add(server) + await session.commit() + logger.info(f"Imported {len(_DEFAULT_MCP_SERVERS)} default MCP servers to database") + else: + for name, config in _DEFAULT_MCP_SERVERS.items(): + result = await session.execute(select(MCPServer).filter(MCPServer.name == name)) + existing = result.scalar_one_or_none() + if not existing: + server = MCPServer( + name=name, + description=config.get("description"), + transport=config["transport"], + url=config.get("url"), + command=config.get("command"), + args=config.get("args"), + env=config.get("env"), + headers=config.get("headers"), + timeout=config.get("timeout"), + sse_read_timeout=config.get("sse_read_timeout"), + tags=config.get("tags"), + icon=config.get("icon"), + enabled=0, + created_by="system", + updated_by="system", + ) + session.add(server) + logger.info(f"Added built-in MCP server '{name}' to database") + else: + changed = False + for field in _SYNCED_MCP_FIELDS: + next_value = config.get(field) + if getattr(existing, field) != next_value: + setattr(existing, field, next_value) + changed = True + if changed: + existing.updated_by = "system" + if session.new: + await session.commit() + elif session.dirty: + await session.commit() + + except Exception as e: + logger.error(f"Failed to ensure builtin MCP servers in database: {e}, traceback: {traceback.format_exc()}") + + +async def _load_enabled_mcp_server_configs( + *, + names: list[str] | None = None, + db: AsyncSession | None = None, +) -> dict[str, dict[str, Any]]: + """从数据库中加载已启用的服务器 MCP 配置""" + if db is not None: + stmt = select(MCPServer).where(MCPServer.enabled == 1) + if names: + stmt = stmt.where(MCPServer.name.in_(names)) + result = await db.execute(stmt) + servers = result.scalars().all() + return {server.name: server.to_mcp_config() for server in servers} + + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + return await _load_enabled_mcp_server_configs(names=names, db=session) + + +async def get_enabled_mcp_server_config(server_name: str, *, db: AsyncSession | None = None) -> dict[str, Any] | None: + """获取最新启用的指定服务器的 MCP 配置""" + configs = await _load_enabled_mcp_server_configs(names=[server_name], db=db) + return configs.get(server_name) + + +def _get_internal_mcp_proxy_base_url() -> str | None: + value = os.getenv(_MCP_PROXY_BASE_URL_ENV, "").strip() + return value or None + + +async def _get_enabled_mcp_server_record(server_name: str, *, db: AsyncSession) -> MCPServer | None: + result = await db.execute( + select(MCPServer).where( + MCPServer.enabled == 1, + MCPServer.name == server_name, + ) + ) + return result.scalar_one_or_none() + + +def _apply_runtime_tool_cache_policy( + config: dict[str, Any], + *, + auth_config: MCPAuthConfig, + auth_context: AuthContext | None, + connection: MCPConnection | None, +) -> dict[str, Any]: + """利用 CachePolicy 模式获取缓存 key 的隔离区划并应用""" + from yuxi.services.mcp.cache_policy import CachePolicyFactory + + policy = CachePolicyFactory.get_policy(auth_config.provider) + partition, is_shared = policy.resolve_cache_partition( + auth_context or AuthContext(), + connection, + ) + config["__yuxi_cache_partition"] = partition + config["__yuxi_allow_global_cache"] = is_shared + return config + + +async def get_runtime_mcp_server_config( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> dict[str, Any] | None: + """解析获取附带运行时鉴权与租户范围的 MCP 服务配置""" + if db is None and auth_context is None: + return await get_enabled_mcp_server_config(server_name) + + if db is not None: + server = await _get_enabled_mcp_server_record(server_name, db=db) + if server is None: + return None + if not server.auth_config_json: + return server.to_mcp_config() + + auth_config = MCPAuthConfig.model_validate(server.auth_config_json) + from yuxi.services.mcp.connection_service import _resolve_scope_id, requires_bound_mcp_connection + + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + if scope_id is None: + return server.to_mcp_config() + + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server_name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + connection = result.scalar_one_or_none() + if connection is None: + if requires_bound_mcp_connection(auth_config): + raise ValueError( + f"Active MCP connection not found for server '{server_name}' and scope " + f"{auth_config.binding_scope}:{scope_id}" + ) + # 无需长期密钥的鉴权机制无需强制绑定连接即可生成运行时配置 + proxy_base_url = _get_internal_mcp_proxy_base_url() + if should_use_internal_proxy(server, auth_config, proxy_base_url): + config = build_proxy_runtime_config( + server, + auth_context=auth_context or AuthContext(), + proxy_base_url=proxy_base_url or "", + ) + else: + config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context or AuthContext(), + connection=connection, + http_client=http_client, + ) + return _apply_runtime_tool_cache_policy( + config, + auth_config=auth_config, + auth_context=auth_context, + connection=connection, + ) + + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + return await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=session, + http_client=http_client, + ) + + +async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: + """获取所有已启用的服务器名称""" + configs = await _load_enabled_mcp_server_configs(db=db) + return list(configs.keys()) + + +async def get_mcp_server(db: AsyncSession, name: str) -> MCPServer | None: + """获取单个服务器对象记录""" + result = await db.execute(select(MCPServer).filter(MCPServer.name == name)) + return result.scalar_one_or_none() + + +async def get_all_mcp_servers(db: AsyncSession) -> list[MCPServer]: + """获取所有配置的服务器对象列表""" + result = await db.execute(select(MCPServer)) + return list(result.scalars().all()) + + +async def create_mcp_server( + db: AsyncSession, + name: str, + transport: str, + url: str = None, + command: str = None, + args: list = None, + env: dict = None, + description: str = None, + headers: dict = None, + timeout: int = None, + sse_read_timeout: int = None, + tags: list = None, + icon: str = None, + auth_config: dict | None = None, + created_by: str = None, +) -> MCPServer: + """创建 MCP 服务器配置""" + existing = await get_mcp_server(db, name) + if existing: + raise ValueError(f"Server name '{name}' already exists") + + server = MCPServer( + name=name, + description=description, + transport=transport, + url=url, + command=command, + args=args, + env=env, + headers=headers, + auth_config_json=auth_config, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + tags=tags, + icon=icon, + enabled=1, + created_by=created_by, + updated_by=created_by, + ) + db.add(server) + await db.commit() + await db.refresh(server) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) + + logger.info(f"Created MCP server '{name}'") + return server + + +async def update_mcp_server( + db: AsyncSession, + name: str, + description: str = None, + transport: str = None, + url: str = None, + command: str = None, + args: list = None, + env: Any = _UNSET, + headers: dict = None, + timeout: int = None, + sse_read_timeout: int = None, + tags: list = None, + icon: str = None, + auth_config: Any = _UNSET, + updated_by: str = None, +) -> MCPServer: + """更新服务器配置""" + server = await get_mcp_server(db, name) + if not server: + raise ValueError(f"Server '{name}' does not exist") + + if description is not None: + server.description = description + if transport is not None: + server.transport = transport + if url is not None: + server.url = url + if command is not None: + server.command = command + if args is not None: + server.args = args + if env is not _UNSET: + server.env = env + if headers is not None: + server.headers = headers + if auth_config is not _UNSET: + server.auth_config_json = auth_config + if timeout is not None: + server.timeout = timeout + if sse_read_timeout is not None: + server.sse_read_timeout = sse_read_timeout + if tags is not None: + server.tags = tags + if icon is not None: + server.icon = icon + if updated_by is not None: + server.updated_by = updated_by + + await db.commit() + await db.refresh(server) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + + if auth_config is not _UNSET: + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) + + logger.info(f"Updated MCP server '{name}'") + return server + + +async def delete_mcp_server(db: AsyncSession, name: str) -> bool: + """删除服务器""" + server = await get_mcp_server(db, name) + if not server: + return False + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + + # NOTE: 必须在级联删除前执行 Redis 缓存清理,否则关联的 connection 行被删除后将无法提取 ID + await _clear_mcp_server_runtime_auth_cache(db, name) + + await db.delete(server) + await db.commit() + + await invalidate_mcp_server_tools_cache(name) + + logger.info(f"Deleted MCP server '{name}'") + return True + + +async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: + """获取依赖于该 MCP 服务器的智能体、技能和连接概要""" + from yuxi.services.mcp.connection_service import list_mcp_connections + + connections = await list_mcp_connections(db, server_name=name) + + skill_rows = (await db.execute(select(Skill))).scalars().all() + matched_skills = [ + {"slug": item.slug, "name": item.name} for item in skill_rows if name in (item.mcp_dependencies or []) + ] + + agent_config_rows = (await db.execute(select(AgentConfig))).scalars().all() + matched_agent_configs = [] + for item in agent_config_rows: + config_json = item.config_json or {} + if name in (config_json.get("mcps") or []): + matched_agent_configs.append({"id": item.id, "name": item.name, "agent_id": item.agent_id}) + + connection_refs = [ + {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} for item in connections + ] + + return { + "has_references": bool(connection_refs or matched_skills or matched_agent_configs), + "connections": connection_refs, + "skills": matched_skills, + "agent_configs": matched_agent_configs, + } + + +async def set_server_enabled( + db: AsyncSession, name: str, enabled: bool, updated_by: str = None +) -> tuple[bool, MCPServer]: + """设置服务器的启用状态""" + server = await get_mcp_server(db, name) + if not server: + raise ValueError(f"Server '{name}' does not exist") + + server.enabled = 1 if enabled else 0 + if updated_by is not None: + server.updated_by = updated_by + await db.commit() + + is_enabled = bool(server.enabled) + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + + if not is_enabled: + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) + + logger.info(f"Set MCP server '{name}' enabled={is_enabled}") + return is_enabled, server + + +async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: + """批量获取服务器配置""" + return await _load_enabled_mcp_server_configs(names=names) diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py new file mode 100644 index 000000000..0bb859c80 --- /dev/null +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import logging +import os +import time +from collections.abc import Callable +from types import SimpleNamespace +from typing import Any, cast + +import httpx +from cachetools import LRUCache +from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + INTERNAL_PROXY_TOKEN_HEADER, +) +from yuxi.services.mcp_tool_cache import RedisMcpToolCache +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + +logger = logging.getLogger("yuxi.mcp.tool_registry_service") + +# 全局共享状态(直接在本模块维护,供外部和测试使用) +_mcp_tools_cache: LRUCache = LRUCache(maxsize=128) +_mcp_tools_stats: LRUCache = LRUCache(maxsize=128) +_mcp_tools_failure_cache: LRUCache = LRUCache(maxsize=256) +_mcp_tool_cache_store = RedisMcpToolCache() +_mcp_lock = asyncio.Lock() +_MCP_TOOL_FAILURE_COOLDOWN_SECONDS = float(os.getenv("YUXI_MCP_TOOL_FAILURE_COOLDOWN_SECONDS", "30")) + + +def to_camel_case(s: str) -> str: + """转换字符串为 lowerCamelCase 命名格式""" + import re + + s = re.sub(r"[-_]+(.)", lambda m: m.group(1).upper(), s) + if len(s) > 0: + s = s[0].lower() + s[1:] + return s + + +def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, Any], str, bool]: + """提取用于缓存 key 比较的标识配置""" + cache_partition = str(server_config.get("__yuxi_cache_partition") or "server") + allow_global_cache = bool(server_config.get("__yuxi_allow_global_cache", True)) + + cache_identity = { + key: value + for key, value in server_config.items() + if key + not in { + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + "disabled_tools", + } + } + + headers = dict(cache_identity.get("headers") or {}) + headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) + if headers: + cache_identity["headers"] = headers + elif "headers" in cache_identity: + cache_identity["headers"] = {} + return cache_identity, cache_partition, allow_global_cache + + +async def _build_mcp_tool_cache_descriptor(server_name: str, server_config: dict[str, Any]) -> dict[str, Any]: + """生成缓存 Key 描述信息字典""" + cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) + config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] + + server_revision = await _mcp_tool_cache_store.get_server_revision(server_name) + partition_revision = 0 + if not allow_global_cache: + partition_revision = await _mcp_tool_cache_store.get_partition_revision(server_name, cache_partition) + revision_token = f"s{server_revision}:p{partition_revision}" + cache_prefix = f"{server_name}:{cache_partition}:{revision_token}:" + + return { + "cache_identity": cache_identity, + "cache_partition": cache_partition, + "allow_global_cache": allow_global_cache, + "config_hash": config_hash, + "cache_prefix": cache_prefix, + "cache_key": f"{cache_prefix}{config_hash}", + "server_revision": server_revision, + "partition_revision": partition_revision, + } + + +def _serialize_mcp_tools_manifest( + *, + server_name: str, + cache_partition: str, + cache_key: str, + tools: list[Callable[..., Any]], +) -> dict[str, Any]: + """将 Langchain 运行态 Tool 转换为 Manifest 字典以缓存到 Redis 中""" + entries = [] + for tool in tools: + if hasattr(tool, "args_schema") and tool.args_schema: + schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {} + parameters = schema.get("properties", {}) + required = schema.get("required", []) + else: + parameters = {} + required = [] + metadata = dict(getattr(tool, "metadata", {}) or {}) + entries.append( + { + "name": tool.name, + "id": metadata.get("id") or tool.name, + "description": getattr(tool, "description", ""), + "parameters": parameters, + "required": required, + } + ) + return { + "server_name": server_name, + "cache_partition": cache_partition, + "cache_key": cache_key, + "tools": entries, + } + + +def _deserialize_mcp_tool_manifest(manifest: dict[str, Any]) -> list[Callable[..., Any]]: + """反序列化 Redis 中的 Manifest 字典还原为本地 Tool 对象结构""" + tools: list[Callable[..., Any]] = [] + for entry in manifest.get("tools", []): + args_schema = None + parameters = entry.get("parameters") or {} + required = entry.get("required") or [] + if parameters or required: + args_schema = SimpleNamespace( + schema=lambda parameters=parameters, required=required: { + "properties": parameters, + "required": required, + } + ) + tools.append( + SimpleNamespace( + name=entry.get("name") or "", + description=entry.get("description") or "", + metadata={"id": entry.get("id") or entry.get("name") or ""}, + args_schema=args_schema, + ) + ) + return tools + + +def _get_mcp_auth_config(server_config: dict[str, Any]) -> MCPAuthConfig | None: + auth_payload = server_config.get("auth_config") or {} + if not auth_payload: + return None + try: + return MCPAuthConfig.model_validate(auth_payload) + except Exception as exc: + logger.warning(f"Invalid MCP auth config while resolving tool preload strategy: {exc}") + return None + + +def _can_preload_mcp_server_tools_without_runtime_auth(server_config: dict[str, Any]) -> bool: + if not (server_config.get("auth_config") or {}): + return True + auth_config = _get_mcp_auth_config(server_config) + if auth_config is None: + return False + return auth_config.provider == "legacy_static" + + +def _get_cached_mcp_tool_failure(cache_key: str) -> dict[str, Any] | None: + entry = _mcp_tools_failure_cache.get(cache_key) + if not entry: + return None + retry_at = float(entry.get("retry_at") or 0) + if retry_at <= time.monotonic(): + _mcp_tools_failure_cache.pop(cache_key, None) + return None + return entry + + +def _record_mcp_tool_failure(cache_key: str, exc: BaseException) -> None: + if _MCP_TOOL_FAILURE_COOLDOWN_SECONDS <= 0: + return + _mcp_tools_failure_cache[cache_key] = { + "retry_at": time.monotonic() + _MCP_TOOL_FAILURE_COOLDOWN_SECONDS, + "message": str(exc) or exc.__class__.__name__, + } + + +def _clear_mcp_tool_failure(cache_key: str) -> None: + _mcp_tools_failure_cache.pop(cache_key, None) + + +def _clear_mcp_tool_failure_cache_for_server(server_name: str) -> None: + prefix = f"{server_name}:" + stale_keys = [key for key in _mcp_tools_failure_cache if key.startswith(prefix)] + for key in stale_keys: + _mcp_tools_failure_cache.pop(key, None) + + +async def get_mcp_tools( + server_name: str, + additional_servers: dict[str, dict[str, Any]] | None = None, + disabled_tools: list[str] = None, + cache: bool = True, + force_refresh: bool = False, +) -> list[Callable[..., Any]]: + """ + 获取指定 MCP 服务器的工具列表。 + + 优化生命周期: + - 集成缓存策略模式 (CachePolicy),动态决策是否在进程内容许缓存 Tool 对象。 + - 集成客户端连接池 (MCPClientPool),复用 Stdio 长期子进程及 HTTP Keep-Alive 连接。 + """ + if additional_servers and server_name in additional_servers: + server_config = additional_servers[server_name] + else: + from yuxi.services.mcp.server_service import get_enabled_mcp_server_config + + server_config = await get_enabled_mcp_server_config(server_name) + + if server_config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, server_config) + cache_partition = cache_descriptor["cache_partition"] + cache_prefix = cache_descriptor["cache_prefix"] + cache_key = cache_descriptor["cache_key"] + + # 策略模式:根据 AuthProvider 确认是否容许内存缓存 Tool 实例对象 + from yuxi.services.mcp.cache_policy import CachePolicyFactory + + auth_config = _get_mcp_auth_config(server_config) + policy = CachePolicyFactory.get_policy(auth_config.provider if auth_config else None) + use_tool_object_cache = ( + cache + and policy.should_cache_tool_object() + and not bool(server_config.get(INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY)) + ) + + all_processed_tools: list[Callable[..., Any]] = [] + + async with _mcp_lock: + if not force_refresh and use_tool_object_cache and cache_key in _mcp_tools_cache: + all_processed_tools = _mcp_tools_cache[cache_key] + + if not all_processed_tools: + if not force_refresh: + failure_entry = _get_cached_mcp_tool_failure(cache_key) + if failure_entry is not None: + retry_in = max(0.0, float(failure_entry.get("retry_at") or 0) - time.monotonic()) + logger.debug( + f"Skip loading MCP tools for '{server_name}' during failure cooldown " + f"({retry_in:.1f}s left): {failure_entry.get('message')}" + ) + return [] + + try: + client_config = { + k: v + for k, v in server_config.items() + if k + not in ( + "disabled_tools", + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + ) + } + + # NOTE: 从长连接池中提取 ClientSession 实例 + # (对 Stdio 而言子进程被挂起复用,避免频繁启停;HTTP 协议亦保持 Keep-Alive) + from yuxi.services.mcp.client_pool import mcp_client_pool + + session = await mcp_client_pool.get_session( + server_name, + partition_key=f"{cache_partition}:s{cache_descriptor['server_revision']}:p{cache_descriptor['partition_revision']}", + runtime_config=client_config, + ) + + # 如果 session 是 Fake Client (有 get_tools 方法),我们直接调用它获取工具列表,避免 load_mcp_tools 报错 + if hasattr(session, "get_tools"): + raw_tools = cast(list[Any], await session.get_tools()) + else: + # 调用 langchain 官方加载工具,直接传入已预备并建立好的 session + from langchain_mcp_adapters.tools import load_mcp_tools + + raw_tools = cast(list[Any], await load_mcp_tools(session, server_name=server_name)) + + server_cc = to_camel_case(server_name) + for tool in raw_tools: + original_name = tool.name + tool_cc = to_camel_case(original_name) + unique_id = f"mcp__{server_cc}__{tool_cc}" + + if tool.metadata is None: + tool.metadata = {} + tool.metadata["id"] = unique_id + tool.handle_tool_error = True + all_processed_tools.append(tool) + + if cache: + if use_tool_object_cache: + async with _mcp_lock: + stale_keys = [ + key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key + ] + for stale_key in stale_keys: + _mcp_tools_cache.pop(stale_key, None) + _mcp_tools_cache[cache_key] = all_processed_tools + + await _mcp_tool_cache_store.set_manifest( + cache_key, + _serialize_mcp_tools_manifest( + server_name=server_name, + cache_partition=cache_partition, + cache_key=cache_key, + tools=all_processed_tools, + ), + ) + + global_config_disabled = server_config.get("disabled_tools") or [] + enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) + _mcp_tools_stats[server_name] = { + "total": len(all_processed_tools), + "enabled": enabled_count, + "disabled": len(all_processed_tools) - enabled_count, + } + + logger.info( + f"Refreshed MCP tools cache for '{server_name}' with key '{cache_key}': " + f"{len(all_processed_tools)} tools loaded." + ) + + _clear_mcp_tool_failure(cache_key) + + except Exception as e: + _record_mcp_tool_failure(cache_key, e) + logger.warning( + f"MCP server '{server_name}' temporarily unavailable; " + f"suppress retries for {_MCP_TOOL_FAILURE_COOLDOWN_SECONDS:.0f}s: {e}" + ) + logger.debug(f"Failed to load tools from MCP server '{server_name}'", exc_info=True) + try: + partition_key = ( + f"{cache_partition}:s{cache_descriptor['server_revision']}:" + f"p{cache_descriptor['partition_revision']}" + ) + from yuxi.services.mcp.client_pool import mcp_client_pool + + await mcp_client_pool.remove_session(server_name, partition_key) + except Exception as pool_err: + logger.warning(f"Failed to remove stale session for {server_name}: {pool_err}") + return [] + + if disabled_tools: + filtered_tools = [t for t in all_processed_tools if t.name not in disabled_tools] + return filtered_tools + + return all_processed_tools + + +async def get_tools_from_all_servers(server_names: list[str] | None = None) -> list[Callable[..., Any]]: + """批量载入指定或所有可用服务的工具(用于系统初始化及预热)""" + from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs + + names: list[str] | None = None + if server_names is not None: + names = [] + seen: set[str] = set() + for value in server_names: + if not isinstance(value, str): + continue + name = value.strip() + if not name or name in seen: + continue + seen.add(name) + names.append(name) + if not names: + return [] + + server_configs = await _load_enabled_mcp_server_configs(names=names) + all_tools = [] + for server_name, server_config in server_configs.items(): + if not _can_preload_mcp_server_tools_without_runtime_auth(server_config): + logger.info(f"Skip MCP tool preload for '{server_name}' because runtime auth context is required") + continue + tools = await get_mcp_tools(server_name, additional_servers={server_name: server_config}) + all_tools.extend(tools) + return all_tools + + +async def clear_mcp_cache() -> None: + """清空本地内存工具缓存""" + global _mcp_tools_cache, _mcp_tools_failure_cache + _mcp_tools_cache = LRUCache(maxsize=128) + _mcp_tools_failure_cache = LRUCache(maxsize=256) + + try: + from yuxi.services.mcp.client_pool import clear_resolved_headers_cache, mcp_client_pool + + await mcp_client_pool.shutdown() + clear_resolved_headers_cache() + except Exception: + pass + + +def clear_mcp_server_tools_cache(server_name: str) -> None: + """清空指定服务器下的所有本地缓存""" + global _mcp_tools_cache + prefix = f"{server_name}:" + stale_keys = [k for k in _mcp_tools_cache if k.startswith(prefix)] + for key in stale_keys: + _mcp_tools_cache.pop(key, None) + _clear_mcp_tool_failure_cache_for_server(server_name) + + try: + from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache + + clear_server_resolved_headers_cache(server_name) + except Exception: + pass + + +def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + """清空指定连接下的本地内存缓存""" + if connection_id is None: + return + global _mcp_tools_cache + suffix = f":connection:{connection_id}:" + stale_keys = [k for k in _mcp_tools_cache if suffix in k and k.startswith(f"{server_name}:")] + for key in stale_keys: + _mcp_tools_cache.pop(key, None) + stale_failure_keys = [ + key for key in _mcp_tools_failure_cache if suffix in key and key.startswith(f"{server_name}:") + ] + for key in stale_failure_keys: + _mcp_tools_failure_cache.pop(key, None) + + try: + from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache + + clear_server_resolved_headers_cache(server_name) + except Exception: + pass + + +async def invalidate_mcp_server_tools_cache(server_name: str) -> None: + """全局失效指定服务器的全部二级缓存""" + clear_mcp_server_tools_cache(server_name) + await _mcp_tool_cache_store.bump_server_revision(server_name) + + +async def invalidate_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + """失效指定连接下的二级缓存区划""" + if connection_id is None: + return + clear_mcp_connection_tools_cache(server_name, connection_id) + await _mcp_tool_cache_store.bump_partition_revision(server_name, f"connection:{connection_id}") + + +async def _invalidate_mcp_tools_cache_for_connection(connection: MCPConnection) -> None: + """依据 Scope 类别自动刷新并失效缓存""" + if connection.scope_type == "system": + await invalidate_mcp_server_tools_cache(connection.server_name) + else: + await invalidate_mcp_connection_tools_cache(connection.server_name, connection.id) + + +async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> None: + """清理 Redis 中缓存的 Access Token 与锁状态""" + if connection_id is None: + return + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + + cache = RedisTokenCache() + try: + await cache.delete_access_token(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection_id}: {exc}") + try: + await cache.release_refresh_lock(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection_id}: {exc}") + + +async def _clear_mcp_server_runtime_auth_cache(db: AsyncSession, server_name: str) -> None: + """清理服务器下所有关联连接的 Token 缓存""" + from yuxi.services.mcp.connection_service import list_mcp_connections + + connections = await list_mcp_connections(db, server_name=server_name) + for connection in connections: + await _clear_mcp_connection_runtime_auth_cache(getattr(connection, "id", None)) + + +def get_mcp_tools_stats(server_name: str) -> dict[str, int] | None: + return _mcp_tools_stats.get(server_name) + + +async def get_enabled_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> list: + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + + token = None + if auth_context: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + + token = mcp_auth_context_var.set(auth_context) + + try: + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) + if config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + disabled_tools = config.get("disabled_tools") or [] + return await get_mcp_tools( + server_name, + additional_servers={server_name: config}, + disabled_tools=disabled_tools, + ) + finally: + if token: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + + mcp_auth_context_var.reset(token) + + +async def get_all_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, + force_refresh: bool = False, +) -> list: + from yuxi.services.mcp.server_service import get_enabled_mcp_server_config, get_runtime_mcp_server_config + + token = None + if auth_context: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + + token = mcp_auth_context_var.set(auth_context) + + try: + if auth_context is None and db is None: + config = await get_enabled_mcp_server_config(server_name) + else: + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) + if config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + if not force_refresh: + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, config) + manifest = await _mcp_tool_cache_store.get_manifest(cache_descriptor["cache_key"]) + if manifest is not None: + return _deserialize_mcp_tool_manifest(manifest) + + return await get_mcp_tools( + server_name, + additional_servers={server_name: config}, + disabled_tools=[], + cache=True, + force_refresh=force_refresh, + ) + finally: + if token: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + + mcp_auth_context_var.reset(token) + + +async def toggle_tool_enabled( + db: AsyncSession, + server_name: str, + tool_name: str, + updated_by: str | None = None, +) -> tuple[bool, MCPServer]: + """切换单个工具的启用状态""" + from yuxi.services.mcp.server_service import get_mcp_server + + server = await get_mcp_server(db, server_name) + if not server: + raise ValueError(f"Server '{server_name}' does not exist") + + disabled_tools = list(server.disabled_tools or []) + + if tool_name in disabled_tools: + disabled_tools.remove(tool_name) + enabled = True + else: + disabled_tools.append(tool_name) + enabled = False + + server.disabled_tools = disabled_tools + if updated_by is not None: + server.updated_by = updated_by + await db.commit() + + # 清除内存工具缓存 + clear_mcp_server_tools_cache(server_name) + + logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") + return enabled, server diff --git a/backend/package/yuxi/services/mcp_auth/__init__.py b/backend/package/yuxi/services/mcp_auth/__init__.py new file mode 100644 index 000000000..ef382794e --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/__init__.py @@ -0,0 +1,14 @@ +"""MCP auth helpers.""" + +from .config_models import MCPAuthConfig +from .crypto import decrypt_credential_blob, encrypt_credential_blob, is_encrypted_credential_blob +from .template_resolver import TemplateResolutionError, resolve_template_value + +__all__ = [ + "MCPAuthConfig", + "TemplateResolutionError", + "decrypt_credential_blob", + "encrypt_credential_blob", + "is_encrypted_credential_blob", + "resolve_template_value", +] diff --git a/backend/package/yuxi/services/mcp_auth/config_models.py b/backend/package/yuxi/services/mcp_auth/config_models.py new file mode 100644 index 000000000..e9ff7fa4a --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/config_models.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class InjectEntry(BaseModel): + name: str + value_template: str + + +class InjectConfig(BaseModel): + target: Literal["headers", "env"] + entries: list[InjectEntry] = Field(default_factory=list) + + +class RefreshPolicy(BaseModel): + pre_refresh_seconds: int = 0 + retry_once_on_401: bool = False + + +class MCPAuthConfig(BaseModel): + model_config = ConfigDict(extra="ignore") + + version: int = 1 + provider: Literal[ + "legacy_static", + "bound_secret", + "client_credentials", + "custom_http_token", + "authorization_code", + "stdio_env", + ] + binding_scope: Literal["inline", "system", "department", "user"] | None = None + manifest_scope: Literal["server", "binding"] | None = None + inject: InjectConfig + refresh_policy: RefreshPolicy = Field(default_factory=RefreshPolicy) + token_request: dict[str, Any] | None = None + + @model_validator(mode="after") + def apply_defaults_and_validate(self) -> MCPAuthConfig: + if self.binding_scope is None: + self.binding_scope = "inline" if self.provider == "legacy_static" else "system" + if self.manifest_scope is None: + self.manifest_scope = "server" + if ( + self.provider in {"custom_http_token", "client_credentials", "authorization_code"} + and not self.token_request + ): + raise ValueError("token_request is required for dynamic auth providers") + return self + + def get_secret_fields(self) -> list[str]: + """Extract all secret fields referenced in the configuration templates.""" + import json + import re + + pattern = re.compile(r"\$\{secret\.([^\}]+)\}") + dumped = json.dumps(self.model_dump(mode="json")) + matches = pattern.findall(dumped) + # Deduplicate while preserving order + return list(dict.fromkeys(matches)) diff --git a/backend/package/yuxi/services/mcp_auth/crypto.py b/backend/package/yuxi/services/mcp_auth/crypto.py new file mode 100644 index 000000000..5edff8e65 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/crypto.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import base64 +import hashlib +import json +import os +from typing import Any + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF + +MASTER_KEY_ENV = "MCP_CREDENTIALS_MASTER_KEY" +ENVELOPE_VERSION = 2 +ENVELOPE_KEY_ID = "local" +_AAD = b"yuxi:mcp_credentials:v1" + + +def _get_master_key() -> str: + value = os.getenv(MASTER_KEY_ENV, "").strip() + if not value: + raise ValueError(f"{MASTER_KEY_ENV} is required when storing encrypted MCP credentials") + return value + + +def _derive_aes_key_v1(master_key: str) -> bytes: + # legacy v1 key derivation (raw sha256) + return hashlib.sha256(master_key.encode("utf-8")).digest() + + +def _derive_aes_key_v2(master_key: str, salt: bytes) -> bytes: + # v2 key derivation using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + info=b"mcp-credentials-v2", + ) + return hkdf.derive(master_key.encode("utf-8")) + + +def _b64encode(value: bytes) -> str: + return base64.urlsafe_b64encode(value).decode("ascii") + + +def _b64decode(value: str) -> bytes: + return base64.urlsafe_b64decode(value.encode("ascii")) + + +def _parse_envelope(blob: str) -> dict[str, Any] | None: + try: + payload = json.loads(blob) + except (TypeError, json.JSONDecodeError): + return None + if not isinstance(payload, dict): + return None + required_keys = {"v", "kid", "nonce", "ciphertext"} + if not required_keys.issubset(payload.keys()): + return None + v = payload.get("v") + if v not in (1, 2): + return None + if v == 2 and "salt" not in payload: + return None + return payload + + +def is_encrypted_credential_blob(blob: str | None) -> bool: + if not blob or not isinstance(blob, str): + return False + return _parse_envelope(blob) is not None + + +def encrypt_credential_blob(plaintext: str) -> str: + if not plaintext: + return plaintext + if is_encrypted_credential_blob(plaintext): + return plaintext + + master_key = _get_master_key() + salt = os.urandom(16) + aesgcm = AESGCM(_derive_aes_key_v2(master_key, salt)) + nonce = os.urandom(12) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), _AAD) + return json.dumps( + { + "v": ENVELOPE_VERSION, + "kid": ENVELOPE_KEY_ID, + "salt": _b64encode(salt), + "nonce": _b64encode(nonce), + "ciphertext": _b64encode(ciphertext), + }, + ensure_ascii=True, + separators=(",", ":"), + ) + + +def decrypt_credential_blob(blob: str | None) -> str | None: + if blob is None or not isinstance(blob, str): + return blob + + payload = _parse_envelope(blob) + if payload is None: + return blob + + master_key = _get_master_key() + v = payload.get("v") + if v == 1: + key = _derive_aes_key_v1(master_key) + elif v == 2: + salt = _b64decode(payload["salt"]) + key = _derive_aes_key_v2(master_key, salt) + else: + return blob + + aesgcm = AESGCM(key) + plaintext = aesgcm.decrypt( + _b64decode(payload["nonce"]), + _b64decode(payload["ciphertext"]), + _AAD, + ) + return plaintext.decode("utf-8") diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py new file mode 100644 index 000000000..ea8ae30aa --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from yuxi.services.mcp_auth.fetchers.base import BaseTokenFetcher, ITokenFetcher +from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory +from yuxi.services.mcp_auth.fetchers.http_fetcher import ClientCredentialsFetcher, CustomHttpTokenFetcher +from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher + +__all__ = [ + "ITokenFetcher", + "BaseTokenFetcher", + "CustomHttpTokenFetcher", + "ClientCredentialsFetcher", + "AuthorizationCodeFetcher", + "TokenFetcherFactory", +] diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/base.py b/backend/package/yuxi/services/mcp_auth/fetchers/base.py new file mode 100644 index 000000000..b8c89cf87 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/base.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.template_resolver import resolve_template_value + +# 注释必须使用简体中文,符合 RULE[user_global] +# NOTE: 所有获取 Token 的具体策略需要继承 ITokenFetcher 并实现 fetch_token 方法。 + +_DEFAULT_TOKEN_RESPONSE_MAP = { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "expires_at": "expires_at", + "scope": "scope", + "token_type": "token_type", +} + + +def extract_path(payload: dict[str, Any], path: str) -> Any: + """从 payload 中根据点分路径提取字段值""" + current: Any = payload + for segment in path.split("."): + if isinstance(current, dict): + current = current[segment] + continue + raise KeyError(path) + return current + + +async def fetch_custom_http_token( + request_config: dict[str, Any], + *, + response_map: dict[str, str] | None, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, +) -> dict[str, Any]: + """执行自定义 HTTP 请求获取 Token""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + + response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) + if http_client is None: + http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)) + should_close = True + else: + should_close = False + + try: + headers = resolve_template_value( + request_config.get("headers") or {}, + context=context_payload, + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body = resolve_template_value( + request_config.get("body_template") or {}, + context=context_payload, + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body_type = request_config.get("body_type", "json") + request_kwargs: dict[str, Any] = { + "method": (request_config.get("method") or "POST").upper(), + "url": request_config["url"], + "headers": headers, + } + if body_type == "json": + request_kwargs["json"] = body + else: + request_kwargs["data"] = body + + request_kwargs["timeout"] = httpx.Timeout(10.0, read=30.0) + + response = await http_client.request(**request_kwargs) + response.raise_for_status() + payload = response.json() + resolved = {} + for field_name, path in response_map.items(): + try: + resolved[field_name] = extract_path(payload, path) + except KeyError: + continue + return _normalize_token_payload(resolved) + except Exception as exc: + import traceback + + from yuxi.utils import logger + + logger.error(f"fetch_custom_http_token failure: {exc}, traceback: {traceback.format_exc()}") + raise + finally: + if should_close: + await http_client.aclose() + + +class ITokenFetcher(ABC): + """Token 获取接口""" + + @abstractmethod + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + """ + 获取或刷新 Access Token + """ + pass + + +class BaseTokenFetcher(ITokenFetcher, ABC): + """带自动 Refresh 逻辑的 Token 获取基类""" + + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + # NOTE: 优先检查是否有可用 refresh token,并进行刷新 + token_request = auth_config.token_request or {} + refresh_request = token_request.get("refresh") + if ( + token_values + and refresh_request + and (token_values.get("refresh_token") or credential_payload.get("refresh_token")) + ): + refresh_token_values = dict(token_values) + if not refresh_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + refresh_token_values["refresh_token"] = credential_payload["refresh_token"] + + refreshed = await fetch_custom_http_token( + refresh_request, + response_map=(refresh_request.get("response_map") or token_request.get("response_map")), + context_payload=context_payload, + secret_values=secret_values, + token_values=refresh_token_values, + http_client=http_client, + ) + if not refreshed.get("refresh_token") and refresh_token_values.get("refresh_token"): + refreshed["refresh_token"] = refresh_token_values["refresh_token"] + return refreshed + + # NOTE: 如果不满足刷新条件,则获取全新 Token + return await self._fetch_new_token( + auth_config, + context_payload=context_payload, + secret_values=secret_values, + credential_payload=credential_payload, + token_values=token_values, + http_client=http_client, + ) + + @abstractmethod + async def _fetch_new_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + """获取全新的 Token""" + pass diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/factory.py b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py new file mode 100644 index 000000000..f02e69ac1 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher +from yuxi.services.mcp_auth.fetchers.http_fetcher import ClientCredentialsFetcher, CustomHttpTokenFetcher +from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher + + +class TokenFetcherFactory: + """TokenFetcher 工厂""" + + @staticmethod + def get_fetcher(provider: str) -> ITokenFetcher: + if provider == "custom_http_token": + return CustomHttpTokenFetcher() + elif provider == "client_credentials": + return ClientCredentialsFetcher() + elif provider == "authorization_code": + return AuthorizationCodeFetcher() + else: + raise ValueError(f"Unsupported MCP auth provider for dynamic token: {provider}") diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py new file mode 100644 index 000000000..7ef45f431 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.fetchers.base import BaseTokenFetcher, fetch_custom_http_token + + +class CustomHttpTokenFetcher(BaseTokenFetcher): + """自定义 HTTP 方式获取 Token""" + + async def _fetch_new_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + token_request = auth_config.token_request or {} + resolved = await fetch_custom_http_token( + token_request, + response_map=token_request.get("response_map"), + context_payload=context_payload, + secret_values=secret_values, + token_values=token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and credential_payload.get("refresh_token"): + resolved["refresh_token"] = credential_payload["refresh_token"] + return resolved + + +class ClientCredentialsFetcher(CustomHttpTokenFetcher): + """客户端凭证 (Client Credentials) 方式获取 Token""" + + # NOTE: 当前其底层获取逻辑与 CustomHttpTokenFetcher 相同 + pass diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py new file mode 100644 index 000000000..fe7b8eaa0 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import Any + +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.fetchers.base import _DEFAULT_TOKEN_RESPONSE_MAP, ITokenFetcher, fetch_custom_http_token + + +class AuthorizationCodeFetcher(ITokenFetcher): + """授权码 (Authorization Code) 模式下的后台 Token 刷新获取""" + + async def _resolve_token_request_config( + self, + token_request: dict[str, Any], + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient, + ) -> tuple[dict[str, Any], dict[str, str]]: + issuer_url = ( + token_request.get("issuer_url") or secret_values.get("issuer_url") or token_values.get("issuer_url") + ) + if not issuer_url: + raise ValueError("authorization_code provider requires token_request.issuer_url") + discovery_url = f"{str(issuer_url).rstrip('/')}/.well-known/openid-configuration" + response = await http_client.get(discovery_url) + response.raise_for_status() + payload = response.json() + token_endpoint = payload.get("token_endpoint") + if not token_endpoint: + raise ValueError("authorization_code provider discovery missing token_endpoint") + + return { + "url": token_endpoint, + "method": "POST", + "body_type": "form", + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "body_template": { + "grant_type": "refresh_token", + "refresh_token": "${token.refresh_token}", + "client_id": token_request.get("client_id", "${secret.client_id}"), + "client_secret": token_request.get("client_secret", "${secret.client_secret}"), + }, + }, dict(_DEFAULT_TOKEN_RESPONSE_MAP) + + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + token_request = auth_config.token_request or {} + authorization_request, response_map = await self._resolve_token_request_config( + token_request=token_request, + secret_values=secret_values, + token_values=token_values or credential_payload, + http_client=http_client, + ) + authorization_token_values = dict(token_values or credential_payload) + if not authorization_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + authorization_token_values["refresh_token"] = credential_payload["refresh_token"] + + resolved = await fetch_custom_http_token( + authorization_request, + response_map=response_map, + context_payload=context_payload, + secret_values=secret_values, + token_values=authorization_token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and authorization_token_values.get("refresh_token"): + resolved["refresh_token"] = authorization_token_values["refresh_token"] + return resolved + finally: + if should_close: + await http_client.aclose() diff --git a/backend/package/yuxi/services/mcp_auth/orchestrator.py b/backend/package/yuxi/services/mcp_auth/orchestrator.py new file mode 100644 index 000000000..d04d1154a --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/orchestrator.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import asyncio +import contextvars +import json +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any + +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob +from yuxi.services.mcp_auth.template_resolver import resolve_template_value +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer +from yuxi.utils import logger + + +@dataclass(slots=True) +class AuthContext: + user_id: str | None = None + department_id: str | None = None + work_id: str | None = None + + +mcp_auth_context_var: contextvars.ContextVar[AuthContext | None] = contextvars.ContextVar( + "mcp_auth_context_var", default=None +) + +_DEFAULT_TOKEN_RESPONSE_MAP = { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "expires_at": "expires_at", + "scope": "scope", + "token_type": "token_type", +} +_REFRESH_LOCK_WAIT_SECONDS = 1.0 +_REFRESH_LOCK_POLL_INTERVAL_SECONDS = 0.05 + + +def _parse_credential_blob(connection: MCPConnection | None) -> dict[str, Any]: + if connection is None or not connection.credential_blob: + return {} + if isinstance(connection.credential_blob, dict): + return dict(connection.credential_blob) + decrypted = decrypt_credential_blob(connection.credential_blob) + if not decrypted: + return {} + try: + return json.loads(decrypted) + except json.JSONDecodeError: + return { + "access_token": decrypted, + "secrets": {"access_token": decrypted}, + } + + +def _extract_path(payload: dict[str, Any], path: str) -> Any: + current: Any = payload + for segment in path.split("."): + if isinstance(current, dict): + current = current[segment] + continue + raise KeyError(path) + return current + + +def _base_server_config(server: MCPServer) -> dict[str, Any]: + config = server.to_mcp_config() + config.pop("auth_config", None) + return config + + +def _context_payload(context: AuthContext) -> dict[str, Any]: + return { + "user_id": context.user_id, + "department_id": context.department_id, + "work_id": context.work_id, + } + + +def _parse_datetime(value: str | None) -> datetime | None: + if not value or not isinstance(value, str): + return None + try: + parsed = datetime.fromisoformat(value) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) + + +def _normalize_token_payload(token_values: dict[str, Any]) -> dict[str, Any]: + normalized = dict(token_values) + expires_at = normalized.get("expires_at") + if isinstance(expires_at, datetime): + if expires_at.tzinfo is None: + # 若无时区信息则默认为 UTC,避免 astimezone() 将其视作本地时区转换 + expires_at = expires_at.replace(tzinfo=UTC) + normalized["expires_at"] = expires_at.astimezone(UTC).isoformat() + return normalized + if isinstance(expires_at, str): + parsed = _parse_datetime(expires_at) + if parsed is not None: + normalized["expires_at"] = parsed.isoformat() + return normalized + expires_in = normalized.get("expires_in") + if isinstance(expires_in, str) and expires_in.isdigit(): + expires_in = int(expires_in) + normalized["expires_in"] = expires_in + if isinstance(expires_in, (int, float)): + normalized["expires_at"] = (datetime.now(tz=UTC) + timedelta(seconds=int(expires_in))).isoformat() + return normalized + + +def _is_token_expiring_soon(token_values: dict[str, Any], *, pre_refresh_seconds: int) -> bool: + expires_at = _parse_datetime(token_values.get("expires_at")) + if expires_at is None: + return False + return expires_at <= datetime.now(tz=UTC) + timedelta(seconds=max(pre_refresh_seconds, 0)) + + +def _merge_injected_entries( + config: dict[str, Any], + *, + inject_target: str, + inject_entries: list[dict[str, str]], + context: AuthContext, + secret_values: dict[str, Any], + token_values: dict[str, Any], + access_token: str | None, +) -> dict[str, Any]: + target_values = dict(config.get(inject_target) or {}) + for entry in inject_entries: + target_values[entry["name"]] = resolve_template_value( + entry["value_template"], + context={ + "user_id": context.user_id, + "department_id": context.department_id, + "work_id": context.work_id, + }, + secret=secret_values, + token=token_values, + access_token=access_token, + ) + config[inject_target] = target_values + return config + + +async def _load_cached_token( + *, + token_cache: Any | None, + connection_id: int | None, +) -> dict[str, Any] | None: + if token_cache is None or connection_id is None: + return None + try: + cached = await token_cache.get_access_token(connection_id) + except Exception as exc: + logger.warning(f"Failed to load MCP access token cache for connection {connection_id}: {exc}") + return None + if not cached: + return None + return _normalize_token_payload(cached) + + +async def _store_cached_token( + *, + token_cache: Any | None, + connection_id: int | None, + token_payload: dict[str, Any], +) -> None: + if token_cache is None or connection_id is None: + return + try: + await token_cache.set_access_token(connection_id, token_payload) + except Exception as exc: + logger.warning(f"Failed to persist MCP access token cache for connection {connection_id}: {exc}") + + +async def _acquire_refresh_lock( + *, + token_cache: Any | None, + connection_id: int | None, +) -> bool: + if token_cache is None or connection_id is None: + return True + acquire_method = getattr(token_cache, "acquire_refresh_lock", None) + if acquire_method is None: + return True + try: + return bool(await acquire_method(connection_id)) + except Exception as exc: + logger.warning(f"Failed to acquire MCP refresh lock for connection {connection_id}: {exc}") + return True + + +async def _release_refresh_lock( + *, + token_cache: Any | None, + connection_id: int | None, + acquired: bool, +) -> None: + if not acquired or token_cache is None or connection_id is None: + return + release_method = getattr(token_cache, "release_refresh_lock", None) + if release_method is None: + return + try: + await release_method(connection_id) + except Exception as exc: + logger.warning(f"Failed to release MCP refresh lock for connection {connection_id}: {exc}") + + +async def _wait_for_refreshed_token( + *, + token_cache: Any | None, + connection_id: int | None, + pre_refresh_seconds: int, +) -> dict[str, Any] | None: + if token_cache is None or connection_id is None: + return None + + remaining = _REFRESH_LOCK_WAIT_SECONDS + while remaining > 0: + await asyncio.sleep(_REFRESH_LOCK_POLL_INTERVAL_SECONDS) + cached_token = await _load_cached_token(token_cache=token_cache, connection_id=connection_id) + if cached_token and not _is_token_expiring_soon( + cached_token, + pre_refresh_seconds=pre_refresh_seconds, + ): + return cached_token + remaining -= _REFRESH_LOCK_POLL_INTERVAL_SECONDS + return None + + +async def _request_dynamic_token_values( + auth_config: MCPAuthConfig, + *, + context: AuthContext, + connection: MCPConnection | None, + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + http_client: httpx.AsyncClient | None, + token_cache: Any | None, + token_values: dict[str, Any], +) -> dict[str, Any]: + from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory + + fetcher = TokenFetcherFactory.get_fetcher(auth_config.provider) + resolved = await fetcher.fetch_token( + auth_config, + context_payload={ + "user_id": context.user_id, + "department_id": context.department_id, + "work_id": context.work_id, + }, + secret_values=secret_values, + credential_payload=credential_payload, + token_values=token_values, + http_client=http_client, + ) + + await _store_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + token_payload=resolved, + ) + return resolved + + +async def _resolve_dynamic_token_values( + auth_config: MCPAuthConfig, + *, + context: AuthContext, + connection: MCPConnection | None, + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + http_client: httpx.AsyncClient | None, + token_cache: Any | None, +) -> dict[str, Any]: + if token_cache is None and connection is not None: + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + + token_cache = RedisTokenCache() + + cached_token = await _load_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + ) + pre_refresh_seconds = auth_config.refresh_policy.pre_refresh_seconds + if cached_token and not _is_token_expiring_soon(cached_token, pre_refresh_seconds=pre_refresh_seconds): + return cached_token + + token_values = dict(cached_token or {}) + if not token_values: + token_values.update( + { + key: value + for key, value in credential_payload.items() + if key in {"access_token", "refresh_token", "expires_in", "expires_at", "scope", "token_type"} + } + ) + token_values = _normalize_token_payload(token_values) + if token_values.get("access_token") and not _is_token_expiring_soon( + token_values, + pre_refresh_seconds=pre_refresh_seconds, + ): + return token_values + connection_id = getattr(connection, "id", None) + lock_acquired = await _acquire_refresh_lock(token_cache=token_cache, connection_id=connection_id) + if not lock_acquired: + refreshed_from_cache = await _wait_for_refreshed_token( + token_cache=token_cache, + connection_id=connection_id, + pre_refresh_seconds=pre_refresh_seconds, + ) + if refreshed_from_cache: + return refreshed_from_cache + + try: + return await _request_dynamic_token_values( + auth_config, + context=context, + connection=connection, + secret_values=secret_values, + credential_payload=credential_payload, + http_client=http_client, + token_cache=token_cache, + token_values=token_values, + ) + finally: + await _release_refresh_lock( + token_cache=token_cache, + connection_id=connection_id, + acquired=lock_acquired, + ) + + +async def resolve_runtime_mcp_config( + server: MCPServer, + *, + auth_context: AuthContext, + connection: MCPConnection | None = None, + http_client: httpx.AsyncClient | None = None, + token_cache: Any | None = None, +) -> dict[str, Any]: + config = _base_server_config(server) + auth_payload = server.auth_config_json or {} + if not auth_payload: + return config + + auth_config = MCPAuthConfig.model_validate(auth_payload) + inject_entries = [entry.model_dump() for entry in auth_config.inject.entries] + credential_payload = _parse_credential_blob(connection) + secret_values = credential_payload.get("secrets") or {} + + if auth_config.provider == "legacy_static": + return config + + if auth_config.provider in {"bound_secret", "stdio_env"}: + return _merge_injected_entries( + config, + inject_target=auth_config.inject.target, + inject_entries=inject_entries, + context=auth_context, + secret_values=secret_values, + token_values=credential_payload, + access_token=None, + ) + + if auth_config.provider in {"custom_http_token", "client_credentials", "authorization_code"}: + token_values = await _resolve_dynamic_token_values( + auth_config, + context=auth_context, + connection=connection, + secret_values=secret_values, + credential_payload=credential_payload, + http_client=http_client, + token_cache=token_cache, + ) + return _merge_injected_entries( + config, + inject_target=auth_config.inject.target, + inject_entries=inject_entries, + context=auth_context, + secret_values=secret_values, + token_values=token_values, + access_token=token_values.get("access_token"), + ) + + raise ValueError(f"Unsupported MCP auth provider: {auth_config.provider}") diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py new file mode 100644 index 000000000..448d6990f --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import Any +from urllib.parse import urlencode + +import httpx +from fastapi import HTTPException, Request, Response +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + +from server.utils.auth_utils import AuthUtils + +_proxy_http_client: httpx.AsyncClient | None = None + + +def get_shared_proxy_client() -> httpx.AsyncClient: + global _proxy_http_client + if _proxy_http_client is None: + _proxy_http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=30.0, pool=120.0, read=120.0, write=30.0)) + return _proxy_http_client + + +async def close_shared_proxy_client() -> None: + global _proxy_http_client + if _proxy_http_client is not None: + await _proxy_http_client.aclose() + _proxy_http_client = None + + +INTERNAL_PROXY_TOKEN_HEADER = "X-Yuxi-MCP-Proxy-Token" +INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY = "__yuxi_disable_tool_object_cache" +_PROXY_TOKEN_TYPE = "mcp_proxy" +_DYNAMIC_HTTP_PROVIDERS = {"custom_http_token", "client_credentials", "authorization_code"} +_HTTP_TRANSPORTS = {"streamable_http", "sse"} +_HOP_BY_HOP_HEADERS = { + "connection", + "content-length", + "host", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +} + + +def should_use_internal_proxy(server: MCPServer, auth_config: MCPAuthConfig, proxy_base_url: str | None) -> bool: + return bool( + proxy_base_url and server.transport in _HTTP_TRANSPORTS and auth_config.provider in _DYNAMIC_HTTP_PROVIDERS + ) + + +def create_proxy_access_token(server_name: str, auth_context: AuthContext) -> str: + return AuthUtils.create_access_token( + { + "sub": f"mcp-proxy:{server_name}", + "token_type": _PROXY_TOKEN_TYPE, + "server_name": server_name, + "user_id": auth_context.user_id, + "department_id": auth_context.department_id, + "work_id": auth_context.work_id, + }, + expires_delta=timedelta(minutes=15), + ) + + +def decode_proxy_access_token(token: str, *, server_name: str) -> AuthContext: + payload = AuthUtils.decode_token(token) + if not payload: + raise ValueError("invalid internal proxy token") + if payload.get("token_type") != _PROXY_TOKEN_TYPE: + raise ValueError("invalid internal proxy token type") + if payload.get("server_name") != server_name: + raise ValueError("internal proxy token server mismatch") + return AuthContext( + user_id=payload.get("user_id"), + department_id=payload.get("department_id"), + work_id=payload.get("work_id"), + ) + + +def build_internal_proxy_url(proxy_base_url: str, server_name: str) -> str: + return f"{proxy_base_url.rstrip('/')}/api/internal/mcp-proxy/{server_name}" + + +def build_proxy_runtime_config( + server: MCPServer, + *, + auth_context: AuthContext, + proxy_base_url: str, +) -> dict[str, Any]: + config = server.to_mcp_config() + config.pop("auth_config", None) + headers = dict(config.get("headers") or {}) + headers[INTERNAL_PROXY_TOKEN_HEADER] = create_proxy_access_token(server.name, auth_context) + config["headers"] = headers + config["url"] = build_internal_proxy_url(proxy_base_url, server.name) + config[INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY] = True + return config + + +def _merge_upstream_headers( + base_headers: dict[str, Any], + request_headers: dict[str, str] | None, +) -> dict[str, Any]: + merged = dict(base_headers or {}) + _PROTECTED_HEADERS = { + INTERNAL_PROXY_TOKEN_HEADER.lower(), + "authorization", + } + for key, value in (request_headers or {}).items(): + if key.lower() in _HOP_BY_HOP_HEADERS or key.lower() in _PROTECTED_HEADERS: + continue + merged[key] = value + return merged + + +def _build_target_url(base_url: str, path: str = "", query_params: dict[str, Any] | None = None) -> str: + if not path: + target = base_url + else: + target = f"{base_url.rstrip('/')}/{path.lstrip('/')}" + if query_params: + return f"{target}?{urlencode(query_params, doseq=True)}" + return target + + +def _mark_reauth_required(connection: MCPConnection | None, message: str) -> None: + if connection is None: + return + connection.status = "reauth_required" + meta_json = dict(connection.meta_json or {}) + meta_json["last_error"] = { + "code": "unauthorized", + "message": message, + } + connection.meta_json = meta_json + + +def _record_scope_error(connection: MCPConnection | None, message: str) -> None: + if connection is None: + return + meta_json = dict(connection.meta_json or {}) + meta_json["last_error"] = { + "code": "insufficient_scope", + "message": message, + } + connection.meta_json = meta_json + + +async def handle_mcp_proxy_request( + server_name: str, + request: Request, + path: str, + internal_token: str, + db: AsyncSession, +) -> Response: + """内部网关主入口:鉴权解析、查库拦截与流式代理""" + from yuxi.services.mcp.server_service import get_mcp_server + + try: + auth_context = decode_proxy_access_token(internal_token, server_name=server_name) + except ValueError as exc: + raise HTTPException(status_code=401, detail=str(exc)) from exc + + server = await get_mcp_server(db, server_name) + if server is None: + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在") + if not bool(getattr(server, "enabled", True)): + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在或已停用") + + auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) + + from yuxi.services.mcp.connection_service import _resolve_scope_id, requires_bound_mcp_connection + + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + connection = None + if scope_id is not None: + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server.name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + connection = result.scalar_one_or_none() + + if connection is None and requires_bound_mcp_connection(auth_config): + raise HTTPException(status_code=403, detail="当前用户没有该 MCP 的有效连接") + + # 注意:我们读取整个 request body,因为 MCP 请求参数通常极小, + # 但由于可能有 401 重试,我们需要保存下 body 来实现背压重发。 + body = await request.body() + return await _proxy_mcp_request_stream( + server=server, + connection=connection, + auth_context=auth_context, + request=request, + body=body, + path=path, + db=db, + ) + + +async def _proxy_mcp_request_stream( + server: MCPServer, + *, + connection: MCPConnection | None, + auth_context: AuthContext, + request: Request, + body: bytes, + path: str = "", + db: AsyncSession, + _http_client: httpx.AsyncClient | None = None, + _token_cache: Any | None = None, +) -> Response: + """底层流式转发逻辑:处理 HTTPX 透传、SSE 和 401 重试闭环事务""" + auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) + if server.transport not in _HTTP_TRANSPORTS: + raise HTTPException( + status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}" + ) + + connect_timeout = server.timeout or 60.0 + read_timeout = server.sse_read_timeout or connect_timeout + request_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=connect_timeout, + pool=connect_timeout, + ) + http_client = _http_client or get_shared_proxy_client() + + if _token_cache is not None: + token_cache = _token_cache + else: + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + + token_cache = RedisTokenCache() + + max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 + + for attempt in range(max_attempts): + runtime_config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context, + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + target_url = _build_target_url(runtime_config["url"], path=path, query_params=dict(request.query_params)) + upstream_headers = _merge_upstream_headers(runtime_config.get("headers") or {}, dict(request.headers)) + + request_obj = http_client.build_request( + method=request.method.upper(), + url=target_url, + headers=upstream_headers, + content=body, + timeout=request_timeout, + ) + + # 使用 send(stream=True) 获取异步可迭代响应而不会阻塞 SSE 长链接 + response = await http_client.send(request_obj, stream=True) + + if response.status_code == 403: + await response.aclose() + _record_scope_error(connection, "MCP upstream rejected request due to insufficient scope") + if connection is not None: + await db.commit() + return Response( + content='{"error": "insufficient_scope", "message": "当前授权范围不足"}', + status_code=403, + media_type="application/json", + background=None, + ) + + if response.status_code != 401: + # 正常响应,此时直接闭环提交事务,防止污染外层 + if connection is not None and hasattr(db, "commit"): + await db.commit() + + async def proxy_stream_generator(): + try: + async for chunk in response.aiter_raw(): + yield chunk + finally: + await response.aclose() + + resp_headers = {} + for k, v in response.headers.items(): + if k.lower() not in _HOP_BY_HOP_HEADERS and k.lower() not in ("content-encoding", "content-length"): + resp_headers[k] = v + + return StreamingResponse( + proxy_stream_generator(), status_code=response.status_code, headers=resp_headers, background=None + ) + + # 如果是 401,回收流连接并准备重试 + await response.aclose() + if attempt + 1 >= max_attempts: + break + if connection is not None and getattr(connection, "id", None) is not None: + await token_cache.delete_access_token(connection.id) + + _mark_reauth_required(connection, "MCP upstream returned 401 after retry") + if connection is not None: + await db.commit() + return Response( + content='{"error": "reauth_required", "message": "连接失效,请重新连接"}', + status_code=424, + media_type="application/json", + background=None, + ) diff --git a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py new file mode 100644 index 000000000..4e96a4016 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime +from typing import Any + +from yuxi.services.run_queue_service import get_redis_client +from yuxi.utils import logger + +ACCESS_TOKEN_KEY_PREFIX = "yuxi:mcp:access_token:v1" +REFRESH_LOCK_KEY_PREFIX = "yuxi:mcp:refresh_lock:v1" +DEFAULT_TOKEN_TTL_SECONDS = 300 +DEFAULT_LOCK_TTL_SECONDS = 30 + + +def _compute_token_ttl_seconds(token_payload: dict[str, Any]) -> int: + expires_at = token_payload.get("expires_at") + if isinstance(expires_at, str): + try: + expires_at_dt = datetime.fromisoformat(expires_at) + if expires_at_dt.tzinfo is None: + expires_at_dt = expires_at_dt.replace(tzinfo=UTC) + ttl = int((expires_at_dt - datetime.now(tz=UTC)).total_seconds()) + return max(ttl, 1) + except ValueError: + logger.warning(f"Invalid expires_at in MCP token payload: {expires_at}") + expires_in = token_payload.get("expires_in") + if isinstance(expires_in, (int, float)) and int(expires_in) > 0: + return int(expires_in) + return DEFAULT_TOKEN_TTL_SECONDS + + +class RedisTokenCache: + def __init__( + self, + redis_client_factory: Callable[[], Awaitable[Any]] | None = None, + key_prefix: str | None = None, + ): + self._redis_client_factory = redis_client_factory or get_redis_client + self._key_prefix = key_prefix + + def _access_token_key(self, connection_id: int) -> str: + key = f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" + if self._key_prefix: + return f"{self._key_prefix}:{key}" + return key + + def _refresh_lock_key(self, connection_id: int) -> str: + key = f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" + if self._key_prefix: + return f"{self._key_prefix}:{key}" + return key + + async def _get_redis(self): + return await self._redis_client_factory() + + async def get_access_token(self, connection_id: int) -> dict[str, Any] | None: + redis = await self._get_redis() + raw = await redis.get(self._access_token_key(connection_id)) + if not raw: + return None + if isinstance(raw, dict): + return raw + return json.loads(raw) + + async def set_access_token(self, connection_id: int, token_payload: dict[str, Any]) -> None: + redis = await self._get_redis() + ttl_seconds = _compute_token_ttl_seconds(token_payload) + await redis.set( + self._access_token_key(connection_id), + json.dumps(token_payload, ensure_ascii=False, separators=(",", ":")), + ex=ttl_seconds, + ) + + async def delete_access_token(self, connection_id: int) -> None: + redis = await self._get_redis() + await redis.delete(self._access_token_key(connection_id)) + + async def acquire_refresh_lock(self, connection_id: int, *, ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS) -> bool: + redis = await self._get_redis() + acquired = await redis.set(self._refresh_lock_key(connection_id), "1", ex=ttl_seconds, nx=True) + return bool(acquired) + + async def release_refresh_lock(self, connection_id: int) -> None: + redis = await self._get_redis() + await redis.delete(self._refresh_lock_key(connection_id)) diff --git a/backend/package/yuxi/services/mcp_auth/template_resolver.py b/backend/package/yuxi/services/mcp_auth/template_resolver.py new file mode 100644 index 000000000..446dbab16 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/template_resolver.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import re +from collections.abc import Mapping +from typing import Any + +_PLACEHOLDER_PATTERN = re.compile(r"\$\{([^}]+)\}") + + +class TemplateResolutionError(ValueError): + """Raised when a template placeholder cannot be resolved.""" + + +def _lookup_path(root: Any, path: str, *, full_expression: str) -> Any: + current = root + for segment in path.split("."): + if isinstance(current, Mapping) and segment in current: + current = current[segment] + continue + raise TemplateResolutionError(f"Unknown template placeholder: {full_expression}") + return current + + +def _resolve_placeholder( + expression: str, + *, + context: Mapping[str, Any], + secret: Mapping[str, Any], + token: Mapping[str, Any], + access_token: str | None, +) -> Any: + if expression == "access_token": + if access_token is None: + raise TemplateResolutionError("Unknown template placeholder: access_token") + return access_token + + if "." not in expression: + raise TemplateResolutionError(f"Unknown template placeholder: {expression}") + + root_name, path = expression.split(".", 1) + roots = { + "context": context, + "secret": secret, + "token": token, + } + if root_name not in roots: + raise TemplateResolutionError(f"Unknown template placeholder: {expression}") + return _lookup_path(roots[root_name], path, full_expression=expression) + + +def resolve_template_value( + value: Any, + *, + context: Mapping[str, Any], + secret: Mapping[str, Any], + token: Mapping[str, Any], + access_token: str | None, +) -> Any: + if isinstance(value, Mapping): + return { + key: resolve_template_value( + item, + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + resolve_template_value( + item, + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + for item in value + ] + + if not isinstance(value, str): + return value + + matches = list(_PLACEHOLDER_PATTERN.finditer(value)) + if not matches: + return value + + if len(matches) == 1 and matches[0].span() == (0, len(value)): + return _resolve_placeholder( + matches[0].group(1), + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + + parts: list[str] = [] + cursor = 0 + for match in matches: + start, end = match.span() + if start > cursor: + parts.append(value[cursor:start]) + resolved = _resolve_placeholder( + match.group(1), + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + parts.append(str(resolved)) + cursor = end + if cursor < len(value): + parts.append(value[cursor:]) + return "".join(parts) diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py deleted file mode 100644 index ebec9f9db..000000000 --- a/backend/package/yuxi/services/mcp_service.py +++ /dev/null @@ -1,626 +0,0 @@ -"""MCP Service - Unified business logic and state management for MCP. - -Responsibilities: -- Server configuration CRUD operations -- Built-in configuration synchronization (Code <-> Database) -- Unified entry point for Agent tool retrieval (auto-filtering disabled_tools) -- MCP Client and Tools management (formerly in agents/common/mcp.py) -""" - -import asyncio -import hashlib -import json -import re -import traceback -from collections.abc import Callable -from typing import Any, cast - -from langchain_mcp_adapters.client import MultiServerMCPClient -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession -from yuxi.storage.postgres.models_business import MCPServer -from yuxi.utils import logger - -# ============================================================================= -# === Global Cache & State === -# ============================================================================= - -# Global Lock for MCP state -_mcp_lock = asyncio.Lock() - -# 本地仅缓存工具对象。配置始终以数据库为准,每次按 server_name 现查。 -# cache key 使用 server_name:config_hash,当配置变化时会自然失效。 -_mcp_tools_cache: dict[str, list[Callable[..., Any]]] = {} - -# MCP tools statistics (for reporting enabled/disabled counts) -_mcp_tools_stats: dict[str, dict[str, int]] = {} -_UNSET = object() - -# Default MCP Server configurations (Imported to DB on first run) -_DEFAULT_MCP_SERVERS = { - "sequentialthinking": { - "url": "https://remote.mcpservers.org/sequentialthinking/mcp", - "transport": "streamable_http", - "description": "顺序思考工具,帮助 AI 将复杂问题分解为多个步骤", - "icon": "🧠", - "tags": ["内置", "AI"], - }, - "mcp-server-chart": { - "command": "npx", - "args": ["-y", "@antv/mcp-server-chart"], - "transport": "stdio", - "description": "图表生成工具,支持生成各类图表(柱状图、折线图、饼图等)", - "icon": "📊", - "tags": ["内置", "图表"], - }, -} - -_SYNCED_MCP_FIELDS = ( - "description", - "transport", - "url", - "command", - "args", - "env", - "headers", - "timeout", - "sse_read_timeout", - "tags", - "icon", -) - -# ============================================================================= -# === Core Logic (Moved from agents/common/mcp.py) === -# ============================================================================= - - -async def ensure_builtin_mcp_servers_in_db() -> None: - """Ensure built-in MCP server definitions exist in the database. - - This function only synchronizes code-defined built-ins to the database. - It does not preload runtime configuration into memory. - """ - # Delayed import to avoid circular references - from yuxi.storage.postgres.manager import pg_manager - - try: - async with pg_manager.get_async_session_context() as session: - # Check if database has MCP configurations - result = await session.execute(select(func.count(MCPServer.name))) - count = result.scalar() - - if count == 0: - # Database is empty, import default configurations - logger.info("No MCP servers in database, importing default configurations...") - for name, config in _DEFAULT_MCP_SERVERS.items(): - server = MCPServer( - name=name, - description=config.get("description"), - transport=config["transport"], - url=config.get("url"), - command=config.get("command"), - args=config.get("args"), - env=config.get("env"), - headers=config.get("headers"), - timeout=config.get("timeout"), - sse_read_timeout=config.get("sse_read_timeout"), - tags=config.get("tags"), - icon=config.get("icon"), - enabled=0, - created_by="system", - updated_by="system", - ) - session.add(server) - await session.commit() - logger.info(f"Imported {len(_DEFAULT_MCP_SERVERS)} default MCP servers to database") - else: - # Ensure all built-in MCP servers exist in database - for name, config in _DEFAULT_MCP_SERVERS.items(): - result = await session.execute(select(MCPServer).filter(MCPServer.name == name)) - existing = result.scalar_one_or_none() - if not existing: - server = MCPServer( - name=name, - description=config.get("description"), - transport=config["transport"], - url=config.get("url"), - command=config.get("command"), - args=config.get("args"), - env=config.get("env"), - headers=config.get("headers"), - timeout=config.get("timeout"), - sse_read_timeout=config.get("sse_read_timeout"), - tags=config.get("tags"), - icon=config.get("icon"), - enabled=0, - created_by="system", - updated_by="system", - ) - session.add(server) - logger.info(f"Added built-in MCP server '{name}' to database") - else: - changed = False - for field in _SYNCED_MCP_FIELDS: - next_value = config.get(field) - if getattr(existing, field) != next_value: - setattr(existing, field, next_value) - changed = True - if changed: - existing.updated_by = "system" - # Commit if any new servers were added (check session state) - if session.new: - await session.commit() - elif session.dirty: - await session.commit() - - except Exception as e: - logger.error(f"Failed to ensure builtin MCP servers in database: {e}, traceback: {traceback.format_exc()}") - - -async def get_mcp_client( - server_configs: dict[str, Any] | None = None, -) -> MultiServerMCPClient | None: - """Initializes an MCP client with the given server configurations.""" - try: - client = MultiServerMCPClient(server_configs) # pyright: ignore[reportArgumentType] - logger.info(f"Initialized MCP client with servers: {list(server_configs.keys())}") - return client - except Exception as e: - logger.error("Failed to initialize MCP client: {}", e) - return None - - -def to_camel_case(s: str) -> str: - """Convert string to lowerCamelCase.""" - - # Handle - and _ - s = re.sub(r"[-_]+(.)", lambda m: m.group(1).upper(), s) - # Lowercase first letter - if len(s) > 0: - s = s[0].lower() + s[1:] - return s - - -async def _load_enabled_mcp_server_configs( - *, - names: list[str] | None = None, - db: AsyncSession | None = None, -) -> dict[str, dict[str, Any]]: - """Load enabled MCP server configs directly from the database.""" - if db is not None: - stmt = select(MCPServer).where(MCPServer.enabled == 1) - if names: - stmt = stmt.where(MCPServer.name.in_(names)) - result = await db.execute(stmt) - servers = result.scalars().all() - return {server.name: server.to_mcp_config() for server in servers} - - from yuxi.storage.postgres.manager import pg_manager - - async with pg_manager.get_async_session_context() as session: - return await _load_enabled_mcp_server_configs(names=names, db=session) - - -async def get_enabled_mcp_server_config(server_name: str, *, db: AsyncSession | None = None) -> dict[str, Any] | None: - """Get the latest enabled MCP server config from the database.""" - configs = await _load_enabled_mcp_server_configs(names=[server_name], db=db) - return configs.get(server_name) - - -async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: - """Get enabled MCP server names from the database.""" - configs = await _load_enabled_mcp_server_configs(db=db) - return list(configs.keys()) - - -async def get_mcp_tools( - server_name: str, - additional_servers: dict[str, dict[str, Any]] | None = None, - disabled_tools: list[str] = None, - cache: bool = True, - force_refresh: bool = False, -) -> list[Callable[..., Any]]: - """Get MCP tools for a specific server. - - Architecture: - 1. Fetching: Connects to MCP server to get ALL tools. - 2. Caching: Stores the FULL, UNFILTERED list of tools in `_mcp_tools_cache`. - 3. Filtering: Filters the return value based on `disabled_tools` argument. - - Args: - server_name: Server name - additional_servers: Additional server configurations - disabled_tools: List of tool names to filter out from the RETURN value (does not affect cache) - cache: Whether to use/update the cache (default: True) - force_refresh: Whether to force a refresh from the server (default: False) - """ - if additional_servers and server_name in additional_servers: - server_config = additional_servers[server_name] - else: - server_config = await get_enabled_mcp_server_config(server_name) - - if server_config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - - # 配置 hash 直接基于完整配置生成。只要数据库中的配置发生变化, - # 本地工具缓存 key 就会变化,从而自然触发重建。 - config_payload = json.dumps(server_config, sort_keys=True, ensure_ascii=True, separators=(",", ":")) - config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] - cache_key = f"{server_name}:{config_hash}" - - all_processed_tools: list[Callable[..., Any]] = [] - - async with _mcp_lock: - if not force_refresh and cache and cache_key in _mcp_tools_cache: - all_processed_tools = _mcp_tools_cache[cache_key] - - if not all_processed_tools: - try: - # disabled_tools 只影响返回值过滤,不参与 MCP client 建连参数。 - client_config = {k: v for k, v in server_config.items() if k not in ("disabled_tools",)} - - client = await get_mcp_client({server_name: client_config}) - if client is None: - return [] - - raw_tools = cast(list[Any], await client.get_tools()) - - server_cc = to_camel_case(server_name) - for tool in raw_tools: - original_name = tool.name - tool_cc = to_camel_case(original_name) - unique_id = f"mcp__{server_cc}__{tool_cc}" - - if tool.metadata is None: - tool.metadata = {} - tool.metadata["id"] = unique_id - # 开启错误处理,防止工具调用抛出 ToolException 时击穿服务 - tool.handle_tool_error = True - all_processed_tools.append(tool) - - if cache: - async with _mcp_lock: - stale_keys = [ - key for key in _mcp_tools_cache if key.startswith(f"{server_name}:") and key != cache_key - ] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - _mcp_tools_cache[cache_key] = all_processed_tools - - global_config_disabled = server_config.get("disabled_tools") or [] - enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) - _mcp_tools_stats[server_name] = { - "total": len(all_processed_tools), - "enabled": enabled_count, - "disabled": len(all_processed_tools) - enabled_count, - } - - logger.info( - f"Refreshed MCP tools cache for '{server_name}' with key '{cache_key}': " - f"{len(all_processed_tools)} tools loaded." - ) - - except Exception as e: - logger.error( - f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}" - ) - return [] - - # 3. Filtering (Apply to Return Value Only) - if disabled_tools: - filtered_tools = [t for t in all_processed_tools if t.name not in disabled_tools] - logger.debug( - f"Returning {len(filtered_tools)}/{len(all_processed_tools)} tools for '{server_name}' " - f"(filtered {len(disabled_tools)} by argument)" - ) - return filtered_tools - - return all_processed_tools - - -async def get_tools_from_all_servers() -> list[Callable[..., Any]]: - """Get all tools from all configured MCP servers.""" - server_configs = await _load_enabled_mcp_server_configs() - all_tools = [] - for server_name in server_configs: - tools = await get_mcp_tools(server_name, additional_servers=server_configs) - all_tools.extend(tools) - return all_tools - - -def clear_mcp_cache() -> None: - """Clear the MCP tools cache (useful for testing).""" - global _mcp_tools_cache, _mcp_tools_stats - _mcp_tools_cache = {} - _mcp_tools_stats = {} - - -def clear_mcp_server_tools_cache(server_name: str) -> None: - """Clear the tools cache for a specific MCP server.""" - global _mcp_tools_cache, _mcp_tools_stats - server_prefix = f"{server_name}:" - stale_keys = [key for key in _mcp_tools_cache if key.startswith(server_prefix)] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - _mcp_tools_stats.pop(server_name, None) - logger.info(f"Cleared tools cache for MCP server '{server_name}'") - - -def get_mcp_tools_stats(server_name: str) -> dict[str, int] | None: - """Get tools statistics for a MCP server. - - Returns: - dict with 'total', 'enabled', 'disabled' counts, or None if not available - """ - return _mcp_tools_stats.get(server_name) - - -# ============================================================================= -# === Server Config CRUD (Existing in mcp_service.py) === -# ============================================================================= - - -async def get_mcp_server(db: AsyncSession, name: str) -> MCPServer | None: - """Get single server configuration.""" - result = await db.execute(select(MCPServer).filter(MCPServer.name == name)) - return result.scalar_one_or_none() - - -async def get_all_mcp_servers(db: AsyncSession) -> list[MCPServer]: - """Get all server configurations.""" - result = await db.execute(select(MCPServer)) - return list(result.scalars().all()) - - -async def create_mcp_server( - db: AsyncSession, - name: str, - transport: str, - url: str = None, - command: str = None, - args: list = None, - env: dict = None, - description: str = None, - headers: dict = None, - timeout: int = None, - sse_read_timeout: int = None, - tags: list = None, - icon: str = None, - created_by: str = None, -) -> MCPServer: - """Create server.""" - # Check if name exists - existing = await get_mcp_server(db, name) - if existing: - raise ValueError(f"Server name '{name}' already exists") - - server = MCPServer( - name=name, - description=description, - transport=transport, - url=url, - command=command, - args=args, - env=env, - headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - tags=tags, - icon=icon, - enabled=1, - created_by=created_by, - updated_by=created_by, - ) - db.add(server) - await db.commit() - await db.refresh(server) - - clear_mcp_server_tools_cache(name) - - logger.info(f"Created MCP server '{name}'") - return server - - -async def update_mcp_server( - db: AsyncSession, - name: str, - description: str = None, - transport: str = None, - url: str = None, - command: str = None, - args: list = None, - env: Any = _UNSET, - headers: dict = None, - timeout: int = None, - sse_read_timeout: int = None, - tags: list = None, - icon: str = None, - updated_by: str = None, -) -> MCPServer: - """Update server configuration.""" - server = await get_mcp_server(db, name) - if not server: - raise ValueError(f"Server '{name}' does not exist") - - if description is not None: - server.description = description - if transport is not None: - server.transport = transport - if url is not None: - server.url = url - if command is not None: - server.command = command - if args is not None: - server.args = args - if env is not _UNSET: - server.env = env - if headers is not None: - server.headers = headers - if timeout is not None: - server.timeout = timeout - if sse_read_timeout is not None: - server.sse_read_timeout = sse_read_timeout - if tags is not None: - server.tags = tags - if icon is not None: - server.icon = icon - if updated_by is not None: - server.updated_by = updated_by - - await db.commit() - await db.refresh(server) - - clear_mcp_server_tools_cache(name) - - logger.info(f"Updated MCP server '{name}'") - return server - - -async def delete_mcp_server(db: AsyncSession, name: str) -> bool: - """Delete server.""" - server = await get_mcp_server(db, name) - if not server: - return False - - await db.delete(server) - await db.commit() - - clear_mcp_server_tools_cache(name) - - logger.info(f"Deleted MCP server '{name}'") - return True - - -# ============================================================================= -# === Tool Management === -# ============================================================================= - - -async def set_server_enabled( - db: AsyncSession, name: str, enabled: bool, updated_by: str = None -) -> tuple[bool, MCPServer]: - """Set server enabled status.""" - server = await get_mcp_server(db, name) - if not server: - raise ValueError(f"Server '{name}' does not exist") - - server.enabled = 1 if enabled else 0 - if updated_by is not None: - server.updated_by = updated_by - await db.commit() - - is_enabled = bool(server.enabled) - clear_mcp_server_tools_cache(name) - - logger.info(f"Set MCP server '{name}' enabled={is_enabled}") - return is_enabled, server - - -async def toggle_tool_enabled( - db: AsyncSession, - server_name: str, - tool_name: str, - updated_by: str = None, -) -> tuple[bool, MCPServer]: - """Toggle single tool enabled status. - - Args: - db: Database session - server_name: Server name - tool_name: Tool name - updated_by: Updater - - Returns: - (enabled, server): Tool enabled status and updated server object - """ - server = await get_mcp_server(db, server_name) - if not server: - raise ValueError(f"Server '{server_name}' does not exist") - - disabled_tools = list(server.disabled_tools or []) - - if tool_name in disabled_tools: - disabled_tools.remove(tool_name) - enabled = True - else: - disabled_tools.append(tool_name) - enabled = False - - server.disabled_tools = disabled_tools - if updated_by is not None: - server.updated_by = updated_by - await db.commit() - - # Clear tool cache (re-filtered on next fetch) - clear_mcp_server_tools_cache(server_name) - - logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") - return enabled, server - - -# ============================================================================= -# === Unified Entry Points (Wrappers) === -# ============================================================================= - - -async def get_enabled_mcp_tools(server_name: str) -> list: - """Get MCP server tools (auto-filtering disabled_tools). - - Unified entry point for Agents, automatically: - 1. Gets the latest server config from database - 2. Gets all tools - 3. Filters out disabled_tools - - Args: - server_name: Server name - - Returns: - List of enabled tools - """ - config = await get_enabled_mcp_server_config(server_name) - if config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - - disabled_tools = config.get("disabled_tools") or [] - return await get_mcp_tools(server_name, additional_servers={server_name: config}, disabled_tools=disabled_tools) - - -async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: - """Batch get server configurations. - - Args: - names: List of server names - - Returns: - {name: config} dictionary, containing only found servers - """ - return await _load_enabled_mcp_server_configs(names=names) - - -async def get_all_mcp_tools(server_name: str) -> list: - """Get all tools of an MCP server (no filtering). - - For management UI to display tool list, supports viewing all tools and their enabled status. - Does NOT update the global tools cache to avoid polluting agent's filtered view. - - Args: - server_name: Server name - - Returns: - List of all tools (unfiltered) - """ - config = await get_enabled_mcp_server_config(server_name) - if config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - - # Get all tools (no filtering, force refresh, no cache update) - return await get_mcp_tools( - server_name, - additional_servers={server_name: config}, - disabled_tools=[], - cache=False, - force_refresh=True, - ) diff --git a/backend/package/yuxi/services/mcp_tool_cache.py b/backend/package/yuxi/services/mcp_tool_cache.py new file mode 100644 index 000000000..f0d9cc115 --- /dev/null +++ b/backend/package/yuxi/services/mcp_tool_cache.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import json +import os +from collections.abc import Awaitable, Callable +from typing import Any + +from yuxi.services.run_queue_service import get_redis_client +from yuxi.utils import logger + +SERVER_REVISION_KEY_PREFIX = "yuxi:mcp:tool_cache:server_revision:v1" +PARTITION_REVISION_KEY_PREFIX = "yuxi:mcp:tool_cache:partition_revision:v1" +MANIFEST_KEY_PREFIX = "yuxi:mcp:tool_cache:manifest:v1" +MANIFEST_TTL_SECONDS = int(os.getenv("YUXI_MCP_TOOL_MANIFEST_TTL_SECONDS", "3600")) + + +def _server_revision_key(server_name: str) -> str: + return f"{SERVER_REVISION_KEY_PREFIX}:{server_name}" + + +def _partition_revision_key(server_name: str, cache_partition: str) -> str: + return f"{PARTITION_REVISION_KEY_PREFIX}:{server_name}:{cache_partition}" + + +def _manifest_key(cache_key: str) -> str: + return f"{MANIFEST_KEY_PREFIX}:{cache_key}" + + +class RedisMcpToolCache: + def __init__(self, redis_client_factory: Callable[[], Awaitable[Any]] | None = None): + self._redis_client_factory = redis_client_factory or get_redis_client + + async def _get_redis(self): + return await self._redis_client_factory() + + async def get_server_revision(self, server_name: str) -> int: + return await self._get_revision(_server_revision_key(server_name)) + + async def get_partition_revision(self, server_name: str, cache_partition: str) -> int: + return await self._get_revision(_partition_revision_key(server_name, cache_partition)) + + async def bump_server_revision(self, server_name: str) -> int: + return await self._bump_revision(_server_revision_key(server_name)) + + async def bump_partition_revision(self, server_name: str, cache_partition: str) -> int: + return await self._bump_revision(_partition_revision_key(server_name, cache_partition)) + + async def get_manifest(self, cache_key: str) -> dict[str, Any] | None: + try: + redis = await self._get_redis() + raw = await redis.get(_manifest_key(cache_key)) + except Exception as exc: + logger.warning(f"Failed to read MCP tool manifest cache for '{cache_key}': {exc}") + return None + if not raw: + return None + if isinstance(raw, dict): + return raw + try: + return json.loads(raw) + except Exception as exc: + logger.warning(f"Failed to decode MCP tool manifest cache for '{cache_key}': {exc}") + return None + + async def set_manifest(self, cache_key: str, manifest: dict[str, Any]) -> None: + try: + redis = await self._get_redis() + await redis.set( + _manifest_key(cache_key), + json.dumps(manifest, ensure_ascii=False, separators=(",", ":")), + ex=MANIFEST_TTL_SECONDS, + ) + except Exception as exc: + logger.warning(f"Failed to write MCP tool manifest cache for '{cache_key}': {exc}") + + async def _get_revision(self, key: str) -> int: + try: + redis = await self._get_redis() + raw = await redis.get(key) + except Exception as exc: + logger.warning(f"Failed to read MCP tool revision cache for '{key}': {exc}") + return 0 + if raw is None: + return 0 + try: + return int(raw) + except (TypeError, ValueError): + logger.warning(f"Invalid MCP tool revision cache value for '{key}': {raw}") + return 0 + + async def _bump_revision(self, key: str) -> int: + try: + redis = await self._get_redis() + return int(await redis.incr(key)) + except Exception as exc: + logger.warning(f"Failed to bump MCP tool revision cache for '{key}': {exc}") + return 0 diff --git a/backend/package/yuxi/services/run_queue_service.py b/backend/package/yuxi/services/run_queue_service.py index ec0422b33..04079a8d7 100644 --- a/backend/package/yuxi/services/run_queue_service.py +++ b/backend/package/yuxi/services/run_queue_service.py @@ -190,7 +190,7 @@ async def list_run_stream_events( ) -> list[dict]: redis = await get_redis_client() key = _event_stream_key(run_id) - start = "-" if after_seq in {"0", "0-0", ""} else f"{after_seq}" + start = "-" if after_seq in {"0", "0-0", ""} else f"({after_seq}" rows = await redis.xrange(key, min=start, max="+", count=limit) events = [] diff --git a/backend/package/yuxi/services/run_worker.py b/backend/package/yuxi/services/run_worker.py index a6ef9e14c..12b606dd3 100644 --- a/backend/package/yuxi/services/run_worker.py +++ b/backend/package/yuxi/services/run_worker.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import OperationalError from yuxi.repositories.agent_run_repository import TERMINAL_RUN_STATUSES, AgentRunRepository from yuxi.services.chat_service import stream_agent_chat -from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db +from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db from yuxi.services.run_queue_service import ( append_run_stream_event, clear_cancel_signal, diff --git a/backend/package/yuxi/services/skill_service.py b/backend/package/yuxi/services/skill_service.py index a6663a637..ce094e720 100644 --- a/backend/package/yuxi/services/skill_service.py +++ b/backend/package/yuxi/services/skill_service.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from yuxi import config as sys_config from yuxi.repositories.skill_repository import SkillRepository -from yuxi.services.mcp_service import get_enabled_mcp_server_names +from yuxi.services.mcp.server_service import get_enabled_mcp_server_names from yuxi.storage.postgres.models_business import Skill from yuxi.utils.logging_config import logger diff --git a/backend/package/yuxi/storage/postgres/manager.py b/backend/package/yuxi/storage/postgres/manager.py index c4d1acbc4..d87498f9d 100644 --- a/backend/package/yuxi/storage/postgres/manager.py +++ b/backend/package/yuxi/storage/postgres/manager.py @@ -196,6 +196,29 @@ async def ensure_business_schema(self): "ALTER TABLE IF EXISTS subagents ADD COLUMN IF NOT EXISTS enabled BOOLEAN NOT NULL DEFAULT TRUE", "ALTER TABLE IF EXISTS conversations ADD COLUMN IF NOT EXISTS is_pinned BOOLEAN NOT NULL DEFAULT FALSE", "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS env JSONB", + "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS auth_config_json JSONB", + """ + CREATE TABLE IF NOT EXISTS mcp_connections ( + id SERIAL PRIMARY KEY, + server_name VARCHAR(100) NOT NULL REFERENCES mcp_servers(name) ON DELETE CASCADE, + scope_type VARCHAR(16) NOT NULL, + scope_id VARCHAR(64) NOT NULL, + display_name VARCHAR(128), + external_subject VARCHAR(255), + status VARCHAR(32) NOT NULL DEFAULT 'active', + credential_blob TEXT, + meta_json JSONB NOT NULL DEFAULT '{}'::jsonb, + created_by VARCHAR(64), + updated_by VARCHAR(64), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT ck_mcp_connections_scope_type CHECK (scope_type IN ('system', 'department', 'user')), + CONSTRAINT ck_mcp_connections_status CHECK ( + status IN ('active', 'disabled', 'reauth_required', 'invalid') + ), + CONSTRAINT uq_mcp_connections_server_scope UNIQUE (server_name, scope_type, scope_id) + ) + """, """ CREATE TABLE IF NOT EXISTS model_providers ( id SERIAL PRIMARY KEY, @@ -244,6 +267,8 @@ async def ensure_business_schema(self): "CREATE INDEX IF NOT EXISTS idx_agent_runs_thread_created ON agent_runs(thread_id, created_at DESC)", "CREATE INDEX IF NOT EXISTS idx_agent_runs_status_updated ON agent_runs(status, updated_at)", "CREATE INDEX IF NOT EXISTS ix_conversations_is_pinned ON conversations(is_pinned)", + "CREATE INDEX IF NOT EXISTS idx_mcp_connections_status ON mcp_connections(status)", + "CREATE INDEX IF NOT EXISTS idx_mcp_connections_subject ON mcp_connections(external_subject)", "CREATE UNIQUE INDEX IF NOT EXISTS ix_model_providers_provider_id ON model_providers(provider_id)", "CREATE INDEX IF NOT EXISTS ix_model_providers_is_enabled ON model_providers(is_enabled)", # Undo/Fork 递归 CTE 性能索引 diff --git a/backend/package/yuxi/storage/postgres/models_business.py b/backend/package/yuxi/storage/postgres/models_business.py index ef3299653..86c4fd2d8 100644 --- a/backend/package/yuxi/storage/postgres/models_business.py +++ b/backend/package/yuxi/storage/postgres/models_business.py @@ -440,6 +440,7 @@ class MCPServer(Base): args = Column(JSON, nullable=True, comment="命令参数数组(stdio)") env = Column(JSON, nullable=True, comment="环境变量(stdio)") headers = Column(JSON, nullable=True, comment="HTTP 请求头") + auth_config_json = Column(JSON, nullable=True, comment="MCP 认证配置") timeout = Column(Integer, nullable=True, comment="HTTP 超时时间(秒)") sse_read_timeout = Column(Integer, nullable=True, comment="SSE 读取超时(秒)") @@ -469,6 +470,7 @@ def to_dict(self) -> dict[str, Any]: "args": self.args or [], "env": self.env or {}, "headers": self.headers or {}, + "auth_config": self.auth_config_json or {}, "timeout": self.timeout, "sse_read_timeout": self.sse_read_timeout, "tags": self.tags or [], @@ -516,6 +518,14 @@ def to_mcp_config(self) -> dict[str, Any]: config["headers"] = json.loads(self.headers) except json.JSONDecodeError: pass + if self.auth_config_json: + if isinstance(self.auth_config_json, dict): + config["auth_config"] = self.auth_config_json + elif isinstance(self.auth_config_json, str): + try: + config["auth_config"] = json.loads(self.auth_config_json) + except json.JSONDecodeError: + pass if self.timeout is not None: config["timeout"] = self.timeout if self.sse_read_timeout is not None: @@ -525,6 +535,53 @@ def to_mcp_config(self) -> dict[str, Any]: return config +class MCPConnection(Base): + """MCP 长期连接与凭据绑定模型""" + + __tablename__ = "mcp_connections" + __table_args__ = ( + UniqueConstraint("server_name", "scope_type", "scope_id", name="uq_mcp_connections_server_scope"), + Index("idx_mcp_connections_status", "status"), + Index("idx_mcp_connections_subject", "external_subject"), + ) + + id = Column(Integer, primary_key=True, autoincrement=True) + server_name = Column(String(100), ForeignKey("mcp_servers.name", ondelete="CASCADE"), nullable=False) + scope_type = Column(String(16), nullable=False, comment="system/department/user") + scope_id = Column(String(64), nullable=False, comment="绑定范围标识") + display_name = Column(String(128), nullable=True, comment="展示名称") + external_subject = Column(String(255), nullable=True, comment="外部系统主体标识") + status = Column(String(32), nullable=False, default="active", comment="连接状态") + credential_blob = Column(Text, nullable=True, comment="加密后的长期敏感凭据") + meta_json = Column(JSON, nullable=False, default=dict, comment="非敏感元数据") + created_by = Column(String(64), nullable=True) + updated_by = Column(String(64), nullable=True) + created_at = Column(DateTime, default=utc_now_naive, comment="创建时间") + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive, comment="更新时间") + + server = relationship("MCPServer") + + def to_dict(self, *, include_credentials: bool = False) -> dict[str, Any]: + payload = { + "id": self.id, + "server_name": self.server_name, + "scope_type": self.scope_type, + "scope_id": self.scope_id, + "display_name": self.display_name, + "external_subject": self.external_subject, + "status": self.status, + "meta_json": self.meta_json or {}, + "has_credentials": bool(self.credential_blob), + "created_by": self.created_by, + "updated_by": self.updated_by, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + } + if include_credentials: + payload["credential_blob"] = self.credential_blob + return payload + + class ModelProvider(Base): """模型供应商配置,存储 provider 基础信息、模型端点和可用模型。""" diff --git a/backend/package/yuxi/utils/logging_config.py b/backend/package/yuxi/utils/logging_config.py index d76f9c996..1235e2208 100644 --- a/backend/package/yuxi/utils/logging_config.py +++ b/backend/package/yuxi/utils/logging_config.py @@ -44,6 +44,12 @@ def _setup_logging_bridge(): lightrag_logger.setLevel(logging.DEBUG) lightrag_logger.propagate = False # 避免重复 + # 桥接 MCP 服务层日志,便于在 agent 运行时直接观察 MCP 选择、加载和失败冷却。 + mcp_logger = logging.getLogger("yuxi.mcp") + mcp_logger.addHandler(loguru_handler) + mcp_logger.setLevel(logging.INFO) + mcp_logger.propagate = False + # 桥接其他常见第三方库(降低级别减少噪音) for lib in ["httpx", "openai", "neo4j", "urllib3"]: lib_logger = logging.getLogger(lib) diff --git a/backend/server/routers/__init__.py b/backend/server/routers/__init__.py index bf998c296..b8528f826 100644 --- a/backend/server/routers/__init__.py +++ b/backend/server/routers/__init__.py @@ -6,6 +6,7 @@ from server.routers.chat_router import chat from server.routers.dashboard_router import dashboard from server.routers.auth_dept_router import department +from server.routers.mcp_internal_router import mcp_internal from server.routers.mcp_router import mcp from server.routers.model_provider_router import model_providers from server.routers.skill_router import skills @@ -32,6 +33,7 @@ router.include_router(department) # /api/departments/* 部门与权限相关数据 router.include_router(tasks) # /api/tasks/* 后台任务查询与管理 router.include_router(mcp) # /api/system/mcp-servers/* MCP 服务管理 +router.include_router(mcp_internal) # /api/internal/mcp-proxy/* 动态 MCP 内部代理 router.include_router(model_providers) # /api/system/model-providers/* 独立模型配置 router.include_router(skills) # /api/system/skills/* Skills 管理 router.include_router(subagents_router) # /api/system/subagents/* 子智能体管理 diff --git a/backend/server/routers/mcp_internal_router.py b/backend/server/routers/mcp_internal_router.py new file mode 100644 index 000000000..2de6fcf26 --- /dev/null +++ b/backend/server/routers/mcp_internal_router.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Header, Request, Response +from sqlalchemy.ext.asyncio import AsyncSession + +from server.utils.auth_middleware import get_db +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_TOKEN_HEADER, + handle_mcp_proxy_request, +) + +mcp_internal = APIRouter(prefix="/internal/mcp-proxy", tags=["mcp-internal"]) + + +@mcp_internal.api_route( + "/{server_name}{path:path}", + methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], +) +async def proxy_mcp_server_request( + server_name: str, + request: Request, + path: str = "", + internal_token: str | None = Header(None, alias=INTERNAL_PROXY_TOKEN_HEADER), + db: AsyncSession = Depends(get_db), +) -> Response: + """代理路由(纯路由层):业务鉴权、DB操作及背压透传已全部下沉到 proxy_service 领域服务处理""" + # 去除前导斜杠,以兼容不带 path 和带 path 两种情况 + path = path.lstrip("/") + + return await handle_mcp_proxy_request( + server_name=server_name, + request=request, + path=path, + internal_token=internal_token or "", + db=db, + ) diff --git a/backend/server/routers/mcp_router.py b/backend/server/routers/mcp_router.py index 76fdbd96a..ad88989d4 100644 --- a/backend/server/routers/mcp_router.py +++ b/backend/server/routers/mcp_router.py @@ -1,20 +1,40 @@ """MCP 服务器管理路由""" -from fastapi import APIRouter, Depends, HTTPException +import json + +from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession -from yuxi.services.mcp_service import ( +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp.server_service import ( create_mcp_server, - get_mcp_tools_stats, delete_mcp_server, get_all_mcp_servers, - get_all_mcp_tools, get_mcp_server, + get_mcp_server_dependency_summary, set_server_enabled, - toggle_tool_enabled, update_mcp_server, ) +from yuxi.services.mcp.connection_service import ( + count_mcp_connections, + create_mcp_connection, + delete_mcp_connection, + get_mcp_connection, + list_mcp_connections_page, + list_mcp_connections, + reauthorize_mcp_connection, + requires_bound_mcp_connection, + set_mcp_connection_status, + test_mcp_connection, + update_mcp_connection, +) +from yuxi.services.mcp.tool_registry_service import ( + get_all_mcp_tools, + get_mcp_tools_stats, + toggle_tool_enabled, +) from yuxi.storage.postgres.models_business import User from yuxi.utils import logger from server.utils.auth_middleware import get_admin_user, get_db, get_required_user @@ -40,6 +60,7 @@ class CreateMcpServerRequest(BaseModel): sse_read_timeout: int | None = Field(None, description="SSE 读取超时(秒)") tags: list | None = Field(None, description="标签数组") icon: str | None = Field(None, description="图标(emoji)") + auth_config: dict | None = Field(None, description="MCP 鉴权配置") class UpdateMcpServerRequest(BaseModel): @@ -54,12 +75,35 @@ class UpdateMcpServerRequest(BaseModel): sse_read_timeout: int | None = Field(None, description="SSE 读取超时(秒)") tags: list | None = Field(None, description="标签数组") icon: str | None = Field(None, description="图标(emoji)") + auth_config: dict | None = Field(None, description="MCP 鉴权配置") class UpdateMcpServerStatusRequest(BaseModel): enabled: bool = Field(..., description="是否启用") +class CreateMcpConnectionRequest(BaseModel): + scope_type: str = Field(..., description="连接范围:system/department/user") + scope_id: str | None = Field(None, description="范围标识") + display_name: str | None = Field(None, description="展示名称") + external_subject: str | None = Field(None, description="外部系统主体标识") + credential: dict | str | None = Field(None, description="长期凭据") + meta_json: dict | None = Field(None, description="非敏感元数据") + status: str = Field("active", description="连接状态") + + +class UpdateMcpConnectionStatusRequest(BaseModel): + status: str = Field(..., description="连接状态") + + +class UpdateMcpConnectionRequest(BaseModel): + display_name: str | None = Field(None, description="展示名称") + external_subject: str | None = Field(None, description="外部系统主体标识") + credential: dict | str | None = Field(None, description="长期凭据") + meta_json: dict | None = Field(None, description="非敏感元数据") + status: str | None = Field(None, description="连接状态") + + # ============================================================================= # === Helpers === # ============================================================================= @@ -73,6 +117,122 @@ async def get_server_or_404(db: AsyncSession, name: str): return server +def _is_admin_user(current_user: User) -> bool: + return current_user.role in ["admin", "superadmin"] + + +def _current_user_scope_id(current_user: User) -> str: + db_id = getattr(current_user, "id", None) + login_id = getattr(current_user, "user_id", None) + resolved_user_id = db_id if db_id is not None else login_id + if resolved_user_id is None: + raise HTTPException(status_code=400, detail="当前用户缺少可用的用户标识") + return str(resolved_user_id) + + +def _public_auth_config(payload: dict | None) -> dict: + if not payload: + return {} + try: + auth_config = MCPAuthConfig.model_validate(payload) + except Exception: + return {} + return { + "version": auth_config.version, + "provider": auth_config.provider, + "binding_scope": auth_config.binding_scope, + "manifest_scope": auth_config.manifest_scope, + "secret_fields": auth_config.get_secret_fields(), + } + + +def _public_mcp_server_detail(server) -> dict: + return { + "name": getattr(server, "name", ""), + "description": getattr(server, "description", None), + "transport": getattr(server, "transport", None), + "auth_config": _public_auth_config(getattr(server, "auth_config_json", None)), + "tags": getattr(server, "tags", None) or [], + "icon": getattr(server, "icon", None), + "enabled": bool(getattr(server, "enabled", True)), + } + + +def _ensure_mcp_server_visible_to_user(server, current_user: User) -> None: + if _is_admin_user(current_user): + return + if not bool(getattr(server, "enabled", True)): + raise HTTPException(status_code=404, detail=f"服务器 '{getattr(server, 'name', '')}' 不存在") + + +def _ensure_personal_connection_server( + server, + current_user: User, + *, + include_admin: bool = False, +) -> None: + _ensure_mcp_server_visible_to_user(server, current_user) + if _is_admin_user(current_user) and not include_admin: + return + try: + auth_config = MCPAuthConfig.model_validate(getattr(server, "auth_config_json", None) or {}) + except Exception as exc: + raise HTTPException(status_code=400, detail="当前 MCP 未配置可用的动态鉴权策略") from exc + if auth_config.binding_scope != "user": + raise HTTPException(status_code=403, detail="仅管理员可以维护非个人范围的 MCP 连接") + + +def _ensure_connection_accessible_to_user(connection, current_user: User) -> None: + if _is_admin_user(current_user): + return + if connection.scope_type != "user" or connection.scope_id != _current_user_scope_id(current_user): + raise HTTPException(status_code=404, detail=f"连接 '{connection.id}' 不存在") + + +async def get_connection_for_server_or_404( + db: AsyncSession, + server_name: str, + connection_id: int, + current_user: User | None = None, +): + connection = await get_mcp_connection(db, connection_id) + if connection is None or connection.server_name != server_name: + raise HTTPException(status_code=404, detail=f"连接 '{connection_id}' 不存在") + if current_user is not None: + _ensure_connection_accessible_to_user(connection, current_user) + return connection + + +def _normalize_credential_blob(value: dict | str | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=False) + + +def _validate_auth_config_or_400(payload: dict | None) -> dict | None: + if not payload: + return None + try: + return MCPAuthConfig.model_validate(payload).model_dump(mode="json") + except Exception as exc: + raise HTTPException(status_code=400, detail=f"auth_config 配置无效: {exc}") from exc + + +def _auth_context_from_user(current_user: User) -> AuthContext: + department_id = getattr(current_user, "department_id", None) + db_id = getattr(current_user, "id", None) + login_id = getattr(current_user, "user_id", None) + resolved_user_id = db_id if db_id is not None else login_id + return AuthContext( + user_id=str(resolved_user_id) if resolved_user_id is not None else None, + department_id=str(department_id) if department_id is not None else None, + work_id=str(login_id) if login_id is not None else None, + ) + + + # ============================================================================= # === MCP 服务器 CRUD === # ============================================================================= @@ -93,6 +253,8 @@ async def get_mcp_servers( # 仿真对象和历史数据,避免未来新增敏感字段或审计信息越权泄露 data = [] for s in servers: + if not bool(getattr(s, "enabled", True)): + continue data.append( { "name": getattr(s, "name", ""), @@ -127,6 +289,7 @@ async def create_mcp_server_route( raise HTTPException(status_code=400, detail="传输类型为 stdio 时,command 必填") try: + auth_config = _validate_auth_config_or_400(request.auth_config) server = await create_mcp_server( db, name=request.name, @@ -141,9 +304,12 @@ async def create_mcp_server_route( sse_read_timeout=request.sse_read_timeout, tags=request.tags, icon=request.icon, + auth_config=auth_config, created_by=current_user.username, ) return {"success": True, "data": server.to_dict()} + except HTTPException: + raise except ValueError as ve: raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: @@ -154,13 +320,16 @@ async def create_mcp_server_route( @mcp.get("/{name}") async def get_mcp_server_route( name: str, - current_user: User = Depends(get_admin_user), + current_user: User = Depends(get_required_user), db: AsyncSession = Depends(get_db), ): - """获取单个 MCP 服务器配置""" + """获取单个 MCP 服务器配置(普通用户仅获取脱敏的基础信息)""" try: server = await get_server_or_404(db, name) - return {"success": True, "data": server.to_dict()} + _ensure_mcp_server_visible_to_user(server, current_user) + if _is_admin_user(current_user): + return {"success": True, "data": server.to_dict()} + return {"success": True, "data": _public_mcp_server_detail(server)} except HTTPException: raise except Exception as e: @@ -182,10 +351,12 @@ async def update_mcp_server_route( raise HTTPException(status_code=400, detail=f"传输类型必须是 {', '.join(valid_transports)} 之一") try: - fields_set = getattr(request, "model_fields_set", getattr(request, "__fields_set__", set())) + fields_set = request.model_fields_set update_kwargs = {} if "env" in fields_set: update_kwargs["env"] = request.env + if "auth_config" in fields_set: + update_kwargs["auth_config"] = _validate_auth_config_or_400(request.auth_config) server = await update_mcp_server( db, @@ -204,6 +375,8 @@ async def update_mcp_server_route( **update_kwargs, ) return {"success": True, "data": server.to_dict()} + except HTTPException: + raise except ValueError as ve: raise HTTPException(status_code=404, detail=str(ve)) except Exception as e: @@ -214,6 +387,7 @@ async def update_mcp_server_route( @mcp.delete("/{name}") async def delete_mcp_server_route( name: str, + hard: bool = False, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db), ): @@ -221,13 +395,26 @@ async def delete_mcp_server_route( try: # 检查是否为系统内置服务器 server = await get_mcp_server(db, name) - if server and server.created_by == "system": + if not server: + raise HTTPException(status_code=404, detail=f"服务器 '{name}' 不存在") + if server.created_by == "system": raise HTTPException(status_code=403, detail="系统内置的 MCP 服务器无法删除") + if not hard: + await set_server_enabled(db, name, False, current_user.username) + return {"success": True, "message": f"服务器 '{name}' 已退役"} + + if bool(server.enabled): + raise HTTPException(status_code=409, detail="请先退役服务器,再执行硬删除") + + dependency_summary = await get_mcp_server_dependency_summary(db, name) + if dependency_summary["has_references"]: + raise HTTPException(status_code=409, detail=dependency_summary) + deleted = await delete_mcp_server(db, name) if not deleted: raise HTTPException(status_code=404, detail=f"服务器 '{name}' 不存在") - return {"success": True, "message": f"服务器 '{name}' 已删除"} + return {"success": True, "message": f"服务器 '{name}' 已彻底删除"} except HTTPException: raise except Exception as e: @@ -249,16 +436,25 @@ async def test_mcp_server( """测试 MCP 服务器连接""" try: await get_server_or_404(db, name) - try: - tools = await get_all_mcp_tools(name) + auth_context = _auth_context_from_user(current_user) + tools = await get_all_mcp_tools(name, auth_context=auth_context, db=db) return { "success": True, "message": f"连接成功,共发现 {len(tools)} 个工具", "tool_count": len(tools), } + except ValueError as val_err: + err_msg = str(val_err) + if "Active MCP connection not found" in err_msg: + raise HTTPException( + status_code=400, + detail="该 MCP 需要绑定连接(需要绑定长期密钥,请在连接页创建对应连接后进行测试)" + ) + raise HTTPException(status_code=500, detail=f"连接失败: {err_msg}") except Exception as test_error: raise HTTPException(status_code=500, detail=f"连接失败: {str(test_error)}") + except HTTPException: raise except Exception as e: @@ -289,6 +485,271 @@ async def update_mcp_server_status_route( raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================= +# === MCP 连接管理 === +# ============================================================================= + + +@mcp.get("/{name}/connections") +async def get_mcp_connections( + name: str, + mine: bool = False, + paginated: bool = False, + status: str = Query("all", description="健康筛选:all/active/attention/disabled"), + search: str | None = Query(None, description="连接名、绑定对象或主体搜索"), + page: int = Query(1, ge=1, description="页码"), + page_size: int = Query(12, ge=1, le=100, description="每页数量"), + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_mcp_server_visible_to_user(server, current_user) + list_kwargs = {"server_name": name} + if mine or not _is_admin_user(current_user): + _ensure_personal_connection_server(server, current_user, include_admin=mine) + list_kwargs.update({"scope_type": "user", "scope_id": _current_user_scope_id(current_user)}) + if paginated or status != "all" or search: + effective_scope_type = None + credentials_required = False + try: + auth_config = MCPAuthConfig.model_validate(getattr(server, "auth_config_json", None) or {}) + if auth_config.binding_scope in {"system", "department", "user"}: + effective_scope_type = auth_config.binding_scope + credentials_required = requires_bound_mcp_connection(auth_config) + except Exception: + effective_scope_type = None + credentials_required = False + + connections, total = await list_mcp_connections_page( + db, + **list_kwargs, + status_filter=status, + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + search=search, + page=page, + page_size=page_size, + ) + summary = { + "total": await count_mcp_connections(db, **list_kwargs), + "active": await count_mcp_connections( + db, + **list_kwargs, + status_filter="active", + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + ), + "attention": await count_mcp_connections( + db, + **list_kwargs, + status_filter="attention", + effective_scope_type=effective_scope_type, + credentials_required=credentials_required, + ), + "disabled": await count_mcp_connections(db, **list_kwargs, status_filter="disabled"), + } + return { + "success": True, + "data": { + "items": [item.to_dict() for item in connections], + "total": total, + "page": page, + "page_size": page_size, + "summary": summary, + }, + } + connections = await list_mcp_connections(db, **list_kwargs) + return {"success": True, "data": [item.to_dict() for item in connections]} + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list MCP connections: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections") +async def create_mcp_connection_route( + name: str, + request: CreateMcpConnectionRequest, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + scope_type = request.scope_type + scope_id = request.scope_id + if scope_type == "user" and not scope_id: + scope_id = _current_user_scope_id(current_user) + if not _is_admin_user(current_user): + if request.scope_type != "user": + raise HTTPException(status_code=403, detail="普通用户只能维护个人专用 MCP 连接") + scope_type = "user" + scope_id = _current_user_scope_id(current_user) + connection = await create_mcp_connection( + db, + server_name=name, + scope_type=scope_type, + scope_id=scope_id, + display_name=request.display_name, + external_subject=request.external_subject, + status=request.status, + credential_blob=_normalize_credential_blob(request.credential), + meta_json=request.meta_json, + created_by=current_user.username, + ) + return {"success": True, "data": connection.to_dict()} + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.put("/{name}/connections/{connection_id}/status") +async def update_mcp_connection_status_route( + name: str, + connection_id: int, + request: UpdateMcpConnectionStatusRequest, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + await get_connection_for_server_or_404(db, name, connection_id, current_user) + connection = await set_mcp_connection_status( + db, + connection_id, + status=request.status, + updated_by=current_user.username, + ) + return {"success": True, "data": connection.to_dict(), "message": "连接状态已更新"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connection status: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.put("/{name}/connections/{connection_id}") +async def update_mcp_connection_route( + name: str, + connection_id: int, + request: UpdateMcpConnectionRequest, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + await get_connection_for_server_or_404(db, name, connection_id, current_user) + fields_set = request.model_fields_set + update_kwargs = {} + if "credential" in fields_set: + update_kwargs["credential_blob"] = _normalize_credential_blob(request.credential) + if "display_name" in fields_set: + update_kwargs["display_name"] = request.display_name + if "external_subject" in fields_set: + update_kwargs["external_subject"] = request.external_subject + if "meta_json" in fields_set: + update_kwargs["meta_json"] = request.meta_json + if "status" in fields_set: + update_kwargs["status"] = request.status + + connection = await update_mcp_connection( + db, + connection_id, + updated_by=current_user.username, + **update_kwargs, + ) + return {"success": True, "data": connection.to_dict(), "message": "连接已更新"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.delete("/{name}/connections/{connection_id}") +async def delete_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + await get_connection_for_server_or_404(db, name, connection_id, current_user) + deleted = await delete_mcp_connection(db, connection_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"连接 '{connection_id}' 不存在") + return {"success": True, "message": "连接已删除"} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections/{connection_id}/test") +async def test_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + await get_connection_for_server_or_404(db, name, connection_id, current_user) + result = await test_mcp_connection(db, connection_id, updated_by=current_user.username) + return { + "success": True, + "tool_count": result["tool_count"], + "message": f"连接成功,共发现 {result['tool_count']} 个工具", + } + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to test MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections/{connection_id}/reauth") +async def reauthorize_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_required_user), + db: AsyncSession = Depends(get_db), +): + try: + server = await get_server_or_404(db, name) + _ensure_personal_connection_server(server, current_user) + await get_connection_for_server_or_404(db, name, connection_id, current_user) + connection = await reauthorize_mcp_connection(db, connection_id, updated_by=current_user.username) + return {"success": True, "data": connection.to_dict(), "message": "连接已重置并重新激活"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to reauthorize MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================= # === MCP 工具管理 === # ============================================================================= @@ -304,10 +765,11 @@ async def get_mcp_server_tools( try: server = await get_server_or_404(db, name) disabled_tools = server.disabled_tools or [] + auth_context = _auth_context_from_user(current_user) try: # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools(name, auth_context=auth_context, db=db) tool_list = [] for tool in tools: @@ -335,6 +797,8 @@ async def get_mcp_server_tools( "data": tool_list, "total": len(tool_list), } + except ValueError as tool_error: + raise HTTPException(status_code=403, detail=f"获取工具失败: {str(tool_error)}") except Exception as tool_error: logger.error(f"Failed to get tools from MCP server '{name}': {tool_error}") raise HTTPException(status_code=500, detail=f"获取工具失败: {str(tool_error)}") @@ -354,10 +818,11 @@ async def refresh_mcp_server_tools( """刷新 MCP 服务器的工具列表(清除缓存重新获取)""" try: await get_server_or_404(db, name) + auth_context = _auth_context_from_user(current_user) try: # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools(name, auth_context=auth_context, db=db, force_refresh=True) # 获取统计信息 stats = get_mcp_tools_stats(name) @@ -377,6 +842,8 @@ async def refresh_mcp_server_tools( "enabled_count": enabled_count, "disabled_count": disabled_count, } + except ValueError as tool_error: + raise HTTPException(status_code=403, detail=f"刷新失败: {str(tool_error)}") except Exception as tool_error: raise HTTPException(status_code=500, detail=f"刷新失败: {str(tool_error)}") except HTTPException: diff --git a/backend/server/utils/lifespan.py b/backend/server/utils/lifespan.py index 10380f252..ccc8d72d1 100644 --- a/backend/server/utils/lifespan.py +++ b/backend/server/utils/lifespan.py @@ -5,7 +5,7 @@ from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from yuxi.services.task_service import tasker -from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db +from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db from yuxi.services.model_provider_service import ensure_builtin_model_providers_in_db from yuxi.services.subagent_service import init_builtin_subagents from yuxi.services.run_queue_service import close_queue_clients, get_redis_client @@ -101,6 +101,14 @@ async def lifespan(app: FastAPI): """) logger.info("Yuxi backend startup complete") yield + + from yuxi.services.mcp.client_pool import mcp_client_pool + from yuxi.services.mcp_auth.proxy_service import close_shared_proxy_client + + logger.info("Shutting down MCP client pool and proxy clients...") + await mcp_client_pool.shutdown() + await close_shared_proxy_client() + await tasker.shutdown() shutdown_sandbox_provider() await close_queue_clients() diff --git a/backend/test/e2e/test_mcp_admin_flow_e2e.py b/backend/test/e2e/test_mcp_admin_flow_e2e.py new file mode 100644 index 000000000..c58403b0d --- /dev/null +++ b/backend/test/e2e/test_mcp_admin_flow_e2e.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import uuid + +import httpx +import pytest + +pytestmark = [pytest.mark.asyncio, pytest.mark.e2e, pytest.mark.slow] + + +def _build_server_name(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def _build_auth_config() -> dict: + return { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + }, + { + "name": "X-Yuxi-User", + "value_template": "${context.user_id}", + }, + { + "name": "X-Yuxi-Department", + "value_template": "${context.department_id}", + }, + ], + }, + "refresh_policy": { + "pre_refresh_seconds": 300, + "retry_once_on_401": True, + }, + "token_request": { + "url": "http://internal-gateway.local/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + } + + +async def _cleanup_server(client: httpx.AsyncClient, headers: dict[str, str], server_name: str) -> None: + list_response = await client.get(f"/api/system/mcp-servers/{server_name}/connections", headers=headers) + if list_response.status_code == 200: + for connection in list_response.json().get("data", []): + await client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection['id']}", + headers=headers, + ) + + await client.delete(f"/api/system/mcp-servers/{server_name}", headers=headers) + await client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=headers, + ) + + +async def test_mcp_admin_flow_e2e_supports_dynamic_auth_connections( + e2e_client: httpx.AsyncClient, + e2e_headers: dict[str, str], +): + invalid_server_name = _build_server_name("e2e-mcp-invalid-auth") + invalid_response = await e2e_client.post( + "/api/system/mcp-servers", + json={ + "name": invalid_server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + } + ], + }, + }, + }, + headers=e2e_headers, + ) + assert invalid_response.status_code == 400, invalid_response.text + assert "auth_config 配置无效" in invalid_response.json()["detail"] + + server_name = _build_server_name("e2e-mcp-auth") + create_server_response = await e2e_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "description": "e2e mcp auth server", + "auth_config": _build_auth_config(), + }, + headers=e2e_headers, + ) + assert create_server_response.status_code == 200, create_server_response.text + + try: + server_payload = create_server_response.json()["data"] + assert server_payload["name"] == server_name + assert server_payload["auth_config"]["provider"] == "custom_http_token" + + create_connection_response = await e2e_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "system", + "scope_id": "ignored-by-normalization", + "display_name": "全局共享连接", + "external_subject": "gateway-service-account", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + }, + "refresh_token": "refresh-1", + }, + "meta_json": {"tenant": "shared"}, + }, + headers=e2e_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_payload = create_connection_response.json()["data"] + connection_id = connection_payload["id"] + assert connection_payload["scope_type"] == "system" + assert connection_payload["scope_id"] == "global" + assert connection_payload["status"] == "active" + assert connection_payload["has_credentials"] is True + + list_connections_response = await e2e_client.get( + f"/api/system/mcp-servers/{server_name}/connections", + headers=e2e_headers, + ) + assert list_connections_response.status_code == 200, list_connections_response.text + assert list_connections_response.json()["data"] == [connection_payload] + + retire_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + headers=e2e_headers, + ) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_conflict_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=e2e_headers, + ) + assert hard_delete_conflict_response.status_code == 409, hard_delete_conflict_response.text + dependency_payload = hard_delete_conflict_response.json()["detail"] + assert dependency_payload["has_references"] is True + assert dependency_payload["connections"] == [ + { + "scope_type": "system", + "scope_id": "global", + "status": "active", + } + ] + + delete_connection_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=e2e_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + hard_delete_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=e2e_headers, + ) + assert hard_delete_response.status_code == 200, hard_delete_response.text + finally: + await _cleanup_server(e2e_client, e2e_headers, server_name) diff --git a/backend/test/integration/api/test_integration_mcp_router.py b/backend/test/integration/api/test_integration_mcp_router.py new file mode 100644 index 000000000..266c9dd4a --- /dev/null +++ b/backend/test/integration/api/test_integration_mcp_router.py @@ -0,0 +1,584 @@ +from __future__ import annotations + +import uuid + +import pytest + +pytestmark = [pytest.mark.asyncio, pytest.mark.integration] + + +def _build_server_name(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def _build_auth_config() -> dict: + return { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + }, + { + "name": "X-Yuxi-User", + "value_template": "${context.user_id}", + }, + { + "name": "X-Yuxi-Department", + "value_template": "${context.department_id}", + }, + ], + }, + "refresh_policy": { + "pre_refresh_seconds": 300, + "retry_once_on_401": True, + }, + "token_request": { + "url": "http://internal-gateway.local/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + } + + +async def _create_server(test_client, admin_headers: dict[str, str], name: str) -> None: + response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "description": "pytest mcp auth server", + "auth_config": _build_auth_config(), + }, + headers=admin_headers, + ) + assert response.status_code == 200, response.text + + +async def _cleanup_server(test_client, admin_headers: dict[str, str], name: str) -> None: + list_response = await test_client.get(f"/api/system/mcp-servers/{name}/connections", headers=admin_headers) + if list_response.status_code == 200: + for connection in list_response.json().get("data", []): + await test_client.delete( + f"/api/system/mcp-servers/{name}/connections/{connection['id']}", + headers=admin_headers, + ) + + await test_client.delete(f"/api/system/mcp-servers/{name}", headers=admin_headers) + await test_client.delete( + f"/api/system/mcp-servers/{name}", + params={"hard": "true"}, + headers=admin_headers, + ) + + +async def test_admin_can_manage_mcp_server_connections_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-auth") + await _create_server(test_client, admin_headers, server_name) + + try: + get_response = await test_client.get(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert get_response.status_code == 200, get_response.text + get_payload = get_response.json()["data"] + assert get_payload["name"] == server_name + assert get_payload["auth_config"]["provider"] == "custom_http_token" + + update_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}", + json={ + "description": "updated auth server", + "auth_config": { + **_build_auth_config(), + "refresh_policy": { + "pre_refresh_seconds": 120, + "retry_once_on_401": True, + }, + }, + }, + headers=admin_headers, + ) + assert update_response.status_code == 200, update_response.text + updated_payload = update_response.json()["data"] + assert updated_payload["name"] == server_name + assert updated_payload["description"] == "updated auth server" + assert updated_payload["auth_config"]["refresh_policy"]["pre_refresh_seconds"] == 120 + + create_connection_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": "finance-dept", + "display_name": "财务共享连接", + "external_subject": "finance-bot", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + }, + "refresh_token": "refresh-1", + }, + "meta_json": {"tenant": "finance"}, + }, + headers=admin_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_payload = create_connection_response.json()["data"] + connection_id = connection_payload["id"] + assert connection_payload["scope_type"] == "department" + assert connection_payload["display_name"] == "财务共享连接" + assert connection_payload["has_credentials"] is True + assert "credential_blob" not in connection_payload + + list_connections_response = await test_client.get( + f"/api/system/mcp-servers/{server_name}/connections", + headers=admin_headers, + ) + assert list_connections_response.status_code == 200, list_connections_response.text + listed_connections = list_connections_response.json()["data"] + assert len(listed_connections) == 1 + assert listed_connections[0]["id"] == connection_id + assert listed_connections[0]["has_credentials"] is True + assert "credential_blob" not in listed_connections[0] + + update_connection_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + json={ + "display_name": "财务共享连接-更新", + "meta_json": {"tenant": "finance", "stage": "updated"}, + }, + headers=admin_headers, + ) + assert update_connection_response.status_code == 200, update_connection_response.text + assert update_connection_response.json()["data"]["display_name"] == "财务共享连接-更新" + assert update_connection_response.json()["data"]["meta_json"]["stage"] == "updated" + + status_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}/status", + json={"status": "reauth_required"}, + headers=admin_headers, + ) + assert status_response.status_code == 200, status_response.text + assert status_response.json()["data"]["status"] == "reauth_required" + + reauth_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}/reauth", + headers=admin_headers, + ) + assert reauth_response.status_code == 200, reauth_response.text + assert reauth_response.json()["data"]["status"] == "active" + + delete_connection_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=admin_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + retire_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_response.status_code == 200, hard_delete_response.text + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_hard_delete_mcp_server_returns_dependency_summary_when_connections_exist( + test_client, admin_headers +): + server_name = _build_server_name("pytest-mcp-delete") + await _create_server(test_client, admin_headers, server_name) + + try: + create_connection_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": "finance-dept", + "display_name": "财务共享连接", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + } + }, + }, + headers=admin_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_id = create_connection_response.json()["data"]["id"] + + retire_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_response.status_code == 409, hard_delete_response.text + detail = hard_delete_response.json()["detail"] + assert detail["has_references"] is True + assert detail["connections"] == [ + { + "scope_type": "department", + "scope_id": "finance-dept", + "status": "active", + } + ] + + delete_connection_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=admin_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + hard_delete_after_cleanup_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_after_cleanup_response.status_code == 200, hard_delete_after_cleanup_response.text + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_bound_auth_server_test_endpoint_requires_connection_level_testing(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-bound-test") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.post(f"/api/system/mcp-servers/{server_name}/test", headers=admin_headers) + assert response.status_code == 400, response.text + assert "需要绑定连接" in response.json()["detail"] + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_bound_auth_server_test_endpoint_succeeds_with_connection(test_client, admin_headers): + me_response = await test_client.get("/api/auth/me", headers=admin_headers) + assert me_response.status_code == 200, me_response.text + me_data = me_response.json() + admin_dept_id = str(me_data["department_id"]) + + server_name = _build_server_name("pytest-mcp-bound-test-ok") + + response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "sse", + "url": "http://mcp-demo-server:8999/sse", + "description": "pytest mcp auth server ok", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${secret.access_token}", + } + ], + }, + }, + }, + headers=admin_headers, + ) + assert response.status_code == 200, response.text + + try: + test_fail_response = await test_client.post(f"/api/system/mcp-servers/{server_name}/test", headers=admin_headers) + assert test_fail_response.status_code == 400, test_fail_response.text + assert "需要绑定连接" in test_fail_response.json()["detail"] + + conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": admin_dept_id, + "display_name": "Dept Scope Test OK", + "credential": {"secrets": {"access_token": "dummy_dept_token"}}, + }, + headers=admin_headers, + ) + assert conn_response.status_code == 200, conn_response.text + conn_id = conn_response.json()["data"]["id"] + + test_ok_response = await test_client.post(f"/api/system/mcp-servers/{server_name}/test", headers=admin_headers) + assert test_ok_response.status_code == 200, test_ok_response.text + assert test_ok_response.json()["success"] is True + assert test_ok_response.json()["tool_count"] > 0 + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + + +async def test_bound_auth_tools_endpoint_requires_current_admin_connection(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-bound-tools") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.get(f"/api/system/mcp-servers/{server_name}/tools", headers=admin_headers) + assert response.status_code == 403, response.text + assert "Active MCP connection not found" in response.json()["detail"] + + refresh_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/tools/refresh", + headers=admin_headers, + ) + assert refresh_response.status_code == 403, refresh_response.text + assert "Active MCP connection not found" in refresh_response.json()["detail"] + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_create_mcp_server_rejects_invalid_auth_config_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-invalid-auth") + + response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + } + ], + }, + }, + }, + headers=admin_headers, + ) + + assert response.status_code == 400, response.text + assert "auth_config 配置无效" in response.json()["detail"] + + +async def test_create_system_connection_defaults_scope_id_to_global_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-system-scope") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "system", + "scope_id": "", + "display_name": "全局共享连接", + "credential": {"secrets": {"client_id": "cid-1", "client_secret": "secret-1"}}, + }, + headers=admin_headers, + ) + assert response.status_code == 200, response.text + payload = response.json()["data"] + assert payload["scope_type"] == "system" + assert payload["scope_id"] == "global" + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_mcp_connections_all_scopes_e2e(test_client, admin_headers): + # 1. 获取当前管理员的用户信息以获取正确的用户 ID 和部门 ID + me_response = await test_client.get("/api/auth/me", headers=admin_headers) + assert me_response.status_code == 200, me_response.text + me_data = me_response.json() + admin_db_id = str(me_data["id"]) + admin_dept_id = str(me_data["department_id"]) + + server_name = _build_server_name("pytest-mcp-scopes") + + try: + # A. 测试个人 (User) 范围 + # 创建一个 binding_scope="user" 的服务器 + create_response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "sse", + "url": "http://mcp-demo-server:8999/sse", # 使用启动的 mock server sse 端口 + "description": "pytest scopes user test", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + }, + headers=admin_headers, + ) + assert create_response.status_code == 200, create_response.text + + # 创建对应的个人连接,scope_id 必须与当前用户的 db_id (主键数字字符串) 一致 + conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "user", + "scope_id": admin_db_id, + "display_name": "User Scope Test", + "credential": {"secrets": {"access_token": "dummy_user_token"}}, + }, + headers=admin_headers, + ) + assert conn_response.status_code == 200, conn_response.text + conn_id = conn_response.json()["data"]["id"] + + # 测试该连接的可用性,测试时会根据 auth_context 自动解析并匹配 scope_id + test_conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections/{conn_id}/test", + headers=admin_headers, + ) + assert test_conn_response.status_code == 200, test_conn_response.text + assert test_conn_response.json()["tool_count"] > 0 + + # 清理该连接 + del_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{conn_id}", + headers=admin_headers, + ) + assert del_response.status_code == 200, del_response.text + + # 清理服务器 (软删除) + retire_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert retire_response.status_code == 200, retire_response.text + hard_del_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}?hard=true", headers=admin_headers) + assert hard_del_response.status_code == 200, hard_del_response.text + + # B. 测试部门 (Department) 范围 + create_response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "sse", + "url": "http://mcp-demo-server:8999/sse", + "description": "pytest scopes dept test", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + }, + headers=admin_headers, + ) + assert create_response.status_code == 200, create_response.text + + # 创建对应的部门连接 + conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": admin_dept_id, + "display_name": "Dept Scope Test", + "credential": {"secrets": {"access_token": "dummy_dept_token"}}, + }, + headers=admin_headers, + ) + assert conn_response.status_code == 200, conn_response.text + conn_id = conn_response.json()["data"]["id"] + + # 测试连接 + test_conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections/{conn_id}/test", + headers=admin_headers, + ) + assert test_conn_response.status_code == 200, test_conn_response.text + assert test_conn_response.json()["tool_count"] > 0 + + # 清理 + await test_client.delete(f"/api/system/mcp-servers/{server_name}/connections/{conn_id}", headers=admin_headers) + await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + await test_client.delete(f"/api/system/mcp-servers/{server_name}?hard=true", headers=admin_headers) + + # C. 测试系统 (System) 范围 + create_response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "sse", + "url": "http://mcp-demo-server:8999/sse", + "description": "pytest scopes system test", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "system", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + }, + headers=admin_headers, + ) + assert create_response.status_code == 200, create_response.text + + # 创建对应的全局连接 + conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "system", + "scope_id": "global", + "display_name": "Global Scope Test", + "credential": {"secrets": {"access_token": "dummy_global_token"}}, + }, + headers=admin_headers, + ) + assert conn_response.status_code == 200, conn_response.text + conn_id = conn_response.json()["data"]["id"] + + # 测试连接 + test_conn_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections/{conn_id}/test", + headers=admin_headers, + ) + assert test_conn_response.status_code == 200, test_conn_response.text + assert test_conn_response.json()["tool_count"] > 0 + + # 清理 + await test_client.delete(f"/api/system/mcp-servers/{server_name}/connections/{conn_id}", headers=admin_headers) + await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + await test_client.delete(f"/api/system/mcp-servers/{server_name}?hard=true", headers=admin_headers) + + finally: + await _cleanup_server(test_client, admin_headers, server_name) + diff --git a/backend/test/integration/services/test_model_provider_runtime_connectivity.py b/backend/test/integration/services/test_model_provider_runtime_connectivity.py index 56f01052b..9964603a8 100644 --- a/backend/test/integration/services/test_model_provider_runtime_connectivity.py +++ b/backend/test/integration/services/test_model_provider_runtime_connectivity.py @@ -15,7 +15,7 @@ from yuxi.models.embed import OllamaEmbedding, OtherEmbedding from yuxi.models.rerank import DashscopeReranker, OpenAIReranker from yuxi.services.model_provider_service import ( - _resolve_api_key, + resolve_api_key, ensure_builtin_model_providers_in_db, get_model_provider_by_id, ) @@ -34,7 +34,7 @@ def _model_spec(provider: ModelProvider, model: dict[str, Any]) -> dict[str, Any]: """Turn an enabled model item into runtime parameters for existing model clients.""" - api_key = _resolve_api_key(provider) + api_key = resolve_api_key(provider) if api_key is None: api_key = "no_api_key" return { diff --git a/backend/test/mcp_demo_server.py b/backend/test/mcp_demo_server.py new file mode 100644 index 000000000..3211e2905 --- /dev/null +++ b/backend/test/mcp_demo_server.py @@ -0,0 +1,233 @@ +from __future__ import annotations +import argparse +import asyncio +import contextvars +import logging +import os +import sys +from typing import Any +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from mcp.server import Server, NotificationOptions +from mcp.server.models import InitializationOptions +import mcp.types as types +from mcp.server.sse import SseServerTransport +from mcp.server.stdio import stdio_server +from contextlib import asynccontextmanager +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.fastmcp.server import StreamableHTTPASGIApp + +# 简体中文注释与日志规范 (RULE[user_global]) +logger = logging.getLogger("mcp_demo_server") +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + +# 用于在不同传输协议下传递当前请求身份上下文的 ContextVar +current_request_headers_var = contextvars.ContextVar("current_request_headers", default=None) + +# 实例化 MCP 核心服务对象 +server = Server("yuxi-mcp-demo-server") + +@server.list_tools() +async def handle_list_tools() -> list[types.Tool]: + """根据身份或环境变量返回过滤后的三级权限工具列表""" + headers = current_request_headers_var.get() or {} + + # 优先级: HTTP Headers > 系统环境变量 (兼容 stdio 与 sse 两种环境的测试) + dept_id = headers.get("x-department-id") or os.environ.get("X_DEPARTMENT_ID") + user_id = headers.get("x-user-id") or os.environ.get("X_USER_ID") + auth_token = headers.get("authorization") or os.environ.get("AUTHORIZATION") + + logger.info(f"Listing tools - AuthToken: {auth_token}, DeptID: {dept_id}, UserID: {user_id}") + + # 基础路由工具 (全局可见) + tools = [ + types.Tool( + name="echo_global", + description="全局通用工具,无须任何权限即可访问", + inputSchema={ + "type": "object", + "properties": { + "message": {"type": "string", "description": "要回显的内容"} + }, + "required": ["message"] + }, + ) + ] + + # 部门级别权限工具 + if dept_id: + tools.append( + types.Tool( + name="echo_dept_data", + description=f"部门级别受限工具,当前已授权部门ID: {dept_id}", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "查询参数"} + }, + "required": ["query"] + }, + ) + ) + + # 用户个人级别权限工具 + if user_id: + tools.append( + types.Tool( + name="echo_user_profile", + description=f"个人受限工具,当前已授权用户ID: {user_id}", + inputSchema={ + "type": "object", + "properties": { + "dummy": {"type": "string", "description": "占位参数"} + } + }, + ) + ) + + return tools + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + """执行工具回显结果""" + logger.info(f"Calling tool: {name} with args: {arguments}") + args = arguments or {} + + if name == "echo_global": + msg = args.get("message", "") + return [types.TextContent(type="text", text=f"[Global Output] 回显内容: {msg}")] + + elif name == "echo_dept_data": + query = args.get("query", "") + return [types.TextContent(type="text", text=f"[Department Output] 数据查询回显: {query}")] + + elif name == "echo_user_profile": + return [types.TextContent(type="text", text="[User Output] 成功获取用户专有敏感配置与画像数据")] + + else: + raise ValueError(f"Unknown tool: {name}") + + +# ============================================================================= +# === SSE 与 Streamable HTTP 传输协议支持 (FastAPI) === +# ============================================================================= + +session_manager = StreamableHTTPSessionManager( + app=server, + stateless=True, +) + +@asynccontextmanager +async def lifespan(app: FastAPI): + async with session_manager.run(): + yield + +app = FastAPI(title="MCP Demo Server", lifespan=lifespan) +app.mount("/mcp", StreamableHTTPASGIApp(session_manager)) + +sse_transport = SseServerTransport("/messages") + +@app.post("/oauth/token") +async def oauth_token(request: Request): + """ + 模拟 OAuth2 认证端点。 + 返回一个 15 秒过期的 access_token,用以充分验证 Yuxi 后台的“短期 Token 自动失效与刷新”链路。 + """ + logger.info("Handling OAuth token request...") + return { + "access_token": "mock_access_token_123456", + "refresh_token": "mock_refresh_token_789000", + "expires_in": 15, # 15 秒过期,利于测试 + "token_type": "Bearer", + "scope": "read write" + } + +class SSEEndpoint: + async def __call__(self, scope, receive, send): + """建立 SSE 长连接通道,并将当前的 Headers 存入 ContextVar""" + headers_dict = {k.decode('utf-8'): v.decode('utf-8') for k, v in scope.get("headers", [])} + logger.info(f"New SSE connection attempt. Headers: {headers_dict}") + + # 注入 ContextVar,使得在该长连接处理循环下的所有 list/call_tool 能读取到 header + token = current_request_headers_var.set(headers_dict) + try: + async with sse_transport.connect_sse( + scope, receive, send + ) as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="yuxi-mcp-demo-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + finally: + current_request_headers_var.reset(token) + +app.add_route("/sse", SSEEndpoint(), methods=["GET"]) + +class MessagesEndpoint: + async def __call__(self, scope, receive, send): + """接收 SSE 通道发来的具体 JSON-RPC 请求""" + await sse_transport.handle_post_message(scope, receive, send) + +app.add_route("/messages", MessagesEndpoint(), methods=["POST"]) + + +# ============================================================================= +# === Stdio 传输协议支持 (本地子进程) === +# ============================================================================= + +async def run_stdio(): + """以 Stdio 形式在控制台管道中拉起""" + logger.info("Starting Stdio server transport...") + + # Stdio 模式下从系统环境变量读取 headers + current_request_headers_var.set(dict(os.environ)) + + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="yuxi-mcp-demo-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + +# ============================================================================= +# === 启动入口 === +# ============================================================================= + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mock MCP Demo Server") + parser.add_argument( + "--transport", + choices=["stdio", "sse", "streamable-http"], + default="sse", + help="传输协议类型 (默认为 sse)" + ) + parser.add_argument( + "--port", + type=int, + default=8999, + help="FastAPI SSE 服务的端口号" + ) + args = parser.parse_args() + + if args.transport == "stdio": + asyncio.run(run_stdio()) + else: + logger.info(f"Starting FastAPI Server (transport: {args.transport}) on port {args.port}...") + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py new file mode 100644 index 000000000..3f6c792cb --- /dev/null +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import logging +from types import SimpleNamespace + +import pytest +from langchain.tools.tool_node import ToolCallRequest +from langchain_core.messages import ToolMessage + +import yuxi.agents.middlewares.runtime_config_middleware as runtime_config_middleware +from yuxi.agents.middlewares.runtime_config_middleware import RuntimeConfigMiddleware + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_tools_from_context_passes_auth_context_to_mcp_loader( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): + captured: list[tuple[str, str | None, str | None]] = [] + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware() + context = SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + ) + + with caplog.at_level(logging.WARNING, logger="Yuxi"): + tools = await middleware.get_tools_from_context(context) + + assert tools == [] + assert captured == [("finance-gateway", "user-1", "dept-9")] + assert "mcp dependency unavailable" not in caplog.text + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_tools_from_context_uses_work_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware() + context = SimpleNamespace( + tools=[], + mcps=["dts-mcp_server"], + user_id="2", + work_id="login-1001", + department_id="dept-9", + ) + + tools = await middleware.get_tools_from_context(context) + + assert tools == [] + assert captured == [("dts-mcp_server", "2", "dept-9")] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_runtime_loaded_mcp_tool_can_be_executed_when_tool_node_did_not_pre_register_it( + monkeypatch: pytest.MonkeyPatch, +): + runtime_tool = SimpleNamespace(name="getTicket") + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del auth_context, db, http_client + assert server_name == "dts-mcp_server" + return [runtime_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + context = SimpleNamespace( + tools=[], + mcps=["dts-mcp_server"], + user_id="2", + work_id="login-1001", + department_id="dept-9", + model=None, + system_prompt="", + ) + + class DummyRequest: + def __init__(self, tools=None): + self.runtime = SimpleNamespace(context=context) + self.tools = tools or [] + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest(kwargs.get("tools", self.tools)) + clone.runtime = kwargs.get("runtime", self.runtime) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + async def model_handler(next_request): + return next_request.tools + + await middleware.awrap_model_call(DummyRequest(), model_handler) + + tool_request = ToolCallRequest( + tool_call={"name": "getTicket", "args": {"arg0": "DTS2026012932159"}, "id": "call-1"}, + tool=None, + state={}, + runtime=SimpleNamespace(context=context), + ) + captured = {} + + async def tool_handler(next_request): + captured["tool"] = next_request.tool + return ToolMessage(content="ok", name=next_request.tool_call["name"], tool_call_id=next_request.tool_call["id"]) + + result = await middleware.awrap_tool_call(tool_request, tool_handler) + + assert result.content == "ok" + assert captured["tool"] is runtime_tool + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_awrap_model_call_appends_runtime_loaded_mcp_tools(monkeypatch: pytest.MonkeyPatch): + runtime_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={}) + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "finance-gateway" + assert auth_context.user_id == "user-1" + assert auth_context.department_id == "dept-9" + return [runtime_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + class DummyRequest: + def __init__(self): + self.runtime = SimpleNamespace( + context=SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + model=None, + system_prompt="", + ) + ) + self.tools = [] + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest() + clone.runtime = kwargs.get("runtime", self.runtime) + clone.tools = kwargs.get("tools", self.tools) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + request = DummyRequest() + + async def handler(next_request): + return next_request.tools + + tools = await middleware.awrap_model_call(request, handler) + + assert tools == [runtime_tool] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_awrap_model_call_replaces_stale_managed_tool_with_fresh_runtime_tool(monkeypatch: pytest.MonkeyPatch): + stale_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={"version": "stale"}) + fresh_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={"version": "fresh"}) + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "finance-gateway" + assert auth_context.user_id == "user-1" + assert auth_context.department_id == "dept-9" + return [fresh_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + class DummyRequest: + def __init__(self, tools): + self.runtime = SimpleNamespace( + context=SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + model=None, + system_prompt="", + ) + ) + self.tools = tools + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest(kwargs.get("tools", self.tools)) + clone.runtime = kwargs.get("runtime", self.runtime) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + middleware.tools = [stale_tool] + request = DummyRequest([stale_tool]) + + async def handler(next_request): + return next_request.tools + + tools = await middleware.awrap_model_call(request, handler) + + assert tools == [fresh_tool] diff --git a/backend/test/unit/middlewares/test_skills_middleware.py b/backend/test/unit/middlewares/test_skills_middleware.py new file mode 100644 index 000000000..f3664e165 --- /dev/null +++ b/backend/test/unit/middlewares/test_skills_middleware.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import logging +from types import SimpleNamespace + +import pytest + +import yuxi.agents.middlewares.skills_middleware as skills_middleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_mcp_tools_from_context_passes_auth_context_to_mcp_loader( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): + captured: list[tuple[str, str | None, str | None]] = [] + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(skills_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = SkillsMiddleware() + context = SimpleNamespace( + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + ) + + with caplog.at_level(logging.WARNING, logger="Yuxi"): + tools = await middleware._get_mcp_tools_from_context(context) + + assert tools == [] + assert captured == [("finance-gateway", "user-1", "dept-9")] + assert "mcp dependency unavailable" not in caplog.text + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_mcp_tools_from_context_uses_work_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(skills_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = SkillsMiddleware() + context = SimpleNamespace( + mcps=["dts-mcp_server"], + user_id="2", + work_id="login-1001", + department_id="dept-9", + ) + + tools = await middleware._get_mcp_tools_from_context(context) + + assert tools == [] + assert captured == [("dts-mcp_server", "2", "dept-9")] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_collect_context_mcp_names_for_preload_includes_configured_skill_dependencies( + monkeypatch: pytest.MonkeyPatch, +): + async def fake_get_dependency_map(db=None): + del db + return { + "reporter": {"tools": [], "mcps": ["charts"], "skills": ["common"]}, + "common": {"tools": [], "mcps": ["finance-gateway"], "skills": []}, + } + + monkeypatch.setattr(skills_middleware, "get_dependency_map", fake_get_dependency_map) + + context = SimpleNamespace(mcps=["direct", "charts"], skills=["reporter"]) + + names = await collect_context_mcp_names_for_preload(context) + + assert names == ["direct", "charts", "finance-gateway"] diff --git a/backend/test/unit/routers/test_mcp_internal_router.py b/backend/test/unit/routers/test_mcp_internal_router.py new file mode 100644 index 000000000..2da4e2a65 --- /dev/null +++ b/backend/test/unit/routers/test_mcp_internal_router.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import httpx +from fastapi import FastAPI, Response +from fastapi.testclient import TestClient + +from server.routers.mcp_internal_router import mcp_internal +from server.utils.auth_middleware import get_db + + +def _build_app() -> FastAPI: + app = FastAPI() + app.include_router(mcp_internal, prefix="/api") + + async def fake_db(): + return None + + app.dependency_overrides[get_db] = fake_db + return app + + +def test_internal_proxy_route_forwards_request(monkeypatch): + async def fake_handle_mcp_proxy_request(server_name, request, path, internal_token, db): + assert server_name == "finance-proxy" + assert path == "some/path" + assert internal_token == "test-token" + return Response(content='{"ok": true}', media_type="application/json") + + monkeypatch.setattr( + "server.routers.mcp_internal_router.handle_mcp_proxy_request", + fake_handle_mcp_proxy_request + ) + + client = TestClient(_build_app()) + resp = client.post( + "/api/internal/mcp-proxy/finance-proxy/some/path", + headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, + json={"jsonrpc": "2.0", "id": 1}, + ) + + assert resp.status_code == 200, resp.text + assert resp.json() == {"ok": True} + + +def test_internal_proxy_route_requires_internal_token(): + client = TestClient(_build_app()) + resp = client.post("/api/internal/mcp-proxy/finance-proxy", json={"jsonrpc": "2.0", "id": 1}) + assert resp.status_code == 401, resp.text # Missing header raises 401 Unauthorized diff --git a/backend/test/unit/routers/test_mcp_router.py b/backend/test/unit/routers/test_mcp_router.py index 4e1bc994c..541a83b94 100644 --- a/backend/test/unit/routers/test_mcp_router.py +++ b/backend/test/unit/routers/test_mcp_router.py @@ -25,6 +25,7 @@ async def fake_admin_user(): user_id="admin", password_hash="x", role="admin", + department_id=42, ) async def fake_required_user(): @@ -33,6 +34,7 @@ async def fake_required_user(): user_id="admin" if allow_admin else "user", password_hash="x", role="admin" if allow_admin else "user", + department_id=42, ) app.dependency_overrides[get_db] = fake_db @@ -41,6 +43,18 @@ async def fake_required_user(): return app +def _auth_config(binding_scope: str = "user") -> dict: + return { + "version": 1, + "provider": "bound_secret", + "binding_scope": binding_scope, + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + } + + def test_update_mcp_server_status(monkeypatch): captured = {} @@ -133,3 +147,876 @@ async def fake_get_all_mcp_servers(db): assert data_user["name"] == "test-mcp" assert data_user["description"] == "test mcp description" assert data_user["enabled"] is True + + +def test_get_mcp_server_normal_user_gets_public_detail(monkeypatch): + class DummyServer: + name = "personal-gateway" + description = "personal gateway" + transport = "streamable_http" + url = "http://gateway.local/mcp" + headers = {"Authorization": "Bearer secret"} + enabled = 1 + tags = ["finance"] + icon = "🔐" + auth_config_json = { + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + } + + def to_dict(self): + return {"name": self.name, "url": self.url, "headers": self.headers} + + async def fake_get_mcp_server(db, name): + del db + assert name == "personal-gateway" + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.get("/api/system/mcp-servers/personal-gateway") + + assert resp.status_code == 200, resp.text + data = resp.json()["data"] + assert data["name"] == "personal-gateway" + assert data["auth_config"] == { + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "manifest_scope": "server", + "secret_fields": ["access_token"], + } + assert "url" not in data + assert "headers" not in data + + +def test_get_mcp_server_normal_user_cannot_read_disabled_server(monkeypatch): + class DummyServer: + name = "disabled-gateway" + enabled = 0 + + async def fake_get_mcp_server(db, name): + del db, name + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.get("/api/system/mcp-servers/disabled-gateway") + + assert resp.status_code == 404, resp.text + + +def test_create_mcp_server_forwards_auth_config(monkeypatch): + captured = {} + + class DummyServer: + def to_dict(self): + return {"name": "gateway", "auth_config": {"provider": "custom_http_token"}} + + async def fake_create_mcp_server(db, **kwargs): + del db + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.create_mcp_server", fake_create_mcp_server) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers", + json={ + "name": "gateway", + "transport": "streamable_http", + "url": "http://gateway.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["auth_config"] == { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 0, "retry_once_on_401": False}, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + +def test_update_mcp_server_forwards_auth_config(monkeypatch): + captured = {} + + class DummyServer: + def to_dict(self): + return {"name": "gateway", "auth_config": {"provider": "bound_secret"}} + + async def fake_update_mcp_server(db, name, **kwargs): + del db + captured["name"] = name + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.update_mcp_server", fake_update_mcp_server) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway", + json={ + "description": "updated", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["name"] == "gateway" + assert captured["auth_config"] == { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 0, "retry_once_on_401": False}, + "token_request": None, + } + + +def test_create_mcp_server_rejects_invalid_auth_config(monkeypatch): + async def fake_create_mcp_server(db, **kwargs): + raise AssertionError("create_mcp_server should not be called when auth_config is invalid") + + monkeypatch.setattr("server.routers.mcp_router.create_mcp_server", fake_create_mcp_server) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers", + json={ + "name": "gateway", + "transport": "streamable_http", + "url": "http://gateway.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + }, + }, + ) + assert resp.status_code == 400, resp.text + assert "auth_config 配置无效" in resp.json()["detail"] + + +def test_list_mcp_connections(monkeypatch): + class DummyConnection: + def __init__(self, connection_id): + self.connection_id = connection_id + + def to_dict(self): + return {"id": self.connection_id, "scope_type": "department", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_list_mcp_connections(db, **kwargs): + del db, kwargs + return [DummyConnection(1), DummyConnection(2)] + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.list_mcp_connections", fake_list_mcp_connections) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/connections") + assert resp.status_code == 200, resp.text + assert resp.json()["data"] == [ + {"id": 1, "scope_type": "department", "status": "active"}, + {"id": 2, "scope_type": "department", "status": "active"}, + ] + + +def test_list_mcp_connections_normal_user_only_lists_own_user_scope(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 9, "scope_type": "user", "scope_id": "user", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name, "enabled": 1, "auth_config_json": _auth_config()})() + + async def fake_list_mcp_connections(db, **kwargs): + del db + captured.update(kwargs) + return [DummyConnection()] + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.list_mcp_connections", fake_list_mcp_connections) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.get("/api/system/mcp-servers/gateway/connections") + + assert resp.status_code == 200, resp.text + assert captured == {"server_name": "gateway", "scope_type": "user", "scope_id": "user"} + assert resp.json()["data"] == [ + {"id": 9, "scope_type": "user", "scope_id": "user", "status": "active"} + ] + + +def test_list_mcp_connections_admin_mine_filters_to_current_user(monkeypatch): + captured = {} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name, "enabled": 1, "auth_config_json": _auth_config()})() + + async def fake_list_mcp_connections(db, **kwargs): + del db + captured.update(kwargs) + return [] + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.list_mcp_connections", fake_list_mcp_connections) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/connections?mine=true") + + assert resp.status_code == 200, resp.text + assert captured == {"server_name": "gateway", "scope_type": "user", "scope_id": "admin"} + + +def test_list_mcp_connections_paginated_returns_summary(monkeypatch): + captured = {} + count_filters = [] + + class DummyConnection: + def to_dict(self): + return {"id": 12, "scope_type": "user", "scope_id": "1", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type( + "DummyServer", + (), + { + "name": name, + "auth_config_json": _auth_config("user"), + }, + )() + + async def fake_list_mcp_connections_page(db, **kwargs): + del db + captured.update(kwargs) + return [DummyConnection()], 17 + + async def fake_count_mcp_connections(db, **kwargs): + del db + count_filters.append(kwargs.get("status_filter", "all")) + return {"all": 30, "active": 14, "attention": 3, "disabled": 5}.get( + kwargs.get("status_filter", "all"), 0 + ) + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr( + "server.routers.mcp_router.list_mcp_connections_page", + fake_list_mcp_connections_page, + ) + monkeypatch.setattr("server.routers.mcp_router.count_mcp_connections", fake_count_mcp_connections) + + client = TestClient(_build_app()) + resp = client.get( + "/api/system/mcp-servers/gateway/connections?paginated=true&status=attention&search=alice&page=2&page_size=5" + ) + + assert resp.status_code == 200, resp.text + assert captured["server_name"] == "gateway" + assert captured["status_filter"] == "attention" + assert captured["effective_scope_type"] == "user" + assert captured["credentials_required"] is True + assert captured["search"] == "alice" + assert captured["page"] == 2 + assert captured["page_size"] == 5 + payload = resp.json()["data"] + assert payload["items"] == [{"id": 12, "scope_type": "user", "scope_id": "1", "status": "active"}] + assert payload["total"] == 17 + assert payload["summary"] == {"total": 30, "active": 14, "attention": 3, "disabled": 5} + assert count_filters == ["all", "active", "attention", "disabled"] + + +def test_create_mcp_connection(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "scope_type": "department", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_create_mcp_connection(db, **kwargs): + del db + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.create_mcp_connection", fake_create_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers/gateway/connections", + json={ + "scope_type": "department", + "scope_id": "42", + "display_name": "财务部共享连接", + "external_subject": "finance-user", + "credential": {"secrets": {"access_token": "token-1"}}, + "meta_json": {"tenant": "finance"}, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["server_name"] == "gateway" + assert captured["scope_type"] == "department" + assert captured["scope_id"] == "42" + assert captured["credential_blob"] == '{"secrets": {"access_token": "token-1"}}' + assert captured["created_by"] == "admin" + + +def test_create_mcp_connection_normal_user_auto_binds_own_scope(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 11, "scope_type": "user", "scope_id": "user", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type( + "DummyServer", + (), + { + "name": name, + "enabled": 1, + "auth_config_json": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + }, + )() + + async def fake_create_mcp_connection(db, **kwargs): + del db + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.create_mcp_connection", fake_create_mcp_connection) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.post( + "/api/system/mcp-servers/gateway/connections", + json={ + "scope_type": "user", + "display_name": "我的连接", + "credential": {"secrets": {"access_token": "token-1"}}, + }, + ) + + assert resp.status_code == 200, resp.text + assert captured["server_name"] == "gateway" + assert captured["scope_type"] == "user" + assert captured["scope_id"] == "user" + assert captured["created_by"] == "user" + + +def test_create_mcp_connection_normal_user_rejects_non_user_scope(monkeypatch): + async def fake_get_mcp_server(db, name): + del db + return type( + "DummyServer", + (), + { + "name": name, + "enabled": 1, + "auth_config_json": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + }, + )() + + async def fake_create_mcp_connection(db, **kwargs): + raise AssertionError("ordinary users must not create shared MCP connections") + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.create_mcp_connection", fake_create_mcp_connection) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.post( + "/api/system/mcp-servers/gateway/connections", + json={"scope_type": "department", "scope_id": "42"}, + ) + + assert resp.status_code == 403, resp.text + + +def test_update_mcp_connection_normal_user_rejects_non_user_binding(monkeypatch): + async def fake_get_mcp_server(db, name): + del db + return type( + "DummyServer", + (), + {"name": name, "enabled": 1, "auth_config_json": _auth_config("department")}, + )() + + async def fake_get_mcp_connection(db, connection_id): + del db, connection_id + raise AssertionError("ordinary users must not manage connections on shared-bound MCPs") + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7", + json={"display_name": "我的连接"}, + ) + + assert resp.status_code == 403, resp.text + + +def test_update_mcp_connection_status(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "status": "reauth_required"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_set_mcp_connection_status(db, connection_id, **kwargs): + del db + captured["connection_id"] = connection_id + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.set_mcp_connection_status", fake_set_mcp_connection_status) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7/status", + json={"status": "reauth_required"}, + ) + assert resp.status_code == 200, resp.text + assert captured == { + "connection_id": 7, + "status": "reauth_required", + "updated_by": "admin", + } + + +def test_update_mcp_connection(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "display_name": "新连接名", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_update_mcp_connection(db, connection_id, **kwargs): + del db + captured["connection_id"] = connection_id + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.update_mcp_connection", fake_update_mcp_connection) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7", + json={ + "display_name": "新连接名", + "credential": {"secrets": {"access_token": "token-2"}}, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["connection_id"] == 7 + assert captured["display_name"] == "新连接名" + assert captured["credential_blob"] == '{"secrets": {"access_token": "token-2"}}' + assert captured["updated_by"] == "admin" + + +def test_delete_mcp_connection(monkeypatch): + captured = {} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_delete_mcp_connection(db, connection_id): + del db + captured["connection_id"] = connection_id + return True + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.delete_mcp_connection", fake_delete_mcp_connection) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway/connections/7") + assert resp.status_code == 200, resp.text + assert captured == {"connection_id": 7} + + +def test_delete_mcp_connection_normal_user_cannot_delete_other_user_connection(monkeypatch): + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name, "enabled": 1, "auth_config_json": _auth_config()})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type( + "DummyConnectionRef", + (), + {"id": connection_id, "server_name": "gateway", "scope_type": "user", "scope_id": "other-user"}, + )() + + async def fake_delete_mcp_connection(db, connection_id): + raise AssertionError("should not delete another user's MCP connection") + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.delete_mcp_connection", fake_delete_mcp_connection) + + client = TestClient(_build_app(allow_admin=False)) + resp = client.delete("/api/system/mcp-servers/gateway/connections/7") + + assert resp.status_code == 404, resp.text + + +def test_test_mcp_connection_route(monkeypatch): + captured = {} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_test_mcp_connection(db, connection_id, *, updated_by=None): + del db + captured["connection_id"] = connection_id + captured["updated_by"] = updated_by + return {"tool_count": 3} + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.test_mcp_connection", fake_test_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/connections/7/test") + assert resp.status_code == 200, resp.text + assert resp.json()["tool_count"] == 3 + assert captured == {"connection_id": 7, "updated_by": "admin"} + + +def test_reauthorize_mcp_connection_route(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_reauthorize_mcp_connection(db, connection_id, *, updated_by=None): + del db + captured["connection_id"] = connection_id + captured["updated_by"] = updated_by + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.reauthorize_mcp_connection", fake_reauthorize_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/connections/7/reauth") + assert resp.status_code == 200, resp.text + assert captured == {"connection_id": 7, "updated_by": "admin"} + + +def test_update_mcp_connection_status_rejects_connection_from_other_server(monkeypatch): + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "other-gateway"})() + + async def fake_set_mcp_connection_status(db, connection_id, **kwargs): + raise AssertionError("should not update a connection that belongs to another server") + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.set_mcp_connection_status", fake_set_mcp_connection_status) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7/status", + json={"status": "reauth_required"}, + ) + assert resp.status_code == 404, resp.text + + +def test_test_mcp_server_requires_connection_level_test_for_bound_auth(monkeypatch): + class DummyServer: + auth_config_json = { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + async def fake_get_server_or_404(db, name): + del db, name + return DummyServer() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del server_name, auth_context, db, http_client, force_refresh + raise ValueError("Active MCP connection not found for server 'gateway' and scope department:42") + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/test", json={}) + assert resp.status_code == 400, resp.text + + +def test_get_mcp_server_tools_uses_current_admin_auth_context(monkeypatch): + captured = {} + + class DummyServer: + disabled_tools = ["tool_b"] + + class DummyArgsSchema: + @staticmethod + def schema(): + return {"properties": {"city": {"type": "string"}}, "required": ["city"]} + + class DummyTool: + name = "tool_a" + description = "tool a" + metadata = {"id": "mcp__gateway__toolA"} + args_schema = DummyArgsSchema() + + async def fake_get_server_or_404(db, name): + del db + assert name == "gateway" + return DummyServer() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del db, http_client, force_refresh + captured["server_name"] = server_name + captured["user_id"] = auth_context.user_id + captured["department_id"] = auth_context.department_id + return [DummyTool()] + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/tools") + + assert resp.status_code == 200, resp.text + assert captured == { + "server_name": "gateway", + "user_id": "admin", + "department_id": "42", + } + payload = resp.json() + assert payload["total"] == 1 + assert payload["data"][0]["required"] == ["city"] + assert payload["data"][0]["enabled"] is True + + +def test_get_mcp_server_tools_returns_403_when_bound_connection_missing(monkeypatch): + class DummyServer: + disabled_tools = [] + + async def fake_get_server_or_404(db, name): + del db, name + return DummyServer() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del server_name, auth_context, db, http_client, force_refresh + raise ValueError("Active MCP connection not found for server 'gateway' and scope department:42") + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/tools") + + assert resp.status_code == 403, resp.text + + +def test_refresh_mcp_server_tools_returns_403_when_bound_connection_missing(monkeypatch): + async def fake_get_server_or_404(db, name): + del db, name + return type("DummyServer", (), {})() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del server_name, auth_context, db, http_client, force_refresh + raise ValueError("Active MCP connection not found for server 'gateway' and scope department:42") + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/tools/refresh") + + assert resp.status_code == 403, resp.text + + +def test_delete_mcp_server_defaults_to_retire(monkeypatch): + captured = {} + + class DummyServer: + created_by = "tester" + + def to_dict(self): + return {"name": "gateway", "enabled": False} + + async def fake_get_mcp_server(db, name): + del db + return DummyServer() + + async def fake_set_server_enabled(db, name, enabled, updated_by=None): + del db + captured["name"] = name + captured["enabled"] = enabled + captured["updated_by"] = updated_by + return False, DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.set_server_enabled", fake_set_server_enabled) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway") + + assert resp.status_code == 200, resp.text + assert resp.json()["message"] == "服务器 'gateway' 已退役" + assert captured == { + "name": "gateway", + "enabled": False, + "updated_by": "admin", + } + + +def test_delete_mcp_server_hard_delete_returns_conflict(monkeypatch): + class DummyServer: + created_by = "tester" + enabled = 0 + + async def fake_get_mcp_server(db, name): + del db + return DummyServer() + + async def fake_get_dependency_summary(db, name): + del db, name + return { + "has_references": True, + "connections": [{"scope_type": "department", "scope_id": "42", "status": "active"}], + "skills": [{"slug": "finance-skill", "name": "Finance Skill"}], + "agent_configs": [], + } + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server_dependency_summary", fake_get_dependency_summary) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway?hard=true") + + assert resp.status_code == 409, resp.text + assert resp.json()["detail"]["connections"] == [ + {"scope_type": "department", "scope_id": "42", "status": "active"} + ] diff --git a/backend/test/unit/services/test_chat_service_langfuse_stream.py b/backend/test/unit/services/test_chat_service_langfuse_stream.py index d9bbf83bb..dcc482218 100644 --- a/backend/test/unit/services/test_chat_service_langfuse_stream.py +++ b/backend/test/unit/services/test_chat_service_langfuse_stream.py @@ -98,6 +98,12 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): yield None return + class FakeAgentConfigRepo: + def __init__(self, db): pass + async def get_by_id(self, config_id): + return SimpleNamespace(id=config_id) + monkeypatch.setattr(svc, "AgentConfigRepository", FakeAgentConfigRepo) + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda agent_id: FakeAgent()) monkeypatch.setattr(svc, "get_agent_config_by_id", fake_get_agent_config_by_id) monkeypatch.setattr(svc, "ConversationRepository", _FakeConvRepo) @@ -105,6 +111,13 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): monkeypatch.setattr(svc.content_guard, "check", fake_guard_check) monkeypatch.setattr(svc.content_guard, "check_with_keywords", fake_guard_check_with_keywords) monkeypatch.setattr(svc, "check_and_handle_interrupts", fake_interrupts) + + import contextlib + @contextlib.asynccontextmanager + async def fake_get_async_session_context(): + yield object() + monkeypatch.setattr(svc.pg_manager, "get_async_session_context", fake_get_async_session_context) + monkeypatch.setattr( svc, "_build_langfuse_run_context", @@ -132,12 +145,19 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-user-1", department_id="dept-1"), db=object(), ): chunks.append(json.loads(chunk.decode("utf-8"))) - assert calls["stream_input_context"] == {"temperature": 0.1, "user_id": "user-1", "thread_id": "thread-1"} + assert calls["stream_input_context"] == { + "temperature": 0.1, + "user_id": "user-1", + "work_id": "login-user-1", + "thread_id": "thread-1", + "department_id": "dept-1", + "system_prompt": "工号: login-user-1", + } assert calls["stream_kwargs"] == { "callbacks": ["handler-1"], "metadata": {"langfuse_user_id": "user-1", "langfuse_session_id": "thread-1"}, @@ -188,6 +208,12 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): yield None return + class FakeAgentConfigRepo: + def __init__(self, db): pass + async def get_by_id(self, config_id): + return SimpleNamespace(id=config_id) + monkeypatch.setattr(svc, "AgentConfigRepository", FakeAgentConfigRepo) + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda agent_id: FakeAgent()) monkeypatch.setattr(svc, "get_agent_config_by_id", fake_get_agent_config_by_id) monkeypatch.setattr(svc, "ConversationRepository", _FakeConvRepo) @@ -195,6 +221,13 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): monkeypatch.setattr(svc.content_guard, "check", fake_guard_check) monkeypatch.setattr(svc.content_guard, "check_with_keywords", fake_guard_check_with_keywords) monkeypatch.setattr(svc, "check_and_handle_interrupts", fake_interrupts) + + import contextlib + @contextlib.asynccontextmanager + async def fake_get_async_session_context(): + yield object() + monkeypatch.setattr(svc.pg_manager, "get_async_session_context", fake_get_async_session_context) + monkeypatch.setattr( svc, "_build_langfuse_run_context", @@ -210,7 +243,7 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-user-1", department_id="dept-1"), db=object(), ): chunks.append(json.loads(chunk.decode("utf-8"))) diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index 16a147636..2c4f27781 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -142,7 +142,7 @@ def fake_get_trace_info(_run_context): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-1"), db=object(), ) @@ -157,7 +157,14 @@ def fake_get_trace_info(_run_context): assert len(invoke_messages) == 1 assert isinstance(invoke_messages[0], HumanMessage) assert invoke_messages[0].content == "hello" - assert calls["invoke_input_context"] == {"temperature": 0.1, "user_id": "user-1", "thread_id": "thread-1"} + assert calls["invoke_input_context"] == { + "system_prompt": "工号: login-1001", + "temperature": 0.1, + "user_id": "user-1", + "work_id": "login-1001", + "thread_id": "thread-1", + "department_id": "dept-1", + } assert calls["invoke_kwargs"] == { "callbacks": ["handler-1"], "metadata": {"langfuse_user_id": "user-1", "langfuse_session_id": "thread-1"}, @@ -225,7 +232,7 @@ async def fake_guard_check(_content): thread_id="thread-2", meta={"request_id": "req-2"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-1"), db=object(), ) @@ -245,17 +252,38 @@ def fake_agents_prompt(_thread_id: str, _user_id: str) -> str: context = await svc._build_agent_input_context( {"system_prompt": "原始系统提示词", "temperature": 0.1}, thread_id="thread-1", - user_id="user-1", + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-9"), ) - assert context["system_prompt"] == "原始系统提示词\n\n用户工作区 agents/AGENTS.md 内容:\n回答前先读取 AGENTS.md" + assert ( + context["system_prompt"] + == "原始系统提示词\n\n用户工作区 agents/AGENTS.md 内容:\n回答前先读取 AGENTS.md\n\n用户信息:\n工号: login-1001" + ) assert context["temperature"] == 0.1 assert context["thread_id"] == "thread-1" assert context["user_id"] == "user-1" + assert context["department_id"] == "dept-9" + + +@pytest.mark.asyncio +async def test_build_agent_input_context_derives_runtime_identity_from_current_user( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(svc, "_load_workspace_agents_prompt", _empty_agents_prompt) + + context = await svc._build_agent_input_context( + {}, + thread_id="thread-1", + current_user=SimpleNamespace(id=2, user_id="login-1001", department_id="dept-9"), + ) + + assert context["user_id"] == "2" + assert context["work_id"] == "login-1001" + assert context["department_id"] == "dept-9" @pytest.mark.asyncio -async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prompt_empty( +async def test_build_agent_input_context_appends_user_info_when_workspace_agents_prompt_empty( monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(svc, "_load_workspace_agents_prompt", _empty_agents_prompt) @@ -263,7 +291,8 @@ async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prom context = await svc._build_agent_input_context( {"system_prompt": "原始系统提示词"}, thread_id="thread-1", - user_id="user-1", + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-9"), ) - assert context["system_prompt"] == "原始系统提示词" + assert context["system_prompt"] == "原始系统提示词\n\n用户信息:\n工号: login-1001" + assert context["department_id"] == "dept-9" diff --git a/backend/test/unit/services/test_chat_stream_attachment_materialize.py b/backend/test/unit/services/test_chat_stream_attachment_materialize.py index d8191192d..3dfc969ba 100644 --- a/backend/test/unit/services/test_chat_stream_attachment_materialize.py +++ b/backend/test/unit/services/test_chat_stream_attachment_materialize.py @@ -67,6 +67,7 @@ async def test_materialize_attachment_files_keeps_original_file_when_markdown_co result = await cs._materialize_attachment_files( thread_id="t-1", + user_id="u-1", upload=upload, file_name="demo.pdf", file_content=b"%PDF-test", @@ -103,6 +104,7 @@ async def _fake_convert(_upload): result = await cs._materialize_attachment_files( thread_id="t-1", + user_id="u-1", upload=upload, file_name="demo.txt", file_content=b"hello", diff --git a/backend/test/unit/services/test_mcp_auth_config_models.py b/backend/test/unit/services/test_mcp_auth_config_models.py new file mode 100644 index 000000000..61fd71fd4 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_config_models.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import ValidationError + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.config_models import MCPAuthConfig + + +def test_mcp_auth_config_applies_legacy_static_defaults(): + config = MCPAuthConfig.model_validate( + { + "version": 1, + "provider": "legacy_static", + "inject": { + "target": "headers", + "entries": [], + }, + } + ) + + assert config.binding_scope == "inline" + assert config.manifest_scope == "server" + assert config.refresh_policy.pre_refresh_seconds == 0 + assert config.refresh_policy.retry_once_on_401 is False + + +def test_mcp_auth_config_requires_token_request_for_dynamic_http_provider(): + with pytest.raises(ValidationError, match="token_request"): + MCPAuthConfig.model_validate( + { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + } + ) diff --git a/backend/test/unit/services/test_mcp_auth_crypto.py b/backend/test/unit/services/test_mcp_auth_crypto.py new file mode 100644 index 000000000..aa117588b --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_crypto.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import json +import os + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob, encrypt_credential_blob + + +pytestmark = [pytest.mark.unit] + + +def test_encrypt_and_decrypt_credential_blob_round_trip(monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + plaintext = json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}, ensure_ascii=False) + + encrypted = encrypt_credential_blob(plaintext) + decrypted = decrypt_credential_blob(encrypted) + + assert encrypted != plaintext + payload = json.loads(encrypted) + assert payload["v"] == 2 + assert "salt" in payload + assert decrypted == plaintext + + +def test_decrypt_legacy_v1_envelope(monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + plaintext = "super-secret-legacy" + + import hashlib + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + from yuxi.services.mcp_auth.crypto import _b64encode + + key = hashlib.sha256(b"local-test-master-key").digest() + aesgcm = AESGCM(key) + nonce = os.urandom(12) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), b"yuxi:mcp_credentials:v1") + + v1_blob = json.dumps({ + "v": 1, + "kid": "local", + "nonce": _b64encode(nonce), + "ciphertext": _b64encode(ciphertext), + }) + + assert decrypt_credential_blob(v1_blob) == plaintext + + +def test_decrypt_credential_blob_keeps_legacy_plaintext_payload(monkeypatch): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + plaintext = '{"secrets":{"access_token":"legacy-token"}}' + + assert decrypt_credential_blob(plaintext) == plaintext + + +def test_encrypt_credential_blob_requires_master_key(monkeypatch): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + + with pytest.raises(ValueError, match="MCP_CREDENTIALS_MASTER_KEY"): + encrypt_credential_blob('{"secrets":{"access_token":"token"}}') diff --git a/backend/test/unit/services/test_mcp_auth_models.py b/backend/test/unit/services/test_mcp_auth_models.py new file mode 100644 index 000000000..c71285fd7 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_models.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import os + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def mcp_auth_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_mcp_server_to_dict_and_mcp_config_include_auth_config(mcp_auth_session): + server = MCPServer( + name="gateway", + description="internal gateway", + transport="streamable_http", + url="http://gateway.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 600, "retry_once_on_401": True}, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(server) + await mcp_auth_session.commit() + + payload = server.to_dict() + config = server.to_mcp_config() + + assert payload["auth_config"]["provider"] == "custom_http_token" + assert config["auth_config"]["binding_scope"] == "department" + + +async def test_mcp_connection_persists_scoped_binding_and_hides_credentials_by_default(mcp_auth_session): + server = MCPServer( + name="finance-gateway", + description="finance", + transport="streamable_http", + url="http://finance.local/mcp", + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(server) + await mcp_auth_session.commit() + + connection = MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + display_name="财务部共享凭据", + external_subject="finance-user", + status="active", + credential_blob="encrypted-secret", + meta_json={"last_success_at": "2026-06-02T10:00:00Z"}, + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(connection) + await mcp_auth_session.commit() + + result = await mcp_auth_session.execute(select(MCPConnection).where(MCPConnection.server_name == "finance-gateway")) + saved = result.scalar_one() + + safe_payload = saved.to_dict() + internal_payload = saved.to_dict(include_credentials=True) + + assert safe_payload["scope_type"] == "department" + assert safe_payload["has_credentials"] is True + assert "credential_blob" not in safe_payload + assert internal_payload["credential_blob"] == "encrypted-secret" diff --git a/backend/test/unit/services/test_mcp_auth_orchestrator.py b/backend/test/unit/services/test_mcp_auth_orchestrator.py new file mode 100644 index 000000000..dc268d1ef --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_orchestrator.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime, timedelta + +import httpx +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class DummyTokenCache: + def __init__(self, token_payload: dict | None = None): + self.token_payload = token_payload + self.token_payloads = None + self.set_calls: list[tuple[int, dict]] = [] + self.deleted_connection_ids: list[int] = [] + self.acquire_calls: list[int] = [] + self.release_calls: list[int] = [] + self.acquire_result = True + + async def get_access_token(self, connection_id: int) -> dict | None: + del connection_id + if self.token_payloads is not None: + if self.token_payloads: + self.token_payload = self.token_payloads.pop(0) + else: + self.token_payload = None + return self.token_payload + + async def set_access_token(self, connection_id: int, token_payload: dict) -> None: + self.set_calls.append((connection_id, token_payload)) + self.token_payload = token_payload + + async def delete_access_token(self, connection_id: int) -> None: + self.deleted_connection_ids.append(connection_id) + self.token_payload = None + + async def acquire_refresh_lock(self, connection_id: int, *, ttl_seconds: int = 30) -> bool: + del ttl_seconds + self.acquire_calls.append(connection_id) + return self.acquire_result + + async def release_refresh_lock(self, connection_id: int) -> None: + self.release_calls.append(connection_id) + + +async def test_resolve_runtime_mcp_config_injects_bound_secret_header(): + os.environ["MCP_CREDENTIALS_MASTER_KEY"] = "local-test-master-key" + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + credential_blob=encrypt_credential_blob(json.dumps({"secrets": {"access_token": "dept-token"}})), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="42"), + connection=connection, + ) + + assert resolved["transport"] == "streamable_http" + assert resolved["headers"] == { + "X-App": "yuxi", + "Authorization": "Bearer dept-token", + } + assert "auth_config" not in resolved + + +async def test_resolve_runtime_mcp_config_supports_raw_token_string_binding(): + os.environ["MCP_CREDENTIALS_MASTER_KEY"] = "local-test-master-key" + server = MCPServer( + name="raw-token-gateway", + transport="streamable_http", + url="http://raw.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "system", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="raw-token-gateway", + scope_type="system", + scope_id="global", + credential_blob=encrypt_credential_blob("raw-token-value"), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(), + connection=connection, + ) + + assert resolved["headers"] == {"Authorization": "Bearer raw-token-value"} + + +async def test_resolve_runtime_mcp_config_fetches_custom_http_token_with_user_context(): + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode("utf-8")) + return httpx.Response( + 200, + json={ + "data": { + "access_token": "fresh-token", + "refresh_token": "refresh-token", + "expires_in": 3600, + } + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-gateway", + transport="streamable_http", + url="http://corp.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 600, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + "X-Client-Id": "${secret.client_id}", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="corp-gateway", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid-1", "client_secret": "secret-1"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-9", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=DummyTokenCache(), + ) + + await http_client.aclose() + + assert captured["url"] == "http://gateway.local/auth/token" + assert captured["body"] == { + "client_id": "cid-1", + "client_secret": "secret-1", + "user_id": "user-9", + "department_id": "finance", + } + assert resolved["headers"] == { + "X-App": "yuxi", + "Authorization": "Bearer fresh-token", + } + + +async def test_resolve_runtime_mcp_config_fetches_client_credentials_token(): + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["body"] = json.loads(request.content.decode("utf-8")) + return httpx.Response( + 200, + json={ + "access_token": "client-token", + "expires_in": 1800, + "token_type": "Bearer", + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="client-credentials-mcp", + transport="streamable_http", + url="http://client.local/mcp", + auth_config_json={ + "version": 1, + "provider": "client_credentials", + "binding_scope": "system", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://gateway.local/oauth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "grant_type": "client_credentials", + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + "token_type": "token_type", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=11, + server_name="client-credentials-mcp", + scope_type="system", + scope_id="global", + credential_blob=json.dumps({"secrets": {"client_id": "cid-cc", "client_secret": "secret-cc"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="d-1"), + connection=connection, + http_client=http_client, + token_cache=DummyTokenCache(), + ) + + await http_client.aclose() + + assert captured["url"] == "http://gateway.local/oauth/token" + assert captured["body"] == { + "grant_type": "client_credentials", + "client_id": "cid-cc", + "client_secret": "secret-cc", + } + assert resolved["headers"] == {"Authorization": "Bearer client-token"} + + +async def test_resolve_runtime_mcp_config_injects_stdio_env_from_secret_binding(): + server = MCPServer( + name="stdio-auth-mcp", + transport="stdio", + command="demo-server", + env={"LOG_LEVEL": "info"}, + auth_config_json={ + "version": 1, + "provider": "stdio_env", + "binding_scope": "user", + "manifest_scope": "binding", + "inject": { + "target": "env", + "entries": [ + {"name": "API_TOKEN", "value_template": "${secret.access_token}"}, + {"name": "YUXI_USER_ID", "value_template": "${context.user_id}"}, + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="stdio-auth-mcp", + scope_type="user", + scope_id="user-1", + credential_blob=json.dumps({"secrets": {"access_token": "stdio-token"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + connection=connection, + ) + + assert resolved["command"] == "demo-server" + assert resolved["env"] == { + "LOG_LEVEL": "info", + "API_TOKEN": "stdio-token", + "YUXI_USER_ID": "user-1", + } + + +async def test_resolve_runtime_mcp_config_uses_cached_custom_http_token_before_fetching(): + def handler(request: httpx.Request) -> httpx.Response: + raise AssertionError(f"unexpected token request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-cache-mcp", + transport="streamable_http", + url="http://corp-cache.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 120, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=21, + server_name="corp-cache-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "cached-token", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=10)).isoformat(), + "token_type": "Bearer", + } + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert resolved["headers"] == {"Authorization": "Bearer cached-token"} + assert token_cache.set_calls == [] + + +async def test_resolve_runtime_mcp_config_refreshes_cached_token_when_expiring_soon(): + captured: list[tuple[str, dict]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content.decode("utf-8")) + captured.append((str(request.url), payload)) + return httpx.Response( + 200, + json={ + "access_token": "refreshed-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-refresh-mcp", + transport="streamable_http", + url="http://corp-refresh.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 300, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "token_type": "token_type", + }, + "refresh": { + "url": "http://gateway.local/auth/refresh", + "method": "POST", + "body_type": "json", + "body_template": { + "refresh_token": "${token.refresh_token}", + }, + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=22, + server_name="corp-refresh-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps( + { + "secrets": {"client_id": "cid", "client_secret": "secret"}, + "refresh_token": "refresh-old", + } + ), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(seconds=60)).isoformat(), + } + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert captured == [ + ( + "http://gateway.local/auth/refresh", + { + "refresh_token": "refresh-old", + }, + ) + ] + assert resolved["headers"] == {"Authorization": "Bearer refreshed-token"} + assert token_cache.set_calls and token_cache.set_calls[0][0] == 22 + assert token_cache.set_calls[0][1]["access_token"] == "refreshed-token" + assert token_cache.acquire_calls == [22] + assert token_cache.release_calls == [22] + + +async def test_resolve_runtime_mcp_config_waits_for_refresh_lock_owner_to_publish_token(): + def handler(request: httpx.Request) -> httpx.Response: + raise AssertionError(f"unexpected token request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-lock-mcp", + transport="streamable_http", + url="http://corp-lock.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 300, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=24, + server_name="corp-lock-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache() + token_cache.acquire_result = False + token_cache.token_payloads = [ + { + "access_token": "stale-token", + "expires_at": (datetime.now(tz=UTC) + timedelta(seconds=10)).isoformat(), + }, + { + "access_token": "fresh-from-other-worker", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + }, + ] + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert resolved["headers"] == {"Authorization": "Bearer fresh-from-other-worker"} + assert token_cache.acquire_calls == [24] + assert token_cache.release_calls == [] + assert token_cache.set_calls == [] + + +async def test_resolve_runtime_mcp_config_refreshes_authorization_code_token(): + captured: list[tuple[str, str]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append((request.method, str(request.url))) + if str(request.url) == "https://id.example.com/.well-known/openid-configuration": + return httpx.Response( + 200, + json={ + "token_endpoint": "https://id.example.com/oauth/token", + }, + ) + if str(request.url) == "https://id.example.com/oauth/token": + body_text = request.content.decode("utf-8") + assert "grant_type=refresh_token" in body_text + assert "refresh_token=refresh-old" in body_text + assert "client_id=oidc-client" in body_text + assert "client_secret=oidc-secret" in body_text + return httpx.Response( + 200, + json={ + "access_token": "oidc-access-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="oidc-mcp", + transport="streamable_http", + url="http://oidc.local/mcp", + auth_config_json={ + "version": 1, + "provider": "authorization_code", + "binding_scope": "user", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 120, "retry_once_on_401": True}, + "token_request": { + "issuer_url": "https://id.example.com", + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=23, + server_name="oidc-mcp", + scope_type="user", + scope_id="user-1", + credential_blob=json.dumps( + { + "secrets": { + "client_id": "oidc-client", + "client_secret": "oidc-secret", + }, + "refresh_token": "refresh-old", + } + ), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + connection=connection, + http_client=http_client, + token_cache=DummyTokenCache(), + ) + + await http_client.aclose() + + assert captured == [ + ("GET", "https://id.example.com/.well-known/openid-configuration"), + ("POST", "https://id.example.com/oauth/token"), + ] + assert resolved["headers"] == {"Authorization": "Bearer oidc-access-token"} + + +async def test_normalize_token_payload_naive_datetime(): + """测试 _normalize_token_payload 对 naive datetime 默认填充 UTC 时区""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + from datetime import datetime, UTC + + # 构造 naive datetime (无 tzinfo) + naive_dt = datetime(2026, 6, 5, 12, 0, 0) + payload = {"expires_at": naive_dt} + + normalized = _normalize_token_payload(payload) + # 期望转换后有时区,并且值为 2026-06-05T12:00:00+00:00 (ISO格式) + expected_iso = datetime(2026, 6, 5, 12, 0, 0, tzinfo=UTC).isoformat() + assert normalized["expires_at"] == expected_iso + + +async def test_normalize_token_payload_aware_datetime(): + """测试 _normalize_token_payload 对于带时区的 datetime 维持原时区对应 UTC 时间""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + from datetime import datetime, timezone, timedelta + + # 构造带时区的 datetime (比如东八区) + shanghai_tz = timezone(timedelta(hours=8)) + aware_dt = datetime(2026, 6, 5, 20, 0, 0, tzinfo=shanghai_tz) + payload = {"expires_at": aware_dt} + + normalized = _normalize_token_payload(payload) + # 转换为 UTC 后应该为 2026-06-05T12:00:00+00:00 + expected_iso = datetime(2026, 6, 5, 12, 0, 0, tzinfo=timezone.utc).isoformat() + assert normalized["expires_at"] == expected_iso diff --git a/backend/test/unit/services/test_mcp_auth_proxy_service.py b/backend/test/unit/services/test_mcp_auth_proxy_service.py new file mode 100644 index 000000000..b35c16a44 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_proxy_service.py @@ -0,0 +1,413 @@ +from __future__ import annotations + + +def make_mock_response(status_code, content): + import httpx + resp = httpx.Response(status_code, content=content) + async def fake_aiter_raw(): + yield content + resp.aiter_raw = fake_aiter_raw + return resp + +import json +import os +from datetime import UTC, datetime, timedelta + +import httpx +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.services.mcp_auth.proxy_service import _proxy_mcp_request_stream +from starlette.requests import Request + + +def make_mock_response(status_code, content): + import httpx + resp = httpx.Response(status_code, content=content) + async def fake_aiter_raw(): + yield content + resp.aiter_raw = fake_aiter_raw + return resp + +import json +async def get_response_json(response): + if hasattr(response, "body_iterator"): + body = b"".join([chunk async for chunk in response.body_iterator]) + else: + body = response.body + return json.loads(body) + +from fastapi import Response +from fastapi.responses import StreamingResponse +from yuxi.services.mcp import server_service +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer +from yuxi.services.mcp_auth import proxy_service +from yuxi.services.mcp_auth.proxy_service import create_proxy_access_token, handle_mcp_proxy_request + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class DummyTokenCache: + def __init__(self, token_payload: dict | None = None): + self.token_payload = token_payload + self.deleted_connection_ids: list[int] = [] + self.set_calls: list[tuple[int, dict]] = [] + + async def get_access_token(self, connection_id: int) -> dict | None: + del connection_id + return self.token_payload + + async def delete_access_token(self, connection_id: int) -> None: + self.deleted_connection_ids.append(connection_id) + self.token_payload = None + + async def set_access_token(self, connection_id: int, token_payload: dict) -> None: + self.set_calls.append((connection_id, token_payload)) + self.token_payload = token_payload + + +async def test_proxy_mcp_request_retries_once_after_401_with_refreshed_token(): + observed_authorizations: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == "http://gateway.local/auth/token": + return httpx.Response( + 200, + json={ + "access_token": "fresh-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + }, + ) + + if str(request.url) == "http://upstream.local/mcp": + observed_authorizations.append(request.headers.get("Authorization")) + if request.headers.get("Authorization") == "Bearer stale-token": + return make_mock_response(401, b'{"error": "expired"}') + if request.headers.get("Authorization") == "Bearer fresh-token": + resp = make_mock_response(200, b'{"result": "ok"}') + resp.is_stream_consumed = False + return resp + + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-retry", + transport="streamable_http", + url="http://upstream.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 60, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=41, + server_name="proxy-retry", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + } + ) + + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + + class DummyDB: + async def commit(self): pass + + response = await _proxy_mcp_request_stream( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + request=req, + body=b'{"jsonrpc":"2.0","id":1}', + db=DummyDB(), + _http_client=http_client, + _token_cache=token_cache, + ) + + await http_client.aclose() + + assert response.status_code == 200 + assert await get_response_json(response) == {"result": "ok"} + assert observed_authorizations == ["Bearer stale-token", "Bearer fresh-token"] + assert token_cache.deleted_connection_ids == [41] + assert token_cache.set_calls and token_cache.set_calls[0][0] == 41 + assert connection.status == "active" + + +async def test_proxy_mcp_request_marks_reauth_required_after_final_401(): + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + if str(request.url) == "http://gateway.local/auth/token": + return httpx.Response( + 200, + json={ + "access_token": f"fresh-token-{attempts}", + "refresh_token": "refresh-next", + "expires_in": 3600, + }, + ) + if str(request.url) == "http://upstream.local/mcp": + attempts += 1 + return make_mock_response(401, b'{"error": "expired"}') + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-fail-401", + transport="streamable_http", + url="http://upstream.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 60, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=42, + server_name="proxy-fail-401", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + } + ) + + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + + class DummyDB: + async def commit(self): pass + + response = await _proxy_mcp_request_stream( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + request=req, + body=b'{"jsonrpc":"2.0","id":1}', + db=DummyDB(), + _http_client=http_client, + _token_cache=token_cache, + ) + + await http_client.aclose() + + assert response.status_code == 424 + assert (await get_response_json(response))["error"] == "reauth_required" + assert connection.status == "reauth_required" + assert connection.meta_json["last_error"]["code"] == "unauthorized" + + +async def test_proxy_mcp_request_records_scope_error_on_403(): + def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == "http://upstream.local/mcp": + return make_mock_response(403, b'{"error": "forbidden"}') + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-403", + transport="streamable_http", + url="http://upstream.local/mcp", + headers={"Authorization": "Bearer static-token"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=43, + server_name="proxy-403", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "static-token"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + class DummyDB: + async def commit(self): pass + response = await _proxy_mcp_request_stream( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + request=req, + body=b'{"jsonrpc":"2.0","id":1}', + db=DummyDB(), + _http_client=http_client, + _token_cache=None, + ) + + await http_client.aclose() + + assert response.status_code == 403 + assert (await get_response_json(response))["error"] == "insufficient_scope" + assert connection.status == "active" + assert connection.meta_json["last_error"]["code"] == "insufficient_scope" + + +async def test_handle_mcp_proxy_request_allows_no_secret_dynamic_config_without_connection(monkeypatch): + monkeypatch.setenv("JWT_SECRET_KEY", "unit-test-jwt-secret-with-at-least-32-bytes") + monkeypatch.setenv("YUXI_INSTANCE_ID", "unit-test-instance") + + server = MCPServer( + name="proxy-no-secret", + transport="streamable_http", + url="http://upstream.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "user", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": {"work_id": "${context.work_id}"}, + "response_map": {"access_token": "access_token", "expires_in": "expires_in"}, + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + + async def fake_get_mcp_server(db, server_name): + del db + assert server_name == "proxy-no-secret" + return server + + observed = {} + + async def fake_proxy_mcp_request_stream(**kwargs): + assert kwargs["server"] is server + observed["connection"] = kwargs["connection"] + observed["auth_context"] = kwargs["auth_context"] + observed["body"] = kwargs["body"] + return Response(status_code=204) + + class EmptyResult: + def scalar_one_or_none(self): + return None + + class DummyDB: + async def execute(self, stmt): + del stmt + return EmptyResult() + + monkeypatch.setattr(server_service, "get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr(proxy_service, "_proxy_mcp_request_stream", fake_proxy_mcp_request_stream) + + token = create_proxy_access_token( + "proxy-no-secret", + AuthContext(user_id="user-1", department_id="dep-1", work_id="W001"), + ) + async def receive(): + return {"type": "http.request", "body": b'{"jsonrpc":"2.0","id":1}', "more_body": False} + + req = Request( + { + "type": "http", + "method": "POST", + "headers": [(b"content-type", b"application/json")], + "query_string": b"", + }, + receive, + ) + + response = await handle_mcp_proxy_request( + "proxy-no-secret", + request=req, + path="", + internal_token=token, + db=DummyDB(), + ) + + assert response.status_code == 204 + assert observed["connection"] is None + assert observed["auth_context"].work_id == "W001" + assert observed["body"] == b'{"jsonrpc":"2.0","id":1}' diff --git a/backend/test/unit/services/test_mcp_auth_runtime.py b/backend/test/unit/services/test_mcp_auth_runtime.py new file mode 100644 index 000000000..94f68a8bf --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_runtime.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import json +import os + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp.client_pool import mcp_client_pool +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def runtime_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_get_runtime_mcp_server_config_resolves_department_connection(runtime_session): + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "dept-token"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + config = await server_service.get_runtime_mcp_server_config( + "finance-gateway", + auth_context=AuthContext(user_id="u-1", department_id="42"), + db=runtime_session, + ) + + assert config is not None + assert config["headers"]["Authorization"] == "Bearer dept-token" + + +async def test_get_enabled_mcp_tools_does_not_reuse_user_connection_for_other_user(runtime_session, monkeypatch): + server = MCPServer( + name="personal-gateway", + transport="streamable_http", + url="http://personal.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + server_name="personal-gateway", + scope_type="user", + scope_id="user-1", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "user-1-token"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + captured_configs: list[dict] = [] + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del disabled_tools, kwargs + assert server_name == "personal-gateway" + captured_configs.append(additional_servers[server_name]) + return ["private-tool"] + + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + user_1_tools = await tool_registry_service.get_enabled_mcp_tools( + "personal-gateway", + auth_context=AuthContext(user_id="user-1"), + db=runtime_session, + ) + + with pytest.raises(ValueError, match="Active MCP connection not found"): + await tool_registry_service.get_enabled_mcp_tools( + "personal-gateway", + auth_context=AuthContext(user_id="user-2"), + db=runtime_session, + ) + + assert user_1_tools == ["private-tool"] + assert len(captured_configs) == 1 + assert captured_configs[0]["headers"]["Authorization"] == "Bearer user-1-token" + + +async def test_get_enabled_mcp_tools_uses_runtime_mcp_config(monkeypatch): + captured: list[dict] = [] + + async def fake_get_runtime_mcp_server_config(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "demo" + assert auth_context is not None + return { + "transport": "stdio", + "command": "demo-with-auth", + "disabled_tools": ["tool_b"], + } + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del kwargs + captured.append( + { + "server_name": server_name, + "additional_servers": additional_servers, + "disabled_tools": list(disabled_tools or []), + } + ) + return ["tool-a"] + + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_enabled_mcp_tools( + "demo", + auth_context=AuthContext(user_id="u-100", department_id="d-9"), + ) + + assert tools == ["tool-a"] + assert captured == [ + { + "server_name": "demo", + "additional_servers": { + "demo": {"transport": "stdio", "command": "demo-with-auth", "disabled_tools": ["tool_b"]} + }, + "disabled_tools": ["tool_b"], + } + ] + + +async def test_get_all_mcp_tools_uses_runtime_mcp_config_when_auth_context_is_provided(monkeypatch): + captured: list[dict] = [] + + async def fake_get_runtime_mcp_server_config(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "demo" + assert auth_context is not None + return { + "transport": "stdio", + "command": "demo-with-auth", + "disabled_tools": ["tool_b"], + } + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del kwargs + captured.append( + { + "server_name": server_name, + "additional_servers": additional_servers, + "disabled_tools": list(disabled_tools or []), + } + ) + return ["tool-a", "tool-b"] + + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_all_mcp_tools( + "demo", + auth_context=AuthContext(user_id="u-100", department_id="d-9"), + ) + + assert tools == ["tool-a", "tool-b"] + assert captured == [ + { + "server_name": "demo", + "additional_servers": { + "demo": {"transport": "stdio", "command": "demo-with-auth", "disabled_tools": ["tool_b"]} + }, + "disabled_tools": [], + } + ] + + +async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_http_provider( + runtime_session, monkeypatch +): + monkeypatch.setenv("YUXI_INTERNAL_MCP_PROXY_BASE_URL", "http://internal-api:5050") + + server = MCPServer( + name="finance-proxy", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + id=31, + server_name="finance-proxy", + scope_type="department", + scope_id="dep-88", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + config = await server_service.get_runtime_mcp_server_config( + "finance-proxy", + auth_context=AuthContext(user_id="user-1", department_id="dep-88"), + db=runtime_session, + ) + + assert config is not None + assert config["url"] == "http://internal-api:5050/api/internal/mcp-proxy/finance-proxy" + assert config["headers"]["X-App"] == "yuxi" + assert "X-Yuxi-MCP-Proxy-Token" in config["headers"] + assert "Authorization" not in config["headers"] + assert config["__yuxi_cache_partition"] == "connection:31" + assert config["__yuxi_allow_global_cache"] is False + assert config["__yuxi_disable_tool_object_cache"] is True + + +async def test_update_mcp_server_auth_config_clears_runtime_auth_cache(runtime_session, monkeypatch): + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "system", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + await runtime_session.commit() + + calls = {"runtime_auth_cache": 0, "tools_cache": 0} + + async def fake_clear_runtime_auth_cache(db, server_name): + assert db is runtime_session + assert server_name == "finance-gateway" + calls["runtime_auth_cache"] += 1 + + async def fake_invalidate_tools_cache(server_name): + assert server_name == "finance-gateway" + calls["tools_cache"] += 1 + + monkeypatch.setattr(tool_registry_service, "_clear_mcp_server_runtime_auth_cache", fake_clear_runtime_auth_cache) + monkeypatch.setattr(tool_registry_service, "invalidate_mcp_server_tools_cache", fake_invalidate_tools_cache) + + await server_service.update_mcp_server( + runtime_session, + "finance-gateway", + auth_config={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "system", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + updated_by="tester", + ) + + assert calls == {"runtime_auth_cache": 1, "tools_cache": 1} diff --git a/backend/test/unit/services/test_mcp_auth_template_resolver.py b/backend/test/unit/services/test_mcp_auth_template_resolver.py new file mode 100644 index 000000000..9a14d75a6 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_template_resolver.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import os + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.template_resolver import TemplateResolutionError, resolve_template_value + + +def test_resolve_template_value_supports_nested_structures(): + resolved = resolve_template_value( + { + "headers": { + "Authorization": "Bearer ${access_token}", + "X-User-Id": "${context.user_id}", + }, + "body": { + "client_id": "${secret.client_id}", + "tenant": "${secret.extra.tenant_code}", + "department_id": "${context.department_id}", + }, + "args": ["--user=${context.user_id}", "--refresh=${token.refresh_token}"], + }, + context={"user_id": "u-100", "department_id": "d-9"}, + secret={"client_id": "cid-1", "extra": {"tenant_code": "finance"}}, + token={"refresh_token": "refresh-1"}, + access_token="access-1", + ) + + assert resolved == { + "headers": { + "Authorization": "Bearer access-1", + "X-User-Id": "u-100", + }, + "body": { + "client_id": "cid-1", + "tenant": "finance", + "department_id": "d-9", + }, + "args": ["--user=u-100", "--refresh=refresh-1"], + } + + +def test_resolve_template_value_raises_for_unknown_placeholder(): + with pytest.raises(TemplateResolutionError, match="context.missing"): + resolve_template_value( + {"user": "${context.missing}"}, + context={"user_id": "u-100"}, + secret={}, + token={}, + access_token="access-1", + ) diff --git a/backend/test/unit/services/test_mcp_cache_policy.py b/backend/test/unit/services/test_mcp_cache_policy.py new file mode 100644 index 000000000..3973fe88b --- /dev/null +++ b/backend/test/unit/services/test_mcp_cache_policy.py @@ -0,0 +1,91 @@ +from __future__ import annotations +import pytest +from unittest.mock import MagicMock + +from yuxi.services.mcp.cache_policy import ( + CachePolicyFactory, + StaticCachePolicy, + TokenInjectedCachePolicy, + DynamicProxyCachePolicy, +) +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import MCPConnection + + +def test_cache_policy_factory(): + """测试 CachePolicyFactory 的策略派发逻辑""" + # 静态/无鉴权 + assert isinstance(CachePolicyFactory.get_policy(None), StaticCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("legacy_static"), StaticCachePolicy) + + # 注入型 + assert isinstance(CachePolicyFactory.get_policy("bound_secret"), TokenInjectedCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("stdio_env"), TokenInjectedCachePolicy) + + # 动态 Token 型 + assert isinstance(CachePolicyFactory.get_policy("custom_http_token"), DynamicProxyCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("client_credentials"), DynamicProxyCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("authorization_code"), DynamicProxyCachePolicy) + + +def test_static_cache_policy(): + """测试 StaticCachePolicy""" + policy = StaticCachePolicy() + assert policy.should_cache_tool_object() is True + + auth_context = AuthContext() + partition, is_shared = policy.resolve_cache_partition(auth_context, None) + assert partition == "global" + assert is_shared is True + + +def test_token_injected_cache_policy(): + """测试 TokenInjectedCachePolicy""" + policy = TokenInjectedCachePolicy() + assert policy.should_cache_tool_object() is True + + auth_context = AuthContext(user_id="user_1", department_id="dept_A") + + # connection 为 None 退避 + partition, is_shared = policy.resolve_cache_partition(auth_context, None) + assert partition == "global" + assert is_shared is True + + # 系统连接,共享 + conn_sys = MCPConnection(id=10, scope_type="system", scope_id="global") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_sys) + assert partition == "connection:10" + assert is_shared is True + + # 部门连接,独占 + conn_dept = MCPConnection(id=11, scope_type="department", scope_id="dept_A") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_dept) + assert partition == "connection:11" + assert is_shared is False + + # 个人连接,独占 + conn_user = MCPConnection(id=12, scope_type="user", scope_id="user_1") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_user) + assert partition == "connection:12" + assert is_shared is False + + +def test_dynamic_proxy_cache_policy(): + """测试 DynamicProxyCachePolicy""" + policy = DynamicProxyCachePolicy() + # 动态鉴权必须禁止把带临时 Token 的 Tool 实例缓存在共享内存中 + assert policy.should_cache_tool_object() is False + + auth_context = AuthContext(user_id="user_1", department_id="dept_A") + + # 部门隔离连接,独占 + conn_dept = MCPConnection(id=20, scope_type="department", scope_id="dept_A") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_dept) + assert partition == "connection:20" + assert is_shared is False + + # 个人隔离连接,独占 + conn_user = MCPConnection(id=21, scope_type="user", scope_id="user_1") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_user) + assert partition == "connection:21" + assert is_shared is False diff --git a/backend/test/unit/services/test_mcp_client_pool.py b/backend/test/unit/services/test_mcp_client_pool.py new file mode 100644 index 000000000..1d085f0a4 --- /dev/null +++ b/backend/test/unit/services/test_mcp_client_pool.py @@ -0,0 +1,218 @@ +from __future__ import annotations +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from yuxi.services.mcp.client_pool import MCPClientPool, LongLivedSession + + +@pytest.mark.asyncio +async def test_long_lived_session_lifecycle(): + """测试 LongLivedSession 正常的启动与停止流程""" + mock_client = MagicMock() + mock_session = MagicMock() + + # 模拟 client.session() 返回一个 AsyncContextManager + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_client.session.return_value = mock_context + + ll_session = LongLivedSession(mock_client, "test_server") + + # 启动 + await ll_session.start() + assert ll_session._running is True + assert ll_session.session == mock_session + mock_client.session.assert_called_once_with("test_server") + + # 停止 + await ll_session.stop() + assert ll_session._running is False + assert ll_session.session is None + + +@pytest.mark.asyncio +async def test_client_pool_reuse_and_recreate(): + """测试 MCPClientPool 的复用逻辑与配置脏变重构逻辑""" + pool = MCPClientPool() + + config_1 = { + "transport": "stdio", + "command": "node", + "args": ["file1.js"], + "__yuxi_cache_partition": "p1", + } + + config_2 = { + "transport": "stdio", + "command": "node", + "args": ["file1.js"], + "__yuxi_cache_partition": "p1", + } + + config_changed = { + "transport": "stdio", + "command": "node", + "args": ["file2.js"], # 配置发生改变 + "__yuxi_cache_partition": "p1", + } + + mock_client_instance = MagicMock() + mock_session_instance = MagicMock() + + # Mock LongLivedSession 的 start/stop 以防真实建连 + with patch("yuxi.services.mcp.client_pool.MultiServerMCPClient", return_value=mock_client_instance), \ + patch("yuxi.services.mcp.client_pool.LongLivedSession") as MockLongLivedSession: + + mock_ll_instance = MagicMock() + mock_ll_instance.session = mock_session_instance + mock_ll_instance.start = AsyncMock() + mock_ll_instance.stop = AsyncMock() + MockLongLivedSession.return_value = mock_ll_instance + + # 1. 首次获取,创建新 Session + session_1 = await pool.get_session("test_server", "p1", config_1) + assert session_1 == mock_session_instance + assert MockLongLivedSession.call_count == 1 + mock_ll_instance.start.assert_called_once() + + # 2. 相同配置获取,直接复用 + session_2 = await pool.get_session("test_server", "p1", config_2) + assert session_2 == mock_session_instance + assert MockLongLivedSession.call_count == 1 # 没增加,说明复用了 + + # 3. 配置改变获取,销毁旧的,重新创建 + session_changed = await pool.get_session("test_server", "p1", config_changed) + assert session_changed == mock_session_instance + # 销毁被调用了 + mock_ll_instance.stop.assert_called_once() + # 创建计数增加 + assert MockLongLivedSession.call_count == 2 + + # 4. shutdown 清理 + await pool.shutdown() + assert mock_ll_instance.stop.call_count == 2 # 新增的那个也被 stop + assert len(pool._sessions) == 0 + + +@pytest.mark.asyncio +async def test_calculate_config_hash_with_non_serializable(): + """测试配置中包含非 JSON 序列化对象时配置哈希计算不崩溃""" + from datetime import datetime + pool = MCPClientPool() + + class DummyObj: + def __str__(self): + return "dummy" + + config = { + "transport": "sse", + "url": "http://example.com/sse", + "custom_obj": DummyObj(), + "created_at": datetime(2026, 6, 5), + } + + # 验证在含有不可 JSON 序列化的对象时依然能正常计算出哈希,不抛出异常 + config_hash = pool._calculate_config_hash(config) + assert isinstance(config_hash, str) + assert len(config_hash) == 16 + + +def test_calculate_config_hash_ignores_internal_proxy_token_header(): + """代理模式下短期 JWT 变化不应触发长连接重建""" + from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER + + pool = MCPClientPool() + config_a = { + "transport": "streamable_http", + "url": "http://api:5050/api/internal/mcp-proxy/demo", + "headers": { + "X-App": "yuxi", + "Authorization": "Bearer upstream-a", + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-a", + }, + } + config_b = { + "transport": "streamable_http", + "url": "http://api:5050/api/internal/mcp-proxy/demo", + "headers": { + "X-App": "yuxi", + "Authorization": "Bearer upstream-b", + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-b", + }, + } + + assert pool._calculate_config_hash(config_a) == pool._calculate_config_hash(config_b) + + +@pytest.mark.asyncio +async def test_dynamic_mcp_token_auth_cache(): + """测试 DynamicMCPTokenAuth 的 in-memory 缓存及联动清除逻辑""" + from yuxi.services.mcp.client_pool import ( + DynamicMCPTokenAuth, + clear_resolved_headers_cache, + clear_server_resolved_headers_cache, + _resolved_headers_cache, + ) + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var, AuthContext + + # 清空可能存在的全局缓存 + clear_resolved_headers_cache() + + auth = DynamicMCPTokenAuth("test_server") + mock_req = MagicMock(headers={}) + + auth_ctx = AuthContext(user_id="u1", department_id="d1") + cache_key = ("test_server", "u1", "d1") + + # 模拟 get_runtime_mcp_server_config 返回的数据 + mock_runtime_config = {"headers": {"Authorization": "Bearer token123"}} + + token = mcp_auth_context_var.set(auth_ctx) + try: + with patch("yuxi.storage.postgres.manager.pg_manager.get_async_session_context") as mock_session_ctx, \ + patch("yuxi.services.mcp.server_service.get_runtime_mcp_server_config", return_value=mock_runtime_config) as mock_get_config: + + # 模拟 async with pg_manager.get_async_session_context() as session + mock_session = MagicMock() + mock_ctx_mgr = AsyncMock() + mock_ctx_mgr.__aenter__.return_value = mock_session + mock_session_ctx.return_value = mock_ctx_mgr + + # 1. 第一次请求:应该执行 DB 查询,获取最新运行时配置 + generator = auth.async_auth_flow(mock_req) + results = [r async for r in generator] + assert len(results) == 1 + assert results[0].headers["Authorization"] == "Bearer token123" + assert mock_get_config.call_count == 1 + + # 2. 第二次请求(+5秒):应该命中缓存,不会执行 DB 查询 + mock_req_2 = MagicMock(headers={}) + generator_2 = auth.async_auth_flow(mock_req_2) + results_2 = [r async for r in generator_2] + assert len(results_2) == 1 + assert results_2[0].headers["Authorization"] == "Bearer token123" + # call_count 依然是 1,说明命中了缓存 + assert mock_get_config.call_count == 1 + + # 3. 细粒度清除缓存:清除指定 server_name + clear_server_resolved_headers_cache("test_server") + + # 4. 第三次请求(+10秒):清除缓存后,应该再次执行 DB 查询 + mock_req_3 = MagicMock(headers={}) + generator_3 = auth.async_auth_flow(mock_req_3) + results_3 = [r async for r in generator_3] + assert len(results_3) == 1 + assert mock_get_config.call_count == 2 + # 4. 测试 clear_mcp_cache / clear_mcp_server_tools_cache 联动清除所有 resolved_headers 缓存 + from yuxi.services.mcp.tool_registry_service import clear_mcp_cache, invalidate_mcp_server_tools_cache + # 确保当前有缓存项 + _resolved_headers_cache[cache_key] = {"Auth": "Bearer test"} + await invalidate_mcp_server_tools_cache("test_server") + assert len(_resolved_headers_cache) == 0 + + _resolved_headers_cache[cache_key] = {"Auth": "Bearer test"} + await clear_mcp_cache() + assert len(_resolved_headers_cache) == 0 + finally: + mcp_auth_context_var.reset(token) diff --git a/backend/test/unit/services/test_mcp_connection_service.py b/backend/test/unit/services/test_mcp_connection_service.py new file mode 100644 index 000000000..d1863c0e4 --- /dev/null +++ b/backend/test/unit/services/test_mcp_connection_service.py @@ -0,0 +1,1036 @@ +from __future__ import annotations + +import os + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob +from yuxi.storage.postgres.models_business import AgentConfig, Department, MCPConnection, MCPServer, Skill, User + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def connection_service_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +@pytest_asyncio.fixture +async def delete_semantics_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Department.__table__.create) + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + await conn.run_sync(Skill.__table__.create) + await conn.run_sync(AgentConfig.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +@pytest_asyncio.fixture +async def connection_listing_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Department.__table__.create) + await conn.run_sync(User.__table__.create) + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_create_and_list_mcp_connections(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="finance-gateway", + scope_type="department", + scope_id="42", + display_name="财务部共享连接", + external_subject="finance-user", + credential_blob="encrypted-secret", + meta_json={"tenant": "finance"}, + created_by="tester", + ) + + listed = await connection_service.list_mcp_connections(connection_service_session, server_name="finance-gateway") + + assert created.server_name == "finance-gateway" + assert created.scope_type == "department" + assert [item.id for item in listed] == [created.id] + + +async def test_create_mcp_connection_normalizes_system_scope_to_global(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="global-gateway", + transport="streamable_http", + url="http://global.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="global-gateway", + scope_type="system", + scope_id="", + display_name="全局共享连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + assert created.scope_type == "system" + assert created.scope_id == "global" + + +async def test_list_mcp_connections_page_filters_health_and_searches_scope_target( + connection_listing_session, +): + connection_listing_session.add(Department(id=10, name="财务部")) + connection_listing_session.add( + User( + id=1, + username="Alice", + user_id="alice", + password_hash="x", + department_id=10, + ) + ) + connection_listing_session.add( + User( + id=2, + username="Bob", + user_id="bob", + password_hash="x", + department_id=10, + ) + ) + connection_listing_session.add( + MCPServer( + name="listing-gateway", + transport="streamable_http", + url="http://listing.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + connection_listing_session.add_all( + [ + MCPConnection( + server_name="listing-gateway", + scope_type="user", + scope_id="1", + display_name="Alice 连接", + status="active", + credential_blob="encrypted-secret", + ), + MCPConnection( + server_name="listing-gateway", + scope_type="user", + scope_id="2", + display_name="Bob 连接", + status="active", + credential_blob=None, + ), + MCPConnection( + server_name="listing-gateway", + scope_type="system", + scope_id="global", + display_name="历史全局连接", + status="active", + credential_blob="encrypted-secret", + ), + MCPConnection( + server_name="listing-gateway", + scope_type="user", + scope_id="3", + display_name="过期连接", + status="reauth_required", + credential_blob="encrypted-secret", + ), + MCPConnection( + server_name="listing-gateway", + scope_type="user", + scope_id="4", + display_name="停用连接", + status="disabled", + credential_blob="encrypted-secret", + ), + MCPConnection( + server_name="listing-gateway", + scope_type="department", + scope_id="10", + display_name="部门异常连接", + status="invalid", + credential_blob="encrypted-secret", + ), + ] + ) + await connection_listing_session.commit() + + active_connections, active_total = await connection_service.list_mcp_connections_page( + connection_listing_session, + server_name="listing-gateway", + status_filter="active", + effective_scope_type="user", + credentials_required=True, + page=1, + page_size=12, + ) + + assert active_total == 1 + assert [item.display_name for item in active_connections] == ["Alice 连接"] + + attention_count = await connection_service.count_mcp_connections( + connection_listing_session, + server_name="listing-gateway", + status_filter="attention", + effective_scope_type="user", + credentials_required=True, + ) + assert attention_count == 4 + + searched_connections, searched_total = await connection_service.list_mcp_connections_page( + connection_listing_session, + server_name="listing-gateway", + search="alice", + page=1, + page_size=12, + ) + + assert searched_total == 1 + assert searched_connections[0].scope_id == "1" + + +async def test_create_mcp_connection_duplicate_scope_uses_user_friendly_message( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="demo_mcp_server", + transport="streamable_http", + url="http://demo.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + await connection_service.create_mcp_connection( + connection_service_session, + server_name="demo_mcp_server", + scope_type="user", + scope_id="1", + display_name="个人连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + with pytest.raises(ValueError) as exc_info: + await connection_service.create_mcp_connection( + connection_service_session, + server_name="demo_mcp_server", + scope_type="user", + scope_id="1", + display_name="重复个人连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + message = str(exc_info.value) + assert message == 'MCP "demo_mcp_server" 的个人专用连接已存在,请直接编辑现有连接。' + assert "user:1" not in message + + +async def test_create_mcp_connection_rejects_scope_that_does_not_match_server_binding( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="personal-gateway", + transport="streamable_http", + url="http://personal.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError) as exc_info: + await connection_service.create_mcp_connection( + connection_service_session, + server_name="personal-gateway", + scope_type="department", + scope_id="42", + display_name="部门连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + assert str(exc_info.value) == 'MCP "personal-gateway" 当前绑定类型是个人专用,只能使用个人专用连接。' + + +async def test_set_mcp_connection_status_updates_status(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="corp-gateway", + transport="streamable_http", + url="http://corp.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="corp-gateway", + scope_type="system", + scope_id="global", + display_name="全局共享连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + updated = await connection_service.set_mcp_connection_status( + connection_service_session, + created.id, + status="reauth_required", + updated_by="admin", + ) + + assert updated.status == "reauth_required" + assert updated.updated_by == "admin" + + +async def test_create_mcp_connection_rejects_invalid_scope_type(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="invalid-scope-gateway", + transport="streamable_http", + url="http://invalid-scope.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="scope_type"): + await connection_service.create_mcp_connection( + connection_service_session, + server_name="invalid-scope-gateway", + scope_type="tenant", + scope_id="x", + created_by="tester", + ) + + +async def test_create_mcp_connection_rejects_missing_department_scope_id(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="missing-scope-id-gateway", + transport="streamable_http", + url="http://missing-scope-id.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="scope_id"): + await connection_service.create_mcp_connection( + connection_service_session, + server_name="missing-scope-id-gateway", + scope_type="department", + scope_id="", + created_by="tester", + ) + + +async def test_set_mcp_connection_status_rejects_invalid_status(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="invalid-status-gateway", + transport="streamable_http", + url="http://invalid-status.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="invalid-status-gateway", + scope_type="system", + scope_id="global", + created_by="tester", + ) + + with pytest.raises(ValueError, match="status"): + await connection_service.set_mcp_connection_status( + connection_service_session, + created.id, + status="broken", + updated_by="admin", + ) + + +async def test_set_mcp_connection_status_rejects_reactivating_scope_mismatch( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="personal-status-gateway", + transport="streamable_http", + url="http://personal-status.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + ) + connection_service_session.add( + MCPConnection( + server_name="personal-status-gateway", + scope_type="system", + scope_id="global", + display_name="历史全局连接", + status="disabled", + credential_blob="encrypted-secret", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + connection = ( + await connection_service.list_mcp_connections( + connection_service_session, + server_name="personal-status-gateway", + ) + )[0] + + with pytest.raises(ValueError) as exc_info: + await connection_service.set_mcp_connection_status( + connection_service_session, + connection.id, + status="active", + updated_by="admin", + ) + + assert str(exc_info.value) == 'MCP "personal-status-gateway" 当前绑定类型是个人专用,只能使用个人专用连接。' + + +async def test_create_mcp_connection_encrypts_credentials(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="secure-gateway", + transport="streamable_http", + url="http://secure.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + plaintext = '{"secrets":{"access_token":"secure-token"}}' + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="secure-gateway", + scope_type="system", + scope_id="global", + credential_blob=plaintext, + created_by="tester", + ) + + assert created.credential_blob != plaintext + assert decrypt_credential_blob(created.credential_blob) == plaintext + + +async def test_create_mcp_connection_rejects_plaintext_credentials_without_master_key( + connection_service_session, monkeypatch +): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + connection_service_session.add( + MCPServer( + name="insecure-gateway", + transport="streamable_http", + url="http://insecure.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="MCP_CREDENTIALS_MASTER_KEY"): + await connection_service.create_mcp_connection( + connection_service_session, + server_name="insecure-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"token"}}', + created_by="tester", + ) + + +async def test_get_mcp_server_dependency_summary_reports_runtime_references(delete_semantics_session): + department = Department(name="研发部", description="dep") + delete_semantics_session.add(department) + delete_semantics_session.add( + MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + enabled=0, + created_by="tester", + updated_by="tester", + ) + ) + delete_semantics_session.add( + MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + status="active", + created_by="tester", + updated_by="tester", + ) + ) + delete_semantics_session.add( + Skill( + slug="finance-skill", + name="Finance Skill", + description="desc", + tool_dependencies=[], + mcp_dependencies=["finance-gateway"], + skill_dependencies=[], + dir_path="skills/finance", + created_by="tester", + updated_by="tester", + ) + ) + await delete_semantics_session.flush() + delete_semantics_session.add( + AgentConfig( + department_id=department.id, + agent_id="agent-1", + name="Finance Agent", + description="desc", + config_json={"mcps": ["finance-gateway"]}, + pics=[], + examples=[], + created_by="tester", + updated_by="tester", + ) + ) + await delete_semantics_session.commit() + + summary = await server_service.get_mcp_server_dependency_summary(delete_semantics_session, "finance-gateway") + + assert summary["has_references"] is True + assert summary["connections"] == [{"scope_type": "department", "scope_id": "42", "status": "active"}] + assert summary["skills"] == [{"slug": "finance-skill", "name": "Finance Skill"}] + assert summary["agent_configs"] == [{"id": 1, "name": "Finance Agent", "agent_id": "agent-1"}] + + +async def test_update_mcp_connection_reencrypts_credentials(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="update-gateway", + transport="streamable_http", + url="http://update.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="update-gateway", + scope_type="system", + scope_id="global", + display_name="old", + credential_blob='{"secrets":{"access_token":"old-token"}}', + created_by="tester", + ) + + updated = await connection_service.update_mcp_connection( + connection_service_session, + created.id, + display_name="new", + credential_blob='{"secrets":{"access_token":"new-token"}}', + updated_by="admin", + ) + + assert updated.display_name == "new" + assert decrypt_credential_blob(updated.credential_blob) == '{"secrets":{"access_token":"new-token"}}' + assert updated.updated_by == "admin" + + +async def test_delete_mcp_connection_removes_record(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="delete-connection-gateway", + transport="streamable_http", + url="http://delete.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="delete-connection-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"token"}}', + created_by="tester", + ) + + deleted = await connection_service.delete_mcp_connection(connection_service_session, created.id) + + assert deleted is True + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + assert await connection_service.get_mcp_connection(connection_service_session, created.id) is None + + +async def test_reauthorize_mcp_connection_clears_runtime_error(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) + + connection_service_session.add( + MCPServer( + name="reauth-gateway", + transport="streamable_http", + url="http://reauth.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="reauth-gateway", + scope_type="system", + scope_id="global", + status="reauth_required", + credential_blob='{"secrets":{"access_token":"token"}}', + meta_json={"last_error": {"message": "expired"}}, + created_by="tester", + ) + + updated = await connection_service.reauthorize_mcp_connection( + connection_service_session, + created.id, + updated_by="admin", + ) + + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + assert updated.status == "active" + assert updated.meta_json == {} + assert updated.updated_by == "admin" + + +async def test_reauthorize_mcp_connection_rejects_scope_that_does_not_match_server_binding( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="personal-reauth-gateway", + transport="streamable_http", + url="http://personal-reauth.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + ) + connection_service_session.add( + MCPConnection( + server_name="personal-reauth-gateway", + scope_type="system", + scope_id="global", + display_name="历史全局连接", + status="reauth_required", + credential_blob="encrypted-secret", + meta_json={"last_error": {"message": "expired"}}, + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + connection = ( + await connection_service.list_mcp_connections( + connection_service_session, + server_name="personal-reauth-gateway", + ) + )[0] + + with pytest.raises(ValueError) as exc_info: + await connection_service.reauthorize_mcp_connection( + connection_service_session, + connection.id, + updated_by="admin", + ) + + message = str(exc_info.value) + assert message == 'MCP "personal-reauth-gateway" 当前绑定类型是个人专用,只能使用个人专用连接。' + assert "user_id is required" not in message + + +async def test_update_mcp_connection_clears_runtime_auth_cache_on_credential_change( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="credential-update-gateway", + transport="streamable_http", + url="http://credential-update.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="credential-update-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"old-token"}}', + created_by="tester", + ) + + updated = await connection_service.update_mcp_connection( + connection_service_session, + created.id, + credential_blob='{"secrets":{"access_token":"new-token"}}', + updated_by="admin", + ) + + assert updated.updated_by == "admin" + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + + +async def test_set_server_enabled_clears_runtime_auth_cache_when_retiring( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="retire-gateway", + transport="streamable_http", + url="http://retire.local/mcp", + enabled=1, + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + first = await connection_service.create_mcp_connection( + connection_service_session, + server_name="retire-gateway", + scope_type="department", + scope_id="dep-1", + credential_blob='{"secrets":{"access_token":"token-1"}}', + created_by="tester", + ) + second = await connection_service.create_mcp_connection( + connection_service_session, + server_name="retire-gateway", + scope_type="department", + scope_id="dep-2", + credential_blob='{"secrets":{"access_token":"token-2"}}', + created_by="tester", + ) + + enabled, server = await server_service.set_server_enabled( + connection_service_session, + "retire-gateway", + False, + updated_by="admin", + ) + + assert enabled is False + assert bool(server.enabled) is False + assert cleared_connection_ids == [first.id, second.id] + assert released_connection_ids == [first.id, second.id] + + +async def test_test_mcp_connection_refreshes_success_metadata(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + + async def fake_get_runtime_mcp_server_config(server_name, *, auth_context=None, db=None, http_client=None): + del auth_context, db, http_client + return {"transport": "stdio", "command": f"{server_name}-cmd", "disabled_tools": []} + + async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tools=None, **kwargs): + del additional_servers, disabled_tools, kwargs + return [server_name, "tool-b"] + + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + connection_service_session.add( + MCPServer( + name="test-gateway", + transport="streamable_http", + url="http://test.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="test-gateway", + scope_type="department", + scope_id="dep-9", + status="reauth_required", + credential_blob='{"secrets":{"access_token":"token"}}', + meta_json={"last_error": {"message": "old"}}, + created_by="tester", + ) + + result = await connection_service.test_mcp_connection( + connection_service_session, + created.id, + updated_by="admin", + ) + + assert result["tool_count"] == 2 + assert result["connection"].status == "active" + assert "last_success_at" in result["connection"].meta_json + assert "last_error" not in result["connection"].meta_json + + +async def test_test_mcp_connection_rejects_scope_that_does_not_match_server_binding( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="personal-runtime-gateway", + transport="streamable_http", + url="http://personal-runtime.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [ + {"name": "Authorization", "value_template": "Bearer ${secret.access_token}"} + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + ) + connection_service_session.add( + MCPConnection( + server_name="personal-runtime-gateway", + scope_type="system", + scope_id="global", + display_name="历史全局连接", + status="active", + credential_blob="encrypted-secret", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + connection = ( + await connection_service.list_mcp_connections( + connection_service_session, + server_name="personal-runtime-gateway", + ) + )[0] + + with pytest.raises(ValueError) as exc_info: + await connection_service.test_mcp_connection( + connection_service_session, + connection.id, + updated_by="admin", + ) + + message = str(exc_info.value) + assert message == 'MCP "personal-runtime-gateway" 当前绑定类型是个人专用,只能使用个人专用连接。' + assert "user_id is required" not in message + + +async def test_test_mcp_connection_populates_work_id_for_user_scope(connection_service_session, monkeypatch): + observed_auth_contexts = [] + + async def fake_get_runtime_mcp_server_config(server_name, *, auth_context=None, db=None, http_client=None): + del db, http_client + observed_auth_contexts.append(auth_context) + return {"transport": "stdio", "command": f"{server_name}-cmd", "disabled_tools": []} + + async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tools=None, **kwargs): + del server_name, additional_servers, disabled_tools, kwargs + return ["tool-a"] + + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + connection_service_session.add( + MCPServer( + name="user-work-id-gateway", + transport="streamable_http", + url="http://test.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await connection_service.create_mcp_connection( + connection_service_session, + server_name="user-work-id-gateway", + scope_type="user", + scope_id="U001", + created_by="tester", + ) + + result = await connection_service.test_mcp_connection( + connection_service_session, + created.id, + updated_by="admin", + ) + + assert result["tool_count"] == 1 + assert observed_auth_contexts[0].user_id == "U001" + assert observed_auth_contexts[0].work_id == "U001" diff --git a/backend/test/unit/services/test_mcp_service.py b/backend/test/unit/services/test_mcp_service.py deleted file mode 100644 index bd38ba833..000000000 --- a/backend/test/unit/services/test_mcp_service.py +++ /dev/null @@ -1,138 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace - -from yuxi.services import mcp_service - - -class _FakeClient: - def __init__(self, tools): - self._tools = tools - - async def get_tools(self): - return self._tools - - -async def test_get_enabled_mcp_tools_loads_latest_config_from_db(monkeypatch): - captured: list[dict] = [] - - async def fake_get_enabled_mcp_server_config(server_name: str, db=None): - del db - assert server_name == "demo" - return {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]} - - async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): - del kwargs - captured.append( - { - "server_name": server_name, - "additional_servers": additional_servers, - "disabled_tools": list(disabled_tools or []), - } - ) - return ["tool-a"] - - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) - - tools = await mcp_service.get_enabled_mcp_tools("demo") - - assert tools == ["tool-a"] - assert captured == [ - { - "server_name": "demo", - "additional_servers": { - "demo": {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]} - }, - "disabled_tools": ["tool_b"], - } - ] - - -async def test_get_mcp_tools_rebuilds_cache_when_config_hash_changes(monkeypatch): - mcp_service.clear_mcp_cache() - - configs = [ - {"transport": "stdio", "command": "demo-v1", "disabled_tools": []}, - {"transport": "stdio", "command": "demo-v2", "disabled_tools": []}, - ] - build_calls: list[str] = [] - - async def fake_get_enabled_mcp_server_config(server_name: str, db=None): - del db - assert server_name == "demo" - return configs[0] - - async def fake_get_mcp_client(server_configs): - config = server_configs["demo"] - build_calls.append(config["command"]) - tool = SimpleNamespace(name=f"tool_for_{config['command']}", metadata={}) - return _FakeClient([tool]) - - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) - - tools_v1_first = await mcp_service.get_mcp_tools("demo") - tools_v1_second = await mcp_service.get_mcp_tools("demo") - - configs[0] = configs[1] - tools_v2 = await mcp_service.get_mcp_tools("demo") - - assert [tool.name for tool in tools_v1_first] == ["tool_for_demo-v1"] - assert [tool.name for tool in tools_v1_second] == ["tool_for_demo-v1"] - assert [tool.name for tool in tools_v2] == ["tool_for_demo-v2"] - assert build_calls == ["demo-v1", "demo-v2"] - - mcp_service.clear_mcp_cache() - - -async def test_get_tools_from_all_servers_loads_names_from_db_once(monkeypatch): - server_configs = { - "alpha": {"transport": "stdio", "command": "cmd-a", "disabled_tools": []}, - "beta": {"transport": "stdio", "command": "cmd-b", "disabled_tools": []}, - } - calls: list[tuple[str, dict[str, dict]]] = [] - - async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): - del names, db - return server_configs - - async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): - del kwargs - calls.append((server_name, additional_servers or {})) - return [server_name] - - monkeypatch.setattr(mcp_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) - - tools = await mcp_service.get_tools_from_all_servers() - - assert tools == ["alpha", "beta"] - assert calls == [ - ("alpha", server_configs), - ("beta", server_configs), - ] - - -async def test_get_mcp_tools_sets_handle_tool_error(monkeypatch): - mcp_service.clear_mcp_cache() - - config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} - - async def fake_get_enabled_mcp_server_config(server_name: str, db=None): - del db - return config - - async def fake_get_mcp_client(server_configs): - tool = SimpleNamespace(name="demo_tool", metadata={}) - return _FakeClient([tool]) - - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) - - tools = await mcp_service.get_mcp_tools("demo") - assert len(tools) == 1 - assert tools[0].handle_tool_error is True - - mcp_service.clear_mcp_cache() - diff --git a/backend/test/unit/services/test_mcp_tool_cache.py b/backend/test/unit/services/test_mcp_tool_cache.py new file mode 100644 index 000000000..54c8ae23a --- /dev/null +++ b/backend/test/unit/services/test_mcp_tool_cache.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest + +from yuxi.services.mcp_tool_cache import RedisMcpToolCache + + +class _FakeRedis: + def __init__(self): + self.data: dict[str, str] = {} + self.expire_calls: dict[str, int] = {} + + async def get(self, key: str) -> str | None: + return self.data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> None: + self.data[key] = value + if ex is not None: + self.expire_calls[key] = ex + + async def incr(self, key: str) -> int: + next_value = int(self.data.get(key) or "0") + 1 + self.data[key] = str(next_value) + return next_value + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_redis_mcp_tool_cache_revision_and_manifest_roundtrip(): + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + cache = RedisMcpToolCache(redis_client_factory=fake_redis_factory) + + assert await cache.get_server_revision("demo") == 0 + assert await cache.get_partition_revision("demo", "connection:7") == 0 + + assert await cache.bump_server_revision("demo") == 1 + assert await cache.bump_partition_revision("demo", "connection:7") == 1 + + assert await cache.get_server_revision("demo") == 1 + assert await cache.get_partition_revision("demo", "connection:7") == 1 + + manifest = { + "server_name": "demo", + "cache_partition": "connection:7", + "cache_key": "demo:connection:7:s1:p1:abc123", + "tools": [ + { + "name": "alpha_tool", + "id": "mcp__demo__alphaTool", + "description": "alpha", + "parameters": {"city": {"type": "string"}}, + "required": ["city"], + } + ], + } + await cache.set_manifest("demo:connection:7:s1:p1:abc123", manifest) + + assert await cache.get_manifest("demo:connection:7:s1:p1:abc123") == manifest diff --git a/backend/test/unit/services/test_mcp_tool_registry_service.py b/backend/test/unit/services/test_mcp_tool_registry_service.py new file mode 100644 index 000000000..dbd40a610 --- /dev/null +++ b/backend/test/unit/services/test_mcp_tool_registry_service.py @@ -0,0 +1,444 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from yuxi.services.mcp import server_service, tool_registry_service +from yuxi.services.mcp.client_pool import mcp_client_pool +from yuxi.services.mcp_tool_cache import RedisMcpToolCache +from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER + + +class _FakeClient: + def __init__(self, tools): + self._tools = tools + + async def get_tools(self): + return self._tools + + +class _FakeRedis: + def __init__(self): + self.data: dict[str, str] = {} + self.expire_calls: dict[str, int] = {} + + async def get(self, key: str) -> str | None: + return self.data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> None: + self.data[key] = value + if ex is not None: + self.expire_calls[key] = ex + + async def incr(self, key: str) -> int: + next_value = int(self.data.get(key) or "0") + 1 + self.data[key] = str(next_value) + return next_value + + +async def test_get_enabled_mcp_tools_loads_latest_config_from_db(monkeypatch): + captured: list[dict] = [] + + async def fake_get_enabled_mcp_server_config(server_name: str, db=None): + del db + assert server_name == "demo" + return {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]} + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del kwargs + captured.append( + { + "server_name": server_name, + "additional_servers": additional_servers, + "disabled_tools": list(disabled_tools or []), + } + ) + return ["tool-a"] + + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_enabled_mcp_tools("demo") + + assert tools == ["tool-a"] + assert captured == [ + { + "server_name": "demo", + "additional_servers": {"demo": {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]}}, + "disabled_tools": ["tool_b"], + } + ] + + +async def test_get_mcp_tools_rebuilds_cache_when_config_hash_changes(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + configs = [ + {"transport": "stdio", "command": "demo-v1", "disabled_tools": []}, + {"transport": "stdio", "command": "demo-v2", "disabled_tools": []}, + ] + build_calls: list[str] = [] + + async def fake_get_enabled_mcp_server_config(server_name: str, db=None): + del db + assert server_name == "demo" + return configs[0] + + async def fake_get_mcp_client(server_configs): + config = server_configs["demo"] + build_calls.append(config["command"]) + tool = SimpleNamespace(name=f"tool_for_{config['command']}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools_v1_first = await tool_registry_service.get_mcp_tools("demo") + tools_v1_second = await tool_registry_service.get_mcp_tools("demo") + + configs[0] = configs[1] + tools_v2 = await tool_registry_service.get_mcp_tools("demo") + + assert [tool.name for tool in tools_v1_first] == ["tool_for_demo-v1"] + assert [tool.name for tool in tools_v1_second] == ["tool_for_demo-v1"] + assert [tool.name for tool in tools_v2] == ["tool_for_demo-v2"] + assert build_calls == ["demo-v1", "demo-v2"] + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_tools_from_all_servers_loads_names_from_db_once(monkeypatch): + server_configs = { + "alpha": {"transport": "stdio", "command": "cmd-a", "disabled_tools": []}, + "beta": {"transport": "stdio", "command": "cmd-b", "disabled_tools": []}, + } + calls: list[tuple[str, dict[str, dict]]] = [] + + async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): + del names, db + return server_configs + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): + del kwargs + calls.append((server_name, additional_servers or {})) + return [server_name] + + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_tools_from_all_servers() + + assert tools == ["alpha", "beta"] + assert calls == [ + ("alpha", {"alpha": server_configs["alpha"]}), + ("beta", {"beta": server_configs["beta"]}), + ] + + +async def test_get_tools_from_all_servers_limits_preload_to_selected_names(monkeypatch): + server_configs = { + "alpha": {"transport": "stdio", "command": "cmd-a", "disabled_tools": []}, + "beta": {"transport": "stdio", "command": "cmd-b", "disabled_tools": []}, + } + loaded_names: list[list[str] | None] = [] + calls: list[str] = [] + + async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): + del db + loaded_names.append(names) + if not names: + return server_configs + return {name: server_configs[name] for name in names if name in server_configs} + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): + del additional_servers, kwargs + calls.append(server_name) + return [server_name] + + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_tools_from_all_servers(["alpha", "alpha", "missing"]) + empty_tools = await tool_registry_service.get_tools_from_all_servers([]) + + assert tools == ["alpha"] + assert empty_tools == [] + assert loaded_names == [["alpha", "missing"]] + assert calls == ["alpha"] + + +async def test_get_mcp_tools_sets_handle_tool_error(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} + + async def fake_get_enabled_mcp_server_config(server_name: str, db=None): + del db + return config + + async def fake_get_mcp_client(server_configs): + tool = SimpleNamespace(name="demo_tool", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools = await tool_registry_service.get_mcp_tools("demo") + assert len(tools) == 1 + assert tools[0].handle_tool_error is True + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_mcp_tools_suppresses_retries_during_failure_cooldown(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + config = {"transport": "stdio", "command": "offline-demo", "disabled_tools": []} + build_calls: list[dict] = [] + + async def fail_get_mcp_client(server_configs): + build_calls.append(server_configs) + raise ConnectionError("mcp service offline") + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fail_get_mcp_client) + + tools_first = await tool_registry_service.get_mcp_tools("offline", additional_servers={"offline": config}) + tools_second = await tool_registry_service.get_mcp_tools("offline", additional_servers={"offline": config}) + tools_forced = await tool_registry_service.get_mcp_tools( + "offline", + additional_servers={"offline": config}, + force_refresh=True, + ) + + assert tools_first == [] + assert tools_second == [] + assert tools_forced == [] + assert len(build_calls) == 2 + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_mcp_tools_keeps_connection_partitions_separate(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + configs = [ + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-user-a", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + }, + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-user-b", + }, + "__yuxi_cache_partition": "connection:202", + "__yuxi_allow_global_cache": False, + }, + ] + build_calls: list[str] = [] + + async def fake_get_mcp_client(server_configs): + token = server_configs["demo"]["headers"][INTERNAL_PROXY_TOKEN_HEADER] + build_calls.append(token) + tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools_a = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_b = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + + assert [tool.name for tool in tools_a] == ["tool_for_proxy-token-user-a"] + assert [tool.name for tool in tools_b] == ["tool_for_proxy-token-user-b"] + assert build_calls == ["proxy-token-user-a", "proxy-token-user-b"] + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_mcp_tools_does_not_cache_internal_proxy_tool_objects(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + configs = [ + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-v1", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + "__yuxi_disable_tool_object_cache": True, + }, + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-v2", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + "__yuxi_disable_tool_object_cache": True, + }, + ] + build_calls: list[str] = [] + tool_load_count = 0 + + class RefreshingFakeClient: + async def get_tools(self): + nonlocal tool_load_count + tool_load_count += 1 + tool = SimpleNamespace(name=f"tool_for_load_{tool_load_count}", metadata={}) + return [tool] + + async def fake_get_mcp_client(server_configs): + token = server_configs["demo"]["headers"][INTERNAL_PROXY_TOKEN_HEADER] + build_calls.append(token) + return RefreshingFakeClient() + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools_first = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_second = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + + assert [tool.name for tool in tools_first] == ["tool_for_load_1"] + assert [tool.name for tool in tools_second] == ["tool_for_load_2"] + assert build_calls == ["proxy-token-v1"] + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_tools_from_all_servers_skips_runtime_auth_servers_without_context(monkeypatch): + server_configs = { + "shared": {"transport": "stdio", "command": "cmd-shared", "disabled_tools": []}, + "bound": { + "transport": "streamable_http", + "url": "http://bound.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://bound.local/token", + "method": "POST", + "response_map": {"access_token": "access_token"}, + }, + }, + "disabled_tools": [], + }, + } + calls: list[tuple[str, dict[str, dict]]] = [] + + async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): + del names, db + return server_configs + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): + del kwargs + calls.append((server_name, additional_servers or {})) + return [server_name] + + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_tools_from_all_servers() + + assert tools == ["shared"] + assert calls == [ + ("shared", {"shared": server_configs["shared"]}), + ] + + +async def test_get_mcp_tools_rebuilds_when_redis_server_revision_changes(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store", + RedisMcpToolCache(redis_client_factory=fake_redis_factory), + ) + + config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} + build_calls: list[str] = [] + + async def fake_get_mcp_client(server_configs): + build_calls.append(server_configs["demo"]["command"]) + tool = SimpleNamespace(name=f"tool_{len(build_calls)}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools_first = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) + tools_second = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) + await tool_registry_service._mcp_tool_cache_store.bump_server_revision("demo") + tools_third = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) + + assert [tool.name for tool in tools_first] == ["tool_1"] + assert [tool.name for tool in tools_second] == ["tool_1"] + assert [tool.name for tool in tools_third] == ["tool_2"] + assert build_calls == ["demo-tool", "demo-tool"] + + await tool_registry_service.clear_mcp_cache() + + +async def test_get_all_mcp_tools_uses_redis_manifest_when_local_cache_is_empty(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store", + RedisMcpToolCache(redis_client_factory=fake_redis_factory), + ) + + config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} + + async def fake_get_mcp_client(server_configs): + del server_configs + tool = SimpleNamespace( + name="alpha_tool", + description="alpha", + metadata={}, + args_schema=SimpleNamespace( + schema=lambda: { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + ), + ) + return _FakeClient([tool]) + + async def fake_get_enabled_mcp_server_config(server_name: str, db=None): + del server_name, db + return config + + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) + + tools_first = await tool_registry_service.get_all_mcp_tools("demo") + assert [tool.name for tool in tools_first] == ["alpha_tool"] + + await tool_registry_service.clear_mcp_cache() + + async def fail_get_mcp_client(server_configs): + raise AssertionError(f"should not fetch live tools when redis manifest is available: {server_configs}") + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fail_get_mcp_client) + + tools_second = await tool_registry_service.get_all_mcp_tools("demo") + + assert [tool.name for tool in tools_second] == ["alpha_tool"] + assert tools_second[0].metadata["id"] == "mcp__demo__alphaTool" diff --git a/backend/test/unit/services/test_model_provider_service.py b/backend/test/unit/services/test_model_provider_service.py index dbbb8148f..3303aca42 100644 --- a/backend/test/unit/services/test_model_provider_service.py +++ b/backend/test/unit/services/test_model_provider_service.py @@ -97,6 +97,12 @@ async def fake_fetch(client, provider, headers, endpoint, model_type): return [{"id": f"{model_type}-model", "type": model_type}] monkeypatch.setattr("yuxi.services.model_provider_service._fetch_models_from_endpoint", fake_fetch) + + from unittest.mock import AsyncMock, MagicMock + mock_client_instance = MagicMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + monkeypatch.setattr("yuxi.services.model_provider_service.httpx.AsyncClient", lambda **kwargs: mock_client_instance) class Provider: base_url = "https://example.com/v1" diff --git a/backend/test/unit/services/test_remote_skill_install_service.py b/backend/test/unit/services/test_remote_skill_install_service.py index e3d6f65da..bf4b3cb38 100644 --- a/backend/test/unit/services/test_remote_skill_install_service.py +++ b/backend/test/unit/services/test_remote_skill_install_service.py @@ -118,7 +118,7 @@ async def fake_import_skill_dir(_db, *, source_dir, created_by): "-y", "--copy", ] - assert captured["source_dir"] == Path(calls[1][1]) / ".agents" / "skills" / "frontend-design" + assert captured["source_dir"].resolve() == (Path(calls[1][1]) / ".agents" / "skills" / "frontend-design").resolve() assert captured["created_by"] == "root" diff --git a/backend/test/unit/storage/test_postgres_manager_business_schema.py b/backend/test/unit/storage/test_postgres_manager_business_schema.py new file mode 100644 index 000000000..dad29f5f0 --- /dev/null +++ b/backend/test/unit/storage/test_postgres_manager_business_schema.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os +from types import SimpleNamespace + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.storage.postgres.manager import PostgresManager + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class _FakeBeginContext: + def __init__(self, statements: list[str]): + self._statements = statements + + async def __aenter__(self): + async def execute(stmt): + self._statements.append(str(stmt)) + + return SimpleNamespace(execute=execute) + + async def __aexit__(self, exc_type, exc, tb): + return False + + +async def test_ensure_business_schema_includes_mcp_auth_tables_and_columns(): + statements: list[str] = [] + manager = object.__new__(PostgresManager) + PostgresManager.__init__(manager) + manager._initialized = True + manager.async_engine = SimpleNamespace(begin=lambda: _FakeBeginContext(statements)) + + await manager.ensure_business_schema() + + assert any( + "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS auth_config_json JSONB" in stmt + for stmt in statements + ) + assert any("CREATE TABLE IF NOT EXISTS mcp_connections" in stmt for stmt in statements) diff --git a/backend/test/unit/test_base_context.py b/backend/test/unit/test_base_context.py new file mode 100644 index 000000000..f2892dffa --- /dev/null +++ b/backend/test/unit/test_base_context.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from yuxi.agents.context import BaseContext + + +def test_base_context_accepts_internal_identity_fields_without_exposing_them_as_configurable(): + context = BaseContext() + + context.update({"department_id": "dept-9", "work_id": "login-1001"}) + + assert context.department_id == "dept-9" + assert context.work_id == "login-1001" + configurable_items = BaseContext.get_configurable_items() + assert "department_id" not in configurable_items + assert "work_id" not in configurable_items diff --git a/backend/uv.lock b/backend/uv.lock index 7fdfad461..41305dd5e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -5650,6 +5650,7 @@ dependencies = [ { name = "argon2-cffi" }, { name = "asyncpg" }, { name = "beautifulsoup4" }, + { name = "cachetools" }, { name = "chardet" }, { name = "colorlog" }, { name = "dashscope" }, @@ -5729,6 +5730,7 @@ requires-dist = [ { name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "asyncpg", specifier = ">=0.30.0" }, { name = "beautifulsoup4", specifier = ">=4.12.0" }, + { name = "cachetools", specifier = ">=5.3.0" }, { name = "chardet", specifier = ">=5.0.0" }, { name = "colorlog", specifier = ">=6.9.0" }, { name = "dashscope", specifier = ">=1.23.2" }, diff --git a/docker-compose.yml b/docker-compose.yml index 910af1211..97aa56ee8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,7 @@ x-api-worker-env: &api-worker-env # DBs and other services POSTGRES_URL: ${POSTGRES_URL:-postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@postgres:5432/${POSTGRES_DB:-yuxi_know}} REDIS_URL: ${REDIS_URL:-redis://redis:6379/0} + YUXI_INTERNAL_MCP_PROXY_BASE_URL: ${YUXI_INTERNAL_MCP_PROXY_BASE_URL:-http://api:5050} NEO4J_URI: ${NEO4J_URI:-bolt://graph:7687} NEO4J_USERNAME: ${NEO4J_USERNAME:-neo4j} NEO4J_PASSWORD: ${NEO4J_PASSWORD:-0123456789} @@ -29,6 +30,8 @@ x-api-worker-env: &api-worker-env # Agent run RUN_CANCEL_KEY_TTL_SECONDS: ${RUN_CANCEL_KEY_TTL_SECONDS:-1800} RUN_EVENTS_STREAM_TTL_SECONDS: ${RUN_EVENTS_STREAM_TTL_SECONDS:-7200} + # MCP auth + MCP_CREDENTIALS_MASTER_KEY: ${MCP_CREDENTIALS_MASTER_KEY:-} # 其他环境变量 NO_PROXY: localhost,127.0.0.1,milvus,graph,minio,milvus-etcd-dev,etcd,mineru,paddlex,sandbox-provisioner,api.siliconflow.cn no_proxy: localhost,127.0.0.1,milvus,graph,minio,milvus-etcd-dev,etcd,mineru,paddlex,sandbox-provisioner,api.siliconflow.cn @@ -406,6 +409,21 @@ services: - app-network restart: unless-stopped + mcp-demo-server: + image: yuxi-api:${YUXI_VERSION:-0.6.2} + container_name: mcp-demo-server + working_dir: /app + volumes: + - ./backend/package:/app/package + - ./backend/test:/app/test + - ./.env:/app/.env + ports: + - "8999:8999" + networks: + - app-network + command: uv run --no-sync --no-dev test/mcp_demo_server.py --transport sse --port 8999 + restart: unless-stopped + volumes: nltk_data: diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index c7dc69d73..23e3fe7fd 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,6 +37,11 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期 and 重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载 of managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用;补齐运行时鉴权 MCP 工具的执行阶段映射,避免模型已绑定 `getTicket` 等动态工具但 ToolNode 静态注册表无法执行的问题;审计并修复该链路隐患:通过 DynamicMCPTokenAuth 引入 15 秒 TTL 在内存缓存(含联动清除机制)解决 httpx 请求对 DB 的高频重复查询问题;修复 `_normalize_token_payload` 处理 naive datetime 的时区偏差问题以消除 token 无限自动刷新的 Bug;改进 `_calculate_config_hash` 哈希计算逻辑,对 json.dumps 增加 default=str 降级保护防止无法序列化而崩溃的问题;优化免密钥连接测试,在 binding_scope 非 inline 且配置未引用 `\${secret.xxx}` 变量时免去 connection 的强检验,允许直接进行测试和工具加载;在 `client_pool` 中实现长连接失效的断线清理机制,防止 anyio.ClosedResourceError 报错固化在缓存中;修复 mock MCP demo server 在 FastAPI 路由下返回值的 ASGI 响应冲突,将其重构为原生 ASGI App 路由,并在 Docker 中容器化部署;前端增加 `${context.work_id}` 快注按键并补齐后端 context.work_id 工号识别支持;修复未配置认证时前端发送空字典 `{}` 导致 Pydantic 400 校验错误的问题。 + - 本次补充:明确 `YUXI_INTERNAL_MCP_PROXY_BASE_URL` 是动态 HTTP MCP 的内部鉴权网关地址;统一 runtime config 与代理入口的 active connection 强制规则,允许未引用 `${secret.xxx}` 的动态 MCP 无绑定连接运行;连接测试补齐 user scope 的 `work_id`,连接池 hash 忽略短期 `X-Yuxi-MCP-Proxy-Token`;补充 MCP 动态鉴权使用说明和开发手册。 + - 本次补充:将个人级 MCP 连接配置收敛到用户设置弹框,普通用户仅可查看脱敏 MCP 信息并维护自己的 `user` scope 连接;管理员仍在扩展管理中维护 MCP 服务、共享连接与工具开关。 + - 本次补充:优化 MCP 连接管理体验,管理页连接区支持健康筛选、绑定对象搜索和分页;连接卡片统一展示生效范围、绑定对象与单一问题主动作,设置页沿用同一卡片语言并在详情头部展示生效范围。 + - 本次补充:为 MCP 工具加载失败增加短期冷却与日志降噪,服务端离线时 Agent 运行态会跳过不可用 MCP,避免每轮运行重复建连并输出大量 error traceback;图构建阶段只预加载当前 agent 配置与已配置 skill 依赖的 MCP,手动刷新或配置变更仍会重新探测。 --- diff --git a/web/src/apis/mcp_api.js b/web/src/apis/mcp_api.js index 66fbdf69a..5e3e92fa1 100644 --- a/web/src/apis/mcp_api.js +++ b/web/src/apis/mcp_api.js @@ -1,4 +1,13 @@ -import { apiGet, apiAdminGet, apiAdminPost, apiAdminPut, apiAdminDelete } from './base' +import { + apiGet, + apiPost, + apiPut, + apiDelete, + apiAdminGet, + apiAdminPost, + apiAdminPut, + apiAdminDelete +} from './base' /** * MCP 服务器管理 API 模块 @@ -25,7 +34,7 @@ export const getMcpServers = async () => { * @returns {Promise} - 服务器配置 */ export const getMcpServer = async (name) => { - return apiAdminGet(`${BASE_URL}/${encodeURIComponent(name)}`) + return apiGet(`${BASE_URL}/${encodeURIComponent(name)}`) } /** @@ -79,6 +88,48 @@ export const updateMcpServerStatus = async (name, enabled) => { return apiAdminPut(`${BASE_URL}/${encodeURIComponent(name)}/status`, { enabled }) } +// ============================================================================= +// === MCP 连接管理 === +// ============================================================================= + +export const getMcpServerConnections = async (name, options = {}) => { + const params = new URLSearchParams() + if (options.mine) params.set('mine', 'true') + if (options.paginated) params.set('paginated', 'true') + if (options.status) params.set('status', options.status) + if (options.search) params.set('search', options.search) + if (options.page) params.set('page', String(options.page)) + if (options.page_size) params.set('page_size', String(options.page_size)) + const query = params.toString() + return apiGet(`${BASE_URL}/${encodeURIComponent(name)}/connections${query ? `?${query}` : ''}`) +} + +export const createMcpServerConnection = async (name, data) => { + return apiPost(`${BASE_URL}/${encodeURIComponent(name)}/connections`, data) +} + +export const updateMcpServerConnection = async (name, connectionId, data) => { + return apiPut(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}`, data) +} + +export const updateMcpConnectionStatus = async (name, connectionId, status) => { + return apiPut(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/status`, { + status + }) +} + +export const deleteMcpServerConnection = async (name, connectionId) => { + return apiDelete(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}`) +} + +export const testMcpConnection = async (name, connectionId) => { + return apiPost(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/test`, {}) +} + +export const reauthorizeMcpConnection = async (name, connectionId) => { + return apiPost(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/reauth`, {}) +} + // ============================================================================= // === MCP 工具管理 === // ============================================================================= @@ -126,6 +177,13 @@ export const mcpApi = { deleteMcpServer, testMcpServer, updateMcpServerStatus, + getMcpServerConnections, + createMcpServerConnection, + updateMcpServerConnection, + updateMcpConnectionStatus, + deleteMcpServerConnection, + testMcpConnection, + reauthorizeMcpConnection, getMcpServerTools, refreshMcpServerTools, toggleMcpServerTool diff --git a/web/src/apis/user_api.js b/web/src/apis/user_api.js new file mode 100644 index 000000000..0af678a51 --- /dev/null +++ b/web/src/apis/user_api.js @@ -0,0 +1,5 @@ +import { apiAdminGet } from './base' + +export const userApi = { + getUsers: () => apiAdminGet('/api/auth/users') +} diff --git a/web/src/components/McpPersonalConnectionsSection.vue b/web/src/components/McpPersonalConnectionsSection.vue new file mode 100644 index 000000000..6ec154dea --- /dev/null +++ b/web/src/components/McpPersonalConnectionsSection.vue @@ -0,0 +1,1185 @@ + + + + + diff --git a/web/src/components/SettingsModal.vue b/web/src/components/SettingsModal.vue index 4f0c10703..b2d34e8f9 100644 --- a/web/src/components/SettingsModal.vue +++ b/web/src/components/SettingsModal.vue @@ -64,6 +64,15 @@ API Key +
+ + MCP 连接 +
@@ -142,6 +151,14 @@ > API Key
+ @@ -166,6 +183,10 @@
+ +
+ +
@@ -178,6 +199,7 @@ import { useUserStore } from '@/stores/user' import { ExternalLink, Key as KeyIcon, + Plug, Settings, SquareCode, Star, @@ -190,11 +212,16 @@ import ModelProvidersComponent from '@/components/ModelProvidersComponent.vue' import UserManagementComponent from '@/components/UserManagementComponent.vue' import DepartmentManagementComponent from '@/components/DepartmentManagementComponent.vue' import ApiKeyManagementComponent from '@/components/ApiKeyManagementComponent.vue' +import McpPersonalConnectionsSection from '@/components/McpPersonalConnectionsSection.vue' const props = defineProps({ visible: { type: Boolean, default: false + }, + initialTab: { + type: String, + default: null } }) @@ -221,6 +248,20 @@ const dismissStarCard = () => { localStorage.setItem(STAR_CARD_STORAGE_KEY, 'true') } +const canOpenTab = (tab) => { + if (tab === 'base' || tab === 'user') return userStore.isAdmin + if (tab === 'model' || tab === 'department') return userStore.isSuperAdmin + if (tab === 'apikey' || tab === 'mcp') return userStore.isLoggedIn + return false +} + +const getDefaultTab = () => { + if (props.initialTab && canOpenTab(props.initialTab)) return props.initialTab + if (userStore.isAdmin) return 'base' + if (userStore.isLoggedIn) return 'mcp' + return 'base' +} + onMounted(() => { showStarCard.value = localStorage.getItem(STAR_CARD_STORAGE_KEY) !== 'true' }) @@ -230,11 +271,7 @@ watch( () => props.visible, (newVal) => { if (newVal) { - if (userStore.isAdmin) { - activeTab.value = 'base' - } else if (userStore.isLogin) { - activeTab.value = 'apikey' - } + activeTab.value = getDefaultTab() } } ) @@ -450,7 +487,8 @@ watch( .model-providers-section, .user-management, .department-management, - .apikey-management { + .apikey-management, + .mcp-personal-settings { min-height: auto; } diff --git a/web/src/components/extensions/McpAuthConfigBuilder.vue b/web/src/components/extensions/McpAuthConfigBuilder.vue new file mode 100644 index 000000000..2fff0b7db --- /dev/null +++ b/web/src/components/extensions/McpAuthConfigBuilder.vue @@ -0,0 +1,1025 @@ + + + + + diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 2de759bef..e823fefdc 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -228,6 +228,10 @@ + @@ -340,6 +344,18 @@ {{ server.created_by }} +
+ + +
@@ -433,6 +449,448 @@ + + + +
+
+
+

连接管理

+

+ {{ + hasAuthConfig + ? '按全局、部门或用户维护长期凭据,运行时自动换取和刷新 token。' + : '当前 MCP 未配置动态鉴权,通常不需要维护连接。' + }} +

+
+ + + + + 新建连接 + + + + + 刷新 + + +
+ +
+
+ 认证方式 + {{ + providerLabelMap[server.auth_config?.provider] || + server.auth_config?.provider || + '未配置' + }} +
+
+ 默认绑定 + {{ authBindingScopeLabel }} +
+
+ 可用连接 + {{ activeConnectionCount }} +
+
+ 需处理 + {{ attentionConnectionCount }} +
+
+ +
+
+ + + + +
+
+ 共 {{ connectionTotal }} 条 + +
+
+ + +
+ +
+
+ + + + 新建连接 + + + 清除筛选 + +
+
+
+
+
+
+ +
+

{{ getConnectionTitle(connection) }}

+ + {{ connection.external_subject }} + +
+
+ + + + + {{ getConnectionScopeLabel(connection.scope_type) }} + +
+ +
+
+
+ 问题:{{ getConnectionIssue(connection).label }} + {{ getConnectionIssue(connection).description }} +
+ + {{ getConnectionIssue(connection).actionLabel }} + +
+
+ 绑定对象: + {{ + getConnectionScopeTargetLabel(connection) + }} +
+
+ 最近记录: + {{ getConnectionLastInfo(connection) }} +
+
+ + +
+
+
+ +
+
+
+ + + +
+
+ 绑定范围 + 决定运行时为哪些请求使用这组凭据。 +
+
+ +
+ + + + + +
+ +
+
+ 展示信息 + 名称用于列表识别,不参与鉴权计算。 +
+ + + +
+ +
+
+ 凭据 + {{ credentialHint }} +
+
+ + + +
+ + + +
+ + + + + + + + + + + + + + + + +
+
+
+
@@ -461,11 +919,20 @@ import { Save, X, Rows3, - Braces + Braces, + KeyRound, + Globe2, + Building2, + UserRound, + Search } from 'lucide-vue-next' import { mcpApi } from '@/apis/mcp_api' import { formatFullDateTime } from '@/utils/time' +import { extractSecretFieldNames } from '@/utils/mcpAuthConfigBuilder' +import McpAuthConfigBuilder from '@/components/extensions/McpAuthConfigBuilder.vue' import McpEnvEditor from '@/components/McpEnvEditor.vue' +import { departmentApi } from '@/apis/department_api' +import { userApi } from '@/apis/user_api' const route = useRoute() const router = useRouter() @@ -476,11 +943,34 @@ const server = ref(null) const detailTab = ref('general') const testLoading = ref(null) +const userList = ref([]) +const departmentList = ref([]) +const isFetchingScopeOptions = ref(false) + const tools = ref([]) const toolsLoading = ref(false) const toolsError = ref(null) const toolSearchText = ref('') const toggleToolLoading = ref(null) +const connections = ref([]) +const connectionsLoading = ref(false) +const connectionsError = ref(null) +const connectionFilter = ref('all') +const connectionSearchText = ref('') +const connectionPage = ref(1) +const connectionPageSize = ref(12) +const connectionTotal = ref(0) +const connectionSummary = reactive({ + total: 0, + active: 0, + attention: 0, + disabled: 0 +}) +let connectionSearchTimer = null +const showConnectionForm = ref(false) +const connectionSubmitting = ref(false) +const connectionActionLoading = ref(null) +const editingConnectionId = ref(null) const isEditing = ref(false) const editLoading = ref(false) @@ -496,17 +986,121 @@ const editForm = reactive({ args: [], env: null, headersText: '', + authConfigText: '', timeout: null, sse_read_timeout: null, tags: [], icon: '' }) +const connectionForm = reactive({ + scopeType: 'department', + scopeId: '', + displayName: '', + externalSubject: '', + credentialText: '', + secretValues: {}, + metaText: '' +}) + +const connectionScopeOptions = [ + { + value: 'system', + label: '全局共享', + description: '所有用户共用', + icon: Globe2 + }, + { + value: 'department', + label: '部门共享', + description: '按部门隔离', + icon: Building2 + }, + { + value: 'user', + label: '个人专用', + description: '按用户隔离', + icon: UserRound + } +] + +const scopeLabelMap = { + inline: '内联', + system: '全局共享', + department: '部门共享', + user: '个人专用' +} + +const statusLabelMap = { + active: '启用', + disabled: '停用', + reauth_required: '需要重连', + invalid: '无效' +} + +const providerLabelMap = { + none: '不启用', + bound_secret: '绑定长期密钥', + custom_http_token: '接口换 Token', + stdio_env: 'StdIO 环境变量', + client_credentials: 'OAuth2 客户端凭证' +} + +const connectionFilterOptions = [ + { label: '全部', value: 'all' }, + { label: '生效中', value: 'active' }, + { label: '需处理', value: 'attention' }, + { label: '未启用', value: 'disabled' } +] + const actionLabel = computed(() => { - if (server.value?.enabled === false) return '添加' - return server.value?.created_by === 'system' ? '移除' : '删除' + if (server.value?.enabled === false) return '恢复' + return server.value?.created_by === 'system' ? '移除' : '退役' +}) + +const isEditingConnection = computed(() => editingConnectionId.value !== null) + +const hasAuthConfig = computed( + () => !!server.value?.auth_config && Object.keys(server.value.auth_config).length > 0 +) + +const validConnectionScopeTypes = ['system', 'department', 'user'] + +const effectiveConnectionScopeType = computed(() => { + const bindingScope = server.value?.auth_config?.binding_scope + return validConnectionScopeTypes.includes(bindingScope) ? bindingScope : '' +}) + +const availableConnectionScopeOptions = computed(() => { + if (!effectiveConnectionScopeType.value) return connectionScopeOptions + return connectionScopeOptions.filter( + (option) => option.value === effectiveConnectionScopeType.value + ) +}) + +const authBindingScopeLabel = computed(() => { + const bindingScope = server.value?.auth_config?.binding_scope + return scopeLabelMap[bindingScope] || '未限定' +}) + +const connectionTotalCount = computed(() => connectionSummary.total || connectionTotal.value) + +const activeConnectionCount = computed(() => connectionSummary.active || 0) + +const attentionConnectionCount = computed(() => connectionSummary.attention || 0) + +const hasConnectionListFilter = computed( + () => connectionFilter.value !== 'all' || Boolean(connectionSearchText.value.trim()) +) + +const connectionEmptyDescription = computed(() => { + if (hasConnectionListFilter.value) return '没有匹配的连接。' + if (hasAuthConfig.value) return '暂无连接。创建连接后,运行时会按绑定范围自动选择凭据。' + return '当前 MCP 没有启用动态鉴权连接。' }) +const connectionDrawerTitle = computed(() => (isEditingConnection.value ? '编辑连接' : '新建连接')) + const filteredTools = computed(() => { if (!toolSearchText.value) return tools.value const search = toolSearchText.value.toLowerCase() @@ -524,6 +1118,36 @@ const isStdioTransport = computed( .toLowerCase() === 'stdio' ) +const credentialSecretFields = computed(() => + extractSecretFieldNames(server.value?.auth_config || {}) +) + +const connectionCredentialsRequired = computed(() => credentialSecretFields.value.length > 0) + +const showScopeIdField = computed(() => connectionForm.scopeType !== 'system') + +const scopeIdLabel = computed(() => { + if (connectionForm.scopeType === 'department') return '部门' + if (connectionForm.scopeType === 'user') return '用户' + return '范围标识' +}) + +const scopeIdPlaceholder = computed(() => { + if (connectionForm.scopeType === 'department') return '请选择部门' + if (connectionForm.scopeType === 'user') return '请选择用户' + return '留空默认 global' +}) + +const credentialHint = computed(() => { + if (isEditingConnection.value) { + return '为安全起见不回显已有凭据;留空表示保持原值。' + } + if (credentialSecretFields.value.length > 0) { + return '系统已根据认证配置推导出需要录入的密钥字段。' + } + return '当前认证配置没有声明密钥字段,可直接粘贴长期 token。' +}) + const goBack = () => { router.push({ path: '/extensions', query: { tab: 'mcp' } }) } @@ -535,6 +1159,164 @@ const getTransportColor = (transport) => { return colors[transport] || 'blue' } +const createEmptySecretValues = () => + Object.fromEntries(credentialSecretFields.value.map((fieldName) => [fieldName, ''])) + +const setNestedSecretValue = (target, path, value) => { + const segments = String(path || '') + .split('.') + .filter(Boolean) + let current = target + segments.forEach((segment, index) => { + if (index === segments.length - 1) { + current[segment] = value + return + } + current[segment] = current[segment] || {} + current = current[segment] + }) +} + +const getConnectionTitle = (connection) => + connection.display_name || + `${getConnectionScopeLabel(connection.scope_type)} ${getConnectionScopeTargetLabel(connection)}` + +const getConnectionScopeLabel = (scopeType) => scopeLabelMap[scopeType] || scopeType || '未知范围' + +const getConnectionStatusLabel = (status) => statusLabelMap[status] || status || '未知状态' + +const isConnectionScopeMatched = (connection) => + !effectiveConnectionScopeType.value || + connection?.scope_type === effectiveConnectionScopeType.value + +const isConnectionCredentialMissing = (connection) => + connectionCredentialsRequired.value && !connection?.has_credentials + +const getConnectionScopeTargetLabel = (connection) => { + const scopeId = String(connection?.scope_id || '') + if (connection?.scope_type === 'system') { + return '全部用户' + } + if (connection?.scope_type === 'department') { + const department = departmentList.value.find((item) => String(item.id) === scopeId) + return department ? `${department.name} (#${department.id})` : `部门 #${scopeId || '-'}` + } + if (connection?.scope_type === 'user') { + const user = userList.value.find( + (item) => String(item.id) === scopeId || String(item.user_id) === scopeId + ) + if (!user) return `用户 #${scopeId || '-'}` + return user.username === user.user_id ? user.username : `${user.username} (${user.user_id})` + } + return scopeId || '未指定' +} + +const canToggleConnectionStatus = (connection) => { + if (!['active', 'disabled'].includes(connection?.status)) return false + if (connection.status === 'disabled') { + return isConnectionScopeMatched(connection) && !isConnectionCredentialMissing(connection) + } + return true +} + +const getConnectionStatusSwitchLabel = (connection) => { + if (connection?.status === 'active') return '已启用' + return getConnectionStatusLabel(connection?.status) +} + +const getConnectionStatusToggleTooltip = (connection) => { + if (connection?.status === 'active') return '停用连接' + if (connection?.status === 'disabled') { + if (!isConnectionScopeMatched(connection)) { + return `该连接未生效,不能启用;当前 MCP 使用${authBindingScopeLabel.value}` + } + if (isConnectionCredentialMissing(connection)) { + return '请先补充凭据' + } + return '启用连接' + } + return '请先重连或编辑凭据' +} + +const canTestConnection = (connection) => + isConnectionScopeMatched(connection) && !isConnectionCredentialMissing(connection) + +const getConnectionTestTooltip = (connection) => { + if (canTestConnection(connection)) return '测试连接' + if (isConnectionCredentialMissing(connection)) return '请先补充凭据' + return `该连接未生效,当前 MCP 使用${authBindingScopeLabel.value}` +} + +const canReauthorizeConnection = (connection) => + isConnectionScopeMatched(connection) && !isConnectionCredentialMissing(connection) + +const getConnectionReauthorizeTooltip = (connection) => { + if (canReauthorizeConnection(connection)) return '重置授权并重新激活' + if (isConnectionCredentialMissing(connection)) return '请先补充凭据' + return `该连接未生效,不能重连;当前 MCP 使用${authBindingScopeLabel.value}` +} + +const getConnectionIssue = (connection) => { + if (!isConnectionScopeMatched(connection)) { + return { + key: 'scope_mismatch', + label: '范围不匹配', + description: `当前 MCP 使用${authBindingScopeLabel.value},这组连接不会在运行时生效。`, + actionLabel: '新建匹配连接', + tone: 'warning' + } + } + if (isConnectionCredentialMissing(connection)) { + return { + key: 'missing_credentials', + label: '缺少凭据', + description: '缺少长期凭据,运行时无法换取或注入 token。', + actionLabel: '补充凭据', + tone: 'error' + } + } + if (connection?.status === 'reauth_required') { + return { + key: 'reauth_required', + label: '授权失效', + description: '缓存 token 已失效,需要重新授权后才能继续使用。', + actionLabel: '重连', + tone: 'warning' + } + } + if (connection?.status === 'invalid' || connection?.meta_json?.last_error?.message) { + return { + key: 'test_failed', + label: '测试失败', + description: connection?.meta_json?.last_error?.message || '最近一次连接检测失败。', + actionLabel: '编辑凭据', + tone: 'error' + } + } + return null +} + +const getConnectionLastInfo = (connection) => { + if (connection.meta_json?.last_success_at) { + return `最近成功 ${formatTime(connection.meta_json.last_success_at)}` + } + if (connection.updated_at) { + return `更新于 ${formatTime(connection.updated_at)}` + } + return '暂无记录' +} + +const getSecretFieldLabel = (fieldName) => { + const labelMap = { + client_id: 'Client ID', + client_secret: 'Client Secret', + access_token: 'Access Token', + refresh_token: 'Refresh Token', + issuer_url: 'Issuer URL' + } + return labelMap[fieldName] || fieldName +} + const resetEditForm = (data) => { Object.assign(editForm, { name: data?.name || '', @@ -545,6 +1327,7 @@ const resetEditForm = (data) => { args: data?.args || [], env: data?.env || null, headersText: data?.headers ? JSON.stringify(data.headers, null, 2) : '', + authConfigText: data?.auth_config ? JSON.stringify(data.auth_config, null, 2) : '', timeout: data?.timeout, sse_read_timeout: data?.sse_read_timeout, tags: data?.tags || [], @@ -586,6 +1369,20 @@ const parseJsonToForm = () => { } } +const parseJsonText = (text, label, { allowRawString = false } = {}) => { + const trimmed = String(text || '').trim() + if (!trimmed) return null + try { + return JSON.parse(trimmed) + } catch { + if (allowRawString) { + return trimmed + } + message.error(`${label} JSON 格式错误`) + return undefined + } +} + const buildEditPayload = () => { if (formMode.value === 'json') { try { @@ -606,6 +1403,11 @@ const buildEditPayload = () => { } } + const authConfig = parseJsonText(editForm.authConfigText, '认证配置') + if (authConfig === undefined) { + return null + } + return { name: editForm.name, description: editForm.description || null, @@ -615,6 +1417,7 @@ const buildEditPayload = () => { args: editForm.args.length > 0 ? editForm.args : null, env: editForm.env, headers, + auth_config: authConfig, timeout: editForm.timeout || null, sse_read_timeout: editForm.sse_read_timeout || null, tags: editForm.tags.length > 0 ? editForm.tags : null, @@ -622,6 +1425,46 @@ const buildEditPayload = () => { } } +const getDefaultConnectionScopeType = () => + availableConnectionScopeOptions.value[0]?.value || 'department' + +const resetConnectionForm = () => { + editingConnectionId.value = null + Object.assign(connectionForm, { + scopeType: getDefaultConnectionScopeType(), + scopeId: '', + displayName: '', + externalSubject: '', + credentialText: '', + secretValues: createEmptySecretValues(), + metaText: '' + }) +} + +const openCreateConnectionDrawer = () => { + resetConnectionForm() + showConnectionForm.value = true +} + +const closeConnectionForm = () => { + showConnectionForm.value = false + resetConnectionForm() +} + +const startEditConnection = (connection) => { + editingConnectionId.value = connection.id + showConnectionForm.value = true + Object.assign(connectionForm, { + scopeType: connection.scope_type || 'department', + scopeId: connection.scope_id || '', + displayName: connection.display_name || '', + externalSubject: connection.external_subject || '', + credentialText: '', + secretValues: createEmptySecretValues(), + metaText: connection.meta_json ? JSON.stringify(connection.meta_json, null, 2) : '' + }) +} + const validateEditPayload = (data) => { if (!data.name?.trim()) { message.error('MCP 名称不能为空') @@ -700,6 +1543,68 @@ const fetchTools = async () => { } } +const fetchConnections = async () => { + if (!server.value) return + try { + connectionsLoading.value = true + connectionsError.value = null + const result = await mcpApi.getMcpServerConnections(server.value.name, { + paginated: true, + status: connectionFilter.value, + search: connectionSearchText.value.trim(), + page: connectionPage.value, + page_size: connectionPageSize.value + }) + if (result.success) { + if (Array.isArray(result.data)) { + connections.value = result.data + connectionTotal.value = result.data.length + Object.assign(connectionSummary, { + total: result.data.length, + active: result.data.filter( + (connection) => + connection.status === 'active' && + isConnectionScopeMatched(connection) && + !isConnectionCredentialMissing(connection) + ).length, + attention: result.data.filter((connection) => Boolean(getConnectionIssue(connection))) + .length, + disabled: result.data.filter((connection) => connection.status === 'disabled').length + }) + return + } + const pageData = result.data || {} + const nextConnections = pageData.items || [] + const nextTotal = pageData.total || 0 + const nextPageSize = pageData.page_size || connectionPageSize.value + if (connectionPage.value > 1 && nextConnections.length === 0 && nextTotal > 0) { + connectionPage.value = Math.ceil(nextTotal / nextPageSize) + await fetchConnections() + return + } + connections.value = nextConnections + connectionTotal.value = nextTotal + connectionPageSize.value = nextPageSize + Object.assign(connectionSummary, { + total: pageData.summary?.total || 0, + active: pageData.summary?.active || 0, + attention: pageData.summary?.attention || 0, + disabled: pageData.summary?.disabled || 0 + }) + } else { + connectionsError.value = result.message || '获取连接列表失败' + connections.value = [] + connectionTotal.value = 0 + } + } catch (err) { + connectionsError.value = err.message || '获取连接列表失败' + connections.value = [] + connectionTotal.value = 0 + } finally { + connectionsLoading.value = false + } +} + const handleToggleTool = async (tool) => { if (!server.value) return try { @@ -745,22 +1650,226 @@ const handleTestServer = async () => { } } -const handleDangerAction = async () => { - if (!server.value) return - if (server.value.enabled === false) { - await handleSetServerEnabled(server.value, true) - return - } - if (server.value.created_by === 'system') { - await handleSetServerEnabled(server.value, false) - return +const buildConnectionCredential = () => { + const rawCredential = parseJsonText(connectionForm.credentialText, '长期凭据', { + allowRawString: true + }) + if (rawCredential === undefined) return undefined + if (rawCredential !== null) return rawCredential + + const secrets = {} + Object.entries(connectionForm.secretValues).forEach(([key, value]) => { + const trimmedValue = String(value || '').trim() + if (trimmedValue) { + setNestedSecretValue(secrets, key, trimmedValue) + } + }) + + if (Object.keys(secrets).length === 0) { + return null } - confirmDeleteServer(server.value) + return { secrets } } -const handleSetServerEnabled = async (srv, enabled) => { - try { - const result = await mcpApi.updateMcpServerStatus(srv.name, enabled) +const validateConnectionCredential = () => { + if (isEditingConnection.value || credentialSecretFields.value.length === 0) { + return true + } + + const missingFields = credentialSecretFields.value.filter( + (fieldName) => !String(connectionForm.secretValues[fieldName] || '').trim() + ) + if (missingFields.length === 0 || connectionForm.credentialText.trim()) { + return true + } + + message.error(`请填写凭据字段:${missingFields.join('、')}`) + return false +} + +const handleSubmitConnection = async () => { + if (!server.value) return + + const scopeId = connectionForm.scopeType === 'system' ? 'global' : connectionForm.scopeId.trim() + if (!scopeId) { + message.error(`${scopeIdLabel.value}不能为空`) + return + } + if (!validateConnectionCredential()) return + + const metaJson = parseJsonText(connectionForm.metaText, '连接元数据') + if (metaJson === undefined) return + const credential = buildConnectionCredential() + if (credential === undefined) return + + try { + connectionSubmitting.value = true + const payload = { + display_name: connectionForm.displayName || null, + external_subject: connectionForm.externalSubject || null, + meta_json: metaJson + } + if (credential !== null) { + payload.credential = credential + } + + const result = isEditingConnection.value + ? await mcpApi.updateMcpServerConnection( + server.value.name, + editingConnectionId.value, + payload + ) + : await mcpApi.createMcpServerConnection(server.value.name, { + scope_type: connectionForm.scopeType, + scope_id: scopeId, + status: 'active', + ...payload + }) + if (result.success) { + message.success(isEditingConnection.value ? '连接更新成功' : '连接创建成功') + showConnectionForm.value = false + resetConnectionForm() + await fetchConnections() + } else { + message.error(result.message || (isEditingConnection.value ? '连接更新失败' : '连接创建失败')) + } + } catch (err) { + message.error(err.message || (isEditingConnection.value ? '连接更新失败' : '连接创建失败')) + } finally { + connectionSubmitting.value = false + } +} + +const handleToggleConnectionStatus = async (connection, checked) => { + if (!server.value || !canToggleConnectionStatus(connection)) return + const nextStatus = checked ? 'active' : 'disabled' + if (connection.status === nextStatus) return + const loadingKey = `${connection.id}:status` + try { + connectionActionLoading.value = loadingKey + const result = await mcpApi.updateMcpConnectionStatus( + server.value.name, + connection.id, + nextStatus + ) + if (result.success) { + message.success(result.message || (checked ? '连接已启用' : '连接已停用')) + await fetchConnections() + } else { + message.error(result.message || '状态更新失败') + await fetchConnections() + } + } catch (err) { + message.error(err.message || '状态更新失败') + await fetchConnections() + } finally { + connectionActionLoading.value = null + } +} + +const handleTestConnection = async (connection) => { + if (!server.value || !canTestConnection(connection)) return + const loadingKey = `${connection.id}:test` + try { + connectionActionLoading.value = loadingKey + const result = await mcpApi.testMcpConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接测试成功') + await fetchConnections() + } else { + message.error(result.message || '连接测试失败') + } + } catch (err) { + message.error(err.message || '连接测试失败') + } finally { + connectionActionLoading.value = null + } +} + +const handleConnectionIssueAction = (connection) => { + const issue = getConnectionIssue(connection) + if (!issue) return + if (issue.key === 'scope_mismatch') { + openCreateConnectionDrawer() + return + } + if (issue.key === 'missing_credentials' || issue.key === 'test_failed') { + startEditConnection(connection) + return + } + if (issue.key === 'reauth_required') { + handleReauthorizeConnection(connection) + } +} + +const resetConnectionFilters = () => { + connectionFilter.value = 'all' + connectionSearchText.value = '' +} + +const handleReauthorizeConnection = async (connection) => { + if (!server.value || !canReauthorizeConnection(connection)) return + const loadingKey = `${connection.id}:reauth` + try { + connectionActionLoading.value = loadingKey + const result = await mcpApi.reauthorizeMcpConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接已重置') + await fetchConnections() + } else { + message.error(result.message || '连接重置失败') + } + } catch (err) { + message.error(err.message || '连接重置失败') + } finally { + connectionActionLoading.value = null + } +} + +const handleDeleteConnection = (connection) => { + if (!server.value) return + Modal.confirm({ + title: '确认删除连接', + content: `确定要删除连接 "${getConnectionTitle(connection)}" 吗?`, + okText: '删除', + okType: 'danger', + cancelText: '取消', + async onOk() { + try { + const result = await mcpApi.deleteMcpServerConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接已删除') + if (editingConnectionId.value === connection.id) { + showConnectionForm.value = false + resetConnectionForm() + } + await fetchConnections() + } else { + message.error(result.message || '连接删除失败') + } + } catch (err) { + message.error(err.message || '连接删除失败') + } + } + }) +} + +const handleDangerAction = async () => { + if (!server.value) return + if (server.value.enabled === false) { + await handleSetServerEnabled(server.value, true) + return + } + if (server.value.created_by === 'system') { + await handleSetServerEnabled(server.value, false) + return + } + confirmDeleteServer(server.value) +} + +const handleSetServerEnabled = async (srv, enabled) => { + try { + const result = await mcpApi.updateMcpServerStatus(srv.name, enabled) if (result.success) { message.success(result.message || `MCP 已${enabled ? '添加' : '移除'}`) await fetchServer() @@ -774,17 +1883,17 @@ const handleSetServerEnabled = async (srv, enabled) => { const confirmDeleteServer = (srv) => { Modal.confirm({ - title: '确认删除 MCP', - content: `确定要删除 MCP "${srv.name}" 吗?此操作不可撤销。`, - okText: '删除', - okType: 'danger', + title: '确认退役 MCP', + content: `确定要退役 MCP "${srv.name}" 吗?退役后不会再被新运行加载,但配置和连接会保留。`, + okText: '退役', + okType: 'primary', cancelText: '取消', async onOk() { try { const result = await mcpApi.deleteMcpServer(srv.name) if (result.success) { - message.success('MCP 删除成功') - router.push({ path: '/extensions', query: { tab: 'mcp' } }) + message.success(result.message || 'MCP 已退役') + await fetchServer() } else { message.error(result.message || '删除失败') } @@ -799,10 +1908,58 @@ watch(detailTab, (tab) => { if (tab === 'tools' && server.value) { fetchTools() } + if (tab === 'connections' && server.value) { + fetchConnections() + } +}) + +watch(connectionFilter, () => { + if (detailTab.value !== 'connections') return + if (connectionPage.value === 1) { + fetchConnections() + } else { + connectionPage.value = 1 + } }) +watch(connectionSearchText, () => { + if (connectionSearchTimer) { + clearTimeout(connectionSearchTimer) + } + connectionSearchTimer = setTimeout(() => { + if (detailTab.value !== 'connections') return + if (connectionPage.value === 1) { + fetchConnections() + } else { + connectionPage.value = 1 + } + }, 300) +}) + +watch([connectionPage, connectionPageSize], () => { + if (detailTab.value !== 'connections') return + fetchConnections() +}) + +const loadScopeOptions = async () => { + try { + isFetchingScopeOptions.value = true + const [usersRes, deptsRes] = await Promise.all([ + userApi.getUsers(), + departmentApi.getDepartments() + ]) + userList.value = usersRes || [] + departmentList.value = deptsRes || [] + } catch (err) { + message.error('获取用户/部门列表失败: ' + err.message) + } finally { + isFetchingScopeOptions.value = false + } +} + onMounted(() => { fetchServer() + loadScopeOptions() }) @@ -1131,6 +2288,398 @@ onMounted(() => { } .mcp-detail { + .connections-tab { + display: flex; + flex-direction: column; + gap: 14px; + } + + .connection-command-bar { + display: flex; + justify-content: space-between; + gap: 16px; + align-items: center; + padding: 16px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + } + + .connection-command-copy { + min-width: 0; + + h3 { + margin: 0 0 4px; + color: var(--gray-900); + font-size: 16px; + font-weight: 600; + } + + p { + margin: 0; + color: var(--gray-500); + font-size: 13px; + line-height: 1.5; + } + } + + .connection-summary-strip { + display: grid; + grid-template-columns: repeat(4, minmax(0, 1fr)); + border: 1px solid var(--gray-150); + border-radius: 8px; + overflow: hidden; + background: var(--gray-0); + } + + .connection-summary-item { + display: flex; + flex-direction: column; + gap: 4px; + padding: 12px 14px; + + & + .connection-summary-item { + border-left: 1px solid var(--gray-100); + } + + .summary-label { + color: var(--gray-500); + font-size: 12px; + } + + strong { + min-width: 0; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + color: var(--gray-900); + font-size: 15px; + font-weight: 600; + } + } + + .connection-list-toolbar { + position: sticky; + top: 0; + z-index: 5; + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + padding: 12px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + } + + .connection-filter-group, + .connection-page-controls { + min-width: 0; + display: flex; + align-items: center; + gap: 12px; + } + + .connection-filter-group { + flex: 1; + } + + .connection-page-controls { + flex-shrink: 0; + justify-content: flex-end; + } + + .connection-result-count { + color: var(--gray-500); + font-size: 12px; + white-space: nowrap; + } + + .connection-search-input { + width: min(320px, 100%); + + :deep(.ant-input-prefix) { + color: var(--gray-500); + } + } + + .connection-empty-state { + display: flex; + flex-direction: column; + align-items: center; + gap: 12px; + padding: 42px 16px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + } + + .connection-list-body { + display: flex; + flex-direction: column; + gap: 12px; + } + + .connection-cards-grid { + display: grid; + grid-template-columns: repeat(auto-fill, minmax(340px, 1fr)); + gap: 12px; + } + + .connection-card { + padding: 12px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + transition: + border-color 0.2s, + box-shadow 0.2s; + + &:hover { + border-color: var(--gray-300); + } + } + + .connection-card-header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; + margin-bottom: 10px; + } + + .connection-key-info { + min-width: 0; + flex: 1; + display: flex; + align-items: center; + gap: 10px; + } + + .connection-key-icon { + color: var(--main-600); + flex-shrink: 0; + } + + .connection-key-copy { + min-width: 0; + + h4 { + margin: 0; + overflow: hidden; + color: var(--gray-900); + font-size: 14px; + font-weight: 600; + line-height: 1.4; + text-overflow: ellipsis; + white-space: nowrap; + } + + span { + display: block; + margin-top: 2px; + overflow: hidden; + color: var(--gray-500); + font-size: 12px; + line-height: 1.4; + text-overflow: ellipsis; + white-space: nowrap; + } + } + + .connection-scope-badge { + display: inline-flex; + align-items: center; + justify-content: center; + height: 28px; + flex-shrink: 0; + gap: 5px; + padding: 0 10px; + border: 1px solid var(--gray-150); + border-radius: 7px; + background: var(--gray-25); + color: var(--gray-700); + font-size: 12px; + font-weight: 500; + line-height: 1; + white-space: nowrap; + + &.scope-system { + border-color: var(--color-success-100); + background: var(--color-success-50); + color: var(--color-success-700); + } + + &.scope-department { + border-color: var(--color-accent-100); + background: var(--color-accent-50); + color: var(--color-accent-700); + } + + &.scope-user { + border-color: var(--color-info-100); + background: var(--color-info-50); + color: var(--color-info-700); + } + + &.is-mismatch { + border-color: var(--color-warning-100); + background: var(--color-warning-50); + color: var(--color-warning-900); + } + } + + .connection-card-content { + margin-bottom: 10px; + } + + .connection-issue { + display: flex; + align-items: center; + justify-content: space-between; + gap: 10px; + margin-bottom: 10px; + padding: 9px 10px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-25); + + &.issue-warning { + border-color: var(--color-warning-100); + background: var(--color-warning-10); + + .issue-copy span { + color: var(--color-warning-900); + } + } + + &.issue-error { + border-color: var(--color-error-100); + background: var(--color-error-10); + + .issue-copy span { + color: var(--color-error-700); + } + } + } + + .issue-copy { + min-width: 0; + display: flex; + flex-direction: column; + gap: 2px; + + span { + font-size: 13px; + font-weight: 600; + line-height: 1.4; + } + + small { + overflow: hidden; + color: var(--gray-600); + font-size: 12px; + line-height: 1.4; + text-overflow: ellipsis; + white-space: nowrap; + } + } + + .issue-action { + flex-shrink: 0; + padding: 0; + font-size: 12px; + font-weight: 500; + } + + .connection-info-item { + display: flex; + align-items: flex-start; + gap: 6px; + margin-bottom: 6px; + color: var(--gray-900); + font-size: 13px; + + &:last-child { + margin-bottom: 0; + } + + .info-label { + color: var(--gray-600); + flex-shrink: 0; + } + + .info-value { + min-width: 0; + color: var(--gray-900); + word-break: break-all; + } + } + + .connection-card-footer { + display: flex; + align-items: center; + justify-content: space-between; + gap: 10px; + padding-top: 8px; + border-top: 1px solid var(--gray-100); + } + + .footer-left { + display: flex; + align-items: center; + gap: 8px; + flex-shrink: 0; + } + + .switch-label { + color: var(--gray-600); + font-size: 12px; + } + + .status-switch-wrap { + display: inline-flex; + align-items: center; + } + + .connection-row-actions { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 4px; + } + + .connection-action-wrap { + display: inline-flex; + } + + .connection-action-btn { + display: inline-flex; + align-items: center; + gap: 4px; + color: var(--gray-700); + font-size: 12px; + + &:hover { + color: var(--main-600); + } + } + + .danger-action-btn { + color: var(--color-error-700); + + &:hover { + background: var(--color-error-50); + color: var(--color-error-900); + } + } + + .connection-pagination { + display: flex; + justify-content: flex-end; + padding-top: 2px; + } + .detail-content-wrapper { flex: 1; min-height: 0; @@ -1139,9 +2688,216 @@ onMounted(() => { } .detail-content-inner { - max-width: 900px; + max-width: 1120px; margin: 0 auto; padding: 16px var(--page-padding); } } + +.connection-drawer-form { + display: flex; + min-height: 100%; + flex-direction: column; + padding: 18px 20px 0; + + :deep(.ant-form-item) { + margin-bottom: 0; + } + + :deep(.ant-form-item-label > label) { + color: var(--gray-700); + font-size: 13px; + font-weight: 500; + } +} + +.drawer-section { + display: flex; + flex-direction: column; + gap: 14px; + padding-bottom: 18px; + + & + .drawer-section { + padding-top: 18px; + border-top: 1px solid var(--gray-100); + } +} + +.drawer-section-title { + display: flex; + flex-direction: column; + gap: 3px; + + span { + color: var(--gray-900); + font-size: 14px; + font-weight: 600; + } + + small { + color: var(--gray-500); + font-size: 12px; + line-height: 1.5; + } +} + +.scope-option-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(136px, 1fr)); + gap: 8px; +} + +.scope-option { + display: flex; + min-height: 78px; + flex-direction: column; + align-items: flex-start; + justify-content: center; + gap: 4px; + padding: 10px 12px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + color: var(--gray-700); + cursor: pointer; + text-align: left; + transition: + border-color 0.15s ease, + background-color 0.15s ease, + color 0.15s ease; + + span { + color: var(--gray-900); + font-size: 13px; + font-weight: 600; + } + + small { + color: var(--gray-500); + font-size: 12px; + } + + &:hover:not(:disabled) { + border-color: var(--main-300); + background: var(--main-10); + color: var(--main-color); + } + + &.active { + border-color: var(--main-color); + background: var(--main-30); + color: var(--main-color); + } + + &:disabled { + cursor: not-allowed; + opacity: 0.7; + } +} + +.secret-field-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 12px; +} + +.connection-advanced-collapse { + margin: 0 -4px 12px; + + :deep(.ant-collapse-header) { + padding: 10px 4px; + color: var(--gray-600); + font-size: 13px; + } + + :deep(.ant-collapse-content-box) { + display: flex; + flex-direction: column; + gap: 14px; + padding: 4px 4px 12px; + } +} + +.connection-drawer-footer { + position: sticky; + bottom: 0; + display: flex; + justify-content: flex-end; + gap: 8px; + margin: auto -20px 0; + padding: 14px 20px; + border-top: 1px solid var(--gray-100); + background: var(--gray-0); +} + +@media (max-width: 980px) { + .mcp-detail { + .connection-list-toolbar { + align-items: flex-start; + flex-direction: column; + } + + .connection-filter-group, + .connection-page-controls { + width: 100%; + justify-content: space-between; + } + + .connection-filter-group { + align-items: flex-start; + flex-direction: column; + } + + .connection-search-input { + width: 100%; + } + + .connection-summary-strip { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + + .connection-summary-item:nth-child(3) { + border-left: 0; + border-top: 1px solid var(--gray-100); + } + + .connection-summary-item:nth-child(4) { + border-top: 1px solid var(--gray-100); + } + + .connection-row-actions { + justify-content: flex-start; + } + + .connection-card-footer { + align-items: flex-start; + flex-direction: column; + } + } +} + +@media (max-width: 640px) { + .mcp-detail { + .connection-command-bar { + align-items: flex-start; + flex-direction: column; + } + + .connection-issue { + align-items: flex-start; + flex-direction: column; + } + + .connection-pagination { + justify-content: flex-start; + } + + .connection-summary-strip, + .connection-cards-grid, + .scope-option-grid, + .secret-field-grid { + grid-template-columns: 1fr; + } + } +} diff --git a/web/src/components/extensions/McpFormModal.vue b/web/src/components/extensions/McpFormModal.vue index 5206170de..4f96eaf7a 100644 --- a/web/src/components/extensions/McpFormModal.vue +++ b/web/src/components/extensions/McpFormModal.vue @@ -6,7 +6,7 @@ :confirmLoading="formLoading" @cancel="visible = false" :maskClosable="false" - width="560px" + width="min(780px, calc(100vw - 32px))" class="server-modal" >
@@ -93,6 +93,7 @@ + { args: obj.args || [], env: obj.env || null, headersText: obj.headers ? JSON.stringify(obj.headers, null, 2) : '', + authConfigText: obj.auth_config ? JSON.stringify(obj.auth_config, null, 2) : '', timeout: obj.timeout || null, sse_read_timeout: obj.sse_read_timeout || null, tags: obj.tags || [], @@ -259,6 +267,16 @@ const handleFormSubmit = async () => { return } } + + let authConfig = null + if (form.authConfigText.trim()) { + try { + authConfig = JSON.parse(form.authConfigText) + } catch { + message.error('认证配置 JSON 格式错误') + return + } + } data = { name: form.name, description: form.description || null, @@ -268,6 +286,7 @@ const handleFormSubmit = async () => { args: form.args.length > 0 ? form.args : null, env: form.env, headers, + auth_config: authConfig, timeout: form.timeout || null, sse_read_timeout: form.sse_read_timeout || null, tags: form.tags.length > 0 ? form.tags : null, diff --git a/web/src/layouts/AppLayout.vue b/web/src/layouts/AppLayout.vue index ff00aeb48..3a19393d7 100644 --- a/web/src/layouts/AppLayout.vue +++ b/web/src/layouts/AppLayout.vue @@ -55,14 +55,21 @@ const showDebugModal = ref(false) // Add state for settings modal const showSettingsModal = ref(false) +const settingsInitialTab = ref(null) const { sidebarCollapsed } = storeToRefs(chatUIStore) // Provide settings modal methods to child components -const openSettingsModal = () => { +const openSettingsModal = (initialTab = null) => { + settingsInitialTab.value = initialTab showSettingsModal.value = true } +const closeSettingsModal = () => { + showSettingsModal.value = false + settingsInitialTab.value = null +} + // Handle debug modal close const handleDebugModalClose = () => { showDebugModal.value = false @@ -387,7 +394,11 @@ provide('settingsModal', { - +
diff --git a/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js b/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js new file mode 100644 index 000000000..f1f83df9e --- /dev/null +++ b/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js @@ -0,0 +1,84 @@ +import assert from 'node:assert/strict' + +import { + authConfigToBuilderForm, + buildAuthConfigFromBuilderForm, + createDefaultAuthBuilderForm, + extractSecretFieldNames +} from '../mcpAuthConfigBuilder.js' + +const run = () => { + { + const form = createDefaultAuthBuilderForm() + assert.equal(buildAuthConfigFromBuilderForm(form), null) + } + + { + const form = createDefaultAuthBuilderForm('custom_http_token') + form.bindingScope = 'department' + form.injectEntries = [ + { name: 'Authorization', value_template: 'Bearer ${access_token}' }, + { name: 'X-Yuxi-User', value_template: '${context.user_id}' } + ] + form.tokenUrl = 'http://internal-gateway/token' + form.tokenHeaders = [{ key: 'Content-Type', value: 'application/json' }] + form.tokenBodyTemplate = [ + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' }, + { key: 'user_id', value: '${context.user_id}' } + ] + form.tokenResponseMap = [ + { key: 'access_token', value: 'data.access_token' }, + { key: 'expires_in', value: 'data.expires_in' } + ] + + const config = buildAuthConfigFromBuilderForm(form) + assert.equal(config.provider, 'custom_http_token') + assert.equal(config.binding_scope, 'department') + assert.equal(config.manifest_scope, 'binding') + assert.deepEqual(config.inject.entries, form.injectEntries) + assert.deepEqual(config.token_request, { + url: 'http://internal-gateway/token', + method: 'POST', + body_type: 'json', + headers: { 'Content-Type': 'application/json' }, + body_template: { + client_id: '${secret.client_id}', + client_secret: '${secret.client_secret}', + user_id: '${context.user_id}' + }, + response_map: { + access_token: 'data.access_token', + expires_in: 'data.expires_in' + } + }) + assert.deepEqual(extractSecretFieldNames(config), ['client_id', 'client_secret']) + } + + { + const config = { + version: 1, + provider: 'bound_secret', + binding_scope: 'user', + manifest_scope: 'binding', + inject: { + target: 'headers', + entries: [{ name: 'X-Api-Key', value_template: '${secret.api_key}' }] + }, + refresh_policy: { + pre_refresh_seconds: 120, + retry_once_on_401: false + } + } + + const form = authConfigToBuilderForm(config) + assert.equal(form.provider, 'bound_secret') + assert.equal(form.bindingScope, 'user') + assert.deepEqual(form.injectEntries, [{ name: 'X-Api-Key', value_template: '${secret.api_key}' }]) + assert.deepEqual(buildAuthConfigFromBuilderForm(form), config) + } + + console.log('mcpAuthConfigBuilder: all assertions passed') +} + +run() diff --git a/web/src/utils/mcpAuthConfigBuilder.js b/web/src/utils/mcpAuthConfigBuilder.js new file mode 100644 index 000000000..eeafc252e --- /dev/null +++ b/web/src/utils/mcpAuthConfigBuilder.js @@ -0,0 +1,206 @@ +export const FORM_AUTH_PROVIDERS = new Set([ + 'bound_secret', + 'custom_http_token', + 'client_credentials', + 'stdio_env' +]) + +const TOKEN_PROVIDERS = new Set(['custom_http_token', 'client_credentials']) + +const DEFAULT_RESPONSE_MAP = [ + { key: 'access_token', value: 'data.access_token' }, + { key: 'refresh_token', value: 'data.refresh_token' }, + { key: 'expires_in', value: 'data.expires_in' } +] + +const DEFAULT_JSON_HEADERS = [{ key: 'Content-Type', value: 'application/json' }] + +const DEFAULT_FORM_HEADERS = [ + { key: 'Content-Type', value: 'application/x-www-form-urlencoded' } +] + +const DEFAULT_GATEWAY_BODY = [ + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' }, + { key: 'user_id', value: '${context.user_id}' }, + { key: 'department_id', value: '${context.department_id}' } +] + +const DEFAULT_CLIENT_CREDENTIALS_BODY = [ + { key: 'grant_type', value: 'client_credentials' }, + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' } +] + +const normalizeText = (value) => String(value ?? '').trim() + +export const objectToKeyValueRows = (value, fallbackRows = [{ key: '', value: '' }]) => { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + return fallbackRows.map((row) => ({ ...row })) + } + + const rows = Object.entries(value).map(([key, rowValue]) => ({ + key, + value: rowValue == null ? '' : String(rowValue) + })) + return rows.length > 0 ? rows : fallbackRows.map((row) => ({ ...row })) +} + +export const keyValueRowsToObject = (rows) => { + const entries = (Array.isArray(rows) ? rows : []) + .map((row) => ({ + key: normalizeText(row?.key), + value: row?.value == null ? '' : String(row.value) + })) + .filter((row) => row.key) + + if (entries.length === 0) { + return {} + } + return Object.fromEntries(entries.map((row) => [row.key, row.value])) +} + +const createDefaultInjectEntries = (provider) => { + if (provider === 'stdio_env') { + return [{ name: 'MCP_ACCESS_TOKEN', value_template: '${secret.access_token}' }] + } + if (provider === 'bound_secret') { + return [{ name: 'Authorization', value_template: 'Bearer ${secret.access_token}' }] + } + return [{ name: 'Authorization', value_template: 'Bearer ${access_token}' }] +} + +export const createDefaultAuthBuilderForm = (provider = 'none') => { + const normalizedProvider = provider === 'none' ? 'none' : provider + const isClientCredentials = normalizedProvider === 'client_credentials' + const isEnvProvider = normalizedProvider === 'stdio_env' + + return { + provider: normalizedProvider, + bindingScope: 'department', + manifestScope: 'binding', + injectTarget: isEnvProvider ? 'env' : 'headers', + injectEntries: + normalizedProvider === 'none' ? [] : createDefaultInjectEntries(normalizedProvider), + preRefreshSeconds: TOKEN_PROVIDERS.has(normalizedProvider) ? 300 : 0, + retryOnceOn401: TOKEN_PROVIDERS.has(normalizedProvider), + tokenUrl: '', + tokenMethod: 'POST', + tokenBodyType: isClientCredentials ? 'form' : 'json', + tokenHeaders: isClientCredentials ? DEFAULT_FORM_HEADERS.map((row) => ({ ...row })) : DEFAULT_JSON_HEADERS.map((row) => ({ ...row })), + tokenBodyTemplate: isClientCredentials + ? DEFAULT_CLIENT_CREDENTIALS_BODY.map((row) => ({ ...row })) + : DEFAULT_GATEWAY_BODY.map((row) => ({ ...row })), + tokenResponseMap: DEFAULT_RESPONSE_MAP.map((row) => ({ ...row })) + } +} + +export const isAuthConfigSupportedByBuilder = (config) => { + if (!config || Object.keys(config).length === 0) { + return true + } + return FORM_AUTH_PROVIDERS.has(config.provider) +} + +export const authConfigToBuilderForm = (config) => { + if (!config || Object.keys(config).length === 0) { + return createDefaultAuthBuilderForm() + } + + const provider = FORM_AUTH_PROVIDERS.has(config.provider) ? config.provider : 'custom_http_token' + const form = createDefaultAuthBuilderForm(provider) + const tokenRequest = config.token_request || {} + + form.bindingScope = config.binding_scope || form.bindingScope + form.manifestScope = config.manifest_scope || form.manifestScope + form.injectTarget = config.inject?.target || form.injectTarget + form.injectEntries = + Array.isArray(config.inject?.entries) && config.inject.entries.length > 0 + ? config.inject.entries.map((entry) => ({ + name: entry.name || '', + value_template: entry.value_template || '' + })) + : form.injectEntries + form.preRefreshSeconds = + Number.isFinite(Number(config.refresh_policy?.pre_refresh_seconds)) + ? Number(config.refresh_policy.pre_refresh_seconds) + : form.preRefreshSeconds + form.retryOnceOn401 = + typeof config.refresh_policy?.retry_once_on_401 === 'boolean' + ? config.refresh_policy.retry_once_on_401 + : form.retryOnceOn401 + form.tokenUrl = tokenRequest.url || '' + form.tokenMethod = tokenRequest.method || form.tokenMethod + form.tokenBodyType = tokenRequest.body_type || form.tokenBodyType + form.tokenHeaders = objectToKeyValueRows(tokenRequest.headers, form.tokenHeaders) + form.tokenBodyTemplate = objectToKeyValueRows(tokenRequest.body_template, form.tokenBodyTemplate) + form.tokenResponseMap = objectToKeyValueRows(tokenRequest.response_map, form.tokenResponseMap) + + return form +} + +const normalizeInjectEntries = (entries) => + (Array.isArray(entries) ? entries : []) + .map((entry) => ({ + name: normalizeText(entry?.name), + value_template: String(entry?.value_template ?? '').trim() + })) + .filter((entry) => entry.name) + +export const buildAuthConfigFromBuilderForm = (form) => { + if (!form || form.provider === 'none') { + return null + } + + const provider = FORM_AUTH_PROVIDERS.has(form.provider) ? form.provider : 'custom_http_token' + const config = { + version: 1, + provider, + binding_scope: form.bindingScope || 'department', + manifest_scope: form.manifestScope || 'binding', + inject: { + target: form.injectTarget || 'headers', + entries: normalizeInjectEntries(form.injectEntries) + }, + refresh_policy: { + pre_refresh_seconds: Number(form.preRefreshSeconds) || 0, + retry_once_on_401: Boolean(form.retryOnceOn401) + } + } + + if (TOKEN_PROVIDERS.has(provider)) { + config.token_request = { + url: normalizeText(form.tokenUrl), + method: normalizeText(form.tokenMethod || 'POST').toUpperCase(), + body_type: form.tokenBodyType || 'json', + headers: keyValueRowsToObject(form.tokenHeaders), + body_template: keyValueRowsToObject(form.tokenBodyTemplate), + response_map: keyValueRowsToObject(form.tokenResponseMap) + } + } + + return config +} + +export const extractSecretFieldNames = (value, fields = new Set()) => { + if (typeof value === 'string') { + const pattern = /\$\{secret\.([^}]+)\}/g + let match = pattern.exec(value) + while (match) { + fields.add(match[1]) + match = pattern.exec(value) + } + return [...fields] + } + + if (Array.isArray(value)) { + value.forEach((item) => extractSecretFieldNames(item, fields)) + return [...fields] + } + + if (value && typeof value === 'object') { + Object.values(value).forEach((item) => extractSecretFieldNames(item, fields)) + } + + return [...fields] +}