From c949f66a144f951a4dfd0120cd7b92773648655e Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Tue, 2 Jun 2026 22:35:33 +0800 Subject: [PATCH 01/36] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81MCP=E5=8A=A8?= =?UTF-8?q?=E6=80=81=E9=89=B4=E6=9D=83=E8=BF=9E=E6=8E=A5=E4=B8=8E=E5=8F=AF?= =?UTF-8?q?=E8=A7=86=E5=8C=96=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.template | 3 +- .gitignore | 5 +- backend/package/yuxi/agents/context.py | 5 + .../middlewares/runtime_config_middleware.py | 9 +- .../agents/middlewares/skills_middleware.py | 9 +- backend/package/yuxi/services/chat_service.py | 39 +- .../yuxi/services/mcp_auth/__init__.py | 14 + .../yuxi/services/mcp_auth/config_models.py | 52 + .../package/yuxi/services/mcp_auth/crypto.py | 94 ++ .../yuxi/services/mcp_auth/orchestrator.py | 533 ++++++++ .../yuxi/services/mcp_auth/proxy_service.py | 211 +++ .../services/mcp_auth/redis_token_cache.py | 78 ++ .../services/mcp_auth/template_resolver.py | 116 ++ backend/package/yuxi/services/mcp_service.py | 459 ++++++- .../package/yuxi/storage/postgres/manager.py | 25 + .../yuxi/storage/postgres/models_business.py | 57 + backend/server/routers/__init__.py | 2 + backend/server/routers/mcp_internal_router.py | 102 ++ backend/server/routers/mcp_router.py | 283 +++- backend/test/e2e/test_mcp_admin_flow_e2e.py | 199 +++ .../test/integration/api/test_mcp_router.py | 332 +++++ .../test_runtime_config_middleware.py | 36 + .../middlewares/test_skills_middleware.py | 33 + .../unit/routers/test_mcp_internal_router.py | 84 ++ backend/test/unit/routers/test_mcp_router.py | 466 +++++++ .../unit/services/test_chat_service_sync.py | 11 +- .../services/test_mcp_auth_config_models.py | 43 + .../unit/services/test_mcp_auth_crypto.py | 39 + .../unit/services/test_mcp_auth_models.py | 100 ++ .../services/test_mcp_auth_orchestrator.py | 668 ++++++++++ .../services/test_mcp_auth_proxy_service.py | 281 ++++ .../test_mcp_auth_template_resolver.py | 54 + .../services/test_mcp_connection_service.py | 607 +++++++++ .../services/test_mcp_service_auth_runtime.py | 179 +++ .../test_postgres_manager_business_schema.py | 43 + backend/test/unit/test_base_context.py | 12 + docker-compose.yml | 3 + docs/develop-guides/roadmap.md | 1 + web/src/apis/mcp_api.js | 41 + .../extensions/McpAuthConfigBuilder.vue | 952 ++++++++++++++ .../components/extensions/McpDetailView.vue | 1141 ++++++++++++++++- .../components/extensions/McpFormModal.vue | 21 +- .../__tests__/mcpAuthConfigBuilder.test.js | 84 ++ web/src/utils/mcpAuthConfigBuilder.js | 206 +++ 44 files changed, 7701 insertions(+), 31 deletions(-) create mode 100644 backend/package/yuxi/services/mcp_auth/__init__.py create mode 100644 backend/package/yuxi/services/mcp_auth/config_models.py create mode 100644 backend/package/yuxi/services/mcp_auth/crypto.py create mode 100644 backend/package/yuxi/services/mcp_auth/orchestrator.py create mode 100644 backend/package/yuxi/services/mcp_auth/proxy_service.py create mode 100644 backend/package/yuxi/services/mcp_auth/redis_token_cache.py create mode 100644 backend/package/yuxi/services/mcp_auth/template_resolver.py create mode 100644 backend/server/routers/mcp_internal_router.py create mode 100644 backend/test/e2e/test_mcp_admin_flow_e2e.py create mode 100644 backend/test/integration/api/test_mcp_router.py create mode 100644 backend/test/unit/middlewares/test_runtime_config_middleware.py create mode 100644 backend/test/unit/middlewares/test_skills_middleware.py create mode 100644 backend/test/unit/routers/test_mcp_internal_router.py create mode 100644 backend/test/unit/services/test_mcp_auth_config_models.py create mode 100644 backend/test/unit/services/test_mcp_auth_crypto.py create mode 100644 backend/test/unit/services/test_mcp_auth_models.py create mode 100644 backend/test/unit/services/test_mcp_auth_orchestrator.py create mode 100644 backend/test/unit/services/test_mcp_auth_proxy_service.py create mode 100644 backend/test/unit/services/test_mcp_auth_template_resolver.py create mode 100644 backend/test/unit/services/test_mcp_connection_service.py create mode 100644 backend/test/unit/services/test_mcp_service_auth_runtime.py create mode 100644 backend/test/unit/storage/test_postgres_manager_business_schema.py create mode 100644 backend/test/unit/test_base_context.py create mode 100644 web/src/components/extensions/McpAuthConfigBuilder.vue create mode 100644 web/src/utils/__tests__/mcpAuthConfigBuilder.test.js create mode 100644 web/src/utils/mcpAuthConfigBuilder.js diff --git a/.env.template b/.env.template index e56b640ea..58f889295 100644 --- a/.env.template +++ b/.env.template @@ -29,6 +29,7 @@ YUXI_INSTANCE_ID= # # Servies # YUXI_SUPER_ADMIN_NAME= # YUXI_SUPER_ADMIN_PASSWORD= +# MCP_CREDENTIALS_MASTER_KEY= # # URL Whitelist (comma-separated domains/IPs, empty to disable URL parsing) # YUXI_URL_WHITELIST=github.com,docs.example.com,gitlab.example.com,127.0.0.1 @@ -73,4 +74,4 @@ YUXI_INSTANCE_ID= # SANDBOX_NODE_HOST=host.docker.internal # KUBECONFIG_PATH=/root/.kube/config # THREAD_PVC=yuxi-thread -# SKILLS_PVC=yuxi-skills # 当前代码会读取,但 Pod 挂载实际仍只使用 THREAD_PVC \ No newline at end of file +# SKILLS_PVC=yuxi-skills # 当前代码会读取,但 Pod 挂载实际仍只使用 THREAD_PVC diff --git a/.gitignore b/.gitignore index ec4c35cac..aa5dc6012 100644 --- a/.gitignore +++ b/.gitignore @@ -79,4 +79,7 @@ docs/vibe /models -.taskr/ \ No newline at end of file +.taskr/ + +.workbuddy +.worktrees/ diff --git a/backend/package/yuxi/agents/context.py b/backend/package/yuxi/agents/context.py index 851780b69..a77bb59ef 100644 --- a/backend/package/yuxi/agents/context.py +++ b/backend/package/yuxi/agents/context.py @@ -33,6 +33,11 @@ def update(self, data: dict): metadata={"name": "用户ID", "configurable": False, "description": "用来唯一标识一个用户"}, ) + department_id: str | None = field( + default=None, + metadata={"name": "部门ID", "configurable": False, "description": "用来标识当前用户所属部门"}, + ) + system_prompt: Annotated[str, {"__template_metadata__": {"kind": "prompt"}}] = field( default="You are a helpful assistant.", metadata={"name": "系统提示词", "description": "用来描述智能体的角色和行为"}, diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index a1b6e98f2..d891225e3 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -8,6 +8,7 @@ from yuxi.agents import load_chat_model from yuxi.agents.toolkits import get_all_tool_instances +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp_service import get_enabled_mcp_tools from yuxi.utils.datetime_utils import shanghai_now from yuxi.utils.logging_config import logger @@ -151,7 +152,13 @@ async def get_tools_from_context(self, context) -> list: continue selected_mcp_servers.add(server_name) try: - mcp_tools = await get_enabled_mcp_tools(server_name) + mcp_tools = await get_enabled_mcp_tools( + server_name, + auth_context=AuthContext( + user_id=getattr(context, "user_id", None), + department_id=getattr(context, "department_id", None), + ), + ) if not mcp_tools: logger.warning(f"RuntimeConfigMiddleware: mcp dependency unavailable, skip: {server_name}") selected_tools.extend(mcp_tools) diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index 21a1d7543..269581e86 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -15,6 +15,7 @@ from yuxi.agents.toolkits import get_all_tool_instances from yuxi.repositories.skill_repository import SkillRepository +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp_service import get_enabled_mcp_tools from yuxi.services.skill_service import _normalize_string_list, is_valid_skill_slug from yuxi.storage.postgres.manager import pg_manager @@ -342,7 +343,13 @@ async def _get_mcp_tools_from_context( async def load_mcp_tools(server_name: str) -> list: """加载单个 MCP 服务器的工具""" try: - mcp_tools = await get_enabled_mcp_tools(server_name) + mcp_tools = await get_enabled_mcp_tools( + server_name, + auth_context=AuthContext( + user_id=getattr(context, "user_id", None), + department_id=getattr(context, "department_id", None), + ), + ) if not mcp_tools: logger.warning(f"SkillsMiddleware: mcp dependency unavailable, skip: {server_name}") return mcp_tools diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index 3430a3180..cfbf8c897 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -56,7 +56,13 @@ def _load_workspace_agents_prompt(thread_id: str, user_id: str) -> str: return prompt -async def _build_agent_input_context(agent_config: dict, *, thread_id: str, user_id: str) -> dict: +async def _build_agent_input_context( + agent_config: dict, + *, + thread_id: str, + user_id: str, + department_id: str | int | None = None, +) -> dict: input_context = dict(agent_config or {}) agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, user_id) @@ -65,7 +71,13 @@ async def _build_agent_input_context(agent_config: dict, *, thread_id: str, user base_prompt = str(input_context.get("system_prompt") or "").rstrip() input_context["system_prompt"] = f"{base_prompt}\n\n{agents_section}" if base_prompt else agents_section - input_context.update({"user_id": user_id, "thread_id": thread_id}) + input_context.update( + { + "user_id": user_id, + "thread_id": thread_id, + "department_id": str(department_id) if department_id is not None else None, + } + ) return input_context @@ -598,7 +610,12 @@ async def agent_chat( thread_id = str(uuid.uuid4()) logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") - input_context = await _build_agent_input_context(agent_config, thread_id=thread_id, user_id=user_id) + input_context = await _build_agent_input_context( + agent_config, + thread_id=thread_id, + user_id=user_id, + department_id=getattr(current_user, "department_id", None), + ) langfuse_run = _build_langfuse_run_context( current_user=current_user, thread_id=thread_id, @@ -814,7 +831,12 @@ def make_chunk(content=None, **kwargs): thread_id = str(uuid.uuid4()) logger.warning(f"No thread_id provided, generated new thread_id: {thread_id}") - input_context = await _build_agent_input_context(agent_config, thread_id=thread_id, user_id=user_id) + input_context = await _build_agent_input_context( + agent_config, + thread_id=thread_id, + user_id=user_id, + department_id=getattr(current_user, "department_id", None), + ) langfuse_run = _build_langfuse_run_context( current_user=current_user, thread_id=thread_id, @@ -1049,7 +1071,14 @@ def make_resume_chunk(content=None, **kwargs): return context = agent.context_schema() - context.update(await _build_agent_input_context(agent_config or {}, thread_id=thread_id, user_id=user_id)) + context.update( + await _build_agent_input_context( + agent_config or {}, + thread_id=thread_id, + user_id=user_id, + department_id=getattr(current_user, "department_id", None), + ) + ) graph = await agent.get_graph(context=context) langfuse_run = _build_langfuse_run_context( current_user=current_user, diff --git a/backend/package/yuxi/services/mcp_auth/__init__.py b/backend/package/yuxi/services/mcp_auth/__init__.py new file mode 100644 index 000000000..ef382794e --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/__init__.py @@ -0,0 +1,14 @@ +"""MCP auth helpers.""" + +from .config_models import MCPAuthConfig +from .crypto import decrypt_credential_blob, encrypt_credential_blob, is_encrypted_credential_blob +from .template_resolver import TemplateResolutionError, resolve_template_value + +__all__ = [ + "MCPAuthConfig", + "TemplateResolutionError", + "decrypt_credential_blob", + "encrypt_credential_blob", + "is_encrypted_credential_blob", + "resolve_template_value", +] diff --git a/backend/package/yuxi/services/mcp_auth/config_models.py b/backend/package/yuxi/services/mcp_auth/config_models.py new file mode 100644 index 000000000..eccb991f8 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/config_models.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class InjectEntry(BaseModel): + name: str + value_template: str + + +class InjectConfig(BaseModel): + target: Literal["headers", "env"] + entries: list[InjectEntry] = Field(default_factory=list) + + +class RefreshPolicy(BaseModel): + pre_refresh_seconds: int = 0 + retry_once_on_401: bool = False + + +class MCPAuthConfig(BaseModel): + model_config = ConfigDict(extra="allow") + + version: int = 1 + provider: Literal[ + "legacy_static", + "bound_secret", + "client_credentials", + "custom_http_token", + "authorization_code", + "stdio_env", + ] + binding_scope: Literal["inline", "system", "department", "user"] | None = None + manifest_scope: Literal["server", "binding"] | None = None + inject: InjectConfig + refresh_policy: RefreshPolicy = Field(default_factory=RefreshPolicy) + token_request: dict[str, Any] | None = None + + @model_validator(mode="after") + def apply_defaults_and_validate(self) -> MCPAuthConfig: + if self.binding_scope is None: + self.binding_scope = "inline" if self.provider == "legacy_static" else "system" + if self.manifest_scope is None: + self.manifest_scope = "server" + if ( + self.provider in {"custom_http_token", "client_credentials", "authorization_code"} + and not self.token_request + ): + raise ValueError("token_request is required for dynamic auth providers") + return self diff --git a/backend/package/yuxi/services/mcp_auth/crypto.py b/backend/package/yuxi/services/mcp_auth/crypto.py new file mode 100644 index 000000000..f4b5ccac4 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/crypto.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import base64 +import hashlib +import json +import os +from typing import Any + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +MASTER_KEY_ENV = "MCP_CREDENTIALS_MASTER_KEY" +ENVELOPE_VERSION = 1 +ENVELOPE_KEY_ID = "local" +_AAD = b"yuxi:mcp_credentials:v1" + + +def _get_master_key() -> str: + value = os.getenv(MASTER_KEY_ENV, "").strip() + if not value: + raise ValueError(f"{MASTER_KEY_ENV} is required when storing encrypted MCP credentials") + return value + + +def _derive_aes_key(master_key: str) -> bytes: + return hashlib.sha256(master_key.encode("utf-8")).digest() + + +def _b64encode(value: bytes) -> str: + return base64.urlsafe_b64encode(value).decode("ascii") + + +def _b64decode(value: str) -> bytes: + return base64.urlsafe_b64decode(value.encode("ascii")) + + +def _parse_envelope(blob: str) -> dict[str, Any] | None: + try: + payload = json.loads(blob) + except (TypeError, json.JSONDecodeError): + return None + if not isinstance(payload, dict): + return None + required_keys = {"v", "kid", "nonce", "ciphertext"} + if not required_keys.issubset(payload.keys()): + return None + if payload.get("v") != ENVELOPE_VERSION: + return None + return payload + + +def is_encrypted_credential_blob(blob: str | None) -> bool: + if not blob or not isinstance(blob, str): + return False + return _parse_envelope(blob) is not None + + +def encrypt_credential_blob(plaintext: str) -> str: + if not plaintext: + return plaintext + if is_encrypted_credential_blob(plaintext): + return plaintext + + master_key = _get_master_key() + aesgcm = AESGCM(_derive_aes_key(master_key)) + nonce = os.urandom(12) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), _AAD) + return json.dumps( + { + "v": ENVELOPE_VERSION, + "kid": ENVELOPE_KEY_ID, + "nonce": _b64encode(nonce), + "ciphertext": _b64encode(ciphertext), + }, + ensure_ascii=True, + separators=(",", ":"), + ) + + +def decrypt_credential_blob(blob: str | None) -> str | None: + if blob is None or not isinstance(blob, str): + return blob + + payload = _parse_envelope(blob) + if payload is None: + return blob + + master_key = _get_master_key() + aesgcm = AESGCM(_derive_aes_key(master_key)) + plaintext = aesgcm.decrypt( + _b64decode(payload["nonce"]), + _b64decode(payload["ciphertext"]), + _AAD, + ) + return plaintext.decode("utf-8") diff --git a/backend/package/yuxi/services/mcp_auth/orchestrator.py b/backend/package/yuxi/services/mcp_auth/orchestrator.py new file mode 100644 index 000000000..18133fc7e --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/orchestrator.py @@ -0,0 +1,533 @@ +from __future__ import annotations +import asyncio +import json +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any + +import httpx + +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache +from yuxi.services.mcp_auth.template_resolver import resolve_template_value +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer +from yuxi.utils import logger + + +@dataclass(slots=True) +class AuthContext: + user_id: str | None = None + department_id: str | None = None + + +_DEFAULT_TOKEN_RESPONSE_MAP = { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "expires_at": "expires_at", + "scope": "scope", + "token_type": "token_type", +} +_REFRESH_LOCK_WAIT_SECONDS = 1.0 +_REFRESH_LOCK_POLL_INTERVAL_SECONDS = 0.05 + + +def _parse_credential_blob(connection: MCPConnection | None) -> dict[str, Any]: + if connection is None or not connection.credential_blob: + return {} + if isinstance(connection.credential_blob, dict): + return dict(connection.credential_blob) + decrypted = decrypt_credential_blob(connection.credential_blob) + if not decrypted: + return {} + try: + return json.loads(decrypted) + except json.JSONDecodeError: + return { + "access_token": decrypted, + "secrets": {"access_token": decrypted}, + } + + +def _extract_path(payload: dict[str, Any], path: str) -> Any: + current: Any = payload + for segment in path.split("."): + if isinstance(current, dict): + current = current[segment] + continue + raise KeyError(path) + return current + + +def _base_server_config(server: MCPServer) -> dict[str, Any]: + config = server.to_mcp_config() + config.pop("auth_config", None) + return config + + +def _context_payload(context: AuthContext) -> dict[str, Any]: + return { + "user_id": context.user_id, + "department_id": context.department_id, + } + + +def _parse_datetime(value: str | None) -> datetime | None: + if not value or not isinstance(value, str): + return None + try: + parsed = datetime.fromisoformat(value) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) + + +def _normalize_token_payload(token_values: dict[str, Any]) -> dict[str, Any]: + normalized = dict(token_values) + expires_at = normalized.get("expires_at") + if isinstance(expires_at, datetime): + normalized["expires_at"] = expires_at.astimezone(UTC).isoformat() + return normalized + if isinstance(expires_at, str): + parsed = _parse_datetime(expires_at) + if parsed is not None: + normalized["expires_at"] = parsed.isoformat() + return normalized + expires_in = normalized.get("expires_in") + if isinstance(expires_in, str) and expires_in.isdigit(): + expires_in = int(expires_in) + normalized["expires_in"] = expires_in + if isinstance(expires_in, (int, float)): + normalized["expires_at"] = (datetime.now(tz=UTC) + timedelta(seconds=int(expires_in))).isoformat() + return normalized + + +def _is_token_expiring_soon(token_values: dict[str, Any], *, pre_refresh_seconds: int) -> bool: + expires_at = _parse_datetime(token_values.get("expires_at")) + if expires_at is None: + return False + return expires_at <= datetime.now(tz=UTC) + timedelta(seconds=max(pre_refresh_seconds, 0)) + + +def _merge_injected_entries( + config: dict[str, Any], + *, + inject_target: str, + inject_entries: list[dict[str, str]], + context: AuthContext, + secret_values: dict[str, Any], + token_values: dict[str, Any], + access_token: str | None, +) -> dict[str, Any]: + target_values = dict(config.get(inject_target) or {}) + for entry in inject_entries: + target_values[entry["name"]] = resolve_template_value( + entry["value_template"], + context={ + "user_id": context.user_id, + "department_id": context.department_id, + }, + secret=secret_values, + token=token_values, + access_token=access_token, + ) + config[inject_target] = target_values + return config + + +async def _fetch_custom_http_token( + request_config: dict[str, Any], + *, + response_map: dict[str, str] | None, + context: AuthContext, + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, +) -> dict[str, Any]: + response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + headers = resolve_template_value( + request_config.get("headers") or {}, + context=_context_payload(context), + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body = resolve_template_value( + request_config.get("body_template") or {}, + context=_context_payload(context), + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body_type = request_config.get("body_type", "json") + request_kwargs: dict[str, Any] = { + "method": (request_config.get("method") or "POST").upper(), + "url": request_config["url"], + "headers": headers, + } + if body_type == "json": + request_kwargs["json"] = body + else: + request_kwargs["data"] = body + + response = await http_client.request(**request_kwargs) + response.raise_for_status() + payload = response.json() + resolved = {} + for field_name, path in response_map.items(): + try: + resolved[field_name] = _extract_path(payload, path) + except KeyError: + continue + return _normalize_token_payload(resolved) + finally: + if should_close: + await http_client.aclose() + + +async def _resolve_authorization_code_token_request( + *, + token_request: dict[str, Any], + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, +) -> tuple[dict[str, Any], dict[str, str]]: + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + issuer_url = ( + token_request.get("issuer_url") + or secret_values.get("issuer_url") + or token_values.get("issuer_url") + ) + if not issuer_url: + raise ValueError("authorization_code provider requires token_request.issuer_url") + discovery_url = f"{str(issuer_url).rstrip('/')}/.well-known/openid-configuration" + response = await http_client.get(discovery_url) + response.raise_for_status() + payload = response.json() + token_endpoint = payload.get("token_endpoint") + if not token_endpoint: + raise ValueError("authorization_code provider discovery missing token_endpoint") + return { + "url": token_endpoint, + "method": "POST", + "body_type": "form", + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "body_template": { + "grant_type": "refresh_token", + "refresh_token": "${token.refresh_token}", + "client_id": token_request.get("client_id", "${secret.client_id}"), + "client_secret": token_request.get("client_secret", "${secret.client_secret}"), + }, + }, dict(_DEFAULT_TOKEN_RESPONSE_MAP) + finally: + if should_close: + await http_client.aclose() + + +async def _load_cached_token( + *, + token_cache: Any | None, + connection_id: int | None, +) -> dict[str, Any] | None: + if token_cache is None or connection_id is None: + return None + try: + cached = await token_cache.get_access_token(connection_id) + except Exception as exc: + logger.warning(f"Failed to load MCP access token cache for connection {connection_id}: {exc}") + return None + if not cached: + return None + return _normalize_token_payload(cached) + + +async def _store_cached_token( + *, + token_cache: Any | None, + connection_id: int | None, + token_payload: dict[str, Any], +) -> None: + if token_cache is None or connection_id is None: + return + try: + await token_cache.set_access_token(connection_id, token_payload) + except Exception as exc: + logger.warning(f"Failed to persist MCP access token cache for connection {connection_id}: {exc}") + + +async def _acquire_refresh_lock( + *, + token_cache: Any | None, + connection_id: int | None, +) -> bool: + if token_cache is None or connection_id is None: + return True + acquire_method = getattr(token_cache, "acquire_refresh_lock", None) + if acquire_method is None: + return True + try: + return bool(await acquire_method(connection_id)) + except Exception as exc: + logger.warning(f"Failed to acquire MCP refresh lock for connection {connection_id}: {exc}") + return True + + +async def _release_refresh_lock( + *, + token_cache: Any | None, + connection_id: int | None, + acquired: bool, +) -> None: + if not acquired or token_cache is None or connection_id is None: + return + release_method = getattr(token_cache, "release_refresh_lock", None) + if release_method is None: + return + try: + await release_method(connection_id) + except Exception as exc: + logger.warning(f"Failed to release MCP refresh lock for connection {connection_id}: {exc}") + + +async def _wait_for_refreshed_token( + *, + token_cache: Any | None, + connection_id: int | None, + pre_refresh_seconds: int, +) -> dict[str, Any] | None: + if token_cache is None or connection_id is None: + return None + + remaining = _REFRESH_LOCK_WAIT_SECONDS + while remaining > 0: + await asyncio.sleep(_REFRESH_LOCK_POLL_INTERVAL_SECONDS) + cached_token = await _load_cached_token(token_cache=token_cache, connection_id=connection_id) + if cached_token and not _is_token_expiring_soon( + cached_token, + pre_refresh_seconds=pre_refresh_seconds, + ): + return cached_token + remaining -= _REFRESH_LOCK_POLL_INTERVAL_SECONDS + return None + + +async def _request_dynamic_token_values( + auth_config: MCPAuthConfig, + *, + context: AuthContext, + connection: MCPConnection | None, + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + http_client: httpx.AsyncClient | None, + token_cache: Any | None, + token_values: dict[str, Any], +) -> dict[str, Any]: + token_request = auth_config.token_request or {} + refresh_request = token_request.get("refresh") + if ( + token_values + and refresh_request + and (token_values.get("refresh_token") or credential_payload.get("refresh_token")) + ): + refresh_token_values = dict(token_values) + if not refresh_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + refresh_token_values["refresh_token"] = credential_payload["refresh_token"] + refreshed = await _fetch_custom_http_token( + refresh_request, + response_map=(refresh_request.get("response_map") or token_request.get("response_map")), + context=context, + secret_values=secret_values, + token_values=refresh_token_values, + http_client=http_client, + ) + if not refreshed.get("refresh_token") and refresh_token_values.get("refresh_token"): + refreshed["refresh_token"] = refresh_token_values["refresh_token"] + await _store_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + token_payload=refreshed, + ) + return refreshed + + if auth_config.provider == "authorization_code": + authorization_request, response_map = await _resolve_authorization_code_token_request( + token_request=token_request, + secret_values=secret_values, + token_values=token_values or credential_payload, + http_client=http_client, + ) + authorization_token_values = dict(token_values or credential_payload) + if not authorization_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + authorization_token_values["refresh_token"] = credential_payload["refresh_token"] + resolved = await _fetch_custom_http_token( + authorization_request, + response_map=response_map, + context=context, + secret_values=secret_values, + token_values=authorization_token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and authorization_token_values.get("refresh_token"): + resolved["refresh_token"] = authorization_token_values["refresh_token"] + await _store_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + token_payload=resolved, + ) + return resolved + + resolved = await _fetch_custom_http_token( + token_request, + response_map=token_request.get("response_map"), + context=context, + secret_values=secret_values, + token_values=token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and credential_payload.get("refresh_token"): + resolved["refresh_token"] = credential_payload["refresh_token"] + await _store_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + token_payload=resolved, + ) + return resolved + + +async def _resolve_dynamic_token_values( + auth_config: MCPAuthConfig, + *, + context: AuthContext, + connection: MCPConnection | None, + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + http_client: httpx.AsyncClient | None, + token_cache: Any | None, +) -> dict[str, Any]: + if token_cache is None and connection is not None: + token_cache = RedisTokenCache() + + cached_token = await _load_cached_token( + token_cache=token_cache, + connection_id=getattr(connection, "id", None), + ) + pre_refresh_seconds = auth_config.refresh_policy.pre_refresh_seconds + if cached_token and not _is_token_expiring_soon(cached_token, pre_refresh_seconds=pre_refresh_seconds): + return cached_token + + token_values = dict(cached_token or {}) + if not token_values: + token_values.update( + { + key: value + for key, value in credential_payload.items() + if key in {"access_token", "refresh_token", "expires_in", "expires_at", "scope", "token_type"} + } + ) + token_values = _normalize_token_payload(token_values) + if token_values.get("access_token") and not _is_token_expiring_soon( + token_values, + pre_refresh_seconds=pre_refresh_seconds, + ): + return token_values + connection_id = getattr(connection, "id", None) + lock_acquired = await _acquire_refresh_lock(token_cache=token_cache, connection_id=connection_id) + if not lock_acquired: + refreshed_from_cache = await _wait_for_refreshed_token( + token_cache=token_cache, + connection_id=connection_id, + pre_refresh_seconds=pre_refresh_seconds, + ) + if refreshed_from_cache: + return refreshed_from_cache + + try: + return await _request_dynamic_token_values( + auth_config, + context=context, + connection=connection, + secret_values=secret_values, + credential_payload=credential_payload, + http_client=http_client, + token_cache=token_cache, + token_values=token_values, + ) + finally: + await _release_refresh_lock( + token_cache=token_cache, + connection_id=connection_id, + acquired=lock_acquired, + ) + + +async def resolve_runtime_mcp_config( + server: MCPServer, + *, + auth_context: AuthContext, + connection: MCPConnection | None = None, + http_client: httpx.AsyncClient | None = None, + token_cache: Any | None = None, +) -> dict[str, Any]: + config = _base_server_config(server) + auth_payload = server.auth_config_json or {} + if not auth_payload: + return config + + auth_config = MCPAuthConfig.model_validate(auth_payload) + inject_entries = [entry.model_dump() for entry in auth_config.inject.entries] + credential_payload = _parse_credential_blob(connection) + secret_values = credential_payload.get("secrets") or {} + + if auth_config.provider == "legacy_static": + return config + + if auth_config.provider in {"bound_secret", "stdio_env"}: + return _merge_injected_entries( + config, + inject_target=auth_config.inject.target, + inject_entries=inject_entries, + context=auth_context, + secret_values=secret_values, + token_values=credential_payload, + access_token=None, + ) + + if auth_config.provider in {"custom_http_token", "client_credentials", "authorization_code"}: + token_values = await _resolve_dynamic_token_values( + auth_config, + context=auth_context, + connection=connection, + secret_values=secret_values, + credential_payload=credential_payload, + http_client=http_client, + token_cache=token_cache, + ) + return _merge_injected_entries( + config, + inject_target=auth_config.inject.target, + inject_entries=inject_entries, + context=auth_context, + secret_values=secret_values, + token_values=token_values, + access_token=token_values.get("access_token"), + ) + + raise ValueError(f"Unsupported MCP auth provider: {auth_config.provider}") diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py new file mode 100644 index 000000000..fb9cff682 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from datetime import timedelta +from typing import Any +from urllib.parse import urlencode + +import httpx + +from server.utils.auth_utils import AuthUtils +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + +INTERNAL_PROXY_TOKEN_HEADER = "X-Yuxi-MCP-Proxy-Token" +_PROXY_TOKEN_TYPE = "mcp_proxy" +_DYNAMIC_HTTP_PROVIDERS = {"custom_http_token", "client_credentials", "authorization_code"} +_HTTP_TRANSPORTS = {"streamable_http", "sse"} +_HOP_BY_HOP_HEADERS = { + "connection", + "content-length", + "host", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailer", + "transfer-encoding", + "upgrade", +} + + +def should_use_internal_proxy(server: MCPServer, auth_config: MCPAuthConfig, proxy_base_url: str | None) -> bool: + return bool( + proxy_base_url + and server.transport in _HTTP_TRANSPORTS + and auth_config.provider in _DYNAMIC_HTTP_PROVIDERS + ) + + +def create_proxy_access_token(server_name: str, auth_context: AuthContext) -> str: + return AuthUtils.create_access_token( + { + "sub": f"mcp-proxy:{server_name}", + "token_type": _PROXY_TOKEN_TYPE, + "server_name": server_name, + "user_id": auth_context.user_id, + "department_id": auth_context.department_id, + }, + expires_delta=timedelta(minutes=15), + ) + + +def decode_proxy_access_token(token: str, *, server_name: str) -> AuthContext: + payload = AuthUtils.decode_token(token) + if not payload: + raise ValueError("invalid internal proxy token") + if payload.get("token_type") != _PROXY_TOKEN_TYPE: + raise ValueError("invalid internal proxy token type") + if payload.get("server_name") != server_name: + raise ValueError("internal proxy token server mismatch") + return AuthContext( + user_id=payload.get("user_id"), + department_id=payload.get("department_id"), + ) + + +def build_internal_proxy_url(proxy_base_url: str, server_name: str) -> str: + return f"{proxy_base_url.rstrip('/')}/api/internal/mcp-proxy/{server_name}" + + +def build_proxy_runtime_config( + server: MCPServer, + *, + auth_config: MCPAuthConfig, + auth_context: AuthContext, + proxy_base_url: str, +) -> dict[str, Any]: + config = server.to_mcp_config() + config.pop("auth_config", None) + headers = dict(config.get("headers") or {}) + headers[INTERNAL_PROXY_TOKEN_HEADER] = create_proxy_access_token(server.name, auth_context) + config["headers"] = headers + config["url"] = build_internal_proxy_url(proxy_base_url, server.name) + if auth_config.manifest_scope == "binding": + if auth_config.binding_scope == "department": + partition = f"department:{auth_context.department_id or 'unknown'}" + elif auth_config.binding_scope == "user": + partition = f"user:{auth_context.user_id or 'unknown'}" + else: + partition = f"{auth_config.binding_scope}:global" + config["__yuxi_cache_partition"] = partition + config["__yuxi_allow_global_cache"] = False + else: + config["__yuxi_cache_partition"] = "server" + config["__yuxi_allow_global_cache"] = True + return config + + +def _merge_upstream_headers( + base_headers: dict[str, Any], + request_headers: dict[str, str] | None, +) -> dict[str, Any]: + merged = dict(base_headers or {}) + for key, value in (request_headers or {}).items(): + if key.lower() in _HOP_BY_HOP_HEADERS or key.lower() == INTERNAL_PROXY_TOKEN_HEADER.lower(): + continue + merged[key] = value + return merged + + +def _build_target_url(base_url: str, path: str = "", query_params: dict[str, Any] | None = None) -> str: + if not path: + target = base_url + else: + target = f"{base_url.rstrip('/')}/{path.lstrip('/')}" + if query_params: + return f"{target}?{urlencode(query_params, doseq=True)}" + return target + + +def _mark_reauth_required(connection: MCPConnection | None, message: str) -> None: + if connection is None: + return + connection.status = "reauth_required" + meta_json = dict(connection.meta_json or {}) + meta_json["last_error"] = { + "code": "unauthorized", + "message": message, + } + connection.meta_json = meta_json + + +def _record_scope_error(connection: MCPConnection | None, message: str) -> None: + if connection is None: + return + meta_json = dict(connection.meta_json or {}) + meta_json["last_error"] = { + "code": "insufficient_scope", + "message": message, + } + connection.meta_json = meta_json + + +async def proxy_mcp_request( + server: MCPServer, + *, + connection: MCPConnection | None, + auth_context: AuthContext, + method: str, + headers: dict[str, str] | None, + query_params: dict[str, Any] | None, + body: bytes, + path: str = "", + http_client: httpx.AsyncClient | None = None, + token_cache: Any | None = None, +) -> httpx.Response: + auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) + if server.transport not in _HTTP_TRANSPORTS: + raise ValueError(f"Internal proxy only supports HTTP MCP transports, got: {server.transport}") + + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 + for attempt in range(max_attempts): + runtime_config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context, + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + target_url = _build_target_url(runtime_config["url"], path=path, query_params=query_params) + upstream_headers = _merge_upstream_headers(runtime_config.get("headers") or {}, headers) + response = await http_client.request( + method=method.upper(), + url=target_url, + headers=upstream_headers, + content=body, + ) + if response.status_code == 403: + _record_scope_error(connection, "MCP upstream rejected request due to insufficient scope") + return httpx.Response( + 403, + json={ + "error": "insufficient_scope", + "message": "当前授权范围不足,请联系管理员或重新授权", + }, + ) + if response.status_code != 401: + return response + if attempt + 1 >= max_attempts: + break + if token_cache is not None and connection is not None and getattr(connection, "id", None) is not None: + await token_cache.delete_access_token(connection.id) + + _mark_reauth_required(connection, "MCP upstream returned 401 after retry") + return httpx.Response( + 424, + json={ + "error": "reauth_required", + "message": "连接失效,请重新连接", + }, + ) + finally: + if should_close: + await http_client.aclose() diff --git a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py new file mode 100644 index 000000000..59177c06b --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from datetime import UTC, datetime +from typing import Any + +from yuxi.services.run_queue_service import get_redis_client +from yuxi.utils import logger + +ACCESS_TOKEN_KEY_PREFIX = "yuxi:mcp:access_token:v1" +REFRESH_LOCK_KEY_PREFIX = "yuxi:mcp:refresh_lock:v1" +DEFAULT_TOKEN_TTL_SECONDS = 300 +DEFAULT_LOCK_TTL_SECONDS = 30 + + +def _access_token_key(connection_id: int) -> str: + return f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" + + +def _refresh_lock_key(connection_id: int) -> str: + return f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" + + +def _compute_token_ttl_seconds(token_payload: dict[str, Any]) -> int: + expires_at = token_payload.get("expires_at") + if isinstance(expires_at, str): + try: + expires_at_dt = datetime.fromisoformat(expires_at) + if expires_at_dt.tzinfo is None: + expires_at_dt = expires_at_dt.replace(tzinfo=UTC) + ttl = int((expires_at_dt - datetime.now(tz=UTC)).total_seconds()) + return max(ttl, 1) + except ValueError: + logger.warning(f"Invalid expires_at in MCP token payload: {expires_at}") + expires_in = token_payload.get("expires_in") + if isinstance(expires_in, (int, float)) and int(expires_in) > 0: + return int(expires_in) + return DEFAULT_TOKEN_TTL_SECONDS + + +class RedisTokenCache: + def __init__(self, redis_client_factory: Callable[[], Awaitable[Any]] | None = None): + self._redis_client_factory = redis_client_factory or get_redis_client + + async def _get_redis(self): + return await self._redis_client_factory() + + async def get_access_token(self, connection_id: int) -> dict[str, Any] | None: + redis = await self._get_redis() + raw = await redis.get(_access_token_key(connection_id)) + if not raw: + return None + if isinstance(raw, dict): + return raw + return json.loads(raw) + + async def set_access_token(self, connection_id: int, token_payload: dict[str, Any]) -> None: + redis = await self._get_redis() + ttl_seconds = _compute_token_ttl_seconds(token_payload) + await redis.set( + _access_token_key(connection_id), + json.dumps(token_payload, ensure_ascii=False, separators=(",", ":")), + ex=ttl_seconds, + ) + + async def delete_access_token(self, connection_id: int) -> None: + redis = await self._get_redis() + await redis.delete(_access_token_key(connection_id)) + + async def acquire_refresh_lock(self, connection_id: int, *, ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS) -> bool: + redis = await self._get_redis() + acquired = await redis.set(_refresh_lock_key(connection_id), "1", ex=ttl_seconds, nx=True) + return bool(acquired) + + async def release_refresh_lock(self, connection_id: int) -> None: + redis = await self._get_redis() + await redis.delete(_refresh_lock_key(connection_id)) diff --git a/backend/package/yuxi/services/mcp_auth/template_resolver.py b/backend/package/yuxi/services/mcp_auth/template_resolver.py new file mode 100644 index 000000000..446dbab16 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/template_resolver.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import re +from collections.abc import Mapping +from typing import Any + +_PLACEHOLDER_PATTERN = re.compile(r"\$\{([^}]+)\}") + + +class TemplateResolutionError(ValueError): + """Raised when a template placeholder cannot be resolved.""" + + +def _lookup_path(root: Any, path: str, *, full_expression: str) -> Any: + current = root + for segment in path.split("."): + if isinstance(current, Mapping) and segment in current: + current = current[segment] + continue + raise TemplateResolutionError(f"Unknown template placeholder: {full_expression}") + return current + + +def _resolve_placeholder( + expression: str, + *, + context: Mapping[str, Any], + secret: Mapping[str, Any], + token: Mapping[str, Any], + access_token: str | None, +) -> Any: + if expression == "access_token": + if access_token is None: + raise TemplateResolutionError("Unknown template placeholder: access_token") + return access_token + + if "." not in expression: + raise TemplateResolutionError(f"Unknown template placeholder: {expression}") + + root_name, path = expression.split(".", 1) + roots = { + "context": context, + "secret": secret, + "token": token, + } + if root_name not in roots: + raise TemplateResolutionError(f"Unknown template placeholder: {expression}") + return _lookup_path(roots[root_name], path, full_expression=expression) + + +def resolve_template_value( + value: Any, + *, + context: Mapping[str, Any], + secret: Mapping[str, Any], + token: Mapping[str, Any], + access_token: str | None, +) -> Any: + if isinstance(value, Mapping): + return { + key: resolve_template_value( + item, + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + for key, item in value.items() + } + + if isinstance(value, list): + return [ + resolve_template_value( + item, + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + for item in value + ] + + if not isinstance(value, str): + return value + + matches = list(_PLACEHOLDER_PATTERN.finditer(value)) + if not matches: + return value + + if len(matches) == 1 and matches[0].span() == (0, len(value)): + return _resolve_placeholder( + matches[0].group(1), + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + + parts: list[str] = [] + cursor = 0 + for match in matches: + start, end = match.span() + if start > cursor: + parts.append(value[cursor:start]) + resolved = _resolve_placeholder( + match.group(1), + context=context, + secret=secret, + token=token, + access_token=access_token, + ) + parts.append(str(resolved)) + cursor = end + if cursor < len(value): + parts.append(value[cursor:]) + return "".join(parts) diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py index ebec9f9db..45d724745 100644 --- a/backend/package/yuxi/services/mcp_service.py +++ b/backend/package/yuxi/services/mcp_service.py @@ -9,16 +9,28 @@ import asyncio import hashlib +import httpx import json +import os import re import traceback from collections.abc import Callable +from datetime import UTC, datetime from typing import Any, cast from langchain_mcp_adapters.client import MultiServerMCPClient from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession -from yuxi.storage.postgres.models_business import MCPServer +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_TOKEN_HEADER, + build_proxy_runtime_config, + should_use_internal_proxy, +) +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache +from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill from yuxi.utils import logger # ============================================================================= @@ -35,6 +47,8 @@ # MCP tools statistics (for reporting enabled/disabled counts) _mcp_tools_stats: dict[str, dict[str, int]] = {} _UNSET = object() +_VALID_MCP_CONNECTION_SCOPE_TYPES = {"system", "department", "user"} +_VALID_MCP_CONNECTION_STATUSES = {"active", "disabled", "reauth_required", "invalid"} # Default MCP Server configurations (Imported to DB on first run) _DEFAULT_MCP_SERVERS = { @@ -69,6 +83,8 @@ "icon", ) +_MCP_PROXY_BASE_URL_ENV = "YUXI_INTERNAL_MCP_PROXY_BASE_URL" + # ============================================================================= # === Core Logic (Moved from agents/common/mcp.py) === # ============================================================================= @@ -207,6 +223,138 @@ async def get_enabled_mcp_server_config(server_name: str, *, db: AsyncSession | return configs.get(server_name) +def _get_internal_mcp_proxy_base_url() -> str | None: + value = os.getenv(_MCP_PROXY_BASE_URL_ENV, "").strip() + return value or None + + +def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, Any], str, bool]: + cache_partition = str(server_config.get("__yuxi_cache_partition") or "server") + allow_global_cache = bool(server_config.get("__yuxi_allow_global_cache", True)) + cache_identity = { + key: value + for key, value in server_config.items() + if key not in {"__yuxi_cache_partition", "__yuxi_allow_global_cache", "disabled_tools"} + } + headers = dict(cache_identity.get("headers") or {}) + headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) + if headers: + cache_identity["headers"] = headers + elif "headers" in cache_identity: + cache_identity["headers"] = {} + return cache_identity, cache_partition, allow_global_cache + + +def _resolve_scope_id(binding_scope: str, auth_context: AuthContext | None) -> str | None: + if binding_scope == "inline": + return None + if binding_scope == "system": + return "global" + if auth_context is None: + raise ValueError(f"auth_context is required for MCP binding scope '{binding_scope}'") + if binding_scope == "department": + if not auth_context.department_id: + raise ValueError("department_id is required for department-scoped MCP auth") + return str(auth_context.department_id) + if binding_scope == "user": + if not auth_context.user_id: + raise ValueError("user_id is required for user-scoped MCP auth") + return str(auth_context.user_id) + raise ValueError(f"Unsupported MCP binding scope: {binding_scope}") + + +def _normalize_mcp_connection_scope(scope_type: str, scope_id: str | None) -> tuple[str, str]: + normalized_scope_type = str(scope_type or "").strip().lower() + if normalized_scope_type not in _VALID_MCP_CONNECTION_SCOPE_TYPES: + raise ValueError("scope_type must be one of: system, department, user") + + normalized_scope_id = str(scope_id or "").strip() + if normalized_scope_type == "system": + return normalized_scope_type, "global" + if not normalized_scope_id: + raise ValueError(f"scope_id is required for {normalized_scope_type}-scoped MCP connections") + return normalized_scope_type, normalized_scope_id + + +def _normalize_mcp_connection_status(status: str) -> str: + normalized_status = str(status or "").strip().lower() + if normalized_status not in _VALID_MCP_CONNECTION_STATUSES: + raise ValueError("status must be one of: active, disabled, reauth_required, invalid") + return normalized_status + + +async def _get_enabled_mcp_server_record(server_name: str, *, db: AsyncSession) -> MCPServer | None: + result = await db.execute( + select(MCPServer).where( + MCPServer.enabled == 1, + MCPServer.name == server_name, + ) + ) + return result.scalar_one_or_none() + + +async def get_runtime_mcp_server_config( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> dict[str, Any] | None: + if db is None and auth_context is None: + return await get_enabled_mcp_server_config(server_name) + + if db is not None: + server = await _get_enabled_mcp_server_record(server_name, db=db) + if server is None: + return None + if not server.auth_config_json: + return server.to_mcp_config() + + auth_config = MCPAuthConfig.model_validate(server.auth_config_json) + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + if scope_id is None: + return server.to_mcp_config() + + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server_name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + connection = result.scalar_one_or_none() + if connection is None: + raise ValueError( + f"Active MCP connection not found for server '{server_name}' and scope " + f"{auth_config.binding_scope}:{scope_id}" + ) + proxy_base_url = _get_internal_mcp_proxy_base_url() + if should_use_internal_proxy(server, auth_config, proxy_base_url): + return build_proxy_runtime_config( + server, + auth_config=auth_config, + auth_context=auth_context or AuthContext(), + proxy_base_url=proxy_base_url or "", + ) + return await resolve_runtime_mcp_config( + server, + auth_context=auth_context or AuthContext(), + connection=connection, + http_client=http_client, + ) + + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + return await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=session, + http_client=http_client, + ) + + async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: """Get enabled MCP server names from the database.""" configs = await _load_enabled_mcp_server_configs(db=db) @@ -245,9 +393,13 @@ async def get_mcp_tools( # 配置 hash 直接基于完整配置生成。只要数据库中的配置发生变化, # 本地工具缓存 key 就会变化,从而自然触发重建。 - config_payload = json.dumps(server_config, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) + config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] - cache_key = f"{server_name}:{config_hash}" + if allow_global_cache: + cache_key = f"{server_name}:{config_hash}" + else: + cache_key = f"{server_name}:{cache_partition}:{config_hash}" all_processed_tools: list[Callable[..., Any]] = [] @@ -258,7 +410,11 @@ async def get_mcp_tools( if not all_processed_tools: try: # disabled_tools 只影响返回值过滤,不参与 MCP client 建连参数。 - client_config = {k: v for k, v in server_config.items() if k not in ("disabled_tools",)} + client_config = { + k: v + for k, v in server_config.items() + if k not in ("disabled_tools", "__yuxi_cache_partition", "__yuxi_allow_global_cache") + } client = await get_mcp_client({server_name: client_config}) if client is None: @@ -347,6 +503,27 @@ def clear_mcp_server_tools_cache(server_name: str) -> None: logger.info(f"Cleared tools cache for MCP server '{server_name}'") +async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> None: + if connection_id is None: + return + + cache = RedisTokenCache() + try: + await cache.delete_access_token(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection_id}: {exc}") + try: + await cache.release_refresh_lock(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection_id}: {exc}") + + +async def _clear_mcp_server_runtime_auth_cache(db: AsyncSession, server_name: str) -> None: + connections = await list_mcp_connections(db, server_name=server_name) + for connection in connections: + await _clear_mcp_connection_runtime_auth_cache(getattr(connection, "id", None)) + + def get_mcp_tools_stats(server_name: str) -> dict[str, int] | None: """Get tools statistics for a MCP server. @@ -373,6 +550,224 @@ async def get_all_mcp_servers(db: AsyncSession) -> list[MCPServer]: return list(result.scalars().all()) +async def get_mcp_connection(db: AsyncSession, connection_id: int) -> MCPConnection | None: + result = await db.execute(select(MCPConnection).where(MCPConnection.id == connection_id)) + return result.scalar_one_or_none() + + +def _auth_context_from_connection(connection: MCPConnection) -> AuthContext: + if connection.scope_type == "department": + return AuthContext(department_id=connection.scope_id) + if connection.scope_type == "user": + return AuthContext(user_id=connection.scope_id) + return AuthContext() + + +async def list_mcp_connections( + db: AsyncSession, + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, +) -> list[MCPConnection]: + stmt = select(MCPConnection) + if server_name is not None: + stmt = stmt.where(MCPConnection.server_name == server_name) + if scope_type is not None: + stmt = stmt.where(MCPConnection.scope_type == scope_type) + if scope_id is not None: + stmt = stmt.where(MCPConnection.scope_id == scope_id) + stmt = stmt.order_by(MCPConnection.id.asc()) + result = await db.execute(stmt) + return list(result.scalars().all()) + + +async def create_mcp_connection( + db: AsyncSession, + *, + server_name: str, + scope_type: str, + scope_id: str, + display_name: str | None = None, + external_subject: str | None = None, + status: str = "active", + credential_blob: str | None = None, + meta_json: dict[str, Any] | None = None, + created_by: str | None = None, +) -> MCPConnection: + server = await get_mcp_server(db, server_name) + if server is None: + raise ValueError(f"Server '{server_name}' does not exist") + normalized_scope_type, normalized_scope_id = _normalize_mcp_connection_scope(scope_type, scope_id) + normalized_status = _normalize_mcp_connection_status(status) + + encrypted_credential_blob = ( + encrypt_credential_blob(credential_blob) + if isinstance(credential_blob, str) and credential_blob.strip() + else credential_blob + ) + + connection = MCPConnection( + server_name=server_name, + scope_type=normalized_scope_type, + scope_id=normalized_scope_id, + display_name=display_name, + external_subject=external_subject, + status=normalized_status, + credential_blob=encrypted_credential_blob, + meta_json=meta_json or {}, + created_by=created_by, + updated_by=created_by, + ) + db.add(connection) + await db.commit() + await db.refresh(connection) + return connection + + +async def update_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + display_name: str | None = None, + external_subject: str | None = None, + credential_blob: Any = _UNSET, + meta_json: dict[str, Any] | None = None, + status: str | None = None, + updated_by: str | None = None, +) -> MCPConnection: + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + should_clear_runtime_auth_cache = False + if display_name is not None: + connection.display_name = display_name + if external_subject is not None: + connection.external_subject = external_subject + if credential_blob is not _UNSET: + if isinstance(credential_blob, str) and credential_blob.strip(): + connection.credential_blob = encrypt_credential_blob(credential_blob) + else: + connection.credential_blob = credential_blob + should_clear_runtime_auth_cache = True + if meta_json is not None: + connection.meta_json = meta_json + if status is not None: + connection.status = _normalize_mcp_connection_status(status) + should_clear_runtime_auth_cache = True + if updated_by is not None: + connection.updated_by = updated_by + + await db.commit() + await db.refresh(connection) + if should_clear_runtime_auth_cache: + await _clear_mcp_connection_runtime_auth_cache(connection.id) + return connection + + +async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: + connection = await get_mcp_connection(db, connection_id) + if connection is None: + return False + deleted_connection_id = connection.id + await db.delete(connection) + await db.commit() + await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) + return True + + +async def set_mcp_connection_status( + db: AsyncSession, + connection_id: int, + *, + status: str, + updated_by: str | None = None, +) -> MCPConnection: + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + connection.status = _normalize_mcp_connection_status(status) + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + await _clear_mcp_connection_runtime_auth_cache(connection.id) + return connection + + +async def reauthorize_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> MCPConnection: + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + cache = RedisTokenCache() + if getattr(connection, "id", None) is not None: + try: + await cache.delete_access_token(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection.id}: {exc}") + try: + await cache.release_refresh_lock(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") + + connection.status = "active" + meta_json = dict(connection.meta_json or {}) + meta_json.pop("last_error", None) + connection.meta_json = meta_json + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return connection + + +async def test_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> dict[str, Any]: + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + + auth_context = _auth_context_from_connection(connection) + config = await get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) + if config is None: + raise ValueError(f"MCP server '{server.name}' runtime config unavailable") + + tools = await get_mcp_tools( + server.name, + additional_servers={server.name: config}, + disabled_tools=[], + cache=False, + force_refresh=True, + ) + + meta_json = dict(connection.meta_json or {}) + meta_json["last_success_at"] = datetime.now(tz=UTC).isoformat() + meta_json.pop("last_error", None) + connection.meta_json = meta_json + connection.status = "active" + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return {"tool_count": len(tools), "connection": connection} + + async def create_mcp_server( db: AsyncSession, name: str, @@ -387,6 +782,7 @@ async def create_mcp_server( sse_read_timeout: int = None, tags: list = None, icon: str = None, + auth_config: dict | None = None, created_by: str = None, ) -> MCPServer: """Create server.""" @@ -404,6 +800,7 @@ async def create_mcp_server( args=args, env=env, headers=headers, + auth_config_json=auth_config, timeout=timeout, sse_read_timeout=sse_read_timeout, tags=tags, @@ -416,6 +813,7 @@ async def create_mcp_server( await db.commit() await db.refresh(server) + await _clear_mcp_server_runtime_auth_cache(db, name) clear_mcp_server_tools_cache(name) logger.info(f"Created MCP server '{name}'") @@ -436,6 +834,7 @@ async def update_mcp_server( sse_read_timeout: int = None, tags: list = None, icon: str = None, + auth_config: Any = _UNSET, updated_by: str = None, ) -> MCPServer: """Update server configuration.""" @@ -457,6 +856,8 @@ async def update_mcp_server( server.env = env if headers is not None: server.headers = headers + if auth_config is not _UNSET: + server.auth_config_json = auth_config if timeout is not None: server.timeout = timeout if sse_read_timeout is not None: @@ -483,15 +884,48 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: if not server: return False + connection_ids = [item.id for item in await list_mcp_connections(db, server_name=name)] await db.delete(server) await db.commit() + for connection_id in connection_ids: + await _clear_mcp_connection_runtime_auth_cache(connection_id) clear_mcp_server_tools_cache(name) logger.info(f"Deleted MCP server '{name}'") return True +async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: + connections = await list_mcp_connections(db, server_name=name) + + skill_rows = (await db.execute(select(Skill))).scalars().all() + matched_skills = [ + {"slug": item.slug, "name": item.name} + for item in skill_rows + if name in (item.mcp_dependencies or []) + ] + + agent_config_rows = (await db.execute(select(AgentConfig))).scalars().all() + matched_agent_configs = [] + for item in agent_config_rows: + config_json = item.config_json or {} + if name in (config_json.get("mcps") or []): + matched_agent_configs.append({"id": item.id, "name": item.name, "agent_id": item.agent_id}) + + connection_refs = [ + {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} + for item in connections + ] + + return { + "has_references": bool(connection_refs or matched_skills or matched_agent_configs), + "connections": connection_refs, + "skills": matched_skills, + "agent_configs": matched_agent_configs, + } + + # ============================================================================= # === Tool Management === # ============================================================================= @@ -511,6 +945,8 @@ async def set_server_enabled( await db.commit() is_enabled = bool(server.enabled) + if not is_enabled: + await _clear_mcp_server_runtime_auth_cache(db, name) clear_mcp_server_tools_cache(name) logger.info(f"Set MCP server '{name}' enabled={is_enabled}") @@ -564,7 +1000,13 @@ async def toggle_tool_enabled( # ============================================================================= -async def get_enabled_mcp_tools(server_name: str) -> list: +async def get_enabled_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> list: """Get MCP server tools (auto-filtering disabled_tools). Unified entry point for Agents, automatically: @@ -578,7 +1020,12 @@ async def get_enabled_mcp_tools(server_name: str) -> list: Returns: List of enabled tools """ - config = await get_enabled_mcp_server_config(server_name) + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) if config is None: logger.warning(f"MCP server '{server_name}' not found in database or disabled") return [] diff --git a/backend/package/yuxi/storage/postgres/manager.py b/backend/package/yuxi/storage/postgres/manager.py index c4a6fde3b..64e220ed1 100644 --- a/backend/package/yuxi/storage/postgres/manager.py +++ b/backend/package/yuxi/storage/postgres/manager.py @@ -196,6 +196,29 @@ async def ensure_business_schema(self): "ALTER TABLE IF EXISTS subagents ADD COLUMN IF NOT EXISTS enabled BOOLEAN NOT NULL DEFAULT TRUE", "ALTER TABLE IF EXISTS conversations ADD COLUMN IF NOT EXISTS is_pinned BOOLEAN NOT NULL DEFAULT FALSE", "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS env JSONB", + "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS auth_config_json JSONB", + """ + CREATE TABLE IF NOT EXISTS mcp_connections ( + id SERIAL PRIMARY KEY, + server_name VARCHAR(100) NOT NULL REFERENCES mcp_servers(name) ON DELETE CASCADE, + scope_type VARCHAR(16) NOT NULL, + scope_id VARCHAR(64) NOT NULL, + display_name VARCHAR(128), + external_subject VARCHAR(255), + status VARCHAR(32) NOT NULL DEFAULT 'active', + credential_blob TEXT, + meta_json JSONB NOT NULL DEFAULT '{}'::jsonb, + created_by VARCHAR(64), + updated_by VARCHAR(64), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT ck_mcp_connections_scope_type CHECK (scope_type IN ('system', 'department', 'user')), + CONSTRAINT ck_mcp_connections_status CHECK ( + status IN ('active', 'disabled', 'reauth_required', 'invalid') + ), + CONSTRAINT uq_mcp_connections_server_scope UNIQUE (server_name, scope_type, scope_id) + ) + """, """ CREATE TABLE IF NOT EXISTS model_providers ( id SERIAL PRIMARY KEY, @@ -244,6 +267,8 @@ async def ensure_business_schema(self): "CREATE INDEX IF NOT EXISTS idx_agent_runs_thread_created ON agent_runs(thread_id, created_at DESC)", "CREATE INDEX IF NOT EXISTS idx_agent_runs_status_updated ON agent_runs(status, updated_at)", "CREATE INDEX IF NOT EXISTS ix_conversations_is_pinned ON conversations(is_pinned)", + "CREATE INDEX IF NOT EXISTS idx_mcp_connections_status ON mcp_connections(status)", + "CREATE INDEX IF NOT EXISTS idx_mcp_connections_subject ON mcp_connections(external_subject)", "CREATE UNIQUE INDEX IF NOT EXISTS ix_model_providers_provider_id ON model_providers(provider_id)", "CREATE INDEX IF NOT EXISTS ix_model_providers_is_enabled ON model_providers(is_enabled)", ] diff --git a/backend/package/yuxi/storage/postgres/models_business.py b/backend/package/yuxi/storage/postgres/models_business.py index ef3299653..86c4fd2d8 100644 --- a/backend/package/yuxi/storage/postgres/models_business.py +++ b/backend/package/yuxi/storage/postgres/models_business.py @@ -440,6 +440,7 @@ class MCPServer(Base): args = Column(JSON, nullable=True, comment="命令参数数组(stdio)") env = Column(JSON, nullable=True, comment="环境变量(stdio)") headers = Column(JSON, nullable=True, comment="HTTP 请求头") + auth_config_json = Column(JSON, nullable=True, comment="MCP 认证配置") timeout = Column(Integer, nullable=True, comment="HTTP 超时时间(秒)") sse_read_timeout = Column(Integer, nullable=True, comment="SSE 读取超时(秒)") @@ -469,6 +470,7 @@ def to_dict(self) -> dict[str, Any]: "args": self.args or [], "env": self.env or {}, "headers": self.headers or {}, + "auth_config": self.auth_config_json or {}, "timeout": self.timeout, "sse_read_timeout": self.sse_read_timeout, "tags": self.tags or [], @@ -516,6 +518,14 @@ def to_mcp_config(self) -> dict[str, Any]: config["headers"] = json.loads(self.headers) except json.JSONDecodeError: pass + if self.auth_config_json: + if isinstance(self.auth_config_json, dict): + config["auth_config"] = self.auth_config_json + elif isinstance(self.auth_config_json, str): + try: + config["auth_config"] = json.loads(self.auth_config_json) + except json.JSONDecodeError: + pass if self.timeout is not None: config["timeout"] = self.timeout if self.sse_read_timeout is not None: @@ -525,6 +535,53 @@ def to_mcp_config(self) -> dict[str, Any]: return config +class MCPConnection(Base): + """MCP 长期连接与凭据绑定模型""" + + __tablename__ = "mcp_connections" + __table_args__ = ( + UniqueConstraint("server_name", "scope_type", "scope_id", name="uq_mcp_connections_server_scope"), + Index("idx_mcp_connections_status", "status"), + Index("idx_mcp_connections_subject", "external_subject"), + ) + + id = Column(Integer, primary_key=True, autoincrement=True) + server_name = Column(String(100), ForeignKey("mcp_servers.name", ondelete="CASCADE"), nullable=False) + scope_type = Column(String(16), nullable=False, comment="system/department/user") + scope_id = Column(String(64), nullable=False, comment="绑定范围标识") + display_name = Column(String(128), nullable=True, comment="展示名称") + external_subject = Column(String(255), nullable=True, comment="外部系统主体标识") + status = Column(String(32), nullable=False, default="active", comment="连接状态") + credential_blob = Column(Text, nullable=True, comment="加密后的长期敏感凭据") + meta_json = Column(JSON, nullable=False, default=dict, comment="非敏感元数据") + created_by = Column(String(64), nullable=True) + updated_by = Column(String(64), nullable=True) + created_at = Column(DateTime, default=utc_now_naive, comment="创建时间") + updated_at = Column(DateTime, default=utc_now_naive, onupdate=utc_now_naive, comment="更新时间") + + server = relationship("MCPServer") + + def to_dict(self, *, include_credentials: bool = False) -> dict[str, Any]: + payload = { + "id": self.id, + "server_name": self.server_name, + "scope_type": self.scope_type, + "scope_id": self.scope_id, + "display_name": self.display_name, + "external_subject": self.external_subject, + "status": self.status, + "meta_json": self.meta_json or {}, + "has_credentials": bool(self.credential_blob), + "created_by": self.created_by, + "updated_by": self.updated_by, + "created_at": format_utc_datetime(self.created_at), + "updated_at": format_utc_datetime(self.updated_at), + } + if include_credentials: + payload["credential_blob"] = self.credential_blob + return payload + + class ModelProvider(Base): """模型供应商配置,存储 provider 基础信息、模型端点和可用模型。""" diff --git a/backend/server/routers/__init__.py b/backend/server/routers/__init__.py index bf998c296..b8528f826 100644 --- a/backend/server/routers/__init__.py +++ b/backend/server/routers/__init__.py @@ -6,6 +6,7 @@ from server.routers.chat_router import chat from server.routers.dashboard_router import dashboard from server.routers.auth_dept_router import department +from server.routers.mcp_internal_router import mcp_internal from server.routers.mcp_router import mcp from server.routers.model_provider_router import model_providers from server.routers.skill_router import skills @@ -32,6 +33,7 @@ router.include_router(department) # /api/departments/* 部门与权限相关数据 router.include_router(tasks) # /api/tasks/* 后台任务查询与管理 router.include_router(mcp) # /api/system/mcp-servers/* MCP 服务管理 +router.include_router(mcp_internal) # /api/internal/mcp-proxy/* 动态 MCP 内部代理 router.include_router(model_providers) # /api/system/model-providers/* 独立模型配置 router.include_router(skills) # /api/system/skills/* Skills 管理 router.include_router(subagents_router) # /api/system/subagents/* 子智能体管理 diff --git a/backend/server/routers/mcp_internal_router.py b/backend/server/routers/mcp_internal_router.py new file mode 100644 index 000000000..b0dc3d827 --- /dev/null +++ b/backend/server/routers/mcp_internal_router.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from server.utils.auth_middleware import get_db +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_TOKEN_HEADER, + decode_proxy_access_token, + proxy_mcp_request, +) +from yuxi.services.mcp_service import _resolve_scope_id, get_mcp_server +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer +from yuxi.utils import logger + +mcp_internal = APIRouter(prefix="/internal/mcp-proxy", tags=["mcp-internal"]) + + +async def _load_active_connection( + db: AsyncSession, + *, + server: MCPServer, + auth_context, +) -> MCPConnection | None: + auth_payload = server.auth_config_json or {} + if not auth_payload: + return None + + auth_config = MCPAuthConfig.model_validate(auth_payload) + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + if scope_id is None: + return None + + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server.name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + return result.scalar_one_or_none() + + +@mcp_internal.api_route( + "/{server_name}", + methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], +) +@mcp_internal.api_route( + "/{server_name}/{path:path}", + methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], +) +async def proxy_mcp_server_request( + server_name: str, + request: Request, + path: str = "", + internal_token: str | None = Header(None, alias=INTERNAL_PROXY_TOKEN_HEADER), + db: AsyncSession = Depends(get_db), +): + if not internal_token: + raise HTTPException(status_code=401, detail="missing internal proxy token") + + try: + auth_context = decode_proxy_access_token(internal_token, server_name=server_name) + except ValueError as exc: + raise HTTPException(status_code=401, detail=str(exc)) from exc + + server = await get_mcp_server(db, server_name) + if server is None: + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在") + + try: + connection = await _load_active_connection(db, server=server, auth_context=auth_context) + body = await request.body() + upstream_response = await proxy_mcp_request( + server, + connection=connection, + auth_context=auth_context, + method=request.method, + headers=dict(request.headers), + query_params=dict(request.query_params), + body=body, + path=path, + ) + if connection is not None and hasattr(db, "commit"): + await db.commit() + response_headers = {} + content_type = upstream_response.headers.get("content-type") + if content_type: + response_headers["content-type"] = content_type + return Response( + content=upstream_response.content, + status_code=upstream_response.status_code, + headers=response_headers, + ) + except HTTPException: + raise + except Exception as exc: + logger.error(f"Failed to proxy MCP server '{server_name}': {exc}") + raise HTTPException(status_code=500, detail=str(exc)) from exc diff --git a/backend/server/routers/mcp_router.py b/backend/server/routers/mcp_router.py index 76fdbd96a..68a68602e 100644 --- a/backend/server/routers/mcp_router.py +++ b/backend/server/routers/mcp_router.py @@ -1,18 +1,30 @@ """MCP 服务器管理路由""" +import json + from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_service import ( + create_mcp_connection, create_mcp_server, + delete_mcp_connection, get_mcp_tools_stats, delete_mcp_server, + get_mcp_connection, + get_mcp_server_dependency_summary, get_all_mcp_servers, get_all_mcp_tools, get_mcp_server, + list_mcp_connections, + reauthorize_mcp_connection, + set_mcp_connection_status, set_server_enabled, + test_mcp_connection, toggle_tool_enabled, + update_mcp_connection, update_mcp_server, ) from yuxi.storage.postgres.models_business import User @@ -40,6 +52,7 @@ class CreateMcpServerRequest(BaseModel): sse_read_timeout: int | None = Field(None, description="SSE 读取超时(秒)") tags: list | None = Field(None, description="标签数组") icon: str | None = Field(None, description="图标(emoji)") + auth_config: dict | None = Field(None, description="MCP 鉴权配置") class UpdateMcpServerRequest(BaseModel): @@ -54,12 +67,35 @@ class UpdateMcpServerRequest(BaseModel): sse_read_timeout: int | None = Field(None, description="SSE 读取超时(秒)") tags: list | None = Field(None, description="标签数组") icon: str | None = Field(None, description="图标(emoji)") + auth_config: dict | None = Field(None, description="MCP 鉴权配置") class UpdateMcpServerStatusRequest(BaseModel): enabled: bool = Field(..., description="是否启用") +class CreateMcpConnectionRequest(BaseModel): + scope_type: str = Field(..., description="连接范围:system/department/user") + scope_id: str = Field(..., description="范围标识") + display_name: str | None = Field(None, description="展示名称") + external_subject: str | None = Field(None, description="外部系统主体标识") + credential: dict | str | None = Field(None, description="长期凭据") + meta_json: dict | None = Field(None, description="非敏感元数据") + status: str = Field("active", description="连接状态") + + +class UpdateMcpConnectionStatusRequest(BaseModel): + status: str = Field(..., description="连接状态") + + +class UpdateMcpConnectionRequest(BaseModel): + display_name: str | None = Field(None, description="展示名称") + external_subject: str | None = Field(None, description="外部系统主体标识") + credential: dict | str | None = Field(None, description="长期凭据") + meta_json: dict | None = Field(None, description="非敏感元数据") + status: str | None = Field(None, description="连接状态") + + # ============================================================================= # === Helpers === # ============================================================================= @@ -73,6 +109,30 @@ async def get_server_or_404(db: AsyncSession, name: str): return server +async def get_connection_for_server_or_404(db: AsyncSession, server_name: str, connection_id: int): + connection = await get_mcp_connection(db, connection_id) + if connection is None or connection.server_name != server_name: + raise HTTPException(status_code=404, detail=f"连接 '{connection_id}' 不存在") + return connection + + +def _normalize_credential_blob(value: dict | str | None) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return json.dumps(value, ensure_ascii=False) + + +def _validate_auth_config_or_400(payload: dict | None) -> dict | None: + if payload is None: + return None + try: + return MCPAuthConfig.model_validate(payload).model_dump(mode="json") + except Exception as exc: + raise HTTPException(status_code=400, detail=f"auth_config 配置无效: {exc}") from exc + + # ============================================================================= # === MCP 服务器 CRUD === # ============================================================================= @@ -127,6 +187,7 @@ async def create_mcp_server_route( raise HTTPException(status_code=400, detail="传输类型为 stdio 时,command 必填") try: + auth_config = _validate_auth_config_or_400(request.auth_config) server = await create_mcp_server( db, name=request.name, @@ -141,9 +202,12 @@ async def create_mcp_server_route( sse_read_timeout=request.sse_read_timeout, tags=request.tags, icon=request.icon, + auth_config=auth_config, created_by=current_user.username, ) return {"success": True, "data": server.to_dict()} + except HTTPException: + raise except ValueError as ve: raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: @@ -182,10 +246,12 @@ async def update_mcp_server_route( raise HTTPException(status_code=400, detail=f"传输类型必须是 {', '.join(valid_transports)} 之一") try: - fields_set = getattr(request, "model_fields_set", getattr(request, "__fields_set__", set())) + fields_set = request.model_fields_set update_kwargs = {} if "env" in fields_set: update_kwargs["env"] = request.env + if "auth_config" in fields_set: + update_kwargs["auth_config"] = _validate_auth_config_or_400(request.auth_config) server = await update_mcp_server( db, @@ -204,6 +270,8 @@ async def update_mcp_server_route( **update_kwargs, ) return {"success": True, "data": server.to_dict()} + except HTTPException: + raise except ValueError as ve: raise HTTPException(status_code=404, detail=str(ve)) except Exception as e: @@ -214,6 +282,7 @@ async def update_mcp_server_route( @mcp.delete("/{name}") async def delete_mcp_server_route( name: str, + hard: bool = False, current_user: User = Depends(get_admin_user), db: AsyncSession = Depends(get_db), ): @@ -221,13 +290,26 @@ async def delete_mcp_server_route( try: # 检查是否为系统内置服务器 server = await get_mcp_server(db, name) - if server and server.created_by == "system": + if not server: + raise HTTPException(status_code=404, detail=f"服务器 '{name}' 不存在") + if server.created_by == "system": raise HTTPException(status_code=403, detail="系统内置的 MCP 服务器无法删除") + if not hard: + await set_server_enabled(db, name, False, current_user.username) + return {"success": True, "message": f"服务器 '{name}' 已退役"} + + if bool(server.enabled): + raise HTTPException(status_code=409, detail="请先退役服务器,再执行硬删除") + + dependency_summary = await get_mcp_server_dependency_summary(db, name) + if dependency_summary["has_references"]: + raise HTTPException(status_code=409, detail=dependency_summary) + deleted = await delete_mcp_server(db, name) if not deleted: raise HTTPException(status_code=404, detail=f"服务器 '{name}' 不存在") - return {"success": True, "message": f"服务器 '{name}' 已删除"} + return {"success": True, "message": f"服务器 '{name}' 已彻底删除"} except HTTPException: raise except Exception as e: @@ -248,7 +330,11 @@ async def test_mcp_server( ): """测试 MCP 服务器连接""" try: - await get_server_or_404(db, name) + server = await get_server_or_404(db, name) + if server.auth_config_json: + auth_config = MCPAuthConfig.model_validate(server.auth_config_json) + if auth_config.binding_scope != "inline": + raise HTTPException(status_code=400, detail="该 MCP 需要绑定连接,请在连接页测试具体连接") try: tools = await get_all_mcp_tools(name) @@ -289,6 +375,195 @@ async def update_mcp_server_status_route( raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================= +# === MCP 连接管理 === +# ============================================================================= + + +@mcp.get("/{name}/connections") +async def get_mcp_connections( + name: str, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + del current_user + try: + await get_server_or_404(db, name) + connections = await list_mcp_connections(db, server_name=name) + return {"success": True, "data": [item.to_dict() for item in connections]} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list MCP connections: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections") +async def create_mcp_connection_route( + name: str, + request: CreateMcpConnectionRequest, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + try: + await get_server_or_404(db, name) + connection = await create_mcp_connection( + db, + server_name=name, + scope_type=request.scope_type, + scope_id=request.scope_id, + display_name=request.display_name, + external_subject=request.external_subject, + status=request.status, + credential_blob=_normalize_credential_blob(request.credential), + meta_json=request.meta_json, + created_by=current_user.username, + ) + return {"success": True, "data": connection.to_dict()} + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.put("/{name}/connections/{connection_id}/status") +async def update_mcp_connection_status_route( + name: str, + connection_id: int, + request: UpdateMcpConnectionStatusRequest, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + try: + await get_server_or_404(db, name) + await get_connection_for_server_or_404(db, name, connection_id) + connection = await set_mcp_connection_status( + db, + connection_id, + status=request.status, + updated_by=current_user.username, + ) + return {"success": True, "data": connection.to_dict(), "message": "连接状态已更新"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connection status: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.put("/{name}/connections/{connection_id}") +async def update_mcp_connection_route( + name: str, + connection_id: int, + request: UpdateMcpConnectionRequest, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + try: + await get_server_or_404(db, name) + await get_connection_for_server_or_404(db, name, connection_id) + fields_set = request.model_fields_set + update_kwargs = {} + if "credential" in fields_set: + update_kwargs["credential_blob"] = _normalize_credential_blob(request.credential) + if "display_name" in fields_set: + update_kwargs["display_name"] = request.display_name + if "external_subject" in fields_set: + update_kwargs["external_subject"] = request.external_subject + if "meta_json" in fields_set: + update_kwargs["meta_json"] = request.meta_json + if "status" in fields_set: + update_kwargs["status"] = request.status + + connection = await update_mcp_connection( + db, + connection_id, + updated_by=current_user.username, + **update_kwargs, + ) + return {"success": True, "data": connection.to_dict(), "message": "连接已更新"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.delete("/{name}/connections/{connection_id}") +async def delete_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + del current_user + try: + await get_server_or_404(db, name) + await get_connection_for_server_or_404(db, name, connection_id) + deleted = await delete_mcp_connection(db, connection_id) + if not deleted: + raise HTTPException(status_code=404, detail=f"连接 '{connection_id}' 不存在") + return {"success": True, "message": "连接已删除"} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections/{connection_id}/test") +async def test_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + try: + await get_server_or_404(db, name) + await get_connection_for_server_or_404(db, name, connection_id) + result = await test_mcp_connection(db, connection_id, updated_by=current_user.username) + return { + "success": True, + "tool_count": result["tool_count"], + "message": f"连接成功,共发现 {result['tool_count']} 个工具", + } + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to test MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@mcp.post("/{name}/connections/{connection_id}/reauth") +async def reauthorize_mcp_connection_route( + name: str, + connection_id: int, + current_user: User = Depends(get_admin_user), + db: AsyncSession = Depends(get_db), +): + try: + await get_server_or_404(db, name) + await get_connection_for_server_or_404(db, name, connection_id) + connection = await reauthorize_mcp_connection(db, connection_id, updated_by=current_user.username) + return {"success": True, "data": connection.to_dict(), "message": "连接已重置并重新激活"} + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to reauthorize MCP connection: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================= # === MCP 工具管理 === # ============================================================================= diff --git a/backend/test/e2e/test_mcp_admin_flow_e2e.py b/backend/test/e2e/test_mcp_admin_flow_e2e.py new file mode 100644 index 000000000..c58403b0d --- /dev/null +++ b/backend/test/e2e/test_mcp_admin_flow_e2e.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import uuid + +import httpx +import pytest + +pytestmark = [pytest.mark.asyncio, pytest.mark.e2e, pytest.mark.slow] + + +def _build_server_name(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def _build_auth_config() -> dict: + return { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + }, + { + "name": "X-Yuxi-User", + "value_template": "${context.user_id}", + }, + { + "name": "X-Yuxi-Department", + "value_template": "${context.department_id}", + }, + ], + }, + "refresh_policy": { + "pre_refresh_seconds": 300, + "retry_once_on_401": True, + }, + "token_request": { + "url": "http://internal-gateway.local/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + } + + +async def _cleanup_server(client: httpx.AsyncClient, headers: dict[str, str], server_name: str) -> None: + list_response = await client.get(f"/api/system/mcp-servers/{server_name}/connections", headers=headers) + if list_response.status_code == 200: + for connection in list_response.json().get("data", []): + await client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection['id']}", + headers=headers, + ) + + await client.delete(f"/api/system/mcp-servers/{server_name}", headers=headers) + await client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=headers, + ) + + +async def test_mcp_admin_flow_e2e_supports_dynamic_auth_connections( + e2e_client: httpx.AsyncClient, + e2e_headers: dict[str, str], +): + invalid_server_name = _build_server_name("e2e-mcp-invalid-auth") + invalid_response = await e2e_client.post( + "/api/system/mcp-servers", + json={ + "name": invalid_server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + } + ], + }, + }, + }, + headers=e2e_headers, + ) + assert invalid_response.status_code == 400, invalid_response.text + assert "auth_config 配置无效" in invalid_response.json()["detail"] + + server_name = _build_server_name("e2e-mcp-auth") + create_server_response = await e2e_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "description": "e2e mcp auth server", + "auth_config": _build_auth_config(), + }, + headers=e2e_headers, + ) + assert create_server_response.status_code == 200, create_server_response.text + + try: + server_payload = create_server_response.json()["data"] + assert server_payload["name"] == server_name + assert server_payload["auth_config"]["provider"] == "custom_http_token" + + create_connection_response = await e2e_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "system", + "scope_id": "ignored-by-normalization", + "display_name": "全局共享连接", + "external_subject": "gateway-service-account", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + }, + "refresh_token": "refresh-1", + }, + "meta_json": {"tenant": "shared"}, + }, + headers=e2e_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_payload = create_connection_response.json()["data"] + connection_id = connection_payload["id"] + assert connection_payload["scope_type"] == "system" + assert connection_payload["scope_id"] == "global" + assert connection_payload["status"] == "active" + assert connection_payload["has_credentials"] is True + + list_connections_response = await e2e_client.get( + f"/api/system/mcp-servers/{server_name}/connections", + headers=e2e_headers, + ) + assert list_connections_response.status_code == 200, list_connections_response.text + assert list_connections_response.json()["data"] == [connection_payload] + + retire_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + headers=e2e_headers, + ) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_conflict_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=e2e_headers, + ) + assert hard_delete_conflict_response.status_code == 409, hard_delete_conflict_response.text + dependency_payload = hard_delete_conflict_response.json()["detail"] + assert dependency_payload["has_references"] is True + assert dependency_payload["connections"] == [ + { + "scope_type": "system", + "scope_id": "global", + "status": "active", + } + ] + + delete_connection_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=e2e_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + hard_delete_response = await e2e_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=e2e_headers, + ) + assert hard_delete_response.status_code == 200, hard_delete_response.text + finally: + await _cleanup_server(e2e_client, e2e_headers, server_name) diff --git a/backend/test/integration/api/test_mcp_router.py b/backend/test/integration/api/test_mcp_router.py new file mode 100644 index 000000000..b4c3ffdfb --- /dev/null +++ b/backend/test/integration/api/test_mcp_router.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import uuid + +import pytest + +pytestmark = [pytest.mark.asyncio, pytest.mark.integration] + + +def _build_server_name(prefix: str) -> str: + return f"{prefix}-{uuid.uuid4().hex[:8]}" + + +def _build_auth_config() -> dict: + return { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + }, + { + "name": "X-Yuxi-User", + "value_template": "${context.user_id}", + }, + { + "name": "X-Yuxi-Department", + "value_template": "${context.department_id}", + }, + ], + }, + "refresh_policy": { + "pre_refresh_seconds": 300, + "retry_once_on_401": True, + }, + "token_request": { + "url": "http://internal-gateway.local/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + } + + +async def _create_server(test_client, admin_headers: dict[str, str], name: str) -> None: + response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "description": "pytest mcp auth server", + "auth_config": _build_auth_config(), + }, + headers=admin_headers, + ) + assert response.status_code == 200, response.text + + +async def _cleanup_server(test_client, admin_headers: dict[str, str], name: str) -> None: + list_response = await test_client.get(f"/api/system/mcp-servers/{name}/connections", headers=admin_headers) + if list_response.status_code == 200: + for connection in list_response.json().get("data", []): + await test_client.delete( + f"/api/system/mcp-servers/{name}/connections/{connection['id']}", + headers=admin_headers, + ) + + await test_client.delete(f"/api/system/mcp-servers/{name}", headers=admin_headers) + await test_client.delete( + f"/api/system/mcp-servers/{name}", + params={"hard": "true"}, + headers=admin_headers, + ) + + +async def test_admin_can_manage_mcp_server_connections_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-auth") + await _create_server(test_client, admin_headers, server_name) + + try: + get_response = await test_client.get(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert get_response.status_code == 200, get_response.text + get_payload = get_response.json()["data"] + assert get_payload["name"] == server_name + assert get_payload["auth_config"]["provider"] == "custom_http_token" + + update_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}", + json={ + "description": "updated auth server", + "auth_config": { + **_build_auth_config(), + "refresh_policy": { + "pre_refresh_seconds": 120, + "retry_once_on_401": True, + }, + }, + }, + headers=admin_headers, + ) + assert update_response.status_code == 200, update_response.text + updated_payload = update_response.json()["data"] + assert updated_payload["name"] == server_name + assert updated_payload["description"] == "updated auth server" + assert updated_payload["auth_config"]["refresh_policy"]["pre_refresh_seconds"] == 120 + + create_connection_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": "finance-dept", + "display_name": "财务共享连接", + "external_subject": "finance-bot", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + }, + "refresh_token": "refresh-1", + }, + "meta_json": {"tenant": "finance"}, + }, + headers=admin_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_payload = create_connection_response.json()["data"] + connection_id = connection_payload["id"] + assert connection_payload["scope_type"] == "department" + assert connection_payload["display_name"] == "财务共享连接" + assert connection_payload["has_credentials"] is True + assert "credential_blob" not in connection_payload + + list_connections_response = await test_client.get( + f"/api/system/mcp-servers/{server_name}/connections", + headers=admin_headers, + ) + assert list_connections_response.status_code == 200, list_connections_response.text + listed_connections = list_connections_response.json()["data"] + assert len(listed_connections) == 1 + assert listed_connections[0]["id"] == connection_id + assert listed_connections[0]["has_credentials"] is True + assert "credential_blob" not in listed_connections[0] + + update_connection_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + json={ + "display_name": "财务共享连接-更新", + "meta_json": {"tenant": "finance", "stage": "updated"}, + }, + headers=admin_headers, + ) + assert update_connection_response.status_code == 200, update_connection_response.text + assert update_connection_response.json()["data"]["display_name"] == "财务共享连接-更新" + assert update_connection_response.json()["data"]["meta_json"]["stage"] == "updated" + + status_response = await test_client.put( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}/status", + json={"status": "reauth_required"}, + headers=admin_headers, + ) + assert status_response.status_code == 200, status_response.text + assert status_response.json()["data"]["status"] == "reauth_required" + + reauth_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}/reauth", + headers=admin_headers, + ) + assert reauth_response.status_code == 200, reauth_response.text + assert reauth_response.json()["data"]["status"] == "active" + + delete_connection_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=admin_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + retire_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_response.status_code == 200, hard_delete_response.text + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_hard_delete_mcp_server_returns_dependency_summary_when_connections_exist( + test_client, admin_headers +): + server_name = _build_server_name("pytest-mcp-delete") + await _create_server(test_client, admin_headers, server_name) + + try: + create_connection_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "department", + "scope_id": "finance-dept", + "display_name": "财务共享连接", + "credential": { + "secrets": { + "client_id": "cid-1", + "client_secret": "secret-1", + } + }, + }, + headers=admin_headers, + ) + assert create_connection_response.status_code == 200, create_connection_response.text + connection_id = create_connection_response.json()["data"]["id"] + + retire_response = await test_client.delete(f"/api/system/mcp-servers/{server_name}", headers=admin_headers) + assert retire_response.status_code == 200, retire_response.text + + hard_delete_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_response.status_code == 409, hard_delete_response.text + detail = hard_delete_response.json()["detail"] + assert detail["has_references"] is True + assert detail["connections"] == [ + { + "scope_type": "department", + "scope_id": "finance-dept", + "status": "active", + } + ] + + delete_connection_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}/connections/{connection_id}", + headers=admin_headers, + ) + assert delete_connection_response.status_code == 200, delete_connection_response.text + + hard_delete_after_cleanup_response = await test_client.delete( + f"/api/system/mcp-servers/{server_name}", + params={"hard": "true"}, + headers=admin_headers, + ) + assert hard_delete_after_cleanup_response.status_code == 200, hard_delete_after_cleanup_response.text + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_bound_auth_server_test_endpoint_requires_connection_level_testing(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-bound-test") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.post(f"/api/system/mcp-servers/{server_name}/test", headers=admin_headers) + assert response.status_code == 400, response.text + assert "需要绑定连接" in response.json()["detail"] + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + +async def test_create_mcp_server_rejects_invalid_auth_config_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-invalid-auth") + + response = await test_client.post( + "/api/system/mcp-servers", + json={ + "name": server_name, + "transport": "streamable_http", + "url": "http://mcp-upstream.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [ + { + "name": "Authorization", + "value_template": "Bearer ${access_token}", + } + ], + }, + }, + }, + headers=admin_headers, + ) + + assert response.status_code == 400, response.text + assert "auth_config 配置无效" in response.json()["detail"] + + +async def test_create_system_connection_defaults_scope_id_to_global_via_real_api(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-system-scope") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/connections", + json={ + "scope_type": "system", + "scope_id": "", + "display_name": "全局共享连接", + "credential": {"secrets": {"client_id": "cid-1", "client_secret": "secret-1"}}, + }, + headers=admin_headers, + ) + assert response.status_code == 200, response.text + payload = response.json()["data"] + assert payload["scope_type"] == "system" + assert payload["scope_id"] == "global" + finally: + await _cleanup_server(test_client, admin_headers, server_name) diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py new file mode 100644 index 000000000..24010b98c --- /dev/null +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import yuxi.agents.middlewares.runtime_config_middleware as runtime_config_middleware +from yuxi.agents.middlewares.runtime_config_middleware import RuntimeConfigMiddleware + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_tools_from_context_passes_auth_context_to_mcp_loader(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware() + context = SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + ) + + tools = await middleware.get_tools_from_context(context) + + assert tools == [] + assert captured == [("finance-gateway", "user-1", "dept-9")] diff --git a/backend/test/unit/middlewares/test_skills_middleware.py b/backend/test/unit/middlewares/test_skills_middleware.py new file mode 100644 index 000000000..1380c15fd --- /dev/null +++ b/backend/test/unit/middlewares/test_skills_middleware.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import yuxi.agents.middlewares.skills_middleware as skills_middleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_mcp_tools_from_context_passes_auth_context_to_mcp_loader(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(skills_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = SkillsMiddleware() + context = SimpleNamespace( + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + ) + + tools = await middleware._get_mcp_tools_from_context(context) + + assert tools == [] + assert captured == [("finance-gateway", "user-1", "dept-9")] diff --git a/backend/test/unit/routers/test_mcp_internal_router.py b/backend/test/unit/routers/test_mcp_internal_router.py new file mode 100644 index 000000000..c7db1b71d --- /dev/null +++ b/backend/test/unit/routers/test_mcp_internal_router.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import httpx +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from server.routers.mcp_internal_router import mcp_internal +from server.utils.auth_middleware import get_db +from yuxi.services.mcp_auth.orchestrator import AuthContext + + +def _build_app() -> FastAPI: + app = FastAPI() + app.include_router(mcp_internal, prefix="/api") + + async def fake_db(): + return None + + app.dependency_overrides[get_db] = fake_db + return app + + +def test_internal_proxy_route_forwards_request(monkeypatch): + class DummyServer: + name = "finance-proxy" + transport = "streamable_http" + auth_config_json = { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + class DummyConnection: + status = "active" + meta_json = {} + + async def fake_get_mcp_server(db, name): + del db + assert name == "finance-proxy" + return DummyServer() + + async def fake_load_connection(db, *, server, auth_context): + del db + assert server.name == "finance-proxy" + assert auth_context.department_id == "dep-1" + return DummyConnection() + + async def fake_proxy_mcp_request(server, **kwargs): + del kwargs + assert server.name == "finance-proxy" + return httpx.Response( + 200, + json={"ok": True}, + headers={"content-type": "application/json"}, + ) + + monkeypatch.setattr( + "server.routers.mcp_internal_router.decode_proxy_access_token", + lambda token, server_name: AuthContext(user_id="user-1", department_id="dep-1"), + ) + monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) + monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) + + client = TestClient(_build_app()) + resp = client.post( + "/api/internal/mcp-proxy/finance-proxy", + headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, + json={"jsonrpc": "2.0", "id": 1}, + ) + + assert resp.status_code == 200, resp.text + assert resp.json() == {"ok": True} + + +def test_internal_proxy_route_requires_internal_token(): + client = TestClient(_build_app()) + resp = client.post("/api/internal/mcp-proxy/finance-proxy", json={"jsonrpc": "2.0", "id": 1}) + assert resp.status_code == 401, resp.text diff --git a/backend/test/unit/routers/test_mcp_router.py b/backend/test/unit/routers/test_mcp_router.py index 4e1bc994c..1d3d969e9 100644 --- a/backend/test/unit/routers/test_mcp_router.py +++ b/backend/test/unit/routers/test_mcp_router.py @@ -133,3 +133,469 @@ async def fake_get_all_mcp_servers(db): assert data_user["name"] == "test-mcp" assert data_user["description"] == "test mcp description" assert data_user["enabled"] is True + + +def test_create_mcp_server_forwards_auth_config(monkeypatch): + captured = {} + + class DummyServer: + def to_dict(self): + return {"name": "gateway", "auth_config": {"provider": "custom_http_token"}} + + async def fake_create_mcp_server(db, **kwargs): + del db + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.create_mcp_server", fake_create_mcp_server) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers", + json={ + "name": "gateway", + "transport": "streamable_http", + "url": "http://gateway.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["auth_config"] == { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 0, "retry_once_on_401": False}, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + +def test_update_mcp_server_forwards_auth_config(monkeypatch): + captured = {} + + class DummyServer: + def to_dict(self): + return {"name": "gateway", "auth_config": {"provider": "bound_secret"}} + + async def fake_update_mcp_server(db, name, **kwargs): + del db + captured["name"] = name + captured.update(kwargs) + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.update_mcp_server", fake_update_mcp_server) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway", + json={ + "description": "updated", + "auth_config": { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["name"] == "gateway" + assert captured["auth_config"] == { + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 0, "retry_once_on_401": False}, + "token_request": None, + } + + +def test_create_mcp_server_rejects_invalid_auth_config(monkeypatch): + async def fake_create_mcp_server(db, **kwargs): + raise AssertionError("create_mcp_server should not be called when auth_config is invalid") + + monkeypatch.setattr("server.routers.mcp_router.create_mcp_server", fake_create_mcp_server) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers", + json={ + "name": "gateway", + "transport": "streamable_http", + "url": "http://gateway.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + }, + }, + ) + assert resp.status_code == 400, resp.text + assert "auth_config 配置无效" in resp.json()["detail"] + + +def test_list_mcp_connections(monkeypatch): + class DummyConnection: + def __init__(self, connection_id): + self.connection_id = connection_id + + def to_dict(self): + return {"id": self.connection_id, "scope_type": "department", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_list_mcp_connections(db, **kwargs): + del db, kwargs + return [DummyConnection(1), DummyConnection(2)] + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.list_mcp_connections", fake_list_mcp_connections) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/connections") + assert resp.status_code == 200, resp.text + assert resp.json()["data"] == [ + {"id": 1, "scope_type": "department", "status": "active"}, + {"id": 2, "scope_type": "department", "status": "active"}, + ] + + +def test_create_mcp_connection(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "scope_type": "department", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_create_mcp_connection(db, **kwargs): + del db + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.create_mcp_connection", fake_create_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post( + "/api/system/mcp-servers/gateway/connections", + json={ + "scope_type": "department", + "scope_id": "42", + "display_name": "财务部共享连接", + "external_subject": "finance-user", + "credential": {"secrets": {"access_token": "token-1"}}, + "meta_json": {"tenant": "finance"}, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["server_name"] == "gateway" + assert captured["scope_type"] == "department" + assert captured["scope_id"] == "42" + assert captured["credential_blob"] == '{"secrets": {"access_token": "token-1"}}' + assert captured["created_by"] == "admin" + + +def test_update_mcp_connection_status(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "status": "reauth_required"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_set_mcp_connection_status(db, connection_id, **kwargs): + del db + captured["connection_id"] = connection_id + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.set_mcp_connection_status", fake_set_mcp_connection_status) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7/status", + json={"status": "reauth_required"}, + ) + assert resp.status_code == 200, resp.text + assert captured == { + "connection_id": 7, + "status": "reauth_required", + "updated_by": "admin", + } + + +def test_update_mcp_connection(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "display_name": "新连接名", "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_update_mcp_connection(db, connection_id, **kwargs): + del db + captured["connection_id"] = connection_id + captured.update(kwargs) + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.update_mcp_connection", fake_update_mcp_connection) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7", + json={ + "display_name": "新连接名", + "credential": {"secrets": {"access_token": "token-2"}}, + }, + ) + assert resp.status_code == 200, resp.text + assert captured["connection_id"] == 7 + assert captured["display_name"] == "新连接名" + assert captured["credential_blob"] == '{"secrets": {"access_token": "token-2"}}' + assert captured["updated_by"] == "admin" + + +def test_delete_mcp_connection(monkeypatch): + captured = {} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_delete_mcp_connection(db, connection_id): + del db + captured["connection_id"] = connection_id + return True + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.delete_mcp_connection", fake_delete_mcp_connection) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway/connections/7") + assert resp.status_code == 200, resp.text + assert captured == {"connection_id": 7} + + +def test_test_mcp_connection_route(monkeypatch): + captured = {} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_test_mcp_connection(db, connection_id, *, updated_by=None): + del db + captured["connection_id"] = connection_id + captured["updated_by"] = updated_by + return {"tool_count": 3} + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.test_mcp_connection", fake_test_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/connections/7/test") + assert resp.status_code == 200, resp.text + assert resp.json()["tool_count"] == 3 + assert captured == {"connection_id": 7, "updated_by": "admin"} + + +def test_reauthorize_mcp_connection_route(monkeypatch): + captured = {} + + class DummyConnection: + def to_dict(self): + return {"id": 7, "status": "active"} + + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "gateway"})() + + async def fake_reauthorize_mcp_connection(db, connection_id, *, updated_by=None): + del db + captured["connection_id"] = connection_id + captured["updated_by"] = updated_by + return DummyConnection() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.reauthorize_mcp_connection", fake_reauthorize_mcp_connection) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/connections/7/reauth") + assert resp.status_code == 200, resp.text + assert captured == {"connection_id": 7, "updated_by": "admin"} + + +def test_update_mcp_connection_status_rejects_connection_from_other_server(monkeypatch): + async def fake_get_mcp_server(db, name): + del db + return type("DummyServer", (), {"name": name})() + + async def fake_get_mcp_connection(db, connection_id): + del db + return type("DummyConnectionRef", (), {"id": connection_id, "server_name": "other-gateway"})() + + async def fake_set_mcp_connection_status(db, connection_id, **kwargs): + raise AssertionError("should not update a connection that belongs to another server") + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_connection", fake_get_mcp_connection) + monkeypatch.setattr("server.routers.mcp_router.set_mcp_connection_status", fake_set_mcp_connection_status) + + client = TestClient(_build_app()) + resp = client.put( + "/api/system/mcp-servers/gateway/connections/7/status", + json={"status": "reauth_required"}, + ) + assert resp.status_code == 404, resp.text + + +def test_test_mcp_server_requires_connection_level_test_for_bound_auth(monkeypatch): + class DummyServer: + auth_config_json = { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + async def fake_get_server_or_404(db, name): + del db, name + return DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/test", json={}) + assert resp.status_code == 400, resp.text + + +def test_delete_mcp_server_defaults_to_retire(monkeypatch): + captured = {} + + class DummyServer: + created_by = "tester" + + def to_dict(self): + return {"name": "gateway", "enabled": False} + + async def fake_get_mcp_server(db, name): + del db + return DummyServer() + + async def fake_set_server_enabled(db, name, enabled, updated_by=None): + del db + captured["name"] = name + captured["enabled"] = enabled + captured["updated_by"] = updated_by + return False, DummyServer() + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.set_server_enabled", fake_set_server_enabled) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway") + + assert resp.status_code == 200, resp.text + assert resp.json()["message"] == "服务器 'gateway' 已退役" + assert captured == { + "name": "gateway", + "enabled": False, + "updated_by": "admin", + } + + +def test_delete_mcp_server_hard_delete_returns_conflict(monkeypatch): + class DummyServer: + created_by = "tester" + enabled = 0 + + async def fake_get_mcp_server(db, name): + del db + return DummyServer() + + async def fake_get_dependency_summary(db, name): + del db, name + return { + "has_references": True, + "connections": [{"scope_type": "department", "scope_id": "42", "status": "active"}], + "skills": [{"slug": "finance-skill", "name": "Finance Skill"}], + "agent_configs": [], + } + + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_router.get_mcp_server_dependency_summary", fake_get_dependency_summary) + + client = TestClient(_build_app()) + resp = client.delete("/api/system/mcp-servers/gateway?hard=true") + + assert resp.status_code == 409, resp.text + assert resp.json()["detail"]["connections"] == [ + {"scope_type": "department", "scope_id": "42", "status": "active"} + ] diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index 16a147636..58422ba91 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -157,7 +157,12 @@ def fake_get_trace_info(_run_context): assert len(invoke_messages) == 1 assert isinstance(invoke_messages[0], HumanMessage) assert invoke_messages[0].content == "hello" - assert calls["invoke_input_context"] == {"temperature": 0.1, "user_id": "user-1", "thread_id": "thread-1"} + assert calls["invoke_input_context"] == { + "temperature": 0.1, + "user_id": "user-1", + "thread_id": "thread-1", + "department_id": "dept-1", + } assert calls["invoke_kwargs"] == { "callbacks": ["handler-1"], "metadata": {"langfuse_user_id": "user-1", "langfuse_session_id": "thread-1"}, @@ -246,12 +251,14 @@ def fake_agents_prompt(_thread_id: str, _user_id: str) -> str: {"system_prompt": "原始系统提示词", "temperature": 0.1}, thread_id="thread-1", user_id="user-1", + department_id="dept-9", ) assert context["system_prompt"] == "原始系统提示词\n\n用户工作区 agents/AGENTS.md 内容:\n回答前先读取 AGENTS.md" assert context["temperature"] == 0.1 assert context["thread_id"] == "thread-1" assert context["user_id"] == "user-1" + assert context["department_id"] == "dept-9" @pytest.mark.asyncio @@ -264,6 +271,8 @@ async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prom {"system_prompt": "原始系统提示词"}, thread_id="thread-1", user_id="user-1", + department_id="dept-9", ) assert context["system_prompt"] == "原始系统提示词" + assert context["department_id"] == "dept-9" diff --git a/backend/test/unit/services/test_mcp_auth_config_models.py b/backend/test/unit/services/test_mcp_auth_config_models.py new file mode 100644 index 000000000..61fd71fd4 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_config_models.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os + +import pytest +from pydantic import ValidationError + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.config_models import MCPAuthConfig + + +def test_mcp_auth_config_applies_legacy_static_defaults(): + config = MCPAuthConfig.model_validate( + { + "version": 1, + "provider": "legacy_static", + "inject": { + "target": "headers", + "entries": [], + }, + } + ) + + assert config.binding_scope == "inline" + assert config.manifest_scope == "server" + assert config.refresh_policy.pre_refresh_seconds == 0 + assert config.refresh_policy.retry_once_on_401 is False + + +def test_mcp_auth_config_requires_token_request_for_dynamic_http_provider(): + with pytest.raises(ValidationError, match="token_request"): + MCPAuthConfig.model_validate( + { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + } + ) diff --git a/backend/test/unit/services/test_mcp_auth_crypto.py b/backend/test/unit/services/test_mcp_auth_crypto.py new file mode 100644 index 000000000..feee81e76 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_crypto.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import json +import os + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob, encrypt_credential_blob + + +pytestmark = [pytest.mark.unit] + + +def test_encrypt_and_decrypt_credential_blob_round_trip(monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + plaintext = json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}, ensure_ascii=False) + + encrypted = encrypt_credential_blob(plaintext) + decrypted = decrypt_credential_blob(encrypted) + + assert encrypted != plaintext + assert json.loads(encrypted)["v"] == 1 + assert decrypted == plaintext + + +def test_decrypt_credential_blob_keeps_legacy_plaintext_payload(monkeypatch): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + plaintext = '{"secrets":{"access_token":"legacy-token"}}' + + assert decrypt_credential_blob(plaintext) == plaintext + + +def test_encrypt_credential_blob_requires_master_key(monkeypatch): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + + with pytest.raises(ValueError, match="MCP_CREDENTIALS_MASTER_KEY"): + encrypt_credential_blob('{"secrets":{"access_token":"token"}}') diff --git a/backend/test/unit/services/test_mcp_auth_models.py b/backend/test/unit/services/test_mcp_auth_models.py new file mode 100644 index 000000000..c71285fd7 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_models.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import os + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def mcp_auth_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_mcp_server_to_dict_and_mcp_config_include_auth_config(mcp_auth_session): + server = MCPServer( + name="gateway", + description="internal gateway", + transport="streamable_http", + url="http://gateway.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 600, "retry_once_on_401": True}, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(server) + await mcp_auth_session.commit() + + payload = server.to_dict() + config = server.to_mcp_config() + + assert payload["auth_config"]["provider"] == "custom_http_token" + assert config["auth_config"]["binding_scope"] == "department" + + +async def test_mcp_connection_persists_scoped_binding_and_hides_credentials_by_default(mcp_auth_session): + server = MCPServer( + name="finance-gateway", + description="finance", + transport="streamable_http", + url="http://finance.local/mcp", + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(server) + await mcp_auth_session.commit() + + connection = MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + display_name="财务部共享凭据", + external_subject="finance-user", + status="active", + credential_blob="encrypted-secret", + meta_json={"last_success_at": "2026-06-02T10:00:00Z"}, + created_by="tester", + updated_by="tester", + ) + mcp_auth_session.add(connection) + await mcp_auth_session.commit() + + result = await mcp_auth_session.execute(select(MCPConnection).where(MCPConnection.server_name == "finance-gateway")) + saved = result.scalar_one() + + safe_payload = saved.to_dict() + internal_payload = saved.to_dict(include_credentials=True) + + assert safe_payload["scope_type"] == "department" + assert safe_payload["has_credentials"] is True + assert "credential_blob" not in safe_payload + assert internal_payload["credential_blob"] == "encrypted-secret" diff --git a/backend/test/unit/services/test_mcp_auth_orchestrator.py b/backend/test/unit/services/test_mcp_auth_orchestrator.py new file mode 100644 index 000000000..0c83578d7 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_orchestrator.py @@ -0,0 +1,668 @@ +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime, timedelta + +import httpx +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class DummyTokenCache: + def __init__(self, token_payload: dict | None = None): + self.token_payload = token_payload + self.token_payloads = None + self.set_calls: list[tuple[int, dict]] = [] + self.deleted_connection_ids: list[int] = [] + self.acquire_calls: list[int] = [] + self.release_calls: list[int] = [] + self.acquire_result = True + + async def get_access_token(self, connection_id: int) -> dict | None: + del connection_id + if self.token_payloads is not None: + if self.token_payloads: + self.token_payload = self.token_payloads.pop(0) + else: + self.token_payload = None + return self.token_payload + + async def set_access_token(self, connection_id: int, token_payload: dict) -> None: + self.set_calls.append((connection_id, token_payload)) + self.token_payload = token_payload + + async def delete_access_token(self, connection_id: int) -> None: + self.deleted_connection_ids.append(connection_id) + self.token_payload = None + + async def acquire_refresh_lock(self, connection_id: int, *, ttl_seconds: int = 30) -> bool: + del ttl_seconds + self.acquire_calls.append(connection_id) + return self.acquire_result + + async def release_refresh_lock(self, connection_id: int) -> None: + self.release_calls.append(connection_id) + + +async def test_resolve_runtime_mcp_config_injects_bound_secret_header(): + os.environ["MCP_CREDENTIALS_MASTER_KEY"] = "local-test-master-key" + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + credential_blob=encrypt_credential_blob(json.dumps({"secrets": {"access_token": "dept-token"}})), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="42"), + connection=connection, + ) + + assert resolved["transport"] == "streamable_http" + assert resolved["headers"] == { + "X-App": "yuxi", + "Authorization": "Bearer dept-token", + } + assert "auth_config" not in resolved + + +async def test_resolve_runtime_mcp_config_supports_raw_token_string_binding(): + os.environ["MCP_CREDENTIALS_MASTER_KEY"] = "local-test-master-key" + server = MCPServer( + name="raw-token-gateway", + transport="streamable_http", + url="http://raw.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "system", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="raw-token-gateway", + scope_type="system", + scope_id="global", + credential_blob=encrypt_credential_blob("raw-token-value"), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(), + connection=connection, + ) + + assert resolved["headers"] == {"Authorization": "Bearer raw-token-value"} + + +async def test_resolve_runtime_mcp_config_fetches_custom_http_token_with_user_context(): + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode("utf-8")) + return httpx.Response( + 200, + json={ + "data": { + "access_token": "fresh-token", + "refresh_token": "refresh-token", + "expires_in": 3600, + } + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-gateway", + transport="streamable_http", + url="http://corp.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 600, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "headers": { + "Content-Type": "application/json", + "X-Client-Id": "${secret.client_id}", + }, + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + "user_id": "${context.user_id}", + "department_id": "${context.department_id}", + }, + "response_map": { + "access_token": "data.access_token", + "refresh_token": "data.refresh_token", + "expires_in": "data.expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="corp-gateway", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid-1", "client_secret": "secret-1"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-9", department_id="finance"), + connection=connection, + http_client=http_client, + ) + + await http_client.aclose() + + assert captured["url"] == "http://gateway.local/auth/token" + assert captured["body"] == { + "client_id": "cid-1", + "client_secret": "secret-1", + "user_id": "user-9", + "department_id": "finance", + } + assert resolved["headers"] == { + "X-App": "yuxi", + "Authorization": "Bearer fresh-token", + } + + +async def test_resolve_runtime_mcp_config_fetches_client_credentials_token(): + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["body"] = json.loads(request.content.decode("utf-8")) + return httpx.Response( + 200, + json={ + "access_token": "client-token", + "expires_in": 1800, + "token_type": "Bearer", + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="client-credentials-mcp", + transport="streamable_http", + url="http://client.local/mcp", + auth_config_json={ + "version": 1, + "provider": "client_credentials", + "binding_scope": "system", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://gateway.local/oauth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "grant_type": "client_credentials", + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + "token_type": "token_type", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=11, + server_name="client-credentials-mcp", + scope_type="system", + scope_id="global", + credential_blob=json.dumps({"secrets": {"client_id": "cid-cc", "client_secret": "secret-cc"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="d-1"), + connection=connection, + http_client=http_client, + ) + + await http_client.aclose() + + assert captured["url"] == "http://gateway.local/oauth/token" + assert captured["body"] == { + "grant_type": "client_credentials", + "client_id": "cid-cc", + "client_secret": "secret-cc", + } + assert resolved["headers"] == {"Authorization": "Bearer client-token"} + + +async def test_resolve_runtime_mcp_config_injects_stdio_env_from_secret_binding(): + server = MCPServer( + name="stdio-auth-mcp", + transport="stdio", + command="demo-server", + env={"LOG_LEVEL": "info"}, + auth_config_json={ + "version": 1, + "provider": "stdio_env", + "binding_scope": "user", + "manifest_scope": "binding", + "inject": { + "target": "env", + "entries": [ + {"name": "API_TOKEN", "value_template": "${secret.access_token}"}, + {"name": "YUXI_USER_ID", "value_template": "${context.user_id}"}, + ], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + server_name="stdio-auth-mcp", + scope_type="user", + scope_id="user-1", + credential_blob=json.dumps({"secrets": {"access_token": "stdio-token"}}), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + connection=connection, + ) + + assert resolved["command"] == "demo-server" + assert resolved["env"] == { + "LOG_LEVEL": "info", + "API_TOKEN": "stdio-token", + "YUXI_USER_ID": "user-1", + } + + +async def test_resolve_runtime_mcp_config_uses_cached_custom_http_token_before_fetching(): + def handler(request: httpx.Request) -> httpx.Response: + raise AssertionError(f"unexpected token request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-cache-mcp", + transport="streamable_http", + url="http://corp-cache.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 120, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=21, + server_name="corp-cache-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "cached-token", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=10)).isoformat(), + "token_type": "Bearer", + } + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert resolved["headers"] == {"Authorization": "Bearer cached-token"} + assert token_cache.set_calls == [] + + +async def test_resolve_runtime_mcp_config_refreshes_cached_token_when_expiring_soon(): + captured: list[tuple[str, dict]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + payload = json.loads(request.content.decode("utf-8")) + captured.append((str(request.url), payload)) + return httpx.Response( + 200, + json={ + "access_token": "refreshed-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-refresh-mcp", + transport="streamable_http", + url="http://corp-refresh.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 300, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "token_type": "token_type", + }, + "refresh": { + "url": "http://gateway.local/auth/refresh", + "method": "POST", + "body_type": "json", + "body_template": { + "refresh_token": "${token.refresh_token}", + }, + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=22, + server_name="corp-refresh-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps( + { + "secrets": {"client_id": "cid", "client_secret": "secret"}, + "refresh_token": "refresh-old", + } + ), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(seconds=60)).isoformat(), + } + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert captured == [ + ( + "http://gateway.local/auth/refresh", + { + "refresh_token": "refresh-old", + }, + ) + ] + assert resolved["headers"] == {"Authorization": "Bearer refreshed-token"} + assert token_cache.set_calls and token_cache.set_calls[0][0] == 22 + assert token_cache.set_calls[0][1]["access_token"] == "refreshed-token" + assert token_cache.acquire_calls == [22] + assert token_cache.release_calls == [22] + + +async def test_resolve_runtime_mcp_config_waits_for_refresh_lock_owner_to_publish_token(): + def handler(request: httpx.Request) -> httpx.Response: + raise AssertionError(f"unexpected token request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="corp-lock-mcp", + transport="streamable_http", + url="http://corp-lock.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 300, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=24, + server_name="corp-lock-mcp", + scope_type="department", + scope_id="finance", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache() + token_cache.acquire_result = False + token_cache.token_payloads = [ + { + "access_token": "stale-token", + "expires_at": (datetime.now(tz=UTC) + timedelta(seconds=10)).isoformat(), + }, + { + "access_token": "fresh-from-other-worker", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + }, + ] + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="u-1", department_id="finance"), + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert resolved["headers"] == {"Authorization": "Bearer fresh-from-other-worker"} + assert token_cache.acquire_calls == [24] + assert token_cache.release_calls == [] + assert token_cache.set_calls == [] + + +async def test_resolve_runtime_mcp_config_refreshes_authorization_code_token(): + captured: list[tuple[str, str]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append((request.method, str(request.url))) + if str(request.url) == "https://id.example.com/.well-known/openid-configuration": + return httpx.Response( + 200, + json={ + "token_endpoint": "https://id.example.com/oauth/token", + }, + ) + if str(request.url) == "https://id.example.com/oauth/token": + body_text = request.content.decode("utf-8") + assert "grant_type=refresh_token" in body_text + assert "refresh_token=refresh-old" in body_text + assert "client_id=oidc-client" in body_text + assert "client_secret=oidc-secret" in body_text + return httpx.Response( + 200, + json={ + "access_token": "oidc-access-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="oidc-mcp", + transport="streamable_http", + url="http://oidc.local/mcp", + auth_config_json={ + "version": 1, + "provider": "authorization_code", + "binding_scope": "user", + "manifest_scope": "binding", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 120, "retry_once_on_401": True}, + "token_request": { + "issuer_url": "https://id.example.com", + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=23, + server_name="oidc-mcp", + scope_type="user", + scope_id="user-1", + credential_blob=json.dumps( + { + "secrets": { + "client_id": "oidc-client", + "client_secret": "oidc-secret", + }, + "refresh_token": "refresh-old", + } + ), + created_by="tester", + updated_by="tester", + ) + + resolved = await resolve_runtime_mcp_config( + server, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + connection=connection, + http_client=http_client, + ) + + await http_client.aclose() + + assert captured == [ + ("GET", "https://id.example.com/.well-known/openid-configuration"), + ("POST", "https://id.example.com/oauth/token"), + ] + assert resolved["headers"] == {"Authorization": "Bearer oidc-access-token"} diff --git a/backend/test/unit/services/test_mcp_auth_proxy_service.py b/backend/test/unit/services/test_mcp_auth_proxy_service.py new file mode 100644 index 000000000..4a6d8a16a --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_proxy_service.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime, timedelta + +import httpx +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.services.mcp_auth.proxy_service import proxy_mcp_request +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class DummyTokenCache: + def __init__(self, token_payload: dict | None = None): + self.token_payload = token_payload + self.deleted_connection_ids: list[int] = [] + self.set_calls: list[tuple[int, dict]] = [] + + async def get_access_token(self, connection_id: int) -> dict | None: + del connection_id + return self.token_payload + + async def delete_access_token(self, connection_id: int) -> None: + self.deleted_connection_ids.append(connection_id) + self.token_payload = None + + async def set_access_token(self, connection_id: int, token_payload: dict) -> None: + self.set_calls.append((connection_id, token_payload)) + self.token_payload = token_payload + + +async def test_proxy_mcp_request_retries_once_after_401_with_refreshed_token(): + observed_authorizations: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == "http://gateway.local/auth/token": + return httpx.Response( + 200, + json={ + "access_token": "fresh-token", + "refresh_token": "refresh-next", + "expires_in": 3600, + }, + ) + + if str(request.url) == "http://upstream.local/mcp": + observed_authorizations.append(request.headers.get("Authorization")) + if request.headers.get("Authorization") == "Bearer stale-token": + return httpx.Response(401, json={"error": "expired"}) + if request.headers.get("Authorization") == "Bearer fresh-token": + return httpx.Response(200, json={"result": "ok"}) + + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-retry", + transport="streamable_http", + url="http://upstream.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 60, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=41, + server_name="proxy-retry", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + } + ) + + response = await proxy_mcp_request( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + method="POST", + headers={"content-type": "application/json"}, + query_params={}, + body=b'{"jsonrpc":"2.0","id":1}', + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert response.status_code == 200 + assert response.json() == {"result": "ok"} + assert observed_authorizations == ["Bearer stale-token", "Bearer fresh-token"] + assert token_cache.deleted_connection_ids == [41] + assert token_cache.set_calls and token_cache.set_calls[0][0] == 41 + assert connection.status == "active" + + +async def test_proxy_mcp_request_marks_reauth_required_after_final_401(): + attempts = 0 + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal attempts + if str(request.url) == "http://gateway.local/auth/token": + return httpx.Response( + 200, + json={ + "access_token": f"fresh-token-{attempts}", + "refresh_token": "refresh-next", + "expires_in": 3600, + }, + ) + if str(request.url) == "http://upstream.local/mcp": + attempts += 1 + return httpx.Response(401, json={"error": "expired"}) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-fail-401", + transport="streamable_http", + url="http://upstream.local/mcp", + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "refresh_policy": {"pre_refresh_seconds": 60, "retry_once_on_401": True}, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "body_type": "json", + "body_template": { + "client_id": "${secret.client_id}", + "client_secret": "${secret.client_secret}", + }, + "response_map": { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + }, + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=42, + server_name="proxy-fail-401", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + token_cache = DummyTokenCache( + { + "access_token": "stale-token", + "refresh_token": "refresh-old", + "expires_at": (datetime.now(tz=UTC) + timedelta(minutes=30)).isoformat(), + } + ) + + response = await proxy_mcp_request( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + method="POST", + headers={"content-type": "application/json"}, + query_params={}, + body=b'{"jsonrpc":"2.0","id":1}', + http_client=http_client, + token_cache=token_cache, + ) + + await http_client.aclose() + + assert response.status_code == 424 + assert response.json()["error"] == "reauth_required" + assert connection.status == "reauth_required" + assert connection.meta_json["last_error"]["code"] == "unauthorized" + + +async def test_proxy_mcp_request_records_scope_error_on_403(): + def handler(request: httpx.Request) -> httpx.Response: + if str(request.url) == "http://upstream.local/mcp": + return httpx.Response(403, json={"error": "forbidden"}) + raise AssertionError(f"unexpected request: {request.method} {request.url}") + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + server = MCPServer( + name="proxy-403", + transport="streamable_http", + url="http://upstream.local/mcp", + headers={"Authorization": "Bearer static-token"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + created_by="tester", + updated_by="tester", + ) + connection = MCPConnection( + id=43, + server_name="proxy-403", + scope_type="department", + scope_id="dep-1", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "static-token"}}), + meta_json={}, + created_by="tester", + updated_by="tester", + ) + + response = await proxy_mcp_request( + server, + connection=connection, + auth_context=AuthContext(user_id="user-1", department_id="dep-1"), + method="POST", + headers={"content-type": "application/json"}, + query_params={}, + body=b'{"jsonrpc":"2.0","id":1}', + http_client=http_client, + token_cache=None, + ) + + await http_client.aclose() + + assert response.status_code == 403 + assert response.json()["error"] == "insufficient_scope" + assert connection.status == "active" + assert connection.meta_json["last_error"]["code"] == "insufficient_scope" diff --git a/backend/test/unit/services/test_mcp_auth_template_resolver.py b/backend/test/unit/services/test_mcp_auth_template_resolver.py new file mode 100644 index 000000000..9a14d75a6 --- /dev/null +++ b/backend/test/unit/services/test_mcp_auth_template_resolver.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import os + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services.mcp_auth.template_resolver import TemplateResolutionError, resolve_template_value + + +def test_resolve_template_value_supports_nested_structures(): + resolved = resolve_template_value( + { + "headers": { + "Authorization": "Bearer ${access_token}", + "X-User-Id": "${context.user_id}", + }, + "body": { + "client_id": "${secret.client_id}", + "tenant": "${secret.extra.tenant_code}", + "department_id": "${context.department_id}", + }, + "args": ["--user=${context.user_id}", "--refresh=${token.refresh_token}"], + }, + context={"user_id": "u-100", "department_id": "d-9"}, + secret={"client_id": "cid-1", "extra": {"tenant_code": "finance"}}, + token={"refresh_token": "refresh-1"}, + access_token="access-1", + ) + + assert resolved == { + "headers": { + "Authorization": "Bearer access-1", + "X-User-Id": "u-100", + }, + "body": { + "client_id": "cid-1", + "tenant": "finance", + "department_id": "d-9", + }, + "args": ["--user=u-100", "--refresh=refresh-1"], + } + + +def test_resolve_template_value_raises_for_unknown_placeholder(): + with pytest.raises(TemplateResolutionError, match="context.missing"): + resolve_template_value( + {"user": "${context.missing}"}, + context={"user_id": "u-100"}, + secret={}, + token={}, + access_token="access-1", + ) diff --git a/backend/test/unit/services/test_mcp_connection_service.py b/backend/test/unit/services/test_mcp_connection_service.py new file mode 100644 index 000000000..348e7b598 --- /dev/null +++ b/backend/test/unit/services/test_mcp_connection_service.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import os + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services import mcp_service +from yuxi.services.mcp_auth.crypto import decrypt_credential_blob +from yuxi.storage.postgres.models_business import AgentConfig, Department, MCPConnection, MCPServer, Skill + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def connection_service_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +@pytest_asyncio.fixture +async def delete_semantics_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Department.__table__.create) + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + await conn.run_sync(Skill.__table__.create) + await conn.run_sync(AgentConfig.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_create_and_list_mcp_connections(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="finance-gateway", + scope_type="department", + scope_id="42", + display_name="财务部共享连接", + external_subject="finance-user", + credential_blob="encrypted-secret", + meta_json={"tenant": "finance"}, + created_by="tester", + ) + + listed = await mcp_service.list_mcp_connections(connection_service_session, server_name="finance-gateway") + + assert created.server_name == "finance-gateway" + assert created.scope_type == "department" + assert [item.id for item in listed] == [created.id] + + +async def test_create_mcp_connection_normalizes_system_scope_to_global(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="global-gateway", + transport="streamable_http", + url="http://global.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="global-gateway", + scope_type="system", + scope_id="", + display_name="全局共享连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + assert created.scope_type == "system" + assert created.scope_id == "global" + + +async def test_set_mcp_connection_status_updates_status(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="corp-gateway", + transport="streamable_http", + url="http://corp.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="corp-gateway", + scope_type="system", + scope_id="global", + display_name="全局共享连接", + credential_blob="encrypted-secret", + created_by="tester", + ) + + updated = await mcp_service.set_mcp_connection_status( + connection_service_session, + created.id, + status="reauth_required", + updated_by="admin", + ) + + assert updated.status == "reauth_required" + assert updated.updated_by == "admin" + + +async def test_create_mcp_connection_rejects_invalid_scope_type(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="invalid-scope-gateway", + transport="streamable_http", + url="http://invalid-scope.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="scope_type"): + await mcp_service.create_mcp_connection( + connection_service_session, + server_name="invalid-scope-gateway", + scope_type="tenant", + scope_id="x", + created_by="tester", + ) + + +async def test_create_mcp_connection_rejects_missing_department_scope_id(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="missing-scope-id-gateway", + transport="streamable_http", + url="http://missing-scope-id.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="scope_id"): + await mcp_service.create_mcp_connection( + connection_service_session, + server_name="missing-scope-id-gateway", + scope_type="department", + scope_id="", + created_by="tester", + ) + + +async def test_set_mcp_connection_status_rejects_invalid_status(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="invalid-status-gateway", + transport="streamable_http", + url="http://invalid-status.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="invalid-status-gateway", + scope_type="system", + scope_id="global", + created_by="tester", + ) + + with pytest.raises(ValueError, match="status"): + await mcp_service.set_mcp_connection_status( + connection_service_session, + created.id, + status="broken", + updated_by="admin", + ) + + +async def test_create_mcp_connection_encrypts_credentials(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="secure-gateway", + transport="streamable_http", + url="http://secure.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + plaintext = '{"secrets":{"access_token":"secure-token"}}' + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="secure-gateway", + scope_type="system", + scope_id="global", + credential_blob=plaintext, + created_by="tester", + ) + + assert created.credential_blob != plaintext + assert decrypt_credential_blob(created.credential_blob) == plaintext + + +async def test_create_mcp_connection_rejects_plaintext_credentials_without_master_key( + connection_service_session, monkeypatch +): + monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) + connection_service_session.add( + MCPServer( + name="insecure-gateway", + transport="streamable_http", + url="http://insecure.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + with pytest.raises(ValueError, match="MCP_CREDENTIALS_MASTER_KEY"): + await mcp_service.create_mcp_connection( + connection_service_session, + server_name="insecure-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"token"}}', + created_by="tester", + ) + + +async def test_get_mcp_server_dependency_summary_reports_runtime_references(delete_semantics_session): + department = Department(name="研发部", description="dep") + delete_semantics_session.add(department) + delete_semantics_session.add( + MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + enabled=0, + created_by="tester", + updated_by="tester", + ) + ) + delete_semantics_session.add( + MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + status="active", + created_by="tester", + updated_by="tester", + ) + ) + delete_semantics_session.add( + Skill( + slug="finance-skill", + name="Finance Skill", + description="desc", + tool_dependencies=[], + mcp_dependencies=["finance-gateway"], + skill_dependencies=[], + dir_path="skills/finance", + created_by="tester", + updated_by="tester", + ) + ) + await delete_semantics_session.flush() + delete_semantics_session.add( + AgentConfig( + department_id=department.id, + agent_id="agent-1", + name="Finance Agent", + description="desc", + config_json={"mcps": ["finance-gateway"]}, + pics=[], + examples=[], + created_by="tester", + updated_by="tester", + ) + ) + await delete_semantics_session.commit() + + summary = await mcp_service.get_mcp_server_dependency_summary(delete_semantics_session, "finance-gateway") + + assert summary["has_references"] is True + assert summary["connections"] == [{"scope_type": "department", "scope_id": "42", "status": "active"}] + assert summary["skills"] == [{"slug": "finance-skill", "name": "Finance Skill"}] + assert summary["agent_configs"] == [{"id": 1, "name": "Finance Agent", "agent_id": "agent-1"}] + + +async def test_update_mcp_connection_reencrypts_credentials(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + connection_service_session.add( + MCPServer( + name="update-gateway", + transport="streamable_http", + url="http://update.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="update-gateway", + scope_type="system", + scope_id="global", + display_name="old", + credential_blob='{"secrets":{"access_token":"old-token"}}', + created_by="tester", + ) + + updated = await mcp_service.update_mcp_connection( + connection_service_session, + created.id, + display_name="new", + credential_blob='{"secrets":{"access_token":"new-token"}}', + updated_by="admin", + ) + + assert updated.display_name == "new" + assert decrypt_credential_blob(updated.credential_blob) == '{"secrets":{"access_token":"new-token"}}' + assert updated.updated_by == "admin" + + +async def test_delete_mcp_connection_removes_record(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="delete-connection-gateway", + transport="streamable_http", + url="http://delete.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="delete-connection-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"token"}}', + created_by="tester", + ) + + deleted = await mcp_service.delete_mcp_connection(connection_service_session, created.id) + + assert deleted is True + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + assert await mcp_service.get_mcp_connection(connection_service_session, created.id) is None + + +async def test_reauthorize_mcp_connection_clears_runtime_error(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + + connection_service_session.add( + MCPServer( + name="reauth-gateway", + transport="streamable_http", + url="http://reauth.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="reauth-gateway", + scope_type="system", + scope_id="global", + status="reauth_required", + credential_blob='{"secrets":{"access_token":"token"}}', + meta_json={"last_error": {"message": "expired"}}, + created_by="tester", + ) + + updated = await mcp_service.reauthorize_mcp_connection( + connection_service_session, + created.id, + updated_by="admin", + ) + + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + assert updated.status == "active" + assert updated.meta_json == {} + assert updated.updated_by == "admin" + + +async def test_update_mcp_connection_clears_runtime_auth_cache_on_credential_change( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="credential-update-gateway", + transport="streamable_http", + url="http://credential-update.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="credential-update-gateway", + scope_type="system", + scope_id="global", + credential_blob='{"secrets":{"access_token":"old-token"}}', + created_by="tester", + ) + + updated = await mcp_service.update_mcp_connection( + connection_service_session, + created.id, + credential_blob='{"secrets":{"access_token":"new-token"}}', + updated_by="admin", + ) + + assert updated.updated_by == "admin" + assert cleared_connection_ids == [created.id] + assert released_connection_ids == [created.id] + + +async def test_set_server_enabled_clears_runtime_auth_cache_when_retiring( + connection_service_session, monkeypatch +): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + cleared_connection_ids = [] + released_connection_ids = [] + + class DummyTokenCache: + async def delete_access_token(self, connection_id): + cleared_connection_ids.append(connection_id) + + async def release_refresh_lock(self, connection_id): + released_connection_ids.append(connection_id) + + monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + connection_service_session.add( + MCPServer( + name="retire-gateway", + transport="streamable_http", + url="http://retire.local/mcp", + enabled=1, + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + first = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="retire-gateway", + scope_type="department", + scope_id="dep-1", + credential_blob='{"secrets":{"access_token":"token-1"}}', + created_by="tester", + ) + second = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="retire-gateway", + scope_type="department", + scope_id="dep-2", + credential_blob='{"secrets":{"access_token":"token-2"}}', + created_by="tester", + ) + + enabled, server = await mcp_service.set_server_enabled( + connection_service_session, + "retire-gateway", + False, + updated_by="admin", + ) + + assert enabled is False + assert bool(server.enabled) is False + assert cleared_connection_ids == [first.id, second.id] + assert released_connection_ids == [first.id, second.id] + + +async def test_test_mcp_connection_refreshes_success_metadata(connection_service_session, monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + + async def fake_get_runtime_mcp_server_config(server_name, *, auth_context=None, db=None, http_client=None): + del auth_context, db, http_client + return {"transport": "stdio", "command": f"{server_name}-cmd", "disabled_tools": []} + + async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tools=None, **kwargs): + del additional_servers, disabled_tools, kwargs + return [server_name, "tool-b"] + + monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + + connection_service_session.add( + MCPServer( + name="test-gateway", + transport="streamable_http", + url="http://test.local/mcp", + created_by="tester", + updated_by="tester", + ) + ) + await connection_service_session.commit() + + created = await mcp_service.create_mcp_connection( + connection_service_session, + server_name="test-gateway", + scope_type="department", + scope_id="dep-9", + status="reauth_required", + credential_blob='{"secrets":{"access_token":"token"}}', + meta_json={"last_error": {"message": "old"}}, + created_by="tester", + ) + + result = await mcp_service.test_mcp_connection( + connection_service_session, + created.id, + updated_by="admin", + ) + + assert result["tool_count"] == 2 + assert result["connection"].status == "active" + assert "last_success_at" in result["connection"].meta_json + assert "last_error" not in result["connection"].meta_json diff --git a/backend/test/unit/services/test_mcp_service_auth_runtime.py b/backend/test/unit/services/test_mcp_service_auth_runtime.py new file mode 100644 index 000000000..3a5273ad6 --- /dev/null +++ b/backend/test/unit/services/test_mcp_service_auth_runtime.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import json +import os + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.services import mcp_service +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +@pytest_asyncio.fixture +async def runtime_session(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(MCPServer.__table__.create) + await conn.run_sync(MCPConnection.__table__.create) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_get_runtime_mcp_server_config_resolves_department_connection(runtime_session): + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + server_name="finance-gateway", + scope_type="department", + scope_id="42", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "dept-token"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + config = await mcp_service.get_runtime_mcp_server_config( + "finance-gateway", + auth_context=AuthContext(user_id="u-1", department_id="42"), + db=runtime_session, + ) + + assert config is not None + assert config["headers"]["Authorization"] == "Bearer dept-token" + + +async def test_get_enabled_mcp_tools_uses_runtime_mcp_config(monkeypatch): + captured: list[dict] = [] + + async def fake_get_runtime_mcp_server_config(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "demo" + assert auth_context is not None + return { + "transport": "stdio", + "command": "demo-with-auth", + "disabled_tools": ["tool_b"], + } + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del kwargs + captured.append( + { + "server_name": server_name, + "additional_servers": additional_servers, + "disabled_tools": list(disabled_tools or []), + } + ) + return ["tool-a"] + + monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await mcp_service.get_enabled_mcp_tools( + "demo", + auth_context=AuthContext(user_id="u-100", department_id="d-9"), + ) + + assert tools == ["tool-a"] + assert captured == [ + { + "server_name": "demo", + "additional_servers": { + "demo": {"transport": "stdio", "command": "demo-with-auth", "disabled_tools": ["tool_b"]} + }, + "disabled_tools": ["tool_b"], + } + ] + + +async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_http_provider( + runtime_session, monkeypatch +): + monkeypatch.setenv("YUXI_INTERNAL_MCP_PROXY_BASE_URL", "http://internal-api:5050") + + server = MCPServer( + name="finance-proxy", + transport="streamable_http", + url="http://finance.local/mcp", + headers={"X-App": "yuxi"}, + auth_config_json={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://gateway.local/auth/token", + "method": "POST", + "response_map": { + "access_token": "access_token", + "expires_in": "expires_in", + }, + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + id=31, + server_name="finance-proxy", + scope_type="department", + scope_id="dep-88", + status="active", + credential_blob=json.dumps({"secrets": {"client_id": "cid", "client_secret": "secret"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + config = await mcp_service.get_runtime_mcp_server_config( + "finance-proxy", + auth_context=AuthContext(user_id="user-1", department_id="dep-88"), + db=runtime_session, + ) + + assert config is not None + assert config["url"] == "http://internal-api:5050/api/internal/mcp-proxy/finance-proxy" + assert config["headers"]["X-App"] == "yuxi" + assert "X-Yuxi-MCP-Proxy-Token" in config["headers"] + assert "Authorization" not in config["headers"] diff --git a/backend/test/unit/storage/test_postgres_manager_business_schema.py b/backend/test/unit/storage/test_postgres_manager_business_schema.py new file mode 100644 index 000000000..dad29f5f0 --- /dev/null +++ b/backend/test/unit/storage/test_postgres_manager_business_schema.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import os +from types import SimpleNamespace + +import pytest + +os.environ.setdefault("OPENAI_API_KEY", "test-key") + +from yuxi.storage.postgres.manager import PostgresManager + + +pytestmark = [pytest.mark.asyncio, pytest.mark.unit] + + +class _FakeBeginContext: + def __init__(self, statements: list[str]): + self._statements = statements + + async def __aenter__(self): + async def execute(stmt): + self._statements.append(str(stmt)) + + return SimpleNamespace(execute=execute) + + async def __aexit__(self, exc_type, exc, tb): + return False + + +async def test_ensure_business_schema_includes_mcp_auth_tables_and_columns(): + statements: list[str] = [] + manager = object.__new__(PostgresManager) + PostgresManager.__init__(manager) + manager._initialized = True + manager.async_engine = SimpleNamespace(begin=lambda: _FakeBeginContext(statements)) + + await manager.ensure_business_schema() + + assert any( + "ALTER TABLE IF EXISTS mcp_servers ADD COLUMN IF NOT EXISTS auth_config_json JSONB" in stmt + for stmt in statements + ) + assert any("CREATE TABLE IF NOT EXISTS mcp_connections" in stmt for stmt in statements) diff --git a/backend/test/unit/test_base_context.py b/backend/test/unit/test_base_context.py new file mode 100644 index 000000000..87fffcda0 --- /dev/null +++ b/backend/test/unit/test_base_context.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from yuxi.agents.context import BaseContext + + +def test_base_context_accepts_department_id_without_exposing_it_as_configurable(): + context = BaseContext() + + context.update({"department_id": "dept-9"}) + + assert context.department_id == "dept-9" + assert "department_id" not in BaseContext.get_configurable_items() diff --git a/docker-compose.yml b/docker-compose.yml index 910af1211..32da99673 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,7 @@ x-api-worker-env: &api-worker-env # DBs and other services POSTGRES_URL: ${POSTGRES_URL:-postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@postgres:5432/${POSTGRES_DB:-yuxi_know}} REDIS_URL: ${REDIS_URL:-redis://redis:6379/0} + YUXI_INTERNAL_MCP_PROXY_BASE_URL: ${YUXI_INTERNAL_MCP_PROXY_BASE_URL:-http://api:5050} NEO4J_URI: ${NEO4J_URI:-bolt://graph:7687} NEO4J_USERNAME: ${NEO4J_USERNAME:-neo4j} NEO4J_PASSWORD: ${NEO4J_PASSWORD:-0123456789} @@ -29,6 +30,8 @@ x-api-worker-env: &api-worker-env # Agent run RUN_CANCEL_KEY_TTL_SECONDS: ${RUN_CANCEL_KEY_TTL_SECONDS:-1800} RUN_EVENTS_STREAM_TTL_SECONDS: ${RUN_EVENTS_STREAM_TTL_SECONDS:-7200} + # MCP auth + MCP_CREDENTIALS_MASTER_KEY: ${MCP_CREDENTIALS_MASTER_KEY:-} # 其他环境变量 NO_PROXY: localhost,127.0.0.1,milvus,graph,minio,milvus-etcd-dev,etcd,mineru,paddlex,sandbox-provisioner,api.siliconflow.cn no_proxy: localhost,127.0.0.1,milvus,graph,minio,milvus-etcd-dev,etcd,mineru,paddlex,sandbox-provisioner,api.siliconflow.cn diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index c7dc69d73..09ae14ca8 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,6 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端。 --- diff --git a/web/src/apis/mcp_api.js b/web/src/apis/mcp_api.js index 66fbdf69a..0e8f0e16c 100644 --- a/web/src/apis/mcp_api.js +++ b/web/src/apis/mcp_api.js @@ -79,6 +79,40 @@ export const updateMcpServerStatus = async (name, enabled) => { return apiAdminPut(`${BASE_URL}/${encodeURIComponent(name)}/status`, { enabled }) } +// ============================================================================= +// === MCP 连接管理 === +// ============================================================================= + +export const getMcpServerConnections = async (name) => { + return apiAdminGet(`${BASE_URL}/${encodeURIComponent(name)}/connections`) +} + +export const createMcpServerConnection = async (name, data) => { + return apiAdminPost(`${BASE_URL}/${encodeURIComponent(name)}/connections`, data) +} + +export const updateMcpServerConnection = async (name, connectionId, data) => { + return apiAdminPut(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}`, data) +} + +export const updateMcpConnectionStatus = async (name, connectionId, status) => { + return apiAdminPut(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/status`, { + status + }) +} + +export const deleteMcpServerConnection = async (name, connectionId) => { + return apiAdminDelete(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}`) +} + +export const testMcpConnection = async (name, connectionId) => { + return apiAdminPost(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/test`, {}) +} + +export const reauthorizeMcpConnection = async (name, connectionId) => { + return apiAdminPost(`${BASE_URL}/${encodeURIComponent(name)}/connections/${connectionId}/reauth`, {}) +} + // ============================================================================= // === MCP 工具管理 === // ============================================================================= @@ -126,6 +160,13 @@ export const mcpApi = { deleteMcpServer, testMcpServer, updateMcpServerStatus, + getMcpServerConnections, + createMcpServerConnection, + updateMcpServerConnection, + updateMcpConnectionStatus, + deleteMcpServerConnection, + testMcpConnection, + reauthorizeMcpConnection, getMcpServerTools, refreshMcpServerTools, toggleMcpServerTool diff --git a/web/src/components/extensions/McpAuthConfigBuilder.vue b/web/src/components/extensions/McpAuthConfigBuilder.vue new file mode 100644 index 000000000..f909f6a68 --- /dev/null +++ b/web/src/components/extensions/McpAuthConfigBuilder.vue @@ -0,0 +1,952 @@ + + + + + diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 2de759bef..81d353205 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -228,6 +228,10 @@ + @@ -340,6 +344,13 @@ {{ server.created_by }} +
+ +
{{ JSON.stringify(server.auth_config, null, 2) }}
+
@@ -433,6 +444,322 @@ + + + +
+
+
+

连接管理

+

+ {{ + hasAuthConfig + ? '按全局、部门或用户维护长期凭据,运行时自动换取和刷新 token。' + : '当前 MCP 未配置动态鉴权,通常不需要维护连接。' + }} +

+
+ + + + + 新建连接 + + + + + 刷新 + + +
+ +
+
+ 认证方式 + {{ server.auth_config?.provider || '未配置' }} +
+
+ 默认绑定 + {{ authBindingScopeLabel }} +
+
+ 可用连接 + {{ activeConnectionCount }} +
+
+ 需处理 + {{ attentionConnectionCount }} +
+
+ + +
+ +
+
+ + + + 新建连接 + +
+
+
+ 连接 + 范围 + 状态 + 凭据 + 最近信息 + 操作 +
+
+
+ {{ getConnectionTitle(connection) }} + + {{ connection.external_subject }} + +
+
+ + + + + {{ getConnectionScopeLabel(connection.scope_type) }} + + {{ connection.scope_id }} +
+
+ + {{ getConnectionStatusLabel(connection.status) }} + +
+
+ {{ connection.has_credentials ? '已配置' : '未配置' }} +
+
+ {{ getConnectionLastInfo(connection) }} +
+
+ + 编辑 + + + 测试 + + + 重连 + + + 删除 + +
+
+
+
+ + + +
+
+ 绑定范围 + 决定运行时为哪些请求使用这组凭据。 +
+
+ +
+ + + +
+ +
+
+ 展示信息 + 名称用于列表识别,不参与鉴权计算。 +
+ + + + + + + {{ option.label }} + + + +
+ +
+
+ 凭据 + {{ credentialHint }} +
+
+ + + +
+ + + +
+ + + + + + + + + + + + + + + + +
+
+
+
@@ -461,10 +788,16 @@ import { Save, X, Rows3, - Braces + Braces, + KeyRound, + Globe2, + Building2, + UserRound } from 'lucide-vue-next' import { mcpApi } from '@/apis/mcp_api' import { formatFullDateTime } from '@/utils/time' +import { extractSecretFieldNames } from '@/utils/mcpAuthConfigBuilder' +import McpAuthConfigBuilder from '@/components/extensions/McpAuthConfigBuilder.vue' import McpEnvEditor from '@/components/McpEnvEditor.vue' const route = useRoute() @@ -481,6 +814,13 @@ const toolsLoading = ref(false) const toolsError = ref(null) const toolSearchText = ref('') const toggleToolLoading = ref(null) +const connections = ref([]) +const connectionsLoading = ref(false) +const connectionsError = ref(null) +const showConnectionForm = ref(false) +const connectionSubmitting = ref(false) +const connectionActionLoading = ref(null) +const editingConnectionId = ref(null) const isEditing = ref(false) const editLoading = ref(false) @@ -496,17 +836,97 @@ const editForm = reactive({ args: [], env: null, headersText: '', + authConfigText: '', timeout: null, sse_read_timeout: null, tags: [], icon: '' }) +const connectionForm = reactive({ + scopeType: 'department', + scopeId: '', + displayName: '', + externalSubject: '', + credentialText: '', + secretValues: {}, + metaText: '', + status: 'active' +}) + +const connectionScopeOptions = [ + { + value: 'system', + label: '全局共享', + description: '所有用户共用', + icon: Globe2 + }, + { + value: 'department', + label: '部门共享', + description: '按部门隔离', + icon: Building2 + }, + { + value: 'user', + label: '个人专用', + description: '按用户隔离', + icon: UserRound + } +] + +const connectionStatusOptions = [ + { value: 'active', label: '启用' }, + { value: 'disabled', label: '停用' }, + { value: 'reauth_required', label: '需要重连' }, + { value: 'invalid', label: '无效' } +] + +const scopeLabelMap = { + inline: '内联', + system: '全局共享', + department: '部门共享', + user: '个人专用' +} + +const statusLabelMap = { + active: '启用', + disabled: '停用', + reauth_required: '需要重连', + invalid: '无效' +} + const actionLabel = computed(() => { - if (server.value?.enabled === false) return '添加' - return server.value?.created_by === 'system' ? '移除' : '删除' + if (server.value?.enabled === false) return '恢复' + return server.value?.created_by === 'system' ? '移除' : '退役' +}) + +const isEditingConnection = computed(() => editingConnectionId.value !== null) + +const hasAuthConfig = computed( + () => !!server.value?.auth_config && Object.keys(server.value.auth_config).length > 0 +) + +const authBindingScopeLabel = computed(() => { + const bindingScope = server.value?.auth_config?.binding_scope + return scopeLabelMap[bindingScope] || '未限定' }) +const activeConnectionCount = computed( + () => connections.value.filter((connection) => connection.status === 'active').length +) + +const attentionConnectionCount = computed( + () => + connections.value.filter((connection) => + ['reauth_required', 'invalid'].includes(connection.status) + ).length +) + +const connectionDrawerTitle = computed(() => + isEditingConnection.value ? '编辑连接' : '新建连接' +) + const filteredTools = computed(() => { if (!toolSearchText.value) return tools.value const search = toolSearchText.value.toLowerCase() @@ -524,6 +944,34 @@ const isStdioTransport = computed( .toLowerCase() === 'stdio' ) +const credentialSecretFields = computed(() => + extractSecretFieldNames(server.value?.auth_config || {}) +) + +const showScopeIdField = computed(() => connectionForm.scopeType !== 'system') + +const scopeIdLabel = computed(() => { + if (connectionForm.scopeType === 'department') return '部门 ID' + if (connectionForm.scopeType === 'user') return '用户 ID' + return '范围标识' +}) + +const scopeIdPlaceholder = computed(() => { + if (connectionForm.scopeType === 'department') return '请输入部门 ID' + if (connectionForm.scopeType === 'user') return '请输入用户 ID' + return '留空默认 global' +}) + +const credentialHint = computed(() => { + if (isEditingConnection.value) { + return '为安全起见不回显已有凭据;留空表示保持原值。' + } + if (credentialSecretFields.value.length > 0) { + return '系统已根据认证配置推导出需要录入的密钥字段。' + } + return '当前认证配置没有声明密钥字段,可直接粘贴长期 token。' +}) + const goBack = () => { router.push({ path: '/extensions', query: { tab: 'mcp' } }) } @@ -535,6 +983,62 @@ const getTransportColor = (transport) => { return colors[transport] || 'blue' } +const createEmptySecretValues = () => + Object.fromEntries(credentialSecretFields.value.map((fieldName) => [fieldName, ''])) + +const setNestedSecretValue = (target, path, value) => { + const segments = String(path || '') + .split('.') + .filter(Boolean) + let current = target + segments.forEach((segment, index) => { + if (index === segments.length - 1) { + current[segment] = value + return + } + current[segment] = current[segment] || {} + current = current[segment] + }) +} + +const getConnectionTitle = (connection) => + connection.display_name || `${getConnectionScopeLabel(connection.scope_type)} ${connection.scope_id}` + +const getConnectionScopeLabel = (scopeType) => scopeLabelMap[scopeType] || scopeType || '未知范围' + +const getConnectionStatusLabel = (status) => statusLabelMap[status] || status || '未知状态' + +const getConnectionStatusClass = (status) => { + if (status === 'active') return 'status-active' + if (status === 'reauth_required') return 'status-warning' + if (status === 'invalid') return 'status-error' + return 'status-muted' +} + +const getConnectionLastInfo = (connection) => { + if (connection.meta_json?.last_error?.message) { + return connection.meta_json.last_error.message + } + if (connection.meta_json?.last_success_at) { + return `最近成功 ${formatTime(connection.meta_json.last_success_at)}` + } + if (connection.updated_at) { + return `更新于 ${formatTime(connection.updated_at)}` + } + return '暂无记录' +} + +const getSecretFieldLabel = (fieldName) => { + const labelMap = { + client_id: 'Client ID', + client_secret: 'Client Secret', + access_token: 'Access Token', + refresh_token: 'Refresh Token', + issuer_url: 'Issuer URL' + } + return labelMap[fieldName] || fieldName +} + const resetEditForm = (data) => { Object.assign(editForm, { name: data?.name || '', @@ -545,6 +1049,7 @@ const resetEditForm = (data) => { args: data?.args || [], env: data?.env || null, headersText: data?.headers ? JSON.stringify(data.headers, null, 2) : '', + authConfigText: data?.auth_config ? JSON.stringify(data.auth_config, null, 2) : '', timeout: data?.timeout, sse_read_timeout: data?.sse_read_timeout, tags: data?.tags || [], @@ -586,6 +1091,20 @@ const parseJsonToForm = () => { } } +const parseJsonText = (text, label, { allowRawString = false } = {}) => { + const trimmed = String(text || '').trim() + if (!trimmed) return null + try { + return JSON.parse(trimmed) + } catch { + if (allowRawString) { + return trimmed + } + message.error(`${label} JSON 格式错误`) + return undefined + } +} + const buildEditPayload = () => { if (formMode.value === 'json') { try { @@ -606,6 +1125,11 @@ const buildEditPayload = () => { } } + const authConfig = parseJsonText(editForm.authConfigText, '认证配置') + if (authConfig === undefined) { + return null + } + return { name: editForm.name, description: editForm.description || null, @@ -615,6 +1139,7 @@ const buildEditPayload = () => { args: editForm.args.length > 0 ? editForm.args : null, env: editForm.env, headers, + auth_config: authConfig, timeout: editForm.timeout || null, sse_read_timeout: editForm.sse_read_timeout || null, tags: editForm.tags.length > 0 ? editForm.tags : null, @@ -622,6 +1147,45 @@ const buildEditPayload = () => { } } +const resetConnectionForm = () => { + editingConnectionId.value = null + Object.assign(connectionForm, { + scopeType: 'department', + scopeId: '', + displayName: '', + externalSubject: '', + credentialText: '', + secretValues: createEmptySecretValues(), + metaText: '', + status: 'active' + }) +} + +const openCreateConnectionDrawer = () => { + resetConnectionForm() + showConnectionForm.value = true +} + +const closeConnectionForm = () => { + showConnectionForm.value = false + resetConnectionForm() +} + +const startEditConnection = (connection) => { + editingConnectionId.value = connection.id + showConnectionForm.value = true + Object.assign(connectionForm, { + scopeType: connection.scope_type || 'department', + scopeId: connection.scope_id || '', + displayName: connection.display_name || '', + externalSubject: connection.external_subject || '', + credentialText: '', + secretValues: createEmptySecretValues(), + metaText: connection.meta_json ? JSON.stringify(connection.meta_json, null, 2) : '', + status: connection.status || 'active' + }) +} + const validateEditPayload = (data) => { if (!data.name?.trim()) { message.error('MCP 名称不能为空') @@ -700,6 +1264,26 @@ const fetchTools = async () => { } } +const fetchConnections = async () => { + if (!server.value) return + try { + connectionsLoading.value = true + connectionsError.value = null + const result = await mcpApi.getMcpServerConnections(server.value.name) + if (result.success) { + connections.value = result.data || [] + } else { + connectionsError.value = result.message || '获取连接列表失败' + connections.value = [] + } + } catch (err) { + connectionsError.value = err.message || '获取连接列表失败' + connections.value = [] + } finally { + connectionsLoading.value = false + } +} + const handleToggleTool = async (tool) => { if (!server.value) return try { @@ -745,6 +1329,161 @@ const handleTestServer = async () => { } } +const buildConnectionCredential = () => { + const rawCredential = parseJsonText(connectionForm.credentialText, '长期凭据', { + allowRawString: true + }) + if (rawCredential === undefined) return undefined + if (rawCredential !== null) return rawCredential + + const secrets = {} + Object.entries(connectionForm.secretValues).forEach(([key, value]) => { + const trimmedValue = String(value || '').trim() + if (trimmedValue) { + setNestedSecretValue(secrets, key, trimmedValue) + } + }) + + if (Object.keys(secrets).length === 0) { + return null + } + return { secrets } +} + +const validateConnectionCredential = () => { + if (isEditingConnection.value || credentialSecretFields.value.length === 0) { + return true + } + + const missingFields = credentialSecretFields.value.filter( + (fieldName) => !String(connectionForm.secretValues[fieldName] || '').trim() + ) + if (missingFields.length === 0 || connectionForm.credentialText.trim()) { + return true + } + + message.error(`请填写凭据字段:${missingFields.join('、')}`) + return false +} + +const handleSubmitConnection = async () => { + if (!server.value) return + + const scopeId = + connectionForm.scopeType === 'system' + ? 'global' + : connectionForm.scopeId.trim() + if (!scopeId) { + message.error(`${scopeIdLabel.value}不能为空`) + return + } + if (!validateConnectionCredential()) return + + const metaJson = parseJsonText(connectionForm.metaText, '连接元数据') + if (metaJson === undefined) return + const credential = buildConnectionCredential() + if (credential === undefined) return + + try { + connectionSubmitting.value = true + const payload = { + display_name: connectionForm.displayName || null, + external_subject: connectionForm.externalSubject || null, + meta_json: metaJson, + status: connectionForm.status + } + if (credential !== null) { + payload.credential = credential + } + + const result = isEditingConnection.value + ? await mcpApi.updateMcpServerConnection(server.value.name, editingConnectionId.value, payload) + : await mcpApi.createMcpServerConnection(server.value.name, { + scope_type: connectionForm.scopeType, + scope_id: scopeId, + ...payload + }) + if (result.success) { + message.success(isEditingConnection.value ? '连接更新成功' : '连接创建成功') + showConnectionForm.value = false + resetConnectionForm() + await fetchConnections() + } else { + message.error(result.message || (isEditingConnection.value ? '连接更新失败' : '连接创建失败')) + } + } catch (err) { + message.error(err.message || (isEditingConnection.value ? '连接更新失败' : '连接创建失败')) + } finally { + connectionSubmitting.value = false + } +} + +const handleTestConnection = async (connection) => { + if (!server.value) return + const loadingKey = `${connection.id}:test` + try { + connectionActionLoading.value = loadingKey + const result = await mcpApi.testMcpConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接测试成功') + await fetchConnections() + } else { + message.error(result.message || '连接测试失败') + } + } catch (err) { + message.error(err.message || '连接测试失败') + } finally { + connectionActionLoading.value = null + } +} + +const handleReauthorizeConnection = async (connection) => { + if (!server.value) return + const loadingKey = `${connection.id}:reauth` + try { + connectionActionLoading.value = loadingKey + const result = await mcpApi.reauthorizeMcpConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接已重置') + await fetchConnections() + } else { + message.error(result.message || '连接重置失败') + } + } catch (err) { + message.error(err.message || '连接重置失败') + } finally { + connectionActionLoading.value = null + } +} + +const handleDeleteConnection = (connection) => { + if (!server.value) return + Modal.confirm({ + title: '确认删除连接', + content: `确定要删除连接 "${connection.display_name || `${connection.scope_type}:${connection.scope_id}`}" 吗?`, + okText: '删除', + okType: 'danger', + cancelText: '取消', + async onOk() { + try { + const result = await mcpApi.deleteMcpServerConnection(server.value.name, connection.id) + if (result.success) { + message.success(result.message || '连接已删除') + if (editingConnectionId.value === connection.id) { + showConnectionForm.value = false + resetConnectionForm() + } + await fetchConnections() + } else { + message.error(result.message || '连接删除失败') + } + } catch (err) { + message.error(err.message || '连接删除失败') + } + } + }) +} + const handleDangerAction = async () => { if (!server.value) return if (server.value.enabled === false) { @@ -774,17 +1513,17 @@ const handleSetServerEnabled = async (srv, enabled) => { const confirmDeleteServer = (srv) => { Modal.confirm({ - title: '确认删除 MCP', - content: `确定要删除 MCP "${srv.name}" 吗?此操作不可撤销。`, - okText: '删除', - okType: 'danger', + title: '确认退役 MCP', + content: `确定要退役 MCP "${srv.name}" 吗?退役后不会再被新运行加载,但配置和连接会保留。`, + okText: '退役', + okType: 'primary', cancelText: '取消', async onOk() { try { const result = await mcpApi.deleteMcpServer(srv.name) if (result.success) { - message.success('MCP 删除成功') - router.push({ path: '/extensions', query: { tab: 'mcp' } }) + message.success(result.message || 'MCP 已退役') + await fetchServer() } else { message.error(result.message || '删除失败') } @@ -799,6 +1538,9 @@ watch(detailTab, (tab) => { if (tab === 'tools' && server.value) { fetchTools() } + if (tab === 'connections' && server.value) { + fetchConnections() + } }) onMounted(() => { @@ -1131,6 +1873,202 @@ onMounted(() => { } .mcp-detail { + .connections-tab { + display: flex; + flex-direction: column; + gap: 14px; + } + + .connection-command-bar { + display: flex; + justify-content: space-between; + gap: 16px; + align-items: center; + padding: 16px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + } + + .connection-command-copy { + min-width: 0; + + h3 { + margin: 0 0 4px; + color: var(--gray-900); + font-size: 16px; + font-weight: 600; + } + + p { + margin: 0; + color: var(--gray-500); + font-size: 13px; + line-height: 1.5; + } + } + + .connection-summary-strip { + display: grid; + grid-template-columns: repeat(4, minmax(0, 1fr)); + border: 1px solid var(--gray-150); + border-radius: 8px; + overflow: hidden; + background: var(--gray-0); + } + + .connection-summary-item { + display: flex; + flex-direction: column; + gap: 4px; + padding: 12px 14px; + + & + .connection-summary-item { + border-left: 1px solid var(--gray-100); + } + + .summary-label { + color: var(--gray-500); + font-size: 12px; + } + + strong { + min-width: 0; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + color: var(--gray-900); + font-size: 15px; + font-weight: 600; + } + } + + .connection-empty-state { + display: flex; + flex-direction: column; + align-items: center; + gap: 12px; + padding: 42px 16px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + } + + .connection-table { + border: 1px solid var(--gray-150); + border-radius: 8px; + overflow: hidden; + background: var(--gray-0); + } + + .connection-table-header, + .connection-table-row { + display: grid; + grid-template-columns: minmax(180px, 1.3fr) minmax(150px, 1fr) 92px 76px minmax(160px, 1fr) 188px; + gap: 12px; + align-items: center; + } + + .connection-table-header { + padding: 10px 14px; + background: var(--gray-25); + border-bottom: 1px solid var(--gray-100); + color: var(--gray-500); + font-size: 12px; + font-weight: 600; + } + + .connection-table-row { + padding: 14px; + min-height: 68px; + + & + .connection-table-row { + border-top: 1px solid var(--gray-100); + } + } + + .connection-main-cell, + .connection-scope-cell { + display: flex; + min-width: 0; + flex-direction: column; + gap: 4px; + } + + .connection-title { + overflow: hidden; + color: var(--gray-900); + font-size: 14px; + font-weight: 600; + text-overflow: ellipsis; + white-space: nowrap; + } + + .connection-subtitle, + .scope-id, + .credential-cell, + .connection-last-cell { + min-width: 0; + overflow: hidden; + color: var(--gray-500); + font-size: 12px; + line-height: 1.45; + text-overflow: ellipsis; + white-space: nowrap; + } + + .scope-pill { + display: inline-flex; + width: fit-content; + max-width: 100%; + align-items: center; + gap: 5px; + padding: 2px 8px; + border: 1px solid var(--gray-150); + border-radius: 6px; + background: var(--gray-25); + color: var(--gray-700); + font-size: 12px; + font-weight: 500; + } + + .status-badge { + display: inline-flex; + align-items: center; + min-height: 24px; + padding: 2px 8px; + border-radius: 6px; + font-size: 12px; + font-weight: 500; + + &.status-active { + background: var(--color-success-50); + color: var(--color-success-700); + } + + &.status-warning { + background: var(--color-warning-50); + color: var(--color-warning-900); + } + + &.status-error { + background: var(--color-error-50); + color: var(--color-error-700); + } + + &.status-muted { + background: var(--gray-100); + color: var(--gray-600); + } + } + + .connection-row-actions { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 2px; + } + .detail-content-wrapper { flex: 1; min-height: 0; @@ -1139,9 +2077,192 @@ onMounted(() => { } .detail-content-inner { - max-width: 900px; + max-width: 1120px; margin: 0 auto; padding: 16px var(--page-padding); } } + +.connection-drawer-form { + display: flex; + min-height: 100%; + flex-direction: column; + padding: 18px 20px 0; + + :deep(.ant-form-item) { + margin-bottom: 0; + } + + :deep(.ant-form-item-label > label) { + color: var(--gray-700); + font-size: 13px; + font-weight: 500; + } +} + +.drawer-section { + display: flex; + flex-direction: column; + gap: 14px; + padding-bottom: 18px; + + & + .drawer-section { + padding-top: 18px; + border-top: 1px solid var(--gray-100); + } +} + +.drawer-section-title { + display: flex; + flex-direction: column; + gap: 3px; + + span { + color: var(--gray-900); + font-size: 14px; + font-weight: 600; + } + + small { + color: var(--gray-500); + font-size: 12px; + line-height: 1.5; + } +} + +.scope-option-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(136px, 1fr)); + gap: 8px; +} + +.scope-option { + display: flex; + min-height: 78px; + flex-direction: column; + align-items: flex-start; + justify-content: center; + gap: 4px; + padding: 10px 12px; + border: 1px solid var(--gray-150); + border-radius: 8px; + background: var(--gray-0); + color: var(--gray-700); + cursor: pointer; + text-align: left; + transition: + border-color 0.15s ease, + background-color 0.15s ease, + color 0.15s ease; + + span { + color: var(--gray-900); + font-size: 13px; + font-weight: 600; + } + + small { + color: var(--gray-500); + font-size: 12px; + } + + &:hover:not(:disabled) { + border-color: var(--main-300); + background: var(--main-10); + color: var(--main-color); + } + + &.active { + border-color: var(--main-color); + background: var(--main-30); + color: var(--main-color); + } + + &:disabled { + cursor: not-allowed; + opacity: 0.7; + } +} + +.secret-field-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(220px, 1fr)); + gap: 12px; +} + +.connection-advanced-collapse { + margin: 0 -4px 12px; + + :deep(.ant-collapse-header) { + padding: 10px 4px; + color: var(--gray-600); + font-size: 13px; + } + + :deep(.ant-collapse-content-box) { + display: flex; + flex-direction: column; + gap: 14px; + padding: 4px 4px 12px; + } +} + +.connection-drawer-footer { + position: sticky; + bottom: 0; + display: flex; + justify-content: flex-end; + gap: 8px; + margin: auto -20px 0; + padding: 14px 20px; + border-top: 1px solid var(--gray-100); + background: var(--gray-0); +} + +@media (max-width: 980px) { + .mcp-detail { + .connection-summary-strip { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + + .connection-summary-item:nth-child(3) { + border-left: 0; + border-top: 1px solid var(--gray-100); + } + + .connection-summary-item:nth-child(4) { + border-top: 1px solid var(--gray-100); + } + + .connection-table-header { + display: none; + } + + .connection-table-row { + grid-template-columns: 1fr; + gap: 10px; + align-items: stretch; + } + + .connection-row-actions { + justify-content: flex-start; + padding-top: 4px; + } + } +} + +@media (max-width: 640px) { + .mcp-detail { + .connection-command-bar { + align-items: flex-start; + flex-direction: column; + } + + .connection-summary-strip, + .scope-option-grid, + .secret-field-grid { + grid-template-columns: 1fr; + } + } +} diff --git a/web/src/components/extensions/McpFormModal.vue b/web/src/components/extensions/McpFormModal.vue index 5206170de..4f96eaf7a 100644 --- a/web/src/components/extensions/McpFormModal.vue +++ b/web/src/components/extensions/McpFormModal.vue @@ -6,7 +6,7 @@ :confirmLoading="formLoading" @cancel="visible = false" :maskClosable="false" - width="560px" + width="min(780px, calc(100vw - 32px))" class="server-modal" >
@@ -93,6 +93,7 @@ + { args: obj.args || [], env: obj.env || null, headersText: obj.headers ? JSON.stringify(obj.headers, null, 2) : '', + authConfigText: obj.auth_config ? JSON.stringify(obj.auth_config, null, 2) : '', timeout: obj.timeout || null, sse_read_timeout: obj.sse_read_timeout || null, tags: obj.tags || [], @@ -259,6 +267,16 @@ const handleFormSubmit = async () => { return } } + + let authConfig = null + if (form.authConfigText.trim()) { + try { + authConfig = JSON.parse(form.authConfigText) + } catch { + message.error('认证配置 JSON 格式错误') + return + } + } data = { name: form.name, description: form.description || null, @@ -268,6 +286,7 @@ const handleFormSubmit = async () => { args: form.args.length > 0 ? form.args : null, env: form.env, headers, + auth_config: authConfig, timeout: form.timeout || null, sse_read_timeout: form.sse_read_timeout || null, tags: form.tags.length > 0 ? form.tags : null, diff --git a/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js b/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js new file mode 100644 index 000000000..f1f83df9e --- /dev/null +++ b/web/src/utils/__tests__/mcpAuthConfigBuilder.test.js @@ -0,0 +1,84 @@ +import assert from 'node:assert/strict' + +import { + authConfigToBuilderForm, + buildAuthConfigFromBuilderForm, + createDefaultAuthBuilderForm, + extractSecretFieldNames +} from '../mcpAuthConfigBuilder.js' + +const run = () => { + { + const form = createDefaultAuthBuilderForm() + assert.equal(buildAuthConfigFromBuilderForm(form), null) + } + + { + const form = createDefaultAuthBuilderForm('custom_http_token') + form.bindingScope = 'department' + form.injectEntries = [ + { name: 'Authorization', value_template: 'Bearer ${access_token}' }, + { name: 'X-Yuxi-User', value_template: '${context.user_id}' } + ] + form.tokenUrl = 'http://internal-gateway/token' + form.tokenHeaders = [{ key: 'Content-Type', value: 'application/json' }] + form.tokenBodyTemplate = [ + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' }, + { key: 'user_id', value: '${context.user_id}' } + ] + form.tokenResponseMap = [ + { key: 'access_token', value: 'data.access_token' }, + { key: 'expires_in', value: 'data.expires_in' } + ] + + const config = buildAuthConfigFromBuilderForm(form) + assert.equal(config.provider, 'custom_http_token') + assert.equal(config.binding_scope, 'department') + assert.equal(config.manifest_scope, 'binding') + assert.deepEqual(config.inject.entries, form.injectEntries) + assert.deepEqual(config.token_request, { + url: 'http://internal-gateway/token', + method: 'POST', + body_type: 'json', + headers: { 'Content-Type': 'application/json' }, + body_template: { + client_id: '${secret.client_id}', + client_secret: '${secret.client_secret}', + user_id: '${context.user_id}' + }, + response_map: { + access_token: 'data.access_token', + expires_in: 'data.expires_in' + } + }) + assert.deepEqual(extractSecretFieldNames(config), ['client_id', 'client_secret']) + } + + { + const config = { + version: 1, + provider: 'bound_secret', + binding_scope: 'user', + manifest_scope: 'binding', + inject: { + target: 'headers', + entries: [{ name: 'X-Api-Key', value_template: '${secret.api_key}' }] + }, + refresh_policy: { + pre_refresh_seconds: 120, + retry_once_on_401: false + } + } + + const form = authConfigToBuilderForm(config) + assert.equal(form.provider, 'bound_secret') + assert.equal(form.bindingScope, 'user') + assert.deepEqual(form.injectEntries, [{ name: 'X-Api-Key', value_template: '${secret.api_key}' }]) + assert.deepEqual(buildAuthConfigFromBuilderForm(form), config) + } + + console.log('mcpAuthConfigBuilder: all assertions passed') +} + +run() diff --git a/web/src/utils/mcpAuthConfigBuilder.js b/web/src/utils/mcpAuthConfigBuilder.js new file mode 100644 index 000000000..eeafc252e --- /dev/null +++ b/web/src/utils/mcpAuthConfigBuilder.js @@ -0,0 +1,206 @@ +export const FORM_AUTH_PROVIDERS = new Set([ + 'bound_secret', + 'custom_http_token', + 'client_credentials', + 'stdio_env' +]) + +const TOKEN_PROVIDERS = new Set(['custom_http_token', 'client_credentials']) + +const DEFAULT_RESPONSE_MAP = [ + { key: 'access_token', value: 'data.access_token' }, + { key: 'refresh_token', value: 'data.refresh_token' }, + { key: 'expires_in', value: 'data.expires_in' } +] + +const DEFAULT_JSON_HEADERS = [{ key: 'Content-Type', value: 'application/json' }] + +const DEFAULT_FORM_HEADERS = [ + { key: 'Content-Type', value: 'application/x-www-form-urlencoded' } +] + +const DEFAULT_GATEWAY_BODY = [ + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' }, + { key: 'user_id', value: '${context.user_id}' }, + { key: 'department_id', value: '${context.department_id}' } +] + +const DEFAULT_CLIENT_CREDENTIALS_BODY = [ + { key: 'grant_type', value: 'client_credentials' }, + { key: 'client_id', value: '${secret.client_id}' }, + { key: 'client_secret', value: '${secret.client_secret}' } +] + +const normalizeText = (value) => String(value ?? '').trim() + +export const objectToKeyValueRows = (value, fallbackRows = [{ key: '', value: '' }]) => { + if (!value || typeof value !== 'object' || Array.isArray(value)) { + return fallbackRows.map((row) => ({ ...row })) + } + + const rows = Object.entries(value).map(([key, rowValue]) => ({ + key, + value: rowValue == null ? '' : String(rowValue) + })) + return rows.length > 0 ? rows : fallbackRows.map((row) => ({ ...row })) +} + +export const keyValueRowsToObject = (rows) => { + const entries = (Array.isArray(rows) ? rows : []) + .map((row) => ({ + key: normalizeText(row?.key), + value: row?.value == null ? '' : String(row.value) + })) + .filter((row) => row.key) + + if (entries.length === 0) { + return {} + } + return Object.fromEntries(entries.map((row) => [row.key, row.value])) +} + +const createDefaultInjectEntries = (provider) => { + if (provider === 'stdio_env') { + return [{ name: 'MCP_ACCESS_TOKEN', value_template: '${secret.access_token}' }] + } + if (provider === 'bound_secret') { + return [{ name: 'Authorization', value_template: 'Bearer ${secret.access_token}' }] + } + return [{ name: 'Authorization', value_template: 'Bearer ${access_token}' }] +} + +export const createDefaultAuthBuilderForm = (provider = 'none') => { + const normalizedProvider = provider === 'none' ? 'none' : provider + const isClientCredentials = normalizedProvider === 'client_credentials' + const isEnvProvider = normalizedProvider === 'stdio_env' + + return { + provider: normalizedProvider, + bindingScope: 'department', + manifestScope: 'binding', + injectTarget: isEnvProvider ? 'env' : 'headers', + injectEntries: + normalizedProvider === 'none' ? [] : createDefaultInjectEntries(normalizedProvider), + preRefreshSeconds: TOKEN_PROVIDERS.has(normalizedProvider) ? 300 : 0, + retryOnceOn401: TOKEN_PROVIDERS.has(normalizedProvider), + tokenUrl: '', + tokenMethod: 'POST', + tokenBodyType: isClientCredentials ? 'form' : 'json', + tokenHeaders: isClientCredentials ? DEFAULT_FORM_HEADERS.map((row) => ({ ...row })) : DEFAULT_JSON_HEADERS.map((row) => ({ ...row })), + tokenBodyTemplate: isClientCredentials + ? DEFAULT_CLIENT_CREDENTIALS_BODY.map((row) => ({ ...row })) + : DEFAULT_GATEWAY_BODY.map((row) => ({ ...row })), + tokenResponseMap: DEFAULT_RESPONSE_MAP.map((row) => ({ ...row })) + } +} + +export const isAuthConfigSupportedByBuilder = (config) => { + if (!config || Object.keys(config).length === 0) { + return true + } + return FORM_AUTH_PROVIDERS.has(config.provider) +} + +export const authConfigToBuilderForm = (config) => { + if (!config || Object.keys(config).length === 0) { + return createDefaultAuthBuilderForm() + } + + const provider = FORM_AUTH_PROVIDERS.has(config.provider) ? config.provider : 'custom_http_token' + const form = createDefaultAuthBuilderForm(provider) + const tokenRequest = config.token_request || {} + + form.bindingScope = config.binding_scope || form.bindingScope + form.manifestScope = config.manifest_scope || form.manifestScope + form.injectTarget = config.inject?.target || form.injectTarget + form.injectEntries = + Array.isArray(config.inject?.entries) && config.inject.entries.length > 0 + ? config.inject.entries.map((entry) => ({ + name: entry.name || '', + value_template: entry.value_template || '' + })) + : form.injectEntries + form.preRefreshSeconds = + Number.isFinite(Number(config.refresh_policy?.pre_refresh_seconds)) + ? Number(config.refresh_policy.pre_refresh_seconds) + : form.preRefreshSeconds + form.retryOnceOn401 = + typeof config.refresh_policy?.retry_once_on_401 === 'boolean' + ? config.refresh_policy.retry_once_on_401 + : form.retryOnceOn401 + form.tokenUrl = tokenRequest.url || '' + form.tokenMethod = tokenRequest.method || form.tokenMethod + form.tokenBodyType = tokenRequest.body_type || form.tokenBodyType + form.tokenHeaders = objectToKeyValueRows(tokenRequest.headers, form.tokenHeaders) + form.tokenBodyTemplate = objectToKeyValueRows(tokenRequest.body_template, form.tokenBodyTemplate) + form.tokenResponseMap = objectToKeyValueRows(tokenRequest.response_map, form.tokenResponseMap) + + return form +} + +const normalizeInjectEntries = (entries) => + (Array.isArray(entries) ? entries : []) + .map((entry) => ({ + name: normalizeText(entry?.name), + value_template: String(entry?.value_template ?? '').trim() + })) + .filter((entry) => entry.name) + +export const buildAuthConfigFromBuilderForm = (form) => { + if (!form || form.provider === 'none') { + return null + } + + const provider = FORM_AUTH_PROVIDERS.has(form.provider) ? form.provider : 'custom_http_token' + const config = { + version: 1, + provider, + binding_scope: form.bindingScope || 'department', + manifest_scope: form.manifestScope || 'binding', + inject: { + target: form.injectTarget || 'headers', + entries: normalizeInjectEntries(form.injectEntries) + }, + refresh_policy: { + pre_refresh_seconds: Number(form.preRefreshSeconds) || 0, + retry_once_on_401: Boolean(form.retryOnceOn401) + } + } + + if (TOKEN_PROVIDERS.has(provider)) { + config.token_request = { + url: normalizeText(form.tokenUrl), + method: normalizeText(form.tokenMethod || 'POST').toUpperCase(), + body_type: form.tokenBodyType || 'json', + headers: keyValueRowsToObject(form.tokenHeaders), + body_template: keyValueRowsToObject(form.tokenBodyTemplate), + response_map: keyValueRowsToObject(form.tokenResponseMap) + } + } + + return config +} + +export const extractSecretFieldNames = (value, fields = new Set()) => { + if (typeof value === 'string') { + const pattern = /\$\{secret\.([^}]+)\}/g + let match = pattern.exec(value) + while (match) { + fields.add(match[1]) + match = pattern.exec(value) + } + return [...fields] + } + + if (Array.isArray(value)) { + value.forEach((item) => extractSecretFieldNames(item, fields)) + return [...fields] + } + + if (value && typeof value === 'object') { + Object.values(value).forEach((item) => extractSecretFieldNames(item, fields)) + } + + return [...fields] +} From b1a44c995e37986a6f0ec1e04097812996e70508 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 14:10:57 +0800 Subject: [PATCH 02/36] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DMCP=E4=B8=AA?= =?UTF-8?q?=E4=BA=BA=E8=BF=9E=E6=8E=A5=E4=BB=A3=E7=90=86=E8=B6=8A=E6=9D=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/server/routers/mcp_internal_router.py | 4 ++ .../unit/routers/test_mcp_internal_router.py | 48 +++++++++++++++ .../services/test_mcp_service_auth_runtime.py | 61 +++++++++++++++++++ docs/develop-guides/roadmap.md | 2 +- 4 files changed, 114 insertions(+), 1 deletion(-) diff --git a/backend/server/routers/mcp_internal_router.py b/backend/server/routers/mcp_internal_router.py index b0dc3d827..177044cd6 100644 --- a/backend/server/routers/mcp_internal_router.py +++ b/backend/server/routers/mcp_internal_router.py @@ -73,6 +73,10 @@ async def proxy_mcp_server_request( try: connection = await _load_active_connection(db, server=server, auth_context=auth_context) + auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) + if auth_config.binding_scope != "inline" and connection is None: + raise HTTPException(status_code=403, detail="当前用户没有该 MCP 的有效连接") + body = await request.body() upstream_response = await proxy_mcp_request( server, diff --git a/backend/test/unit/routers/test_mcp_internal_router.py b/backend/test/unit/routers/test_mcp_internal_router.py index c7db1b71d..e7b5a9d18 100644 --- a/backend/test/unit/routers/test_mcp_internal_router.py +++ b/backend/test/unit/routers/test_mcp_internal_router.py @@ -82,3 +82,51 @@ def test_internal_proxy_route_requires_internal_token(): client = TestClient(_build_app()) resp = client.post("/api/internal/mcp-proxy/finance-proxy", json={"jsonrpc": "2.0", "id": 1}) assert resp.status_code == 401, resp.text + + +def test_internal_proxy_route_rejects_user_scoped_request_without_active_connection(monkeypatch): + class DummyServer: + name = "personal-proxy" + transport = "streamable_http" + auth_config_json = { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "user", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + async def fake_get_mcp_server(db, name): + del db + assert name == "personal-proxy" + return DummyServer() + + async def fake_load_connection(db, *, server, auth_context): + del db + assert server.name == "personal-proxy" + assert auth_context.user_id == "user-2" + return None + + async def fake_proxy_mcp_request(server, **kwargs): + del server, kwargs + raise AssertionError("proxy request should not run without an active user connection") + + monkeypatch.setattr( + "server.routers.mcp_internal_router.decode_proxy_access_token", + lambda token, server_name: AuthContext(user_id="user-2", department_id="dep-1"), + ) + monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) + monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) + + client = TestClient(_build_app()) + resp = client.post( + "/api/internal/mcp-proxy/personal-proxy", + headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, + json={"jsonrpc": "2.0", "id": 1}, + ) + + assert resp.status_code == 403, resp.text diff --git a/backend/test/unit/services/test_mcp_service_auth_runtime.py b/backend/test/unit/services/test_mcp_service_auth_runtime.py index 3a5273ad6..75932ba47 100644 --- a/backend/test/unit/services/test_mcp_service_auth_runtime.py +++ b/backend/test/unit/services/test_mcp_service_auth_runtime.py @@ -75,6 +75,67 @@ async def test_get_runtime_mcp_server_config_resolves_department_connection(runt assert config["headers"]["Authorization"] == "Bearer dept-token" +async def test_get_enabled_mcp_tools_does_not_reuse_user_connection_for_other_user(runtime_session, monkeypatch): + server = MCPServer( + name="personal-gateway", + transport="streamable_http", + url="http://personal.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "user", + "manifest_scope": "server", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + runtime_session.add( + MCPConnection( + server_name="personal-gateway", + scope_type="user", + scope_id="user-1", + status="active", + credential_blob=json.dumps({"secrets": {"access_token": "user-1-token"}}), + created_by="tester", + updated_by="tester", + ) + ) + await runtime_session.commit() + + captured_configs: list[dict] = [] + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del disabled_tools, kwargs + assert server_name == "personal-gateway" + captured_configs.append(additional_servers[server_name]) + return ["private-tool"] + + monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + + user_1_tools = await mcp_service.get_enabled_mcp_tools( + "personal-gateway", + auth_context=AuthContext(user_id="user-1"), + db=runtime_session, + ) + + with pytest.raises(ValueError, match="Active MCP connection not found"): + await mcp_service.get_enabled_mcp_tools( + "personal-gateway", + auth_context=AuthContext(user_id="user-2"), + db=runtime_session, + ) + + assert user_1_tools == ["private-tool"] + assert len(captured_configs) == 1 + assert captured_configs[0]["headers"]["Authorization"] == "Bearer user-1-token" + + async def test_get_enabled_mcp_tools_uses_runtime_mcp_config(monkeypatch): captured: list[dict] = [] diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 09ae14ca8..d040300df 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用。 --- From 815a5706098899bbb0ea05c7a1b6998906c7648b Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 16:40:37 +0800 Subject: [PATCH 03/36] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20MCP=20?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E5=B7=A5=E5=85=B7=E7=BC=93=E5=AD=98=E9=9A=94?= =?UTF-8?q?=E7=A6=BB=E4=B8=8E=20Redis=20=E5=88=86=E7=BA=A7=E7=BC=93?= =?UTF-8?q?=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../middlewares/runtime_config_middleware.py | 12 +- .../yuxi/services/mcp_auth/proxy_service.py | 13 - backend/package/yuxi/services/mcp_service.py | 267 ++++++++++++++++-- .../package/yuxi/services/mcp_tool_cache.py | 97 +++++++ backend/server/routers/mcp_router.py | 20 +- .../test/integration/api/test_mcp_router.py | 19 ++ .../test_runtime_config_middleware.py | 99 +++++++ backend/test/unit/routers/test_mcp_router.py | 88 ++++++ .../test/unit/services/test_mcp_service.py | 203 ++++++++++++- .../services/test_mcp_service_auth_runtime.py | 46 +++ .../test/unit/services/test_mcp_tool_cache.py | 62 ++++ docs/develop-guides/roadmap.md | 2 +- 12 files changed, 876 insertions(+), 52 deletions(-) create mode 100644 backend/package/yuxi/services/mcp_tool_cache.py create mode 100644 backend/test/unit/services/test_mcp_tool_cache.py diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index d891225e3..5f07fef16 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -90,14 +90,18 @@ async def awrap_model_call( # 获取上下文配置的工具 enabled_tools = await self.get_tools_from_context(runtime_context) existing_tools = list(request.tools or []) - enabled_tool_names = {t.name for t in enabled_tools} managed_tool_names = {t.name for t in self.tools} merged_tools = [] for t_bind in existing_tools: - # (1) 已启用的工具保留 - # (2) 非本中间件管理的工具保留 - if t_bind.name in enabled_tool_names or t_bind.name not in managed_tool_names: + # 非本中间件管理的工具保留;本中间件管理的工具统一用本轮实时加载结果覆盖。 + if t_bind.name not in managed_tool_names: merged_tools.append(t_bind) + merged_tool_names = {t.name for t in merged_tools} + for tool in enabled_tools: + if tool.name in merged_tool_names: + continue + merged_tools.append(tool) + merged_tool_names.add(tool.name) overrides["tools"] = merged_tools logger.debug(f"RuntimeConfigMiddleware selected tools: {[t.name for t in merged_tools]}") diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index fb9cff682..7deeda0ff 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -71,7 +71,6 @@ def build_internal_proxy_url(proxy_base_url: str, server_name: str) -> str: def build_proxy_runtime_config( server: MCPServer, *, - auth_config: MCPAuthConfig, auth_context: AuthContext, proxy_base_url: str, ) -> dict[str, Any]: @@ -81,18 +80,6 @@ def build_proxy_runtime_config( headers[INTERNAL_PROXY_TOKEN_HEADER] = create_proxy_access_token(server.name, auth_context) config["headers"] = headers config["url"] = build_internal_proxy_url(proxy_base_url, server.name) - if auth_config.manifest_scope == "binding": - if auth_config.binding_scope == "department": - partition = f"department:{auth_context.department_id or 'unknown'}" - elif auth_config.binding_scope == "user": - partition = f"user:{auth_context.user_id or 'unknown'}" - else: - partition = f"{auth_config.binding_scope}:global" - config["__yuxi_cache_partition"] = partition - config["__yuxi_allow_global_cache"] = False - else: - config["__yuxi_cache_partition"] = "server" - config["__yuxi_allow_global_cache"] = True return config diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py index 45d724745..9cf9dc7e7 100644 --- a/backend/package/yuxi/services/mcp_service.py +++ b/backend/package/yuxi/services/mcp_service.py @@ -16,6 +16,7 @@ import traceback from collections.abc import Callable from datetime import UTC, datetime +from types import SimpleNamespace from typing import Any, cast from langchain_mcp_adapters.client import MultiServerMCPClient @@ -30,6 +31,7 @@ should_use_internal_proxy, ) from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache +from yuxi.services.mcp_tool_cache import RedisMcpToolCache from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill from yuxi.utils import logger @@ -43,6 +45,7 @@ # 本地仅缓存工具对象。配置始终以数据库为准,每次按 server_name 现查。 # cache key 使用 server_name:config_hash,当配置变化时会自然失效。 _mcp_tools_cache: dict[str, list[Callable[..., Any]]] = {} +_mcp_tool_cache_store = RedisMcpToolCache() # MCP tools statistics (for reporting enabled/disabled counts) _mcp_tools_stats: dict[str, dict[str, int]] = {} @@ -245,6 +248,140 @@ def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, An return cache_identity, cache_partition, allow_global_cache +async def _build_mcp_tool_cache_descriptor(server_name: str, server_config: dict[str, Any]) -> dict[str, Any]: + cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) + config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] + server_revision = await _mcp_tool_cache_store.get_server_revision(server_name) + partition_revision = 0 + if not allow_global_cache: + partition_revision = await _mcp_tool_cache_store.get_partition_revision(server_name, cache_partition) + revision_token = f"s{server_revision}:p{partition_revision}" + cache_prefix = f"{server_name}:{cache_partition}:{revision_token}:" + return { + "cache_identity": cache_identity, + "cache_partition": cache_partition, + "allow_global_cache": allow_global_cache, + "config_hash": config_hash, + "cache_prefix": cache_prefix, + "cache_key": f"{cache_prefix}{config_hash}", + "server_revision": server_revision, + "partition_revision": partition_revision, + } + + +def _serialize_mcp_tools_manifest( + *, + server_name: str, + cache_partition: str, + cache_key: str, + tools: list[Callable[..., Any]], +) -> dict[str, Any]: + entries = [] + for tool in tools: + if hasattr(tool, "args_schema") and tool.args_schema: + schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {} + parameters = schema.get("properties", {}) + required = schema.get("required", []) + else: + parameters = {} + required = [] + metadata = dict(getattr(tool, "metadata", {}) or {}) + entries.append( + { + "name": tool.name, + "id": metadata.get("id") or tool.name, + "description": getattr(tool, "description", ""), + "parameters": parameters, + "required": required, + } + ) + return { + "server_name": server_name, + "cache_partition": cache_partition, + "cache_key": cache_key, + "tools": entries, + } + + +def _deserialize_mcp_tool_manifest(manifest: dict[str, Any]) -> list[Callable[..., Any]]: + tools: list[Callable[..., Any]] = [] + for entry in manifest.get("tools", []): + args_schema = None + parameters = entry.get("parameters") or {} + required = entry.get("required") or [] + if parameters or required: + args_schema = SimpleNamespace( + schema=lambda parameters=parameters, required=required: { + "properties": parameters, + "required": required, + } + ) + tools.append( + SimpleNamespace( + name=entry.get("name") or "", + description=entry.get("description") or "", + metadata={"id": entry.get("id") or entry.get("name") or ""}, + args_schema=args_schema, + ) + ) + return tools + + +def _resolve_runtime_tool_cache_partition( + *, + auth_config: MCPAuthConfig, + auth_context: AuthContext | None, + connection: MCPConnection | None, +) -> tuple[str, bool]: + if auth_config.binding_scope in {"department", "user"}: + connection_id = getattr(connection, "id", None) + if connection_id is not None: + return f"connection:{connection_id}", False + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + if scope_id is None: + raise ValueError(f"auth_context is required for MCP binding scope '{auth_config.binding_scope}'") + return f"{auth_config.binding_scope}:{scope_id}", False + return "server", True + + +def _apply_runtime_tool_cache_policy( + config: dict[str, Any], + *, + auth_config: MCPAuthConfig, + auth_context: AuthContext | None, + connection: MCPConnection | None, +) -> dict[str, Any]: + partition, allow_global_cache = _resolve_runtime_tool_cache_partition( + auth_config=auth_config, + auth_context=auth_context, + connection=connection, + ) + config["__yuxi_cache_partition"] = partition + config["__yuxi_allow_global_cache"] = allow_global_cache + return config + + +def _get_mcp_auth_config(server_config: dict[str, Any]) -> MCPAuthConfig | None: + auth_payload = server_config.get("auth_config") or {} + if not auth_payload: + return None + try: + return MCPAuthConfig.model_validate(auth_payload) + except Exception as exc: + logger.warning(f"Invalid MCP auth config while resolving tool preload strategy: {exc}") + return None + + +def _can_preload_mcp_server_tools_without_runtime_auth(server_config: dict[str, Any]) -> bool: + if not (server_config.get("auth_config") or {}): + return True + auth_config = _get_mcp_auth_config(server_config) + if auth_config is None: + return False + return auth_config.provider == "legacy_static" + + def _resolve_scope_id(binding_scope: str, auth_context: AuthContext | None) -> str | None: if binding_scope == "inline": return None @@ -331,17 +468,23 @@ async def get_runtime_mcp_server_config( ) proxy_base_url = _get_internal_mcp_proxy_base_url() if should_use_internal_proxy(server, auth_config, proxy_base_url): - return build_proxy_runtime_config( + config = build_proxy_runtime_config( server, - auth_config=auth_config, auth_context=auth_context or AuthContext(), proxy_base_url=proxy_base_url or "", ) - return await resolve_runtime_mcp_config( - server, - auth_context=auth_context or AuthContext(), + else: + config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context or AuthContext(), + connection=connection, + http_client=http_client, + ) + return _apply_runtime_tool_cache_policy( + config, + auth_config=auth_config, + auth_context=auth_context, connection=connection, - http_client=http_client, ) from yuxi.storage.postgres.manager import pg_manager @@ -391,15 +534,10 @@ async def get_mcp_tools( logger.warning(f"MCP server '{server_name}' not found in database or disabled") return [] - # 配置 hash 直接基于完整配置生成。只要数据库中的配置发生变化, - # 本地工具缓存 key 就会变化,从而自然触发重建。 - cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) - config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) - config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] - if allow_global_cache: - cache_key = f"{server_name}:{config_hash}" - else: - cache_key = f"{server_name}:{cache_partition}:{config_hash}" + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, server_config) + cache_partition = cache_descriptor["cache_partition"] + cache_prefix = cache_descriptor["cache_prefix"] + cache_key = cache_descriptor["cache_key"] all_processed_tools: list[Callable[..., Any]] = [] @@ -437,12 +575,19 @@ async def get_mcp_tools( if cache: async with _mcp_lock: - stale_keys = [ - key for key in _mcp_tools_cache if key.startswith(f"{server_name}:") and key != cache_key - ] + stale_keys = [key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key] for stale_key in stale_keys: _mcp_tools_cache.pop(stale_key, None) _mcp_tools_cache[cache_key] = all_processed_tools + await _mcp_tool_cache_store.set_manifest( + cache_key, + _serialize_mcp_tools_manifest( + server_name=server_name, + cache_partition=cache_partition, + cache_key=cache_key, + tools=all_processed_tools, + ), + ) global_config_disabled = server_config.get("disabled_tools") or [] enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) @@ -479,8 +624,11 @@ async def get_tools_from_all_servers() -> list[Callable[..., Any]]: """Get all tools from all configured MCP servers.""" server_configs = await _load_enabled_mcp_server_configs() all_tools = [] - for server_name in server_configs: - tools = await get_mcp_tools(server_name, additional_servers=server_configs) + for server_name, server_config in server_configs.items(): + if not _can_preload_mcp_server_tools_without_runtime_auth(server_config): + logger.info(f"Skip MCP tool preload for '{server_name}' because runtime auth context is required") + continue + tools = await get_mcp_tools(server_name, additional_servers={server_name: server_config}) all_tools.extend(tools) return all_tools @@ -503,6 +651,37 @@ def clear_mcp_server_tools_cache(server_name: str) -> None: logger.info(f"Cleared tools cache for MCP server '{server_name}'") +def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + if connection_id is None: + return + global _mcp_tools_cache + cache_prefix = f"{server_name}:connection:{connection_id}:" + stale_keys = [key for key in _mcp_tools_cache if key.startswith(cache_prefix)] + for stale_key in stale_keys: + _mcp_tools_cache.pop(stale_key, None) + if stale_keys: + logger.info(f"Cleared tools cache for MCP connection {connection_id} on server '{server_name}'") + + +async def invalidate_mcp_server_tools_cache(server_name: str) -> None: + clear_mcp_server_tools_cache(server_name) + await _mcp_tool_cache_store.bump_server_revision(server_name) + + +async def invalidate_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + clear_mcp_connection_tools_cache(server_name, connection_id) + if connection_id is None: + return + await _mcp_tool_cache_store.bump_partition_revision(server_name, f"connection:{connection_id}") + + +async def _invalidate_mcp_tools_cache_for_connection(connection: MCPConnection) -> None: + if connection.scope_type == "system": + await invalidate_mcp_server_tools_cache(connection.server_name) + return + await invalidate_mcp_connection_tools_cache(connection.server_name, getattr(connection, "id", None)) + + async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> None: if connection_id is None: return @@ -663,6 +842,7 @@ async def update_mcp_connection( await db.refresh(connection) if should_clear_runtime_auth_cache: await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) return connection @@ -671,9 +851,15 @@ async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: if connection is None: return False deleted_connection_id = connection.id + deleted_server_name = connection.server_name + deleted_scope_type = connection.scope_type await db.delete(connection) await db.commit() await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) + if deleted_scope_type == "system": + await invalidate_mcp_server_tools_cache(deleted_server_name) + else: + await invalidate_mcp_connection_tools_cache(deleted_server_name, deleted_connection_id) return True @@ -694,6 +880,7 @@ async def set_mcp_connection_status( await db.commit() await db.refresh(connection) await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) return connection @@ -717,6 +904,7 @@ async def reauthorize_mcp_connection( await cache.release_refresh_lock(connection.id) except Exception as exc: logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") + await _invalidate_mcp_tools_cache_for_connection(connection) connection.status = "active" meta_json = dict(connection.meta_json or {}) @@ -814,7 +1002,7 @@ async def create_mcp_server( await db.refresh(server) await _clear_mcp_server_runtime_auth_cache(db, name) - clear_mcp_server_tools_cache(name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Created MCP server '{name}'") return server @@ -872,7 +1060,7 @@ async def update_mcp_server( await db.commit() await db.refresh(server) - clear_mcp_server_tools_cache(name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Updated MCP server '{name}'") return server @@ -890,7 +1078,7 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: for connection_id in connection_ids: await _clear_mcp_connection_runtime_auth_cache(connection_id) - clear_mcp_server_tools_cache(name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Deleted MCP server '{name}'") return True @@ -947,7 +1135,7 @@ async def set_server_enabled( is_enabled = bool(server.enabled) if not is_enabled: await _clear_mcp_server_runtime_auth_cache(db, name) - clear_mcp_server_tools_cache(name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Set MCP server '{name}' enabled={is_enabled}") return is_enabled, server @@ -1046,11 +1234,17 @@ async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: return await _load_enabled_mcp_server_configs(names=names) -async def get_all_mcp_tools(server_name: str) -> list: +async def get_all_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, + force_refresh: bool = False, +) -> list: """Get all tools of an MCP server (no filtering). For management UI to display tool list, supports viewing all tools and their enabled status. - Does NOT update the global tools cache to avoid polluting agent's filtered view. Args: server_name: Server name @@ -1058,16 +1252,29 @@ async def get_all_mcp_tools(server_name: str) -> list: Returns: List of all tools (unfiltered) """ - config = await get_enabled_mcp_server_config(server_name) + if auth_context is None and db is None: + config = await get_enabled_mcp_server_config(server_name) + else: + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) if config is None: logger.warning(f"MCP server '{server_name}' not found in database or disabled") return [] - # Get all tools (no filtering, force refresh, no cache update) + if not force_refresh: + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, config) + manifest = await _mcp_tool_cache_store.get_manifest(cache_descriptor["cache_key"]) + if manifest is not None: + return _deserialize_mcp_tool_manifest(manifest) + return await get_mcp_tools( server_name, additional_servers={server_name: config}, disabled_tools=[], - cache=False, - force_refresh=True, + cache=True, + force_refresh=force_refresh, ) diff --git a/backend/package/yuxi/services/mcp_tool_cache.py b/backend/package/yuxi/services/mcp_tool_cache.py new file mode 100644 index 000000000..f0d9cc115 --- /dev/null +++ b/backend/package/yuxi/services/mcp_tool_cache.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import json +import os +from collections.abc import Awaitable, Callable +from typing import Any + +from yuxi.services.run_queue_service import get_redis_client +from yuxi.utils import logger + +SERVER_REVISION_KEY_PREFIX = "yuxi:mcp:tool_cache:server_revision:v1" +PARTITION_REVISION_KEY_PREFIX = "yuxi:mcp:tool_cache:partition_revision:v1" +MANIFEST_KEY_PREFIX = "yuxi:mcp:tool_cache:manifest:v1" +MANIFEST_TTL_SECONDS = int(os.getenv("YUXI_MCP_TOOL_MANIFEST_TTL_SECONDS", "3600")) + + +def _server_revision_key(server_name: str) -> str: + return f"{SERVER_REVISION_KEY_PREFIX}:{server_name}" + + +def _partition_revision_key(server_name: str, cache_partition: str) -> str: + return f"{PARTITION_REVISION_KEY_PREFIX}:{server_name}:{cache_partition}" + + +def _manifest_key(cache_key: str) -> str: + return f"{MANIFEST_KEY_PREFIX}:{cache_key}" + + +class RedisMcpToolCache: + def __init__(self, redis_client_factory: Callable[[], Awaitable[Any]] | None = None): + self._redis_client_factory = redis_client_factory or get_redis_client + + async def _get_redis(self): + return await self._redis_client_factory() + + async def get_server_revision(self, server_name: str) -> int: + return await self._get_revision(_server_revision_key(server_name)) + + async def get_partition_revision(self, server_name: str, cache_partition: str) -> int: + return await self._get_revision(_partition_revision_key(server_name, cache_partition)) + + async def bump_server_revision(self, server_name: str) -> int: + return await self._bump_revision(_server_revision_key(server_name)) + + async def bump_partition_revision(self, server_name: str, cache_partition: str) -> int: + return await self._bump_revision(_partition_revision_key(server_name, cache_partition)) + + async def get_manifest(self, cache_key: str) -> dict[str, Any] | None: + try: + redis = await self._get_redis() + raw = await redis.get(_manifest_key(cache_key)) + except Exception as exc: + logger.warning(f"Failed to read MCP tool manifest cache for '{cache_key}': {exc}") + return None + if not raw: + return None + if isinstance(raw, dict): + return raw + try: + return json.loads(raw) + except Exception as exc: + logger.warning(f"Failed to decode MCP tool manifest cache for '{cache_key}': {exc}") + return None + + async def set_manifest(self, cache_key: str, manifest: dict[str, Any]) -> None: + try: + redis = await self._get_redis() + await redis.set( + _manifest_key(cache_key), + json.dumps(manifest, ensure_ascii=False, separators=(",", ":")), + ex=MANIFEST_TTL_SECONDS, + ) + except Exception as exc: + logger.warning(f"Failed to write MCP tool manifest cache for '{cache_key}': {exc}") + + async def _get_revision(self, key: str) -> int: + try: + redis = await self._get_redis() + raw = await redis.get(key) + except Exception as exc: + logger.warning(f"Failed to read MCP tool revision cache for '{key}': {exc}") + return 0 + if raw is None: + return 0 + try: + return int(raw) + except (TypeError, ValueError): + logger.warning(f"Invalid MCP tool revision cache value for '{key}': {raw}") + return 0 + + async def _bump_revision(self, key: str) -> int: + try: + redis = await self._get_redis() + return int(await redis.incr(key)) + except Exception as exc: + logger.warning(f"Failed to bump MCP tool revision cache for '{key}': {exc}") + return 0 diff --git a/backend/server/routers/mcp_router.py b/backend/server/routers/mcp_router.py index 68a68602e..f0bd081ac 100644 --- a/backend/server/routers/mcp_router.py +++ b/backend/server/routers/mcp_router.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_service import ( create_mcp_connection, @@ -133,6 +134,15 @@ def _validate_auth_config_or_400(payload: dict | None) -> dict | None: raise HTTPException(status_code=400, detail=f"auth_config 配置无效: {exc}") from exc +def _auth_context_from_user(current_user: User) -> AuthContext: + department_id = getattr(current_user, "department_id", None) + user_id = getattr(current_user, "user_id", None) + return AuthContext( + user_id=str(user_id) if user_id is not None else None, + department_id=str(department_id) if department_id is not None else None, + ) + + # ============================================================================= # === MCP 服务器 CRUD === # ============================================================================= @@ -579,10 +589,11 @@ async def get_mcp_server_tools( try: server = await get_server_or_404(db, name) disabled_tools = server.disabled_tools or [] + auth_context = _auth_context_from_user(current_user) try: # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools(name, auth_context=auth_context, db=db) tool_list = [] for tool in tools: @@ -610,6 +621,8 @@ async def get_mcp_server_tools( "data": tool_list, "total": len(tool_list), } + except ValueError as tool_error: + raise HTTPException(status_code=403, detail=f"获取工具失败: {str(tool_error)}") except Exception as tool_error: logger.error(f"Failed to get tools from MCP server '{name}': {tool_error}") raise HTTPException(status_code=500, detail=f"获取工具失败: {str(tool_error)}") @@ -629,10 +642,11 @@ async def refresh_mcp_server_tools( """刷新 MCP 服务器的工具列表(清除缓存重新获取)""" try: await get_server_or_404(db, name) + auth_context = _auth_context_from_user(current_user) try: # 获取所有工具(不过滤 disabled_tools) - tools = await get_all_mcp_tools(name) + tools = await get_all_mcp_tools(name, auth_context=auth_context, db=db, force_refresh=True) # 获取统计信息 stats = get_mcp_tools_stats(name) @@ -652,6 +666,8 @@ async def refresh_mcp_server_tools( "enabled_count": enabled_count, "disabled_count": disabled_count, } + except ValueError as tool_error: + raise HTTPException(status_code=403, detail=f"刷新失败: {str(tool_error)}") except Exception as tool_error: raise HTTPException(status_code=500, detail=f"刷新失败: {str(tool_error)}") except HTTPException: diff --git a/backend/test/integration/api/test_mcp_router.py b/backend/test/integration/api/test_mcp_router.py index b4c3ffdfb..026400313 100644 --- a/backend/test/integration/api/test_mcp_router.py +++ b/backend/test/integration/api/test_mcp_router.py @@ -278,6 +278,25 @@ async def test_bound_auth_server_test_endpoint_requires_connection_level_testing await _cleanup_server(test_client, admin_headers, server_name) +async def test_bound_auth_tools_endpoint_requires_current_admin_connection(test_client, admin_headers): + server_name = _build_server_name("pytest-mcp-bound-tools") + await _create_server(test_client, admin_headers, server_name) + + try: + response = await test_client.get(f"/api/system/mcp-servers/{server_name}/tools", headers=admin_headers) + assert response.status_code == 403, response.text + assert "Active MCP connection not found" in response.json()["detail"] + + refresh_response = await test_client.post( + f"/api/system/mcp-servers/{server_name}/tools/refresh", + headers=admin_headers, + ) + assert refresh_response.status_code == 403, refresh_response.text + assert "Active MCP connection not found" in refresh_response.json()["detail"] + finally: + await _cleanup_server(test_client, admin_headers, server_name) + + async def test_create_mcp_server_rejects_invalid_auth_config_via_real_api(test_client, admin_headers): server_name = _build_server_name("pytest-mcp-invalid-auth") diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py index 24010b98c..7b8453558 100644 --- a/backend/test/unit/middlewares/test_runtime_config_middleware.py +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -34,3 +34,102 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= assert tools == [] assert captured == [("finance-gateway", "user-1", "dept-9")] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_awrap_model_call_appends_runtime_loaded_mcp_tools(monkeypatch: pytest.MonkeyPatch): + runtime_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={}) + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "finance-gateway" + assert auth_context.user_id == "user-1" + assert auth_context.department_id == "dept-9" + return [runtime_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + class DummyRequest: + def __init__(self): + self.runtime = SimpleNamespace( + context=SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + model=None, + system_prompt="", + ) + ) + self.tools = [] + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest() + clone.runtime = kwargs.get("runtime", self.runtime) + clone.tools = kwargs.get("tools", self.tools) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + request = DummyRequest() + + async def handler(next_request): + return next_request.tools + + tools = await middleware.awrap_model_call(request, handler) + + assert tools == [runtime_tool] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_awrap_model_call_replaces_stale_managed_tool_with_fresh_runtime_tool(monkeypatch: pytest.MonkeyPatch): + stale_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={"version": "stale"}) + fresh_tool = SimpleNamespace(name="mcp__financeGateway__query", metadata={"version": "fresh"}) + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "finance-gateway" + assert auth_context.user_id == "user-1" + assert auth_context.department_id == "dept-9" + return [fresh_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + class DummyRequest: + def __init__(self, tools): + self.runtime = SimpleNamespace( + context=SimpleNamespace( + tools=[], + mcps=["finance-gateway"], + user_id="user-1", + department_id="dept-9", + model=None, + system_prompt="", + ) + ) + self.tools = tools + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest(kwargs.get("tools", self.tools)) + clone.runtime = kwargs.get("runtime", self.runtime) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + middleware.tools = [stale_tool] + request = DummyRequest([stale_tool]) + + async def handler(next_request): + return next_request.tools + + tools = await middleware.awrap_model_call(request, handler) + + assert tools == [fresh_tool] diff --git a/backend/test/unit/routers/test_mcp_router.py b/backend/test/unit/routers/test_mcp_router.py index 1d3d969e9..7334cef27 100644 --- a/backend/test/unit/routers/test_mcp_router.py +++ b/backend/test/unit/routers/test_mcp_router.py @@ -25,6 +25,7 @@ async def fake_admin_user(): user_id="admin", password_hash="x", role="admin", + department_id=42, ) async def fake_required_user(): @@ -33,6 +34,7 @@ async def fake_required_user(): user_id="admin" if allow_admin else "user", password_hash="x", role="admin" if allow_admin else "user", + department_id=42, ) app.dependency_overrides[get_db] = fake_db @@ -536,6 +538,92 @@ async def fake_get_server_or_404(db, name): assert resp.status_code == 400, resp.text +def test_get_mcp_server_tools_uses_current_admin_auth_context(monkeypatch): + captured = {} + + class DummyServer: + disabled_tools = ["tool_b"] + + class DummyArgsSchema: + @staticmethod + def schema(): + return {"properties": {"city": {"type": "string"}}, "required": ["city"]} + + class DummyTool: + name = "tool_a" + description = "tool a" + metadata = {"id": "mcp__gateway__toolA"} + args_schema = DummyArgsSchema() + + async def fake_get_server_or_404(db, name): + del db + assert name == "gateway" + return DummyServer() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del db, http_client, force_refresh + captured["server_name"] = server_name + captured["user_id"] = auth_context.user_id + captured["department_id"] = auth_context.department_id + return [DummyTool()] + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/tools") + + assert resp.status_code == 200, resp.text + assert captured == { + "server_name": "gateway", + "user_id": "admin", + "department_id": "42", + } + payload = resp.json() + assert payload["total"] == 1 + assert payload["data"][0]["required"] == ["city"] + assert payload["data"][0]["enabled"] is True + + +def test_get_mcp_server_tools_returns_403_when_bound_connection_missing(monkeypatch): + class DummyServer: + disabled_tools = [] + + async def fake_get_server_or_404(db, name): + del db, name + return DummyServer() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del server_name, auth_context, db, http_client, force_refresh + raise ValueError("Active MCP connection not found for server 'gateway' and scope department:42") + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.get("/api/system/mcp-servers/gateway/tools") + + assert resp.status_code == 403, resp.text + + +def test_refresh_mcp_server_tools_returns_403_when_bound_connection_missing(monkeypatch): + async def fake_get_server_or_404(db, name): + del db, name + return type("DummyServer", (), {})() + + async def fake_get_all_mcp_tools(server_name, *, auth_context=None, db=None, http_client=None, force_refresh=False): + del server_name, auth_context, db, http_client, force_refresh + raise ValueError("Active MCP connection not found for server 'gateway' and scope department:42") + + monkeypatch.setattr("server.routers.mcp_router.get_server_or_404", fake_get_server_or_404) + monkeypatch.setattr("server.routers.mcp_router.get_all_mcp_tools", fake_get_all_mcp_tools) + + client = TestClient(_build_app()) + resp = client.post("/api/system/mcp-servers/gateway/tools/refresh") + + assert resp.status_code == 403, resp.text + + def test_delete_mcp_server_defaults_to_retire(monkeypatch): captured = {} diff --git a/backend/test/unit/services/test_mcp_service.py b/backend/test/unit/services/test_mcp_service.py index bd38ba833..aadd0094f 100644 --- a/backend/test/unit/services/test_mcp_service.py +++ b/backend/test/unit/services/test_mcp_service.py @@ -3,6 +3,8 @@ from types import SimpleNamespace from yuxi.services import mcp_service +from yuxi.services.mcp_tool_cache import RedisMcpToolCache +from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER class _FakeClient: @@ -13,6 +15,25 @@ async def get_tools(self): return self._tools +class _FakeRedis: + def __init__(self): + self.data: dict[str, str] = {} + self.expire_calls: dict[str, int] = {} + + async def get(self, key: str) -> str | None: + return self.data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> None: + self.data[key] = value + if ex is not None: + self.expire_calls[key] = ex + + async def incr(self, key: str) -> int: + next_value = int(self.data.get(key) or "0") + 1 + self.data[key] = str(next_value) + return next_value + + async def test_get_enabled_mcp_tools_loads_latest_config_from_db(monkeypatch): captured: list[dict] = [] @@ -109,8 +130,8 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs assert tools == ["alpha", "beta"] assert calls == [ - ("alpha", server_configs), - ("beta", server_configs), + ("alpha", {"alpha": server_configs["alpha"]}), + ("beta", {"beta": server_configs["beta"]}), ] @@ -136,3 +157,181 @@ async def fake_get_mcp_client(server_configs): mcp_service.clear_mcp_cache() + +async def test_get_mcp_tools_keeps_connection_partitions_separate(monkeypatch): + mcp_service.clear_mcp_cache() + + configs = [ + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-user-a", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + }, + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-user-b", + }, + "__yuxi_cache_partition": "connection:202", + "__yuxi_allow_global_cache": False, + }, + ] + build_calls: list[str] = [] + + async def fake_get_mcp_client(server_configs): + token = server_configs["demo"]["headers"][INTERNAL_PROXY_TOKEN_HEADER] + build_calls.append(token) + tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + + tools_a = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_b = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + + assert [tool.name for tool in tools_a] == ["tool_for_proxy-token-user-a"] + assert [tool.name for tool in tools_b] == ["tool_for_proxy-token-user-b"] + assert build_calls == ["proxy-token-user-a", "proxy-token-user-b"] + + mcp_service.clear_mcp_cache() + + +async def test_get_tools_from_all_servers_skips_runtime_auth_servers_without_context(monkeypatch): + server_configs = { + "shared": {"transport": "stdio", "command": "cmd-shared", "disabled_tools": []}, + "bound": { + "transport": "streamable_http", + "url": "http://bound.local/mcp", + "auth_config": { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": { + "url": "http://bound.local/token", + "method": "POST", + "response_map": {"access_token": "access_token"}, + }, + }, + "disabled_tools": [], + }, + } + calls: list[tuple[str, dict[str, dict]]] = [] + + async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): + del names, db + return server_configs + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): + del kwargs + calls.append((server_name, additional_servers or {})) + return [server_name] + + monkeypatch.setattr(mcp_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await mcp_service.get_tools_from_all_servers() + + assert tools == ["shared"] + assert calls == [ + ("shared", {"shared": server_configs["shared"]}), + ] + + +async def test_get_mcp_tools_rebuilds_when_redis_server_revision_changes(monkeypatch): + mcp_service.clear_mcp_cache() + + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + monkeypatch.setattr( + mcp_service, + "_mcp_tool_cache_store", + RedisMcpToolCache(redis_client_factory=fake_redis_factory), + ) + + config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} + build_calls: list[str] = [] + + async def fake_get_mcp_client(server_configs): + build_calls.append(server_configs["demo"]["command"]) + tool = SimpleNamespace(name=f"tool_{len(build_calls)}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + + tools_first = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) + tools_second = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) + await mcp_service._mcp_tool_cache_store.bump_server_revision("demo") + tools_third = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) + + assert [tool.name for tool in tools_first] == ["tool_1"] + assert [tool.name for tool in tools_second] == ["tool_1"] + assert [tool.name for tool in tools_third] == ["tool_2"] + assert build_calls == ["demo-tool", "demo-tool"] + + mcp_service.clear_mcp_cache() + + +async def test_get_all_mcp_tools_uses_redis_manifest_when_local_cache_is_empty(monkeypatch): + mcp_service.clear_mcp_cache() + + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + monkeypatch.setattr( + mcp_service, + "_mcp_tool_cache_store", + RedisMcpToolCache(redis_client_factory=fake_redis_factory), + ) + + config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} + + async def fake_get_mcp_client(server_configs): + del server_configs + tool = SimpleNamespace( + name="alpha_tool", + description="alpha", + metadata={}, + args_schema=SimpleNamespace( + schema=lambda: { + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } + ), + ) + return _FakeClient([tool]) + + async def fake_get_enabled_mcp_server_config(server_name: str, db=None): + del server_name, db + return config + + monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + + tools_first = await mcp_service.get_all_mcp_tools("demo") + assert [tool.name for tool in tools_first] == ["alpha_tool"] + + mcp_service.clear_mcp_cache() + + async def fail_get_mcp_client(server_configs): + raise AssertionError(f"should not fetch live tools when redis manifest is available: {server_configs}") + + monkeypatch.setattr(mcp_service, "get_mcp_client", fail_get_mcp_client) + + tools_second = await mcp_service.get_all_mcp_tools("demo") + + assert [tool.name for tool in tools_second] == ["alpha_tool"] + assert tools_second[0].metadata["id"] == "mcp__demo__alphaTool" diff --git a/backend/test/unit/services/test_mcp_service_auth_runtime.py b/backend/test/unit/services/test_mcp_service_auth_runtime.py index 75932ba47..b10969c76 100644 --- a/backend/test/unit/services/test_mcp_service_auth_runtime.py +++ b/backend/test/unit/services/test_mcp_service_auth_runtime.py @@ -180,6 +180,50 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled ] +async def test_get_all_mcp_tools_uses_runtime_mcp_config_when_auth_context_is_provided(monkeypatch): + captured: list[dict] = [] + + async def fake_get_runtime_mcp_server_config(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + assert server_name == "demo" + assert auth_context is not None + return { + "transport": "stdio", + "command": "demo-with-auth", + "disabled_tools": ["tool_b"], + } + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled_tools=None, **kwargs): + del kwargs + captured.append( + { + "server_name": server_name, + "additional_servers": additional_servers, + "disabled_tools": list(disabled_tools or []), + } + ) + return ["tool-a", "tool-b"] + + monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await mcp_service.get_all_mcp_tools( + "demo", + auth_context=AuthContext(user_id="u-100", department_id="d-9"), + ) + + assert tools == ["tool-a", "tool-b"] + assert captured == [ + { + "server_name": "demo", + "additional_servers": { + "demo": {"transport": "stdio", "command": "demo-with-auth", "disabled_tools": ["tool_b"]} + }, + "disabled_tools": [], + } + ] + + async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_http_provider( runtime_session, monkeypatch ): @@ -238,3 +282,5 @@ async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_ assert config["headers"]["X-App"] == "yuxi" assert "X-Yuxi-MCP-Proxy-Token" in config["headers"] assert "Authorization" not in config["headers"] + assert config["__yuxi_cache_partition"] == "connection:31" + assert config["__yuxi_allow_global_cache"] is False diff --git a/backend/test/unit/services/test_mcp_tool_cache.py b/backend/test/unit/services/test_mcp_tool_cache.py new file mode 100644 index 000000000..54c8ae23a --- /dev/null +++ b/backend/test/unit/services/test_mcp_tool_cache.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest + +from yuxi.services.mcp_tool_cache import RedisMcpToolCache + + +class _FakeRedis: + def __init__(self): + self.data: dict[str, str] = {} + self.expire_calls: dict[str, int] = {} + + async def get(self, key: str) -> str | None: + return self.data.get(key) + + async def set(self, key: str, value: str, ex: int | None = None) -> None: + self.data[key] = value + if ex is not None: + self.expire_calls[key] = ex + + async def incr(self, key: str) -> int: + next_value = int(self.data.get(key) or "0") + 1 + self.data[key] = str(next_value) + return next_value + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_redis_mcp_tool_cache_revision_and_manifest_roundtrip(): + fake_redis = _FakeRedis() + + async def fake_redis_factory(): + return fake_redis + + cache = RedisMcpToolCache(redis_client_factory=fake_redis_factory) + + assert await cache.get_server_revision("demo") == 0 + assert await cache.get_partition_revision("demo", "connection:7") == 0 + + assert await cache.bump_server_revision("demo") == 1 + assert await cache.bump_partition_revision("demo", "connection:7") == 1 + + assert await cache.get_server_revision("demo") == 1 + assert await cache.get_partition_revision("demo", "connection:7") == 1 + + manifest = { + "server_name": "demo", + "cache_partition": "connection:7", + "cache_key": "demo:connection:7:s1:p1:abc123", + "tools": [ + { + "name": "alpha_tool", + "id": "mcp__demo__alphaTool", + "description": "alpha", + "parameters": {"city": {"type": "string"}}, + "required": ["city"], + } + ], + } + await cache.set_manifest("demo:connection:7:s1:p1:abc123", manifest) + + assert await cache.get_manifest("demo:connection:7:s1:p1:abc123") == manifest diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index d040300df..5c1340401 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果。 --- From 2bb296f574217a3d69b2ef7ad0f7cf579c2271ca Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 19:30:31 +0800 Subject: [PATCH 04/36] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20MCP=20?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E4=BB=A3=E7=90=86=E7=BC=93=E5=AD=98=E5=A4=B1?= =?UTF-8?q?=E6=95=88=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../yuxi/services/mcp_auth/proxy_service.py | 6 +- backend/package/yuxi/services/mcp_service.py | 42 +++++++++----- backend/server/routers/mcp_internal_router.py | 2 + .../unit/routers/test_mcp_internal_router.py | 47 ++++++++++++++++ .../test/unit/services/test_mcp_service.py | 49 ++++++++++++++++- .../services/test_mcp_service_auth_runtime.py | 55 +++++++++++++++++++ docs/develop-guides/roadmap.md | 2 +- 7 files changed, 183 insertions(+), 20 deletions(-) diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index 7deeda0ff..836992d59 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -12,6 +12,7 @@ from yuxi.storage.postgres.models_business import MCPConnection, MCPServer INTERNAL_PROXY_TOKEN_HEADER = "X-Yuxi-MCP-Proxy-Token" +INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY = "__yuxi_disable_tool_object_cache" _PROXY_TOKEN_TYPE = "mcp_proxy" _DYNAMIC_HTTP_PROVIDERS = {"custom_http_token", "client_credentials", "authorization_code"} _HTTP_TRANSPORTS = {"streamable_http", "sse"} @@ -31,9 +32,7 @@ def should_use_internal_proxy(server: MCPServer, auth_config: MCPAuthConfig, proxy_base_url: str | None) -> bool: return bool( - proxy_base_url - and server.transport in _HTTP_TRANSPORTS - and auth_config.provider in _DYNAMIC_HTTP_PROVIDERS + proxy_base_url and server.transport in _HTTP_TRANSPORTS and auth_config.provider in _DYNAMIC_HTTP_PROVIDERS ) @@ -80,6 +79,7 @@ def build_proxy_runtime_config( headers[INTERNAL_PROXY_TOKEN_HEADER] = create_proxy_access_token(server.name, auth_context) config["headers"] = headers config["url"] = build_internal_proxy_url(proxy_base_url, server.name) + config[INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY] = True return config diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py index 9cf9dc7e7..43cf4d3d1 100644 --- a/backend/package/yuxi/services/mcp_service.py +++ b/backend/package/yuxi/services/mcp_service.py @@ -26,6 +26,7 @@ from yuxi.services.mcp_auth.crypto import encrypt_credential_blob from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, INTERNAL_PROXY_TOKEN_HEADER, build_proxy_runtime_config, should_use_internal_proxy, @@ -237,7 +238,13 @@ def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, An cache_identity = { key: value for key, value in server_config.items() - if key not in {"__yuxi_cache_partition", "__yuxi_allow_global_cache", "disabled_tools"} + if key + not in { + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + "disabled_tools", + } } headers = dict(cache_identity.get("headers") or {}) headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) @@ -538,11 +545,12 @@ async def get_mcp_tools( cache_partition = cache_descriptor["cache_partition"] cache_prefix = cache_descriptor["cache_prefix"] cache_key = cache_descriptor["cache_key"] + use_tool_object_cache = cache and not bool(server_config.get(INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY)) all_processed_tools: list[Callable[..., Any]] = [] async with _mcp_lock: - if not force_refresh and cache and cache_key in _mcp_tools_cache: + if not force_refresh and use_tool_object_cache and cache_key in _mcp_tools_cache: all_processed_tools = _mcp_tools_cache[cache_key] if not all_processed_tools: @@ -551,7 +559,13 @@ async def get_mcp_tools( client_config = { k: v for k, v in server_config.items() - if k not in ("disabled_tools", "__yuxi_cache_partition", "__yuxi_allow_global_cache") + if k + not in ( + "disabled_tools", + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + ) } client = await get_mcp_client({server_name: client_config}) @@ -574,11 +588,14 @@ async def get_mcp_tools( all_processed_tools.append(tool) if cache: - async with _mcp_lock: - stale_keys = [key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - _mcp_tools_cache[cache_key] = all_processed_tools + if use_tool_object_cache: + async with _mcp_lock: + stale_keys = [ + key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key + ] + for stale_key in stale_keys: + _mcp_tools_cache.pop(stale_key, None) + _mcp_tools_cache[cache_key] = all_processed_tools await _mcp_tool_cache_store.set_manifest( cache_key, _serialize_mcp_tools_manifest( @@ -1060,6 +1077,8 @@ async def update_mcp_server( await db.commit() await db.refresh(server) + if auth_config is not _UNSET: + await _clear_mcp_server_runtime_auth_cache(db, name) await invalidate_mcp_server_tools_cache(name) logger.info(f"Updated MCP server '{name}'") @@ -1089,9 +1108,7 @@ async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict skill_rows = (await db.execute(select(Skill))).scalars().all() matched_skills = [ - {"slug": item.slug, "name": item.name} - for item in skill_rows - if name in (item.mcp_dependencies or []) + {"slug": item.slug, "name": item.name} for item in skill_rows if name in (item.mcp_dependencies or []) ] agent_config_rows = (await db.execute(select(AgentConfig))).scalars().all() @@ -1102,8 +1119,7 @@ async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict matched_agent_configs.append({"id": item.id, "name": item.name, "agent_id": item.agent_id}) connection_refs = [ - {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} - for item in connections + {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} for item in connections ] return { diff --git a/backend/server/routers/mcp_internal_router.py b/backend/server/routers/mcp_internal_router.py index 177044cd6..b5f48555b 100644 --- a/backend/server/routers/mcp_internal_router.py +++ b/backend/server/routers/mcp_internal_router.py @@ -70,6 +70,8 @@ async def proxy_mcp_server_request( server = await get_mcp_server(db, server_name) if server is None: raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在") + if not bool(getattr(server, "enabled", True)): + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在或已停用") try: connection = await _load_active_connection(db, server=server, auth_context=auth_context) diff --git a/backend/test/unit/routers/test_mcp_internal_router.py b/backend/test/unit/routers/test_mcp_internal_router.py index e7b5a9d18..e510bc302 100644 --- a/backend/test/unit/routers/test_mcp_internal_router.py +++ b/backend/test/unit/routers/test_mcp_internal_router.py @@ -84,6 +84,53 @@ def test_internal_proxy_route_requires_internal_token(): assert resp.status_code == 401, resp.text +def test_internal_proxy_route_rejects_disabled_server(monkeypatch): + class DummyServer: + name = "disabled-proxy" + transport = "streamable_http" + enabled = 0 + auth_config_json = { + "version": 1, + "provider": "custom_http_token", + "binding_scope": "department", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + } + + async def fake_get_mcp_server(db, name): + del db + assert name == "disabled-proxy" + return DummyServer() + + async def fake_load_connection(db, *, server, auth_context): + del db, server, auth_context + raise AssertionError("disabled MCP server should be rejected before loading a connection") + + async def fake_proxy_mcp_request(server, **kwargs): + del server, kwargs + raise AssertionError("disabled MCP server should not be proxied") + + monkeypatch.setattr( + "server.routers.mcp_internal_router.decode_proxy_access_token", + lambda token, server_name: AuthContext(user_id="user-1", department_id="dep-1"), + ) + monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) + monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) + monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) + + client = TestClient(_build_app()) + resp = client.post( + "/api/internal/mcp-proxy/disabled-proxy", + headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, + json={"jsonrpc": "2.0", "id": 1}, + ) + + assert resp.status_code == 404, resp.text + + def test_internal_proxy_route_rejects_user_scoped_request_without_active_connection(monkeypatch): class DummyServer: name = "personal-proxy" diff --git a/backend/test/unit/services/test_mcp_service.py b/backend/test/unit/services/test_mcp_service.py index aadd0094f..6668edf4a 100644 --- a/backend/test/unit/services/test_mcp_service.py +++ b/backend/test/unit/services/test_mcp_service.py @@ -62,9 +62,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled assert captured == [ { "server_name": "demo", - "additional_servers": { - "demo": {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]} - }, + "additional_servers": {"demo": {"transport": "stdio", "command": "demo", "disabled_tools": ["tool_b"]}}, "disabled_tools": ["tool_b"], } ] @@ -201,6 +199,51 @@ async def fake_get_mcp_client(server_configs): mcp_service.clear_mcp_cache() +async def test_get_mcp_tools_does_not_cache_internal_proxy_tool_objects(monkeypatch): + mcp_service.clear_mcp_cache() + + configs = [ + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-v1", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + "__yuxi_disable_tool_object_cache": True, + }, + { + "transport": "streamable_http", + "url": "http://internal-api:5050/api/internal/mcp-proxy/demo", + "headers": { + INTERNAL_PROXY_TOKEN_HEADER: "proxy-token-v2", + }, + "__yuxi_cache_partition": "connection:101", + "__yuxi_allow_global_cache": False, + "__yuxi_disable_tool_object_cache": True, + }, + ] + build_calls: list[str] = [] + + async def fake_get_mcp_client(server_configs): + token = server_configs["demo"]["headers"][INTERNAL_PROXY_TOKEN_HEADER] + build_calls.append(token) + tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) + return _FakeClient([tool]) + + monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + + tools_first = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_second = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + + assert [tool.name for tool in tools_first] == ["tool_for_proxy-token-v1"] + assert [tool.name for tool in tools_second] == ["tool_for_proxy-token-v2"] + assert build_calls == ["proxy-token-v1", "proxy-token-v2"] + + mcp_service.clear_mcp_cache() + + async def test_get_tools_from_all_servers_skips_runtime_auth_servers_without_context(monkeypatch): server_configs = { "shared": {"transport": "stdio", "command": "cmd-shared", "disabled_tools": []}, diff --git a/backend/test/unit/services/test_mcp_service_auth_runtime.py b/backend/test/unit/services/test_mcp_service_auth_runtime.py index b10969c76..d9adf8e87 100644 --- a/backend/test/unit/services/test_mcp_service_auth_runtime.py +++ b/backend/test/unit/services/test_mcp_service_auth_runtime.py @@ -284,3 +284,58 @@ async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_ assert "Authorization" not in config["headers"] assert config["__yuxi_cache_partition"] == "connection:31" assert config["__yuxi_allow_global_cache"] is False + assert config["__yuxi_disable_tool_object_cache"] is True + + +async def test_update_mcp_server_auth_config_clears_runtime_auth_cache(runtime_session, monkeypatch): + server = MCPServer( + name="finance-gateway", + transport="streamable_http", + url="http://finance.local/mcp", + auth_config_json={ + "version": 1, + "provider": "bound_secret", + "binding_scope": "system", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${secret.access_token}"}], + }, + }, + enabled=1, + created_by="tester", + updated_by="tester", + ) + runtime_session.add(server) + await runtime_session.commit() + + calls = {"runtime_auth_cache": 0, "tools_cache": 0} + + async def fake_clear_runtime_auth_cache(db, server_name): + assert db is runtime_session + assert server_name == "finance-gateway" + calls["runtime_auth_cache"] += 1 + + async def fake_invalidate_tools_cache(server_name): + assert server_name == "finance-gateway" + calls["tools_cache"] += 1 + + monkeypatch.setattr(mcp_service, "_clear_mcp_server_runtime_auth_cache", fake_clear_runtime_auth_cache) + monkeypatch.setattr(mcp_service, "invalidate_mcp_server_tools_cache", fake_invalidate_tools_cache) + + await mcp_service.update_mcp_server( + runtime_session, + "finance-gateway", + auth_config={ + "version": 1, + "provider": "custom_http_token", + "binding_scope": "system", + "inject": { + "target": "headers", + "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], + }, + "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, + }, + updated_by="tester", + ) + + assert calls == {"runtime_auth_cache": 1, "tools_cache": 1} diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 5c1340401..0658f2618 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题。 --- From d29510230c89e9a7d274bd93e7e1cb4592f850ef Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 20:03:59 +0800 Subject: [PATCH 05/36] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80=20MCP=20?= =?UTF-8?q?=E4=B8=AA=E4=BA=BA=E8=BF=9E=E6=8E=A5=E8=BF=90=E8=A1=8C=E6=80=81?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E6=A0=87=E8=AF=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/package/yuxi/agents/context.py | 9 ++++++ .../middlewares/runtime_config_middleware.py | 3 +- .../agents/middlewares/skills_middleware.py | 3 +- backend/package/yuxi/services/chat_service.py | 6 ++++ .../test_runtime_config_middleware.py | 29 +++++++++++++++++++ .../middlewares/test_skills_middleware.py | 26 +++++++++++++++++ .../unit/services/test_chat_service_sync.py | 22 +++++++++++++- backend/test/unit/test_base_context.py | 9 ++++-- docs/develop-guides/roadmap.md | 2 +- 9 files changed, 102 insertions(+), 7 deletions(-) diff --git a/backend/package/yuxi/agents/context.py b/backend/package/yuxi/agents/context.py index a77bb59ef..283befd4d 100644 --- a/backend/package/yuxi/agents/context.py +++ b/backend/package/yuxi/agents/context.py @@ -33,6 +33,15 @@ def update(self, data: dict): metadata={"name": "用户ID", "configurable": False, "description": "用来唯一标识一个用户"}, ) + mcp_user_id: str | None = field( + default=None, + metadata={ + "name": "MCP用户标识", + "configurable": False, + "description": "用来匹配个人 MCP 连接绑定范围的用户标识", + }, + ) + department_id: str | None = field( default=None, metadata={"name": "部门ID", "configurable": False, "description": "用来标识当前用户所属部门"}, diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index 5f07fef16..5eaa9f1f9 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -156,10 +156,11 @@ async def get_tools_from_context(self, context) -> list: continue selected_mcp_servers.add(server_name) try: + mcp_user_id = getattr(context, "mcp_user_id", None) or getattr(context, "user_id", None) mcp_tools = await get_enabled_mcp_tools( server_name, auth_context=AuthContext( - user_id=getattr(context, "user_id", None), + user_id=mcp_user_id, department_id=getattr(context, "department_id", None), ), ) diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index 269581e86..a561f0f0f 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -343,10 +343,11 @@ async def _get_mcp_tools_from_context( async def load_mcp_tools(server_name: str) -> list: """加载单个 MCP 服务器的工具""" try: + mcp_user_id = getattr(context, "mcp_user_id", None) or getattr(context, "user_id", None) mcp_tools = await get_enabled_mcp_tools( server_name, auth_context=AuthContext( - user_id=getattr(context, "user_id", None), + user_id=mcp_user_id, department_id=getattr(context, "department_id", None), ), ) diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index cfbf8c897..7bb4519fc 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -61,10 +61,12 @@ async def _build_agent_input_context( *, thread_id: str, user_id: str, + mcp_user_id: str | int | None = None, department_id: str | int | None = None, ) -> dict: input_context = dict(agent_config or {}) agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, user_id) + mcp_scope_user_id = str(mcp_user_id) if mcp_user_id is not None else user_id if agents_prompt: agents_section = f"用户工作区 agents/AGENTS.md 内容:\n{agents_prompt}" @@ -74,6 +76,7 @@ async def _build_agent_input_context( input_context.update( { "user_id": user_id, + "mcp_user_id": mcp_scope_user_id, "thread_id": thread_id, "department_id": str(department_id) if department_id is not None else None, } @@ -614,6 +617,7 @@ async def agent_chat( agent_config, thread_id=thread_id, user_id=user_id, + mcp_user_id=getattr(current_user, "user_id", None), department_id=getattr(current_user, "department_id", None), ) langfuse_run = _build_langfuse_run_context( @@ -835,6 +839,7 @@ def make_chunk(content=None, **kwargs): agent_config, thread_id=thread_id, user_id=user_id, + mcp_user_id=getattr(current_user, "user_id", None), department_id=getattr(current_user, "department_id", None), ) langfuse_run = _build_langfuse_run_context( @@ -1076,6 +1081,7 @@ def make_resume_chunk(content=None, **kwargs): agent_config or {}, thread_id=thread_id, user_id=user_id, + mcp_user_id=getattr(current_user, "user_id", None), department_id=getattr(current_user, "department_id", None), ) ) diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py index 7b8453558..ac113e0d1 100644 --- a/backend/test/unit/middlewares/test_runtime_config_middleware.py +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -36,6 +36,35 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= assert captured == [("finance-gateway", "user-1", "dept-9")] +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_tools_from_context_uses_mcp_user_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware() + context = SimpleNamespace( + tools=[], + mcps=["dts-mcp_server"], + user_id="2", + mcp_user_id="login-1001", + department_id="dept-9", + ) + + tools = await middleware.get_tools_from_context(context) + + assert tools == [] + assert captured == [("dts-mcp_server", "login-1001", "dept-9")] + + @pytest.mark.asyncio @pytest.mark.unit async def test_awrap_model_call_appends_runtime_loaded_mcp_tools(monkeypatch: pytest.MonkeyPatch): diff --git a/backend/test/unit/middlewares/test_skills_middleware.py b/backend/test/unit/middlewares/test_skills_middleware.py index 1380c15fd..25a2ea152 100644 --- a/backend/test/unit/middlewares/test_skills_middleware.py +++ b/backend/test/unit/middlewares/test_skills_middleware.py @@ -31,3 +31,29 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= assert tools == [] assert captured == [("finance-gateway", "user-1", "dept-9")] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_get_mcp_tools_from_context_uses_mcp_user_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): + captured: list[tuple[str, str | None, str | None]] = [] + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del db, http_client + captured.append((server_name, auth_context.user_id, auth_context.department_id)) + return [] + + monkeypatch.setattr(skills_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = SkillsMiddleware() + context = SimpleNamespace( + mcps=["dts-mcp_server"], + user_id="2", + mcp_user_id="login-1001", + department_id="dept-9", + ) + + tools = await middleware._get_mcp_tools_from_context(context) + + assert tools == [] + assert captured == [("dts-mcp_server", "login-1001", "dept-9")] diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index 58422ba91..d52cba702 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -142,7 +142,7 @@ def fake_get_trace_info(_run_context): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-1"), db=object(), ) @@ -160,6 +160,7 @@ def fake_get_trace_info(_run_context): assert calls["invoke_input_context"] == { "temperature": 0.1, "user_id": "user-1", + "mcp_user_id": "login-1001", "thread_id": "thread-1", "department_id": "dept-1", } @@ -261,6 +262,25 @@ def fake_agents_prompt(_thread_id: str, _user_id: str) -> str: assert context["department_id"] == "dept-9" +@pytest.mark.asyncio +async def test_build_agent_input_context_keeps_runtime_user_id_and_mcp_scope_user_id( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(svc, "_load_workspace_agents_prompt", _empty_agents_prompt) + + context = await svc._build_agent_input_context( + {}, + thread_id="thread-1", + user_id="2", + department_id="dept-9", + mcp_user_id="login-1001", + ) + + assert context["user_id"] == "2" + assert context["mcp_user_id"] == "login-1001" + assert context["department_id"] == "dept-9" + + @pytest.mark.asyncio async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prompt_empty( monkeypatch: pytest.MonkeyPatch, diff --git a/backend/test/unit/test_base_context.py b/backend/test/unit/test_base_context.py index 87fffcda0..01c5c96f7 100644 --- a/backend/test/unit/test_base_context.py +++ b/backend/test/unit/test_base_context.py @@ -3,10 +3,13 @@ from yuxi.agents.context import BaseContext -def test_base_context_accepts_department_id_without_exposing_it_as_configurable(): +def test_base_context_accepts_internal_identity_fields_without_exposing_them_as_configurable(): context = BaseContext() - context.update({"department_id": "dept-9"}) + context.update({"department_id": "dept-9", "mcp_user_id": "login-1001"}) assert context.department_id == "dept-9" - assert "department_id" not in BaseContext.get_configurable_items() + assert context.mcp_user_id == "login-1001" + configurable_items = BaseContext.get_configurable_items() + assert "department_id" not in configurable_items + assert "mcp_user_id" not in configurable_items diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 0658f2618..7ebb85ac1 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用。 --- From 61bfd75afcfa8d87fae9ecb690a864c907ff5e06 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 20:50:11 +0800 Subject: [PATCH 06/36] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=97=B6=20MCP=20=E5=B7=A5=E5=85=B7=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../middlewares/runtime_config_middleware.py | 14 ++++ backend/package/yuxi/services/chat_service.py | 37 +++++++---- .../test_runtime_config_middleware.py | 64 +++++++++++++++++++ .../unit/services/test_chat_service_sync.py | 22 +++---- docs/develop-guides/roadmap.md | 2 +- 5 files changed, 114 insertions(+), 25 deletions(-) diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index 5eaa9f1f9..54142b236 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -4,6 +4,7 @@ from typing import Any from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse +from langchain.tools.tool_node import ToolCallRequest from langchain_core.messages import SystemMessage from yuxi.agents import load_chat_model @@ -13,6 +14,8 @@ from yuxi.utils.datetime_utils import shanghai_now from yuxi.utils.logging_config import logger +_RUNTIME_DYNAMIC_TOOLS_ATTR = "_runtime_config_dynamic_tools_by_name" + class RuntimeConfigMiddleware(AgentMiddleware): """运行时配置中间件 - 应用模型/工具/MCP/提示词配置 @@ -102,6 +105,7 @@ async def awrap_model_call( continue merged_tools.append(tool) merged_tool_names.add(tool.name) + setattr(runtime_context, _RUNTIME_DYNAMIC_TOOLS_ATTR, {tool.name: tool for tool in enabled_tools}) overrides["tools"] = merged_tools logger.debug(f"RuntimeConfigMiddleware selected tools: {[t.name for t in merged_tools]}") @@ -121,6 +125,16 @@ async def awrap_model_call( return await handler(request) + async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable[[ToolCallRequest], Any]): + """Allow ToolNode to execute runtime-auth MCP tools loaded during the last model call.""" + if request.tool is None: + runtime_context = getattr(request.runtime, "context", None) + dynamic_tools = getattr(runtime_context, _RUNTIME_DYNAMIC_TOOLS_ATTR, {}) or {} + tool = dynamic_tools.get(request.tool_call.get("name")) if isinstance(dynamic_tools, dict) else None + if tool is not None: + request = request.override(tool=tool) + return await handler(request) + async def get_tools_from_context(self, context) -> list: """从上下文配置中获取工具列表""" selected_tools = [] diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index 7bb4519fc..b5127ca36 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -60,13 +60,14 @@ async def _build_agent_input_context( agent_config: dict, *, thread_id: str, - user_id: str, - mcp_user_id: str | int | None = None, - department_id: str | int | None = None, + current_user: User, ) -> dict: input_context = dict(agent_config or {}) + user_id = str(current_user.id) + current_user_scope_id = getattr(current_user, "user_id", None) + mcp_scope_user_id = str(current_user_scope_id) if current_user_scope_id is not None else user_id + department_id = getattr(current_user, "department_id", None) agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, user_id) - mcp_scope_user_id = str(mcp_user_id) if mcp_user_id is not None else user_id if agents_prompt: agents_section = f"用户工作区 agents/AGENTS.md 内容:\n{agents_prompt}" @@ -81,6 +82,22 @@ async def _build_agent_input_context( "department_id": str(department_id) if department_id is not None else None, } ) + + # 将用户信息拼接到 system_prompt + user_info_parts = [] + if username := getattr(current_user, "username", None): + user_info_parts.append(f"姓名: {username}") + if role := getattr(current_user, "role", None): + user_info_parts.append(f"角色: {role}") + if work_id := getattr(current_user, "user_id", None): + user_info_parts.append(f"工号: {work_id}") + + if user_info_parts: + user_info_block = "\n".join(user_info_parts) + current_prompt = str(input_context.get("system_prompt") or "").rstrip() + input_context["system_prompt"] = ( + f"{current_prompt}\n\n用户信息:\n{user_info_block}" if current_prompt else user_info_block + ) return input_context @@ -616,9 +633,7 @@ async def agent_chat( input_context = await _build_agent_input_context( agent_config, thread_id=thread_id, - user_id=user_id, - mcp_user_id=getattr(current_user, "user_id", None), - department_id=getattr(current_user, "department_id", None), + current_user=current_user, ) langfuse_run = _build_langfuse_run_context( current_user=current_user, @@ -838,9 +853,7 @@ def make_chunk(content=None, **kwargs): input_context = await _build_agent_input_context( agent_config, thread_id=thread_id, - user_id=user_id, - mcp_user_id=getattr(current_user, "user_id", None), - department_id=getattr(current_user, "department_id", None), + current_user=current_user, ) langfuse_run = _build_langfuse_run_context( current_user=current_user, @@ -1080,9 +1093,7 @@ def make_resume_chunk(content=None, **kwargs): await _build_agent_input_context( agent_config or {}, thread_id=thread_id, - user_id=user_id, - mcp_user_id=getattr(current_user, "user_id", None), - department_id=getattr(current_user, "department_id", None), + current_user=current_user, ) ) graph = await agent.get_graph(context=context) diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py index ac113e0d1..6b538f532 100644 --- a/backend/test/unit/middlewares/test_runtime_config_middleware.py +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -3,6 +3,8 @@ from types import SimpleNamespace import pytest +from langchain.tools.tool_node import ToolCallRequest +from langchain_core.messages import ToolMessage import yuxi.agents.middlewares.runtime_config_middleware as runtime_config_middleware from yuxi.agents.middlewares.runtime_config_middleware import RuntimeConfigMiddleware @@ -65,6 +67,68 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= assert captured == [("dts-mcp_server", "login-1001", "dept-9")] +@pytest.mark.asyncio +@pytest.mark.unit +async def test_runtime_loaded_mcp_tool_can_be_executed_when_tool_node_did_not_pre_register_it( + monkeypatch: pytest.MonkeyPatch, +): + runtime_tool = SimpleNamespace(name="getTicket") + + monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) + + async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): + del auth_context, db, http_client + assert server_name == "dts-mcp_server" + return [runtime_tool] + + monkeypatch.setattr(runtime_config_middleware, "get_enabled_mcp_tools", fake_get_enabled_mcp_tools) + + middleware = RuntimeConfigMiddleware(enable_model_override=False, enable_system_prompt_override=False) + context = SimpleNamespace( + tools=[], + mcps=["dts-mcp_server"], + user_id="2", + mcp_user_id="login-1001", + department_id="dept-9", + model=None, + system_prompt="", + ) + + class DummyRequest: + def __init__(self, tools=None): + self.runtime = SimpleNamespace(context=context) + self.tools = tools or [] + self.system_message = None + + def override(self, **kwargs): + clone = DummyRequest(kwargs.get("tools", self.tools)) + clone.runtime = kwargs.get("runtime", self.runtime) + clone.system_message = kwargs.get("system_message", self.system_message) + return clone + + async def model_handler(next_request): + return next_request.tools + + await middleware.awrap_model_call(DummyRequest(), model_handler) + + tool_request = ToolCallRequest( + tool_call={"name": "getTicket", "args": {"arg0": "DTS2026012932159"}, "id": "call-1"}, + tool=None, + state={}, + runtime=SimpleNamespace(context=context), + ) + captured = {} + + async def tool_handler(next_request): + captured["tool"] = next_request.tool + return ToolMessage(content="ok", name=next_request.tool_call["name"], tool_call_id=next_request.tool_call["id"]) + + result = await middleware.awrap_tool_call(tool_request, tool_handler) + + assert result.content == "ok" + assert captured["tool"] is runtime_tool + + @pytest.mark.asyncio @pytest.mark.unit async def test_awrap_model_call_appends_runtime_loaded_mcp_tools(monkeypatch: pytest.MonkeyPatch): diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index d52cba702..ab9fddfee 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -158,6 +158,7 @@ def fake_get_trace_info(_run_context): assert isinstance(invoke_messages[0], HumanMessage) assert invoke_messages[0].content == "hello" assert calls["invoke_input_context"] == { + "system_prompt": "工号: login-1001", "temperature": 0.1, "user_id": "user-1", "mcp_user_id": "login-1001", @@ -251,11 +252,13 @@ def fake_agents_prompt(_thread_id: str, _user_id: str) -> str: context = await svc._build_agent_input_context( {"system_prompt": "原始系统提示词", "temperature": 0.1}, thread_id="thread-1", - user_id="user-1", - department_id="dept-9", + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-9"), ) - assert context["system_prompt"] == "原始系统提示词\n\n用户工作区 agents/AGENTS.md 内容:\n回答前先读取 AGENTS.md" + assert ( + context["system_prompt"] + == "原始系统提示词\n\n用户工作区 agents/AGENTS.md 内容:\n回答前先读取 AGENTS.md\n\n用户信息:\n工号: login-1001" + ) assert context["temperature"] == 0.1 assert context["thread_id"] == "thread-1" assert context["user_id"] == "user-1" @@ -263,7 +266,7 @@ def fake_agents_prompt(_thread_id: str, _user_id: str) -> str: @pytest.mark.asyncio -async def test_build_agent_input_context_keeps_runtime_user_id_and_mcp_scope_user_id( +async def test_build_agent_input_context_derives_runtime_identity_from_current_user( monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(svc, "_load_workspace_agents_prompt", _empty_agents_prompt) @@ -271,9 +274,7 @@ async def test_build_agent_input_context_keeps_runtime_user_id_and_mcp_scope_use context = await svc._build_agent_input_context( {}, thread_id="thread-1", - user_id="2", - department_id="dept-9", - mcp_user_id="login-1001", + current_user=SimpleNamespace(id=2, user_id="login-1001", department_id="dept-9"), ) assert context["user_id"] == "2" @@ -282,7 +283,7 @@ async def test_build_agent_input_context_keeps_runtime_user_id_and_mcp_scope_use @pytest.mark.asyncio -async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prompt_empty( +async def test_build_agent_input_context_appends_user_info_when_workspace_agents_prompt_empty( monkeypatch: pytest.MonkeyPatch, ): monkeypatch.setattr(svc, "_load_workspace_agents_prompt", _empty_agents_prompt) @@ -290,9 +291,8 @@ async def test_build_agent_input_context_keeps_prompt_when_workspace_agents_prom context = await svc._build_agent_input_context( {"system_prompt": "原始系统提示词"}, thread_id="thread-1", - user_id="user-1", - department_id="dept-9", + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-9"), ) - assert context["system_prompt"] == "原始系统提示词" + assert context["system_prompt"] == "原始系统提示词\n\n用户信息:\n工号: login-1001" assert context["department_id"] == "dept-9" diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 7ebb85ac1..45d65d1e1 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用;补齐运行时鉴权 MCP 工具的执行阶段映射,避免模型已绑定 `getTicket` 等动态工具但 ToolNode 静态注册表无法执行的问题。 --- From f51bae001adc23de8ba97f635f12df58d6d78cf1 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 20:54:53 +0800 Subject: [PATCH 07/36] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=20current?= =?UTF-8?q?=5Fuser=20=E5=B7=A5=E5=8F=B7=E4=BD=BF=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/package/yuxi/services/chat_service.py | 15 +++++++-------- .../test/unit/services/test_chat_service_sync.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index b5127ca36..56bfadd34 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -63,11 +63,10 @@ async def _build_agent_input_context( current_user: User, ) -> dict: input_context = dict(agent_config or {}) - user_id = str(current_user.id) - current_user_scope_id = getattr(current_user, "user_id", None) - mcp_scope_user_id = str(current_user_scope_id) if current_user_scope_id is not None else user_id - department_id = getattr(current_user, "department_id", None) - agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, user_id) + db_user_id = str(current_user.id) + work_id = current_user.user_id + department_id = current_user.department_id + agents_prompt = await asyncio.to_thread(_load_workspace_agents_prompt, thread_id, db_user_id) if agents_prompt: agents_section = f"用户工作区 agents/AGENTS.md 内容:\n{agents_prompt}" @@ -76,8 +75,8 @@ async def _build_agent_input_context( input_context.update( { - "user_id": user_id, - "mcp_user_id": mcp_scope_user_id, + "user_id": db_user_id, + "mcp_user_id": work_id, "thread_id": thread_id, "department_id": str(department_id) if department_id is not None else None, } @@ -89,7 +88,7 @@ async def _build_agent_input_context( user_info_parts.append(f"姓名: {username}") if role := getattr(current_user, "role", None): user_info_parts.append(f"角色: {role}") - if work_id := getattr(current_user, "user_id", None): + if work_id: user_info_parts.append(f"工号: {work_id}") if user_info_parts: diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index ab9fddfee..2a1650ded 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -232,7 +232,7 @@ async def fake_guard_check(_content): thread_id="thread-2", meta={"request_id": "req-2"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-1001", department_id="dept-1"), db=object(), ) From 5736090ba2480931f082412b004d621563f3ba69 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 21:55:04 +0800 Subject: [PATCH 08/36] =?UTF-8?q?refactor:=20=E5=90=8C=E6=AD=A5=20skills?= =?UTF-8?q?=5Fmiddleware=20=E5=86=85=E9=83=A8=E7=94=A8=E6=88=B7=E6=A0=87?= =?UTF-8?q?=E8=AF=86=E4=B8=BA=20work=5Fid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 延续 f51bae00,将 BaseContext.mcp_user_id 字段在 SkillsMiddleware 加载 MCP 工具时统一读取为 work_id,与 runtime_config_middleware 保持一致。同步更新对应测试夹具与字段名。 --- backend/package/yuxi/agents/context.py | 6 +++--- .../yuxi/agents/middlewares/runtime_config_middleware.py | 4 ++-- .../package/yuxi/agents/middlewares/skills_middleware.py | 4 ++-- backend/package/yuxi/services/chat_service.py | 2 +- .../test/unit/middlewares/test_runtime_config_middleware.py | 6 +++--- backend/test/unit/middlewares/test_skills_middleware.py | 4 ++-- backend/test/unit/services/test_chat_service_sync.py | 4 ++-- backend/test/unit/test_base_context.py | 6 +++--- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/backend/package/yuxi/agents/context.py b/backend/package/yuxi/agents/context.py index 283befd4d..9050b173d 100644 --- a/backend/package/yuxi/agents/context.py +++ b/backend/package/yuxi/agents/context.py @@ -33,12 +33,12 @@ def update(self, data: dict): metadata={"name": "用户ID", "configurable": False, "description": "用来唯一标识一个用户"}, ) - mcp_user_id: str | None = field( + work_id: str | None = field( default=None, metadata={ - "name": "MCP用户标识", + "name": "工号", "configurable": False, - "description": "用来匹配个人 MCP 连接绑定范围的用户标识", + "description": "用来匹配个人 MCP 连接绑定范围的工号", }, ) diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index 54142b236..3fee7a8b7 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -170,11 +170,11 @@ async def get_tools_from_context(self, context) -> list: continue selected_mcp_servers.add(server_name) try: - mcp_user_id = getattr(context, "mcp_user_id", None) or getattr(context, "user_id", None) + work_id = getattr(context, "work_id", None) or getattr(context, "user_id", None) mcp_tools = await get_enabled_mcp_tools( server_name, auth_context=AuthContext( - user_id=mcp_user_id, + user_id=work_id, department_id=getattr(context, "department_id", None), ), ) diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index a561f0f0f..b626c54e2 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -343,11 +343,11 @@ async def _get_mcp_tools_from_context( async def load_mcp_tools(server_name: str) -> list: """加载单个 MCP 服务器的工具""" try: - mcp_user_id = getattr(context, "mcp_user_id", None) or getattr(context, "user_id", None) + work_id = getattr(context, "work_id", None) or getattr(context, "user_id", None) mcp_tools = await get_enabled_mcp_tools( server_name, auth_context=AuthContext( - user_id=mcp_user_id, + user_id=work_id, department_id=getattr(context, "department_id", None), ), ) diff --git a/backend/package/yuxi/services/chat_service.py b/backend/package/yuxi/services/chat_service.py index 56bfadd34..b91cc57e0 100644 --- a/backend/package/yuxi/services/chat_service.py +++ b/backend/package/yuxi/services/chat_service.py @@ -76,7 +76,7 @@ async def _build_agent_input_context( input_context.update( { "user_id": db_user_id, - "mcp_user_id": work_id, + "work_id": work_id, "thread_id": thread_id, "department_id": str(department_id) if department_id is not None else None, } diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py index 6b538f532..2fa28e6f6 100644 --- a/backend/test/unit/middlewares/test_runtime_config_middleware.py +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -40,7 +40,7 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= @pytest.mark.asyncio @pytest.mark.unit -async def test_get_tools_from_context_uses_mcp_user_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): +async def test_get_tools_from_context_uses_work_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): captured: list[tuple[str, str | None, str | None]] = [] monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) @@ -57,7 +57,7 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= tools=[], mcps=["dts-mcp_server"], user_id="2", - mcp_user_id="login-1001", + work_id="login-1001", department_id="dept-9", ) @@ -88,7 +88,7 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= tools=[], mcps=["dts-mcp_server"], user_id="2", - mcp_user_id="login-1001", + work_id="login-1001", department_id="dept-9", model=None, system_prompt="", diff --git a/backend/test/unit/middlewares/test_skills_middleware.py b/backend/test/unit/middlewares/test_skills_middleware.py index 25a2ea152..9f5734dca 100644 --- a/backend/test/unit/middlewares/test_skills_middleware.py +++ b/backend/test/unit/middlewares/test_skills_middleware.py @@ -35,7 +35,7 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= @pytest.mark.asyncio @pytest.mark.unit -async def test_get_mcp_tools_from_context_uses_mcp_user_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): +async def test_get_mcp_tools_from_context_uses_work_id_for_user_scoped_auth(monkeypatch: pytest.MonkeyPatch): captured: list[tuple[str, str | None, str | None]] = [] async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): @@ -49,7 +49,7 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= context = SimpleNamespace( mcps=["dts-mcp_server"], user_id="2", - mcp_user_id="login-1001", + work_id="login-1001", department_id="dept-9", ) diff --git a/backend/test/unit/services/test_chat_service_sync.py b/backend/test/unit/services/test_chat_service_sync.py index 2a1650ded..2c4f27781 100644 --- a/backend/test/unit/services/test_chat_service_sync.py +++ b/backend/test/unit/services/test_chat_service_sync.py @@ -161,7 +161,7 @@ def fake_get_trace_info(_run_context): "system_prompt": "工号: login-1001", "temperature": 0.1, "user_id": "user-1", - "mcp_user_id": "login-1001", + "work_id": "login-1001", "thread_id": "thread-1", "department_id": "dept-1", } @@ -278,7 +278,7 @@ async def test_build_agent_input_context_derives_runtime_identity_from_current_u ) assert context["user_id"] == "2" - assert context["mcp_user_id"] == "login-1001" + assert context["work_id"] == "login-1001" assert context["department_id"] == "dept-9" diff --git a/backend/test/unit/test_base_context.py b/backend/test/unit/test_base_context.py index 01c5c96f7..f2892dffa 100644 --- a/backend/test/unit/test_base_context.py +++ b/backend/test/unit/test_base_context.py @@ -6,10 +6,10 @@ def test_base_context_accepts_internal_identity_fields_without_exposing_them_as_configurable(): context = BaseContext() - context.update({"department_id": "dept-9", "mcp_user_id": "login-1001"}) + context.update({"department_id": "dept-9", "work_id": "login-1001"}) assert context.department_id == "dept-9" - assert context.mcp_user_id == "login-1001" + assert context.work_id == "login-1001" configurable_items = BaseContext.get_configurable_items() assert "department_id" not in configurable_items - assert "mcp_user_id" not in configurable_items + assert "work_id" not in configurable_items From 125c68f4ec15c9649b8606787da8d03c6d648266 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Thu, 4 Jun 2026 22:13:47 +0800 Subject: [PATCH 09/36] =?UTF-8?q?test:=20=E4=BF=AE=E5=A4=8D=205=20?= =?UTF-8?q?=E4=B8=AA=E9=A2=84=E5=AD=98=E6=B5=8B=E8=AF=95=E5=A4=B1=E8=B4=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - run_queue_service.list_run_stream_events 恢复 after_seq 排他游标 (f"({after_seq}"),修复 list_run_stream_events 轮询时同 seq 重复返回导致 agent_run_service 死循环 - test_chat_service_langfuse_stream 补 current_user.user_id 字段并更新 input_context 断言(work_id/department_id) - test_chat_stream_attachment_materialize 补 _materialize_attachment_files 必填 user_id 参数 --- backend/package/yuxi/services/run_queue_service.py | 2 +- .../services/test_chat_service_langfuse_stream.py | 12 +++++++++--- .../test_chat_stream_attachment_materialize.py | 2 ++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/backend/package/yuxi/services/run_queue_service.py b/backend/package/yuxi/services/run_queue_service.py index ec0422b33..04079a8d7 100644 --- a/backend/package/yuxi/services/run_queue_service.py +++ b/backend/package/yuxi/services/run_queue_service.py @@ -190,7 +190,7 @@ async def list_run_stream_events( ) -> list[dict]: redis = await get_redis_client() key = _event_stream_key(run_id) - start = "-" if after_seq in {"0", "0-0", ""} else f"{after_seq}" + start = "-" if after_seq in {"0", "0-0", ""} else f"({after_seq}" rows = await redis.xrange(key, min=start, max="+", count=limit) events = [] diff --git a/backend/test/unit/services/test_chat_service_langfuse_stream.py b/backend/test/unit/services/test_chat_service_langfuse_stream.py index d9bbf83bb..331d66ccf 100644 --- a/backend/test/unit/services/test_chat_service_langfuse_stream.py +++ b/backend/test/unit/services/test_chat_service_langfuse_stream.py @@ -132,12 +132,18 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-user-1", department_id="dept-1"), db=object(), ): chunks.append(json.loads(chunk.decode("utf-8"))) - assert calls["stream_input_context"] == {"temperature": 0.1, "user_id": "user-1", "thread_id": "thread-1"} + assert calls["stream_input_context"] == { + "temperature": 0.1, + "user_id": "user-1", + "work_id": "login-user-1", + "thread_id": "thread-1", + "department_id": "dept-1", + } assert calls["stream_kwargs"] == { "callbacks": ["handler-1"], "metadata": {"langfuse_user_id": "user-1", "langfuse_session_id": "thread-1"}, @@ -210,7 +216,7 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): thread_id="thread-1", meta={"request_id": "req-1"}, image_content=None, - current_user=SimpleNamespace(id="user-1", department_id="dept-1"), + current_user=SimpleNamespace(id="user-1", user_id="login-user-1", department_id="dept-1"), db=object(), ): chunks.append(json.loads(chunk.decode("utf-8"))) diff --git a/backend/test/unit/services/test_chat_stream_attachment_materialize.py b/backend/test/unit/services/test_chat_stream_attachment_materialize.py index d8191192d..3dfc969ba 100644 --- a/backend/test/unit/services/test_chat_stream_attachment_materialize.py +++ b/backend/test/unit/services/test_chat_stream_attachment_materialize.py @@ -67,6 +67,7 @@ async def test_materialize_attachment_files_keeps_original_file_when_markdown_co result = await cs._materialize_attachment_files( thread_id="t-1", + user_id="u-1", upload=upload, file_name="demo.pdf", file_content=b"%PDF-test", @@ -103,6 +104,7 @@ async def _fake_convert(_upload): result = await cs._materialize_attachment_files( thread_id="t-1", + user_id="u-1", upload=upload, file_name="demo.txt", file_content=b"hello", From 41e43717279ed649c055428a9f386088a4007200 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 00:54:28 +0800 Subject: [PATCH 10/36] refactor(mcp): save intermediate progress before pure architecture refactoring --- .../middlewares/runtime_config_middleware.py | 15 + backend/package/yuxi/services/mcp/__init__.py | 97 ++ .../package/yuxi/services/mcp/cache_policy.py | 102 ++ .../package/yuxi/services/mcp/client_pool.py | 210 +++ .../yuxi/services/mcp/connection_service.py | 313 ++++ .../yuxi/services/mcp/server_service.py | 469 ++++++ .../services/mcp/tool_registry_service.py | 519 +++++++ .../services/mcp_auth/fetchers/__init__.py | 14 + .../yuxi/services/mcp_auth/fetchers/base.py | 174 +++ .../services/mcp_auth/fetchers/factory.py | 19 + .../mcp_auth/fetchers/http_fetcher.py | 38 + .../mcp_auth/fetchers/oauth_fetcher.py | 89 ++ .../yuxi/services/mcp_auth/orchestrator.py | 184 +-- .../services/mcp_auth/redis_token_cache.py | 16 +- backend/package/yuxi/services/mcp_service.py | 1343 +---------------- backend/test/mcp_demo_server.py | 214 +++ .../unit/services/test_mcp_cache_policy.py | 91 ++ .../unit/services/test_mcp_client_pool.py | 95 ++ 18 files changed, 2573 insertions(+), 1429 deletions(-) create mode 100644 backend/package/yuxi/services/mcp/__init__.py create mode 100644 backend/package/yuxi/services/mcp/cache_policy.py create mode 100644 backend/package/yuxi/services/mcp/client_pool.py create mode 100644 backend/package/yuxi/services/mcp/connection_service.py create mode 100644 backend/package/yuxi/services/mcp/server_service.py create mode 100644 backend/package/yuxi/services/mcp/tool_registry_service.py create mode 100644 backend/package/yuxi/services/mcp_auth/fetchers/__init__.py create mode 100644 backend/package/yuxi/services/mcp_auth/fetchers/base.py create mode 100644 backend/package/yuxi/services/mcp_auth/fetchers/factory.py create mode 100644 backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py create mode 100644 backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py create mode 100644 backend/test/mcp_demo_server.py create mode 100644 backend/test/unit/services/test_mcp_cache_policy.py create mode 100644 backend/test/unit/services/test_mcp_client_pool.py diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index 3fee7a8b7..519dd148d 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -133,6 +133,21 @@ async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable[[Too tool = dynamic_tools.get(request.tool_call.get("name")) if isinstance(dynamic_tools, dict) else None if tool is not None: request = request.override(tool=tool) + + # NOTE: 注入当前的 AuthContext 以便于长连接拦截器 DynamicMCPTokenAuth 随时刷新 token + runtime_context = getattr(request.runtime, "context", None) + if runtime_context is not None: + work_id = getattr(runtime_context, "work_id", None) or getattr(runtime_context, "user_id", None) + dept_id = getattr(runtime_context, "department_id", None) + auth_context = AuthContext(user_id=work_id, department_id=dept_id) + + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) + try: + return await handler(request) + finally: + mcp_auth_context_var.reset(token) + return await handler(request) async def get_tools_from_context(self, context) -> list: diff --git a/backend/package/yuxi/services/mcp/__init__.py b/backend/package/yuxi/services/mcp/__init__.py new file mode 100644 index 000000000..fe4a61b73 --- /dev/null +++ b/backend/package/yuxi/services/mcp/__init__.py @@ -0,0 +1,97 @@ +from __future__ import annotations +from yuxi.services.mcp.cache_policy import ( + MCPCachePolicy, + StaticCachePolicy, + TokenInjectedCachePolicy, + DynamicProxyCachePolicy, + CachePolicyFactory, +) +from yuxi.services.mcp.client_pool import ( + mcp_client_pool, + MCPClientPool, +) +from yuxi.services.mcp.server_service import ( + ensure_builtin_mcp_servers_in_db, + get_enabled_mcp_server_config, + get_runtime_mcp_server_config, + get_enabled_mcp_server_names, + get_mcp_server, + get_all_mcp_servers, + create_mcp_server, + update_mcp_server, + delete_mcp_server, + get_mcp_server_dependency_summary, + set_server_enabled, + get_servers_config, +) +from yuxi.services.mcp.connection_service import ( + get_mcp_connection, + list_mcp_connections, + create_mcp_connection, + update_mcp_connection, + delete_mcp_connection, + set_mcp_connection_status, + reauthorize_mcp_connection, + test_mcp_connection, +) +from yuxi.services.mcp.tool_registry_service import ( + to_camel_case, + get_mcp_tools, + get_tools_from_all_servers, + clear_mcp_cache, + clear_mcp_server_tools_cache, + clear_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, + invalidate_mcp_connection_tools_cache, + get_mcp_tools_stats, + get_enabled_mcp_tools, + get_all_mcp_tools, +) + +__all__ = [ + # 策略模式与对象池 + "MCPCachePolicy", + "StaticCachePolicy", + "TokenInjectedCachePolicy", + "DynamicProxyCachePolicy", + "CachePolicyFactory", + "mcp_client_pool", + "MCPClientPool", + + # Server CRUD + "ensure_builtin_mcp_servers_in_db", + "get_enabled_mcp_server_config", + "get_runtime_mcp_server_config", + "get_enabled_mcp_server_names", + "get_mcp_server", + "get_all_mcp_servers", + "create_mcp_server", + "update_mcp_server", + "delete_mcp_server", + "get_mcp_server_dependency_summary", + "set_server_enabled", + "get_servers_config", + + # Connection CRUD + "get_mcp_connection", + "list_mcp_connections", + "create_mcp_connection", + "update_mcp_connection", + "delete_mcp_connection", + "set_mcp_connection_status", + "reauthorize_mcp_connection", + "test_mcp_connection", + + # Tool Registry + "to_camel_case", + "get_mcp_tools", + "get_tools_from_all_servers", + "clear_mcp_cache", + "clear_mcp_server_tools_cache", + "clear_mcp_connection_tools_cache", + "invalidate_mcp_server_tools_cache", + "invalidate_mcp_connection_tools_cache", + "get_mcp_tools_stats", + "get_enabled_mcp_tools", + "get_all_mcp_tools", +] diff --git a/backend/package/yuxi/services/mcp/cache_policy.py b/backend/package/yuxi/services/mcp/cache_policy.py new file mode 100644 index 000000000..1c6ad1f5f --- /dev/null +++ b/backend/package/yuxi/services/mcp/cache_policy.py @@ -0,0 +1,102 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from yuxi.services.mcp_auth.orchestrator import AuthContext + from yuxi.storage.postgres.models_business import MCPConnection + + +class MCPCachePolicy(ABC): + """MCP 缓存策略抽象基类""" + + @abstractmethod + def should_cache_tool_object(self) -> bool: + """是否在内存中缓存底层的 Tool 实例对象""" + pass + + @abstractmethod + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + """ + 解析该连接应被划分到哪一个缓存分区中。 + + 返回: + tuple[partition_key, is_shared_across_users] + - partition_key: 用于区分 Redis 缓存或内存缓存隔离区段的 Key。 + - is_shared_across_users: 表明该分区下的缓存在不同用户间是否可以共享。 + """ + pass + + +class StaticCachePolicy(MCPCachePolicy): + """静态配置(无鉴权)缓存策略""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 静态服务无任何鉴权和状态变化,完全可以缓存 Tool 对象以提升性能 + return True + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + # NOTE: 静态连接全局共享同一个分区 + return "global", True + + +class TokenInjectedCachePolicy(MCPCachePolicy): + """静态凭据/环境变量注入型缓存策略(例如 bound_secret, stdio_env)""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 绑定了静态凭据或环境变量的连接,一旦 Connection 确定,其工具列表也是确定的,支持缓存 Tool 对象 + return True + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + if connection is None: + return "global", True + + # NOTE: 仅系统级别(system)是多用户共享的,部门和个人级别一律判定为独占 + is_shared = connection.scope_type == "system" + return f"connection:{connection.id}", is_shared + + +class DynamicProxyCachePolicy(MCPCachePolicy): + """动态 Token 鉴权代理缓存策略(例如 custom_http_token, authorization_code)""" + + def should_cache_tool_object(self) -> bool: + # NOTE: 动态 Token 具有时效性且可能因用户身份变化,为了安全性,禁止在内存中缓存带有具体 Token 的 Tool 实例 + return False + + def resolve_cache_partition( + self, + auth_context: AuthContext, + connection: MCPConnection | None, + ) -> tuple[str, bool]: + if connection is None: + return "global", True + + # NOTE: 仅系统级别(system)是多用户共享的,部门和个人级别一律判定为独占 + is_shared = connection.scope_type == "system" + return f"connection:{connection.id}", is_shared + + +class CachePolicyFactory: + """缓存策略工厂,根据 auth_provider 获取匹配的 CachePolicy 实例""" + + @staticmethod + def get_policy(provider: str | None) -> MCPCachePolicy: + if not provider or provider == "legacy_static": + return StaticCachePolicy() + elif provider in ("bound_secret", "stdio_env"): + return TokenInjectedCachePolicy() + else: + # 默认为动态代理鉴权策略(支持 custom_http_token, client_credentials, authorization_code 等) + return DynamicProxyCachePolicy() diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py new file mode 100644 index 000000000..a55cf24ea --- /dev/null +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -0,0 +1,210 @@ +from __future__ import annotations +import asyncio +import hashlib +import json +import logging +from typing import Any, AsyncGenerator, TYPE_CHECKING +import httpx + +from langchain_mcp_adapters.client import MultiServerMCPClient +from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + +if TYPE_CHECKING: + from mcp import ClientSession + +logger = logging.getLogger("yuxi.mcp.client_pool") + + +class DynamicMCPTokenAuth(httpx.Auth): + """动态 MCP Token 认证拦截器,每次 HTTP 请求前从 ContextVar 动态读取并注入 Authorization 头部""" + + def __init__(self, server_name: str): + self.server_name = server_name + + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + # NOTE: 1. 从当前协程上下文读取 AuthContext + auth_context = mcp_auth_context_var.get() + if auth_context: + try: + # 导入数据库会话管理器以获取连接与 Token + from yuxi.storage.postgres.manager import pg_manager + async with pg_manager.get_async_session_context() as session: + from yuxi.services.mcp_service import get_runtime_mcp_server_config + # NOTE: 2. 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) + runtime_config = await get_runtime_mcp_server_config( + self.server_name, + auth_context=auth_context, + db=session, + ) + if runtime_config: + # NOTE: 3. 将最新的头部注入到当前 HTTP 请求中 + headers = runtime_config.get("headers") or {} + for key, val in headers.items(): + request.headers[key] = str(val) + except Exception as exc: + logger.error( + f"DynamicMCPTokenAuth failed to resolve token headers for '{self.server_name}': {exc}" + ) + yield request + + +class LongLivedSession: + """长期存活的 MCP Client 及其 Session 生命周期管理器""" + + def __init__(self, client: MultiServerMCPClient, server_name: str): + self.client = client + self.server_name = server_name + self.session: ClientSession | None = None + self._running = False + self._loop_task: asyncio.Task | None = None + self._ready_event = asyncio.Event() + self._stop_event = asyncio.Event() + + async def start(self): + """在后台启动长连接 Session""" + if not hasattr(self.client, "session"): + self.session = self.client + self._ready_event.set() + return + + self._running = True + self._stop_event.clear() + self._ready_event.clear() + self._loop_task = asyncio.create_task(self._run_loop()) + # 等待 Session 成功连接并完成 initialize() + await self._ready_event.wait() + if not self.session: + raise RuntimeError(f"Failed to startup MCP ClientSession for {self.server_name}") + + async def _run_loop(self): + try: + # NOTE: 利用 client.session 会在退出上下文时自动释放底层的 Stdio 子进程或 HTTP Keep-Alive 连接 + async with self.client.session(self.server_name) as session: + self.session = session + self._ready_event.set() + # 挂起直到收到停止指令 + await self._stop_event.wait() + except Exception as exc: + logger.error(f"Error in long-lived MCP session loop for {self.server_name}: {exc}") + finally: + self.session = None + self._running = False + self._ready_event.set() + + async def stop(self): + """停止长连接,回收子进程与 TCP 连接资源""" + self._stop_event.set() + if self._loop_task: + try: + await asyncio.wait_for(self._loop_task, timeout=5.0) + except asyncio.TimeoutError: + logger.warning(f"Timeout waiting for long-lived session of {self.server_name} to stop.") + self._loop_task.cancel() + except Exception as exc: + logger.debug(f"Exception during long-lived session cleanup of {self.server_name}: {exc}") + self._loop_task = None + + +class MCPClientPool: + """MCP 客户端连接池实现""" + + def __init__(self): + # 缓存键格式: (server_name, partition_key) -> (LongLivedSession, config_hash) + self._sessions: dict[tuple[str, str], tuple[LongLivedSession, str]] = {} + self._lock = asyncio.Lock() + + def _calculate_config_hash(self, config: dict[str, Any]) -> str: + """根据配置计算 Hash 用于比对配置是否脏变""" + clean_config = { + k: v + for k, v in config.items() + if k not in { + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + "disabled_tools", + } + } + # 剔除 header 中可能随时变化的 token/Authorization 以便准确比对静态配置 + headers = dict(clean_config.get("headers") or {}) + headers.pop("Authorization", None) + if headers: + clean_config["headers"] = headers + elif "headers" in clean_config: + clean_config["headers"] = {} + + payload = json.dumps(clean_config, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] + + async def get_session( + self, + server_name: str, + partition_key: str, + runtime_config: dict[str, Any], + ) -> ClientSession: + """获取或重建匹配当前配置的 ClientSession""" + config_hash = self._calculate_config_hash(runtime_config) + cache_key = (server_name, partition_key) + + async with self._lock: + existing = self._sessions.get(cache_key) + if existing: + ll_session, cached_hash = existing + # NOTE: 如果配置无变化且 Session 处于活动状态,直接复用 + if cached_hash == config_hash and ll_session.session is not None: + return ll_session.session + + # 如果发生配置变化或 Session 断开,执行销毁 + logger.info(f"Destroying stale/disconnected MCP session for {cache_key}") + await ll_session.stop() + self._sessions.pop(cache_key, None) + + # NOTE: 针对 HTTP/SSE 协议,注入自定义的 httpx.Auth 认证流以支持长连接动态 Token + client_config = dict(runtime_config) + # 清理框架保留魔法键 + for magic_k in ( + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + "disabled_tools", + ): + client_config.pop(magic_k, None) + + if client_config.get("transport") in ("sse", "http", "streamable_http", "streamable-http"): + # 注入 DynamicMCPTokenAuth,让底层 httpx 在长连接执行每个具体请求时动态提取最新 Token + client_config["auth"] = DynamicMCPTokenAuth(server_name) + + logger.info(f"Creating new long-lived MCP session for {cache_key} (transport: {client_config.get('transport')})") + from yuxi.services.mcp_service import get_mcp_client + client = await get_mcp_client({server_name: client_config}) + if client is None: + raise RuntimeError(f"Failed to initialize MCP client for {server_name}") + ll_session = LongLivedSession(client, server_name) + await ll_session.start() + + self._sessions[cache_key] = (ll_session, config_hash) + return ll_session.session + + async def ensure_prewarm( + self, + server_name: str, + partition_key: str, + runtime_config: dict[str, Any], + ): + """后台异步预热加载,减少首次访问时的冷启动卡顿""" + try: + await self.get_session(server_name, partition_key, runtime_config) + except Exception as exc: + logger.warning(f"Failed to pre-warm MCP server '{server_name}': {exc}") + + async def shutdown(self): + """关闭并回收连接池中的所有连接""" + async with self._lock: + for cache_key, (ll_session, _) in list(self._sessions.items()): + logger.info(f"Stopping MCP session for {cache_key} during shutdown") + await ll_session.stop() + self._sessions.clear() + + +# 全局单例连接池 +mcp_client_pool = MCPClientPool() diff --git a/backend/package/yuxi/services/mcp/connection_service.py b/backend/package/yuxi/services/mcp/connection_service.py new file mode 100644 index 000000000..c4073f1e0 --- /dev/null +++ b/backend/package/yuxi/services/mcp/connection_service.py @@ -0,0 +1,313 @@ +from __future__ import annotations +import logging +from datetime import UTC, datetime +from typing import Any +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import MCPConnection + +logger = logging.getLogger("yuxi.mcp.connection_service") + +_UNSET = object() +_VALID_MCP_CONNECTION_SCOPE_TYPES = {"system", "department", "user"} +_VALID_MCP_CONNECTION_STATUSES = {"active", "disabled", "reauth_required", "invalid"} + + +def _resolve_scope_id(binding_scope: str, auth_context: AuthContext | None) -> str | None: + """依据 Scope 类别从 AuthContext 中解出匹配的 ID""" + if binding_scope == "inline": + return None + if binding_scope == "system": + return "global" + if auth_context is None: + raise ValueError(f"auth_context is required for MCP binding scope '{binding_scope}'") + if binding_scope == "department": + if not auth_context.department_id: + raise ValueError("department_id is required for department-scoped MCP auth") + return str(auth_context.department_id) + if binding_scope == "user": + if not auth_context.user_id: + raise ValueError("user_id is required for user-scoped MCP auth") + return str(auth_context.user_id) + raise ValueError(f"Unsupported MCP binding scope: {binding_scope}") + + +def _normalize_mcp_connection_scope(scope_type: str, scope_id: str | None) -> tuple[str, str]: + normalized_scope_type = str(scope_type or "").strip().lower() + if normalized_scope_type not in _VALID_MCP_CONNECTION_SCOPE_TYPES: + raise ValueError("scope_type must be one of: system, department, user") + + normalized_scope_id = str(scope_id or "").strip() + if normalized_scope_type == "system": + return normalized_scope_type, "global" + if not normalized_scope_id: + raise ValueError(f"scope_id is required for {normalized_scope_type}-scoped MCP connections") + return normalized_scope_type, normalized_scope_id + + +def _normalize_mcp_connection_status(status: str) -> str: + normalized_status = str(status or "").strip().lower() + if normalized_status not in _VALID_MCP_CONNECTION_STATUSES: + raise ValueError("status must be one of: active, disabled, reauth_required, invalid") + return normalized_status + + +async def get_mcp_connection(db: AsyncSession, connection_id: int) -> MCPConnection | None: + """获取单个 Connection 记录""" + result = await db.execute(select(MCPConnection).where(MCPConnection.id == connection_id)) + return result.scalar_one_or_none() + + +def _auth_context_from_connection(connection: MCPConnection) -> AuthContext: + """基于连接绑定生成对应的 AuthContext 用于模拟联调与测试""" + if connection.scope_type == "department": + return AuthContext(department_id=connection.scope_id) + if connection.scope_type == "user": + return AuthContext(user_id=connection.scope_id) + return AuthContext() + + +async def list_mcp_connections( + db: AsyncSession, + *, + server_name: str | None = None, + scope_type: str | None = None, + scope_id: str | None = None, +) -> list[MCPConnection]: + """多条件列表查询 Connection""" + stmt = select(MCPConnection) + if server_name is not None: + stmt = stmt.where(MCPConnection.server_name == server_name) + if scope_type is not None: + stmt = stmt.where(MCPConnection.scope_type == scope_type) + if scope_id is not None: + stmt = stmt.where(MCPConnection.scope_id == scope_id) + stmt = stmt.order_by(MCPConnection.id.asc()) + result = await db.execute(stmt) + return list(result.scalars().all()) + + +async def create_mcp_connection( + db: AsyncSession, + *, + server_name: str, + scope_type: str, + scope_id: str, + display_name: str | None = None, + external_subject: str | None = None, + status: str = "active", + credential_blob: str | None = None, + meta_json: dict[str, Any] | None = None, + created_by: str | None = None, +) -> MCPConnection: + """创建 MCP 绑定连接""" + from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, server_name) + if server is None: + raise ValueError(f"Server '{server_name}' does not exist") + normalized_scope_type, normalized_scope_id = _normalize_mcp_connection_scope(scope_type, scope_id) + normalized_status = _normalize_mcp_connection_status(status) + + encrypted_credential_blob = ( + encrypt_credential_blob(credential_blob) + if isinstance(credential_blob, str) and credential_blob.strip() + else credential_blob + ) + + connection = MCPConnection( + server_name=server_name, + scope_type=normalized_scope_type, + scope_id=normalized_scope_id, + display_name=display_name, + external_subject=external_subject, + status=normalized_status, + credential_blob=encrypted_credential_blob, + meta_json=meta_json or {}, + created_by=created_by, + updated_by=created_by, + ) + db.add(connection) + await db.commit() + await db.refresh(connection) + return connection + + +async def update_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + display_name: str | None = None, + external_subject: str | None = None, + credential_blob: Any = _UNSET, + meta_json: dict[str, Any] | None = None, + status: str | None = None, + updated_by: str | None = None, +) -> MCPConnection: + """更新 MCP 绑定连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + should_clear_runtime_auth_cache = False + if display_name is not None: + connection.display_name = display_name + if external_subject is not None: + connection.external_subject = external_subject + if credential_blob is not _UNSET: + if isinstance(credential_blob, str) and credential_blob.strip(): + connection.credential_blob = encrypt_credential_blob(credential_blob) + else: + connection.credential_blob = credential_blob + should_clear_runtime_auth_cache = True + if meta_json is not None: + connection.meta_json = meta_json + if status is not None: + connection.status = _normalize_mcp_connection_status(status) + should_clear_runtime_auth_cache = True + if updated_by is not None: + connection.updated_by = updated_by + + await db.commit() + await db.refresh(connection) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + _invalidate_mcp_tools_cache_for_connection, + ) + if should_clear_runtime_auth_cache: + await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) + return connection + + +async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: + """删除 MCP 绑定连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + return False + deleted_connection_id = connection.id + deleted_server_name = connection.server_name + deleted_scope_type = connection.scope_type + await db.delete(connection) + await db.commit() + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + invalidate_mcp_connection_tools_cache, + ) + await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) + if deleted_scope_type == "system": + await invalidate_mcp_server_tools_cache(deleted_server_name) + else: + await invalidate_mcp_connection_tools_cache(deleted_server_name, deleted_connection_id) + return True + + +async def set_mcp_connection_status( + db: AsyncSession, + connection_id: int, + *, + status: str, + updated_by: str | None = None, +) -> MCPConnection: + """设置 MCP 绑定状态""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + connection.status = _normalize_mcp_connection_status(status) + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_connection_runtime_auth_cache, + _invalidate_mcp_tools_cache_for_connection, + ) + await _clear_mcp_connection_runtime_auth_cache(connection.id) + await _invalidate_mcp_tools_cache_for_connection(connection) + return connection + + +async def reauthorize_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> MCPConnection: + """重置授权连接凭据缓存并重新开启连接""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + from yuxi.services import mcp_service + cache = mcp_service.RedisTokenCache() + if getattr(connection, "id", None) is not None: + try: + await cache.delete_access_token(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection.id}: {exc}") + try: + await cache.release_refresh_lock(connection.id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") + + from yuxi.services.mcp.tool_registry_service import _invalidate_mcp_tools_cache_for_connection + await _invalidate_mcp_tools_cache_for_connection(connection) + + connection.status = "active" + meta_json = dict(connection.meta_json or {}) + meta_json.pop("last_error", None) + connection.meta_json = meta_json + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return connection + + +async def test_mcp_connection( + db: AsyncSession, + connection_id: int, + *, + updated_by: str | None = None, +) -> dict[str, Any]: + """测试连接联调可用性,获取可用的工具列表""" + connection = await get_mcp_connection(db, connection_id) + if connection is None: + raise ValueError(f"MCP connection '{connection_id}' does not exist") + + from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, connection.server_name) + if server is None: + raise ValueError(f"Server '{connection.server_name}' does not exist") + + auth_context = _auth_context_from_connection(connection) + from yuxi.services import mcp_service + config = await mcp_service.get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) + if config is None: + raise ValueError(f"MCP server '{server.name}' runtime config unavailable") + + tools = await mcp_service.get_mcp_tools( + server.name, + additional_servers={server.name: config}, + disabled_tools=[], + cache=False, + force_refresh=True, + ) + + meta_json = dict(connection.meta_json or {}) + meta_json["last_success_at"] = datetime.now(tz=UTC).isoformat() + meta_json.pop("last_error", None) + connection.meta_json = meta_json + connection.status = "active" + if updated_by is not None: + connection.updated_by = updated_by + await db.commit() + await db.refresh(connection) + return {"tool_count": len(tools), "connection": connection} diff --git a/backend/package/yuxi/services/mcp/server_service.py b/backend/package/yuxi/services/mcp/server_service.py new file mode 100644 index 000000000..2187419b2 --- /dev/null +++ b/backend/package/yuxi/services/mcp/server_service.py @@ -0,0 +1,469 @@ +from __future__ import annotations +import hashlib +import json +import logging +import os +import traceback +from typing import Any +import httpx +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + build_proxy_runtime_config, + should_use_internal_proxy, +) +from yuxi.services.mcp_tool_cache import RedisMcpToolCache +from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill + +logger = logging.getLogger("yuxi.mcp.server_service") + +_UNSET = object() +_MCP_PROXY_BASE_URL_ENV = "YUXI_INTERNAL_MCP_PROXY_BASE_URL" + +_DEFAULT_MCP_SERVERS = { + "sequentialthinking": { + "url": "https://remote.mcpservers.org/sequentialthinking/mcp", + "transport": "streamable_http", + "description": "顺序思考工具,帮助 AI 将复杂问题分解为多个步骤", + "icon": "🧠", + "tags": ["内置", "AI"], + }, + "mcp-server-chart": { + "command": "npx", + "args": ["-y", "@antv/mcp-server-chart"], + "transport": "stdio", + "description": "图表生成工具,支持生成各类图表(柱状图、折线图、饼图等)", + "icon": "📊", + "tags": ["内置", "图表"], + }, +} + +_SYNCED_MCP_FIELDS = ( + "description", + "transport", + "url", + "command", + "args", + "env", + "headers", + "timeout", + "sse_read_timeout", + "tags", + "icon", +) + + +async def ensure_builtin_mcp_servers_in_db() -> None: + """同步代码预置的内置 MCP 服务器至数据库中""" + from yuxi.storage.postgres.manager import pg_manager + + try: + async with pg_manager.get_async_session_context() as session: + result = await session.execute(select(func.count(MCPServer.name))) + count = result.scalar() + + if count == 0: + logger.info("No MCP servers in database, importing default configurations...") + for name, config in _DEFAULT_MCP_SERVERS.items(): + server = MCPServer( + name=name, + description=config.get("description"), + transport=config["transport"], + url=config.get("url"), + command=config.get("command"), + args=config.get("args"), + env=config.get("env"), + headers=config.get("headers"), + timeout=config.get("timeout"), + sse_read_timeout=config.get("sse_read_timeout"), + tags=config.get("tags"), + icon=config.get("icon"), + enabled=0, + created_by="system", + updated_by="system", + ) + session.add(server) + await session.commit() + logger.info(f"Imported {len(_DEFAULT_MCP_SERVERS)} default MCP servers to database") + else: + for name, config in _DEFAULT_MCP_SERVERS.items(): + result = await session.execute(select(MCPServer).filter(MCPServer.name == name)) + existing = result.scalar_one_or_none() + if not existing: + server = MCPServer( + name=name, + description=config.get("description"), + transport=config["transport"], + url=config.get("url"), + command=config.get("command"), + args=config.get("args"), + env=config.get("env"), + headers=config.get("headers"), + timeout=config.get("timeout"), + sse_read_timeout=config.get("sse_read_timeout"), + tags=config.get("tags"), + icon=config.get("icon"), + enabled=0, + created_by="system", + updated_by="system", + ) + session.add(server) + logger.info(f"Added built-in MCP server '{name}' to database") + else: + changed = False + for field in _SYNCED_MCP_FIELDS: + next_value = config.get(field) + if getattr(existing, field) != next_value: + setattr(existing, field, next_value) + changed = True + if changed: + existing.updated_by = "system" + if session.new: + await session.commit() + elif session.dirty: + await session.commit() + + except Exception as e: + logger.error(f"Failed to ensure builtin MCP servers in database: {e}, traceback: {traceback.format_exc()}") + + +async def _load_enabled_mcp_server_configs( + *, + names: list[str] | None = None, + db: AsyncSession | None = None, +) -> dict[str, dict[str, Any]]: + """从数据库中加载已启用的服务器 MCP 配置""" + if db is not None: + stmt = select(MCPServer).where(MCPServer.enabled == 1) + if names: + stmt = stmt.where(MCPServer.name.in_(names)) + result = await db.execute(stmt) + servers = result.scalars().all() + return {server.name: server.to_mcp_config() for server in servers} + + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + return await _load_enabled_mcp_server_configs(names=names, db=session) + + +async def get_enabled_mcp_server_config(server_name: str, *, db: AsyncSession | None = None) -> dict[str, Any] | None: + """获取最新启用的指定服务器的 MCP 配置""" + configs = await _load_enabled_mcp_server_configs(names=[server_name], db=db) + return configs.get(server_name) + + +def _get_internal_mcp_proxy_base_url() -> str | None: + value = os.getenv(_MCP_PROXY_BASE_URL_ENV, "").strip() + return value or None + + +async def _get_enabled_mcp_server_record(server_name: str, *, db: AsyncSession) -> MCPServer | None: + result = await db.execute( + select(MCPServer).where( + MCPServer.enabled == 1, + MCPServer.name == server_name, + ) + ) + return result.scalar_one_or_none() + + +def _apply_runtime_tool_cache_policy( + config: dict[str, Any], + *, + auth_config: MCPAuthConfig, + auth_context: AuthContext | None, + connection: MCPConnection | None, +) -> dict[str, Any]: + """利用 CachePolicy 模式获取缓存 key 的隔离区划并应用""" + from yuxi.services.mcp.cache_policy import CachePolicyFactory + + policy = CachePolicyFactory.get_policy(auth_config.provider) + partition, is_shared = policy.resolve_cache_partition( + auth_context or AuthContext(), + connection, + ) + config["__yuxi_cache_partition"] = partition + config["__yuxi_allow_global_cache"] = is_shared + return config + + +async def get_runtime_mcp_server_config( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> dict[str, Any] | None: + """解析获取附带运行时鉴权与租户范围的 MCP 服务配置""" + if db is None and auth_context is None: + from yuxi.services import mcp_service + return await mcp_service.get_enabled_mcp_server_config(server_name) + + if db is not None: + server = await _get_enabled_mcp_server_record(server_name, db=db) + if server is None: + return None + if not server.auth_config_json: + return server.to_mcp_config() + + auth_config = MCPAuthConfig.model_validate(server.auth_config_json) + from yuxi.services import mcp_service + scope_id = mcp_service._resolve_scope_id(auth_config.binding_scope, auth_context) + if scope_id is None: + return server.to_mcp_config() + + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server_name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + connection = result.scalar_one_or_none() + if connection is None: + raise ValueError( + f"Active MCP connection not found for server '{server_name}' and scope " + f"{auth_config.binding_scope}:{scope_id}" + ) + proxy_base_url = _get_internal_mcp_proxy_base_url() + if should_use_internal_proxy(server, auth_config, proxy_base_url): + config = build_proxy_runtime_config( + server, + auth_context=auth_context or AuthContext(), + proxy_base_url=proxy_base_url or "", + ) + else: + config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context or AuthContext(), + connection=connection, + http_client=http_client, + ) + return _apply_runtime_tool_cache_policy( + config, + auth_config=auth_config, + auth_context=auth_context, + connection=connection, + ) + + from yuxi.storage.postgres.manager import pg_manager + + async with pg_manager.get_async_session_context() as session: + return await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=session, + http_client=http_client, + ) + + +async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: + """获取所有已启用的服务器名称""" + from yuxi.services import mcp_service + configs = await mcp_service._load_enabled_mcp_server_configs(db=db) + return list(configs.keys()) + + +async def get_mcp_server(db: AsyncSession, name: str) -> MCPServer | None: + """获取单个服务器对象记录""" + result = await db.execute(select(MCPServer).filter(MCPServer.name == name)) + return result.scalar_one_or_none() + + +async def get_all_mcp_servers(db: AsyncSession) -> list[MCPServer]: + """获取所有配置的服务器对象列表""" + result = await db.execute(select(MCPServer)) + return list(result.scalars().all()) + + +async def create_mcp_server( + db: AsyncSession, + name: str, + transport: str, + url: str = None, + command: str = None, + args: list = None, + env: dict = None, + description: str = None, + headers: dict = None, + timeout: int = None, + sse_read_timeout: int = None, + tags: list = None, + icon: str = None, + auth_config: dict | None = None, + created_by: str = None, +) -> MCPServer: + """创建 MCP 服务器配置""" + existing = await get_mcp_server(db, name) + if existing: + raise ValueError(f"Server name '{name}' already exists") + + server = MCPServer( + name=name, + description=description, + transport=transport, + url=url, + command=command, + args=args, + env=env, + headers=headers, + auth_config_json=auth_config, + timeout=timeout, + sse_read_timeout=sse_read_timeout, + tags=tags, + icon=icon, + enabled=1, + created_by=created_by, + updated_by=created_by, + ) + db.add(server) + await db.commit() + await db.refresh(server) + + from yuxi.services import mcp_service + await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) + await mcp_service.invalidate_mcp_server_tools_cache(name) + + logger.info(f"Created MCP server '{name}'") + return server + + +async def update_mcp_server( + db: AsyncSession, + name: str, + description: str = None, + transport: str = None, + url: str = None, + command: str = None, + args: list = None, + env: Any = _UNSET, + headers: dict = None, + timeout: int = None, + sse_read_timeout: int = None, + tags: list = None, + icon: str = None, + auth_config: Any = _UNSET, + updated_by: str = None, +) -> MCPServer: + """更新服务器配置""" + server = await get_mcp_server(db, name) + if not server: + raise ValueError(f"Server '{name}' does not exist") + + if description is not None: + server.description = description + if transport is not None: + server.transport = transport + if url is not None: + server.url = url + if command is not None: + server.command = command + if args is not None: + server.args = args + if env is not _UNSET: + server.env = env + if headers is not None: + server.headers = headers + if auth_config is not _UNSET: + server.auth_config_json = auth_config + if timeout is not None: + server.timeout = timeout + if sse_read_timeout is not None: + server.sse_read_timeout = sse_read_timeout + if tags is not None: + server.tags = tags + if icon is not None: + server.icon = icon + if updated_by is not None: + server.updated_by = updated_by + + await db.commit() + await db.refresh(server) + + from yuxi.services import mcp_service + if auth_config is not _UNSET: + await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) + await mcp_service.invalidate_mcp_server_tools_cache(name) + + logger.info(f"Updated MCP server '{name}'") + return server + + +async def delete_mcp_server(db: AsyncSession, name: str) -> bool: + """删除服务器""" + server = await get_mcp_server(db, name) + if not server: + return False + + await db.delete(server) + await db.commit() + + from yuxi.services import mcp_service + await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) + await mcp_service.invalidate_mcp_server_tools_cache(name) + + logger.info(f"Deleted MCP server '{name}'") + return True + + +async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: + """获取依赖于该 MCP 服务器的智能体、技能和连接概要""" + from yuxi.services import mcp_service + connections = await mcp_service.list_mcp_connections(db, server_name=name) + + skill_rows = (await db.execute(select(Skill))).scalars().all() + matched_skills = [ + {"slug": item.slug, "name": item.name} for item in skill_rows if name in (item.mcp_dependencies or []) + ] + + agent_config_rows = (await db.execute(select(AgentConfig))).scalars().all() + matched_agent_configs = [] + for item in agent_config_rows: + config_json = item.config_json or {} + if name in (config_json.get("mcps") or []): + matched_agent_configs.append({"id": item.id, "name": item.name, "agent_id": item.agent_id}) + + connection_refs = [ + {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} for item in connections + ] + + return { + "has_references": bool(connection_refs or matched_skills or matched_agent_configs), + "connections": connection_refs, + "skills": matched_skills, + "agent_configs": matched_agent_configs, + } + + +async def set_server_enabled( + db: AsyncSession, name: str, enabled: bool, updated_by: str = None +) -> tuple[bool, MCPServer]: + """设置服务器的启用状态""" + server = await get_mcp_server(db, name) + if not server: + raise ValueError(f"Server '{name}' does not exist") + + server.enabled = 1 if enabled else 0 + if updated_by is not None: + server.updated_by = updated_by + await db.commit() + + is_enabled = bool(server.enabled) + from yuxi.services import mcp_service + if not is_enabled: + await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) + await mcp_service.invalidate_mcp_server_tools_cache(name) + + logger.info(f"Set MCP server '{name}' enabled={is_enabled}") + return is_enabled, server + + +async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: + """批量获取服务器配置""" + return await _load_enabled_mcp_server_configs(names=names) diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py new file mode 100644 index 000000000..33b8bff3a --- /dev/null +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -0,0 +1,519 @@ +from __future__ import annotations +import asyncio +import hashlib +import json +import logging +import traceback +from collections.abc import Callable +from types import SimpleNamespace +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.crypto import encrypt_credential_blob +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.services.mcp_auth.proxy_service import ( + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + INTERNAL_PROXY_TOKEN_HEADER, +) +from yuxi.services.mcp_tool_cache import RedisMcpToolCache +from yuxi.storage.postgres.models_business import MCPConnection, MCPServer + +logger = logging.getLogger("yuxi.mcp.tool_registry_service") + +# 全局共享状态(直接在本模块维护,供外部和测试使用) +_mcp_tools_cache: dict[str, list[Callable[..., Any]]] = {} +_mcp_tools_stats: dict[str, dict[str, int]] = {} +_mcp_tool_cache_store = RedisMcpToolCache() +_mcp_lock = asyncio.Lock() + + + + +def to_camel_case(s: str) -> str: + """转换字符串为 lowerCamelCase 命名格式""" + import re + s = re.sub(r"[-_]+(.)", lambda m: m.group(1).upper(), s) + if len(s) > 0: + s = s[0].lower() + s[1:] + return s + + +def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, Any], str, bool]: + """提取用于缓存 key 比较的标识配置""" + cache_partition = str(server_config.get("__yuxi_cache_partition") or "server") + allow_global_cache = bool(server_config.get("__yuxi_allow_global_cache", True)) + + cache_identity = { + key: value + for key, value in server_config.items() + if key not in { + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + "disabled_tools", + } + } + + headers = dict(cache_identity.get("headers") or {}) + headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) + if headers: + cache_identity["headers"] = headers + elif "headers" in cache_identity: + cache_identity["headers"] = {} + return cache_identity, cache_partition, allow_global_cache + + +async def _build_mcp_tool_cache_descriptor(server_name: str, server_config: dict[str, Any]) -> dict[str, Any]: + """生成缓存 Key 描述信息字典""" + cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) + config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] + + server_revision = await _mcp_tool_cache_store.get_server_revision(server_name) + partition_revision = 0 + if not allow_global_cache: + partition_revision = await _mcp_tool_cache_store.get_partition_revision(server_name, cache_partition) + revision_token = f"s{server_revision}:p{partition_revision}" + cache_prefix = f"{server_name}:{cache_partition}:{revision_token}:" + + return { + "cache_identity": cache_identity, + "cache_partition": cache_partition, + "allow_global_cache": allow_global_cache, + "config_hash": config_hash, + "cache_prefix": cache_prefix, + "cache_key": f"{cache_prefix}{config_hash}", + "server_revision": server_revision, + "partition_revision": partition_revision, + } + + +def _serialize_mcp_tools_manifest( + *, + server_name: str, + cache_partition: str, + cache_key: str, + tools: list[Callable[..., Any]], +) -> dict[str, Any]: + """将 Langchain 运行态 Tool 转换为 Manifest 字典以缓存到 Redis 中""" + entries = [] + for tool in tools: + if hasattr(tool, "args_schema") and tool.args_schema: + schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {} + parameters = schema.get("properties", {}) + required = schema.get("required", []) + else: + parameters = {} + required = [] + metadata = dict(getattr(tool, "metadata", {}) or {}) + entries.append( + { + "name": tool.name, + "id": metadata.get("id") or tool.name, + "description": getattr(tool, "description", ""), + "parameters": parameters, + "required": required, + } + ) + return { + "server_name": server_name, + "cache_partition": cache_partition, + "cache_key": cache_key, + "tools": entries, + } + + +def _deserialize_mcp_tool_manifest(manifest: dict[str, Any]) -> list[Callable[..., Any]]: + """反序列化 Redis 中的 Manifest 字典还原为本地 Tool 对象结构""" + tools: list[Callable[..., Any]] = [] + for entry in manifest.get("tools", []): + args_schema = None + parameters = entry.get("parameters") or {} + required = entry.get("required") or [] + if parameters or required: + args_schema = SimpleNamespace( + schema=lambda parameters=parameters, required=required: { + "properties": parameters, + "required": required, + } + ) + tools.append( + SimpleNamespace( + name=entry.get("name") or "", + description=entry.get("description") or "", + metadata={"id": entry.get("id") or entry.get("name") or ""}, + args_schema=args_schema, + ) + ) + return tools + + +def _get_mcp_auth_config(server_config: dict[str, Any]) -> MCPAuthConfig | None: + auth_payload = server_config.get("auth_config") or {} + if not auth_payload: + return None + try: + return MCPAuthConfig.model_validate(auth_payload) + except Exception as exc: + logger.warning(f"Invalid MCP auth config while resolving tool preload strategy: {exc}") + return None + + +def _can_preload_mcp_server_tools_without_runtime_auth(server_config: dict[str, Any]) -> bool: + if not (server_config.get("auth_config") or {}): + return True + auth_config = _get_mcp_auth_config(server_config) + if auth_config is None: + return False + return auth_config.provider == "legacy_static" + + +async def get_mcp_tools( + server_name: str, + additional_servers: dict[str, dict[str, Any]] | None = None, + disabled_tools: list[str] = None, + cache: bool = True, + force_refresh: bool = False, +) -> list[Callable[..., Any]]: + """ + 获取指定 MCP 服务器的工具列表。 + + 优化生命周期: + - 集成缓存策略模式 (CachePolicy),动态决策是否在进程内容许缓存 Tool 对象。 + - 集成客户端连接池 (MCPClientPool),复用 Stdio 长期子进程及 HTTP Keep-Alive 连接。 + """ + if additional_servers and server_name in additional_servers: + server_config = additional_servers[server_name] + else: + from yuxi.services.mcp.server_service import get_enabled_mcp_server_config + server_config = await get_enabled_mcp_server_config(server_name) + + if server_config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, server_config) + cache_partition = cache_descriptor["cache_partition"] + cache_prefix = cache_descriptor["cache_prefix"] + cache_key = cache_descriptor["cache_key"] + + # 策略模式:根据 AuthProvider 确认是否容许内存缓存 Tool 实例对象 + from yuxi.services.mcp.cache_policy import CachePolicyFactory + auth_config = _get_mcp_auth_config(server_config) + policy = CachePolicyFactory.get_policy(auth_config.provider if auth_config else None) + use_tool_object_cache = ( + cache + and policy.should_cache_tool_object() + and not bool(server_config.get(INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY)) + ) + + all_processed_tools: list[Callable[..., Any]] = [] + + async with _mcp_lock: + if not force_refresh and use_tool_object_cache and cache_key in _mcp_tools_cache: + all_processed_tools = _mcp_tools_cache[cache_key] + + if not all_processed_tools: + try: + client_config = { + k: v + for k, v in server_config.items() + if k not in ( + "disabled_tools", + "__yuxi_cache_partition", + "__yuxi_allow_global_cache", + INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, + ) + } + + # NOTE: 从长连接池中提取 ClientSession 实例(对 Stdio 而言子进程被挂起复用,避免频繁启停;HTTP 协议亦保持 Keep-Alive) + from yuxi.services.mcp.client_pool import mcp_client_pool + session = await mcp_client_pool.get_session( + server_name, + partition_key=f"{cache_partition}:s{cache_descriptor['server_revision']}:p{cache_descriptor['partition_revision']}", + runtime_config=client_config, + ) + + # 如果 session 是 Fake Client (有 get_tools 方法),我们直接调用它获取工具列表,避免 load_mcp_tools 报错 + if hasattr(session, "get_tools"): + raw_tools = cast(list[Any], await session.get_tools()) + else: + # 调用 langchain 官方加载工具,直接传入已预备并建立好的 session + from langchain_mcp_adapters.tools import load_mcp_tools + raw_tools = cast(list[Any], await load_mcp_tools(session, server_name=server_name)) + + server_cc = to_camel_case(server_name) + for tool in raw_tools: + original_name = tool.name + tool_cc = to_camel_case(original_name) + unique_id = f"mcp__{server_cc}__{tool_cc}" + + if tool.metadata is None: + tool.metadata = {} + tool.metadata["id"] = unique_id + tool.handle_tool_error = True + all_processed_tools.append(tool) + + if cache: + if use_tool_object_cache: + async with _mcp_lock: + stale_keys = [ + key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key + ] + for stale_key in stale_keys: + _mcp_tools_cache.pop(stale_key, None) + _mcp_tools_cache[cache_key] = all_processed_tools + + await _mcp_tool_cache_store.set_manifest( + cache_key, + _serialize_mcp_tools_manifest( + server_name=server_name, + cache_partition=cache_partition, + cache_key=cache_key, + tools=all_processed_tools, + ), + ) + + global_config_disabled = server_config.get("disabled_tools") or [] + enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) + _mcp_tools_stats[server_name] = { + "total": len(all_processed_tools), + "enabled": enabled_count, + "disabled": len(all_processed_tools) - enabled_count, + } + + logger.info( + f"Refreshed MCP tools cache for '{server_name}' with key '{cache_key}': " + f"{len(all_processed_tools)} tools loaded." + ) + + except Exception as e: + logger.error( + f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}" + ) + return [] + + if disabled_tools: + filtered_tools = [t for t in all_processed_tools if t.name not in disabled_tools] + return filtered_tools + + return all_processed_tools + + +async def get_tools_from_all_servers() -> list[Callable[..., Any]]: + """批量载入所有可用服务的工具(用于系统初始化及预热)""" + from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs + server_configs = await _load_enabled_mcp_server_configs() + all_tools = [] + for server_name, server_config in server_configs.items(): + if not _can_preload_mcp_server_tools_without_runtime_auth(server_config): + logger.info(f"Skip MCP tool preload for '{server_name}' because runtime auth context is required") + continue + tools = await get_mcp_tools(server_name, additional_servers={server_name: server_config}) + all_tools.extend(tools) + return all_tools + + +def clear_mcp_cache() -> None: + """清空本地内存工具缓存""" + global _mcp_tools_cache + _mcp_tools_cache = {} + + try: + from yuxi.services.mcp.client_pool import mcp_client_pool + mcp_client_pool._sessions.clear() + except Exception: + pass + + +def clear_mcp_server_tools_cache(server_name: str) -> None: + """清空指定服务器下的所有本地缓存""" + global _mcp_tools_cache + prefix = f"{server_name}:" + stale_keys = [k for k in _mcp_tools_cache if k.startswith(prefix)] + for key in stale_keys: + _mcp_tools_cache.pop(key, None) + + +def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + """清空指定连接下的本地内存缓存""" + if connection_id is None: + return + global _mcp_tools_cache + suffix = f":connection:{connection_id}:" + stale_keys = [k for k in _mcp_tools_cache if suffix in k and k.startswith(f"{server_name}:")] + for key in stale_keys: + _mcp_tools_cache.pop(key, None) + + +async def invalidate_mcp_server_tools_cache(server_name: str) -> None: + """全局失效指定服务器的全部二级缓存""" + clear_mcp_server_tools_cache(server_name) + await _mcp_tool_cache_store.bump_server_revision(server_name) + + +async def invalidate_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: + """失效指定连接下的二级缓存区划""" + if connection_id is None: + return + clear_mcp_connection_tools_cache(server_name, connection_id) + await _mcp_tool_cache_store.bump_partition_revision(server_name, f"connection:{connection_id}") + + +async def _invalidate_mcp_tools_cache_for_connection(connection: MCPConnection) -> None: + """依据 Scope 类别自动刷新并失效缓存""" + if connection.scope_type == "system": + await invalidate_mcp_server_tools_cache(connection.server_name) + else: + await invalidate_mcp_connection_tools_cache(connection.server_name, connection.id) + + +async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> None: + """清理 Redis 中缓存的 Access Token 与锁状态""" + if connection_id is None: + return + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + cache = RedisTokenCache() + try: + await cache.delete_access_token(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP token cache for connection {connection_id}: {exc}") + try: + await cache.release_refresh_lock(connection_id) + except Exception as exc: + logger.warning(f"Failed to clear MCP refresh lock for connection {connection_id}: {exc}") + + +async def _clear_mcp_server_runtime_auth_cache(db: AsyncSession, server_name: str) -> None: + """清理服务器下所有关联连接的 Token 缓存""" + from yuxi.services.mcp.connection_service import list_mcp_connections + connections = await list_mcp_connections(db, server_name=server_name) + for connection in connections: + await _clear_mcp_connection_runtime_auth_cache(getattr(connection, "id", None)) + + +def get_mcp_tools_stats(server_name: str) -> dict[str, int] | None: + return _mcp_tools_stats.get(server_name) + + +async def get_enabled_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, +) -> list: + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + + token = None + if auth_context: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) + + try: + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) + if config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + disabled_tools = config.get("disabled_tools") or [] + return await get_mcp_tools( + server_name, + additional_servers={server_name: config}, + disabled_tools=disabled_tools, + ) + finally: + if token: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + mcp_auth_context_var.reset(token) + + +async def get_all_mcp_tools( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: AsyncSession | None = None, + http_client: httpx.AsyncClient | None = None, + force_refresh: bool = False, +) -> list: + from yuxi.services.mcp.server_service import get_enabled_mcp_server_config, get_runtime_mcp_server_config + + token = None + if auth_context: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) + + try: + if auth_context is None and db is None: + config = await get_enabled_mcp_server_config(server_name) + else: + config = await get_runtime_mcp_server_config( + server_name, + auth_context=auth_context, + db=db, + http_client=http_client, + ) + if config is None: + logger.warning(f"MCP server '{server_name}' not found in database or disabled") + return [] + + if not force_refresh: + cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, config) + manifest = await _mcp_tool_cache_store.get_manifest(cache_descriptor["cache_key"]) + if manifest is not None: + return _deserialize_mcp_tool_manifest(manifest) + + return await get_mcp_tools( + server_name, + additional_servers={server_name: config}, + disabled_tools=[], + cache=True, + force_refresh=force_refresh, + ) + finally: + if token: + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + mcp_auth_context_var.reset(token) + + +async def toggle_tool_enabled( + db: AsyncSession, + server_name: str, + tool_name: str, + updated_by: str | None = None, +) -> tuple[bool, MCPServer]: + """切换单个工具的启用状态""" + from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, server_name) + if not server: + raise ValueError(f"Server '{server_name}' does not exist") + + disabled_tools = list(server.disabled_tools or []) + + if tool_name in disabled_tools: + disabled_tools.remove(tool_name) + enabled = True + else: + disabled_tools.append(tool_name) + enabled = False + + server.disabled_tools = disabled_tools + if updated_by is not None: + server.updated_by = updated_by + await db.commit() + + # 清除内存工具缓存 + clear_mcp_server_tools_cache(server_name) + + logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") + return enabled, server + + diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py new file mode 100644 index 000000000..ca812760c --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py @@ -0,0 +1,14 @@ +from __future__ import annotations +from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher, BaseTokenFetcher +from yuxi.services.mcp_auth.fetchers.http_fetcher import CustomHttpTokenFetcher, ClientCredentialsFetcher +from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher +from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory + +__all__ = [ + "ITokenFetcher", + "BaseTokenFetcher", + "CustomHttpTokenFetcher", + "ClientCredentialsFetcher", + "AuthorizationCodeFetcher", + "TokenFetcherFactory", +] diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/base.py b/backend/package/yuxi/services/mcp_auth/fetchers/base.py new file mode 100644 index 000000000..299773c80 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/base.py @@ -0,0 +1,174 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Any +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.template_resolver import resolve_template_value + +# 注释必须使用简体中文,符合 RULE[user_global] +# NOTE: 所有获取 Token 的具体策略需要继承 ITokenFetcher 并实现 fetch_token 方法。 + +_DEFAULT_TOKEN_RESPONSE_MAP = { + "access_token": "access_token", + "refresh_token": "refresh_token", + "expires_in": "expires_in", + "expires_at": "expires_at", + "scope": "scope", + "token_type": "token_type", +} + +def extract_path(payload: dict[str, Any], path: str) -> Any: + """从 payload 中根据点分路径提取字段值""" + current: Any = payload + for segment in path.split("."): + if isinstance(current, dict): + current = current[segment] + continue + raise KeyError(path) + return current + +async def fetch_custom_http_token( + request_config: dict[str, Any], + *, + response_map: dict[str, str] | None, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, +) -> dict[str, Any]: + """执行自定义 HTTP 请求获取 Token""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + + response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + headers = resolve_template_value( + request_config.get("headers") or {}, + context=context_payload, + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body = resolve_template_value( + request_config.get("body_template") or {}, + context=context_payload, + secret=secret_values, + token=token_values, + access_token=token_values.get("access_token"), + ) + body_type = request_config.get("body_type", "json") + request_kwargs: dict[str, Any] = { + "method": (request_config.get("method") or "POST").upper(), + "url": request_config["url"], + "headers": headers, + } + if body_type == "json": + request_kwargs["json"] = body + else: + request_kwargs["data"] = body + + response = await http_client.request(**request_kwargs) + response.raise_for_status() + payload = response.json() + resolved = {} + for field_name, path in response_map.items(): + try: + resolved[field_name] = extract_path(payload, path) + except KeyError: + continue + return _normalize_token_payload(resolved) + except Exception as exc: + import traceback + from yuxi.utils import logger + logger.error(f"fetch_custom_http_token failure: {exc}, traceback: {traceback.format_exc()}") + raise + finally: + if should_close: + await http_client.aclose() + + +class ITokenFetcher(ABC): + """Token 获取接口""" + + @abstractmethod + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + """ + 获取或刷新 Access Token + """ + pass + + +class BaseTokenFetcher(ITokenFetcher, ABC): + """带自动 Refresh 逻辑的 Token 获取基类""" + + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + # NOTE: 优先检查是否有可用 refresh token,并进行刷新 + token_request = auth_config.token_request or {} + refresh_request = token_request.get("refresh") + if ( + token_values + and refresh_request + and (token_values.get("refresh_token") or credential_payload.get("refresh_token")) + ): + refresh_token_values = dict(token_values) + if not refresh_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + refresh_token_values["refresh_token"] = credential_payload["refresh_token"] + + refreshed = await fetch_custom_http_token( + refresh_request, + response_map=(refresh_request.get("response_map") or token_request.get("response_map")), + context_payload=context_payload, + secret_values=secret_values, + token_values=refresh_token_values, + http_client=http_client, + ) + if not refreshed.get("refresh_token") and refresh_token_values.get("refresh_token"): + refreshed["refresh_token"] = refresh_token_values["refresh_token"] + return refreshed + + # NOTE: 如果不满足刷新条件,则获取全新 Token + return await self._fetch_new_token( + auth_config, + context_payload=context_payload, + secret_values=secret_values, + credential_payload=credential_payload, + token_values=token_values, + http_client=http_client, + ) + + @abstractmethod + async def _fetch_new_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + """获取全新的 Token""" + pass diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/factory.py b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py new file mode 100644 index 000000000..0b546cc61 --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py @@ -0,0 +1,19 @@ +from __future__ import annotations +from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher +from yuxi.services.mcp_auth.fetchers.http_fetcher import CustomHttpTokenFetcher, ClientCredentialsFetcher +from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher + + +class TokenFetcherFactory: + """TokenFetcher 工厂""" + + @staticmethod + def get_fetcher(provider: str) -> ITokenFetcher: + if provider == "custom_http_token": + return CustomHttpTokenFetcher() + elif provider == "client_credentials": + return ClientCredentialsFetcher() + elif provider == "authorization_code": + return AuthorizationCodeFetcher() + else: + raise ValueError(f"Unsupported MCP auth provider for dynamic token: {provider}") diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py new file mode 100644 index 000000000..ffc74602f --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py @@ -0,0 +1,38 @@ +from __future__ import annotations +from typing import Any +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.fetchers.base import BaseTokenFetcher, fetch_custom_http_token + + +class CustomHttpTokenFetcher(BaseTokenFetcher): + """自定义 HTTP 方式获取 Token""" + + async def _fetch_new_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + token_request = auth_config.token_request or {} + resolved = await fetch_custom_http_token( + token_request, + response_map=token_request.get("response_map"), + context_payload=context_payload, + secret_values=secret_values, + token_values=token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and credential_payload.get("refresh_token"): + resolved["refresh_token"] = credential_payload["refresh_token"] + return resolved + + +class ClientCredentialsFetcher(CustomHttpTokenFetcher): + """客户端凭证 (Client Credentials) 方式获取 Token""" + # NOTE: 当前其底层获取逻辑与 CustomHttpTokenFetcher 相同 + pass diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py new file mode 100644 index 000000000..00446803f --- /dev/null +++ b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py @@ -0,0 +1,89 @@ +from __future__ import annotations +from typing import Any +import httpx +from yuxi.services.mcp_auth.config_models import MCPAuthConfig +from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher, fetch_custom_http_token, _DEFAULT_TOKEN_RESPONSE_MAP + + +class AuthorizationCodeFetcher(ITokenFetcher): + """授权码 (Authorization Code) 模式下的后台 Token 刷新获取""" + + async def _resolve_token_request_config( + self, + token_request: dict[str, Any], + secret_values: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient, + ) -> tuple[dict[str, Any], dict[str, str]]: + issuer_url = ( + token_request.get("issuer_url") + or secret_values.get("issuer_url") + or token_values.get("issuer_url") + ) + if not issuer_url: + raise ValueError("authorization_code provider requires token_request.issuer_url") + discovery_url = f"{str(issuer_url).rstrip('/')}/.well-known/openid-configuration" + response = await http_client.get(discovery_url) + response.raise_for_status() + payload = response.json() + token_endpoint = payload.get("token_endpoint") + if not token_endpoint: + raise ValueError("authorization_code provider discovery missing token_endpoint") + + return { + "url": token_endpoint, + "method": "POST", + "body_type": "form", + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + }, + "body_template": { + "grant_type": "refresh_token", + "refresh_token": "${token.refresh_token}", + "client_id": token_request.get("client_id", "${secret.client_id}"), + "client_secret": token_request.get("client_secret", "${secret.client_secret}"), + }, + }, dict(_DEFAULT_TOKEN_RESPONSE_MAP) + + async def fetch_token( + self, + auth_config: MCPAuthConfig, + *, + context_payload: dict[str, Any], + secret_values: dict[str, Any], + credential_payload: dict[str, Any], + token_values: dict[str, Any], + http_client: httpx.AsyncClient | None, + ) -> dict[str, Any]: + if http_client is None: + http_client = httpx.AsyncClient() + should_close = True + else: + should_close = False + + try: + token_request = auth_config.token_request or {} + authorization_request, response_map = await self._resolve_token_request_config( + token_request=token_request, + secret_values=secret_values, + token_values=token_values or credential_payload, + http_client=http_client, + ) + authorization_token_values = dict(token_values or credential_payload) + if not authorization_token_values.get("refresh_token") and credential_payload.get("refresh_token"): + authorization_token_values["refresh_token"] = credential_payload["refresh_token"] + + resolved = await fetch_custom_http_token( + authorization_request, + response_map=response_map, + context_payload=context_payload, + secret_values=secret_values, + token_values=authorization_token_values, + http_client=http_client, + ) + if not resolved.get("refresh_token") and authorization_token_values.get("refresh_token"): + resolved["refresh_token"] = authorization_token_values["refresh_token"] + return resolved + finally: + if should_close: + await http_client.aclose() diff --git a/backend/package/yuxi/services/mcp_auth/orchestrator.py b/backend/package/yuxi/services/mcp_auth/orchestrator.py index 18133fc7e..a3548f6bf 100644 --- a/backend/package/yuxi/services/mcp_auth/orchestrator.py +++ b/backend/package/yuxi/services/mcp_auth/orchestrator.py @@ -15,12 +15,19 @@ from yuxi.utils import logger +import contextvars + + @dataclass(slots=True) class AuthContext: user_id: str | None = None department_id: str | None = None +mcp_auth_context_var: contextvars.ContextVar[AuthContext | None] = contextvars.ContextVar( + "mcp_auth_context_var", default=None +) + _DEFAULT_TOKEN_RESPONSE_MAP = { "access_token": "access_token", "refresh_token": "refresh_token", @@ -138,109 +145,6 @@ def _merge_injected_entries( return config -async def _fetch_custom_http_token( - request_config: dict[str, Any], - *, - response_map: dict[str, str] | None, - context: AuthContext, - secret_values: dict[str, Any], - token_values: dict[str, Any], - http_client: httpx.AsyncClient | None, -) -> dict[str, Any]: - response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) - if http_client is None: - http_client = httpx.AsyncClient() - should_close = True - else: - should_close = False - - try: - headers = resolve_template_value( - request_config.get("headers") or {}, - context=_context_payload(context), - secret=secret_values, - token=token_values, - access_token=token_values.get("access_token"), - ) - body = resolve_template_value( - request_config.get("body_template") or {}, - context=_context_payload(context), - secret=secret_values, - token=token_values, - access_token=token_values.get("access_token"), - ) - body_type = request_config.get("body_type", "json") - request_kwargs: dict[str, Any] = { - "method": (request_config.get("method") or "POST").upper(), - "url": request_config["url"], - "headers": headers, - } - if body_type == "json": - request_kwargs["json"] = body - else: - request_kwargs["data"] = body - - response = await http_client.request(**request_kwargs) - response.raise_for_status() - payload = response.json() - resolved = {} - for field_name, path in response_map.items(): - try: - resolved[field_name] = _extract_path(payload, path) - except KeyError: - continue - return _normalize_token_payload(resolved) - finally: - if should_close: - await http_client.aclose() - - -async def _resolve_authorization_code_token_request( - *, - token_request: dict[str, Any], - secret_values: dict[str, Any], - token_values: dict[str, Any], - http_client: httpx.AsyncClient | None, -) -> tuple[dict[str, Any], dict[str, str]]: - if http_client is None: - http_client = httpx.AsyncClient() - should_close = True - else: - should_close = False - - try: - issuer_url = ( - token_request.get("issuer_url") - or secret_values.get("issuer_url") - or token_values.get("issuer_url") - ) - if not issuer_url: - raise ValueError("authorization_code provider requires token_request.issuer_url") - discovery_url = f"{str(issuer_url).rstrip('/')}/.well-known/openid-configuration" - response = await http_client.get(discovery_url) - response.raise_for_status() - payload = response.json() - token_endpoint = payload.get("token_endpoint") - if not token_endpoint: - raise ValueError("authorization_code provider discovery missing token_endpoint") - return { - "url": token_endpoint, - "method": "POST", - "body_type": "form", - "headers": { - "Content-Type": "application/x-www-form-urlencoded", - }, - "body_template": { - "grant_type": "refresh_token", - "refresh_token": "${token.refresh_token}", - "client_id": token_request.get("client_id", "${secret.client_id}"), - "client_secret": token_request.get("client_secret", "${secret.client_secret}"), - }, - }, dict(_DEFAULT_TOKEN_RESPONSE_MAP) - finally: - if should_close: - await http_client.aclose() - async def _load_cached_token( *, @@ -340,70 +244,21 @@ async def _request_dynamic_token_values( token_cache: Any | None, token_values: dict[str, Any], ) -> dict[str, Any]: - token_request = auth_config.token_request or {} - refresh_request = token_request.get("refresh") - if ( - token_values - and refresh_request - and (token_values.get("refresh_token") or credential_payload.get("refresh_token")) - ): - refresh_token_values = dict(token_values) - if not refresh_token_values.get("refresh_token") and credential_payload.get("refresh_token"): - refresh_token_values["refresh_token"] = credential_payload["refresh_token"] - refreshed = await _fetch_custom_http_token( - refresh_request, - response_map=(refresh_request.get("response_map") or token_request.get("response_map")), - context=context, - secret_values=secret_values, - token_values=refresh_token_values, - http_client=http_client, - ) - if not refreshed.get("refresh_token") and refresh_token_values.get("refresh_token"): - refreshed["refresh_token"] = refresh_token_values["refresh_token"] - await _store_cached_token( - token_cache=token_cache, - connection_id=getattr(connection, "id", None), - token_payload=refreshed, - ) - return refreshed - - if auth_config.provider == "authorization_code": - authorization_request, response_map = await _resolve_authorization_code_token_request( - token_request=token_request, - secret_values=secret_values, - token_values=token_values or credential_payload, - http_client=http_client, - ) - authorization_token_values = dict(token_values or credential_payload) - if not authorization_token_values.get("refresh_token") and credential_payload.get("refresh_token"): - authorization_token_values["refresh_token"] = credential_payload["refresh_token"] - resolved = await _fetch_custom_http_token( - authorization_request, - response_map=response_map, - context=context, - secret_values=secret_values, - token_values=authorization_token_values, - http_client=http_client, - ) - if not resolved.get("refresh_token") and authorization_token_values.get("refresh_token"): - resolved["refresh_token"] = authorization_token_values["refresh_token"] - await _store_cached_token( - token_cache=token_cache, - connection_id=getattr(connection, "id", None), - token_payload=resolved, - ) - return resolved - - resolved = await _fetch_custom_http_token( - token_request, - response_map=token_request.get("response_map"), - context=context, + from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory + + fetcher = TokenFetcherFactory.get_fetcher(auth_config.provider) + resolved = await fetcher.fetch_token( + auth_config, + context_payload={ + "user_id": context.user_id, + "department_id": context.department_id, + }, secret_values=secret_values, + credential_payload=credential_payload, token_values=token_values, http_client=http_client, ) - if not resolved.get("refresh_token") and credential_payload.get("refresh_token"): - resolved["refresh_token"] = credential_payload["refresh_token"] + await _store_cached_token( token_cache=token_cache, connection_id=getattr(connection, "id", None), @@ -412,6 +267,7 @@ async def _request_dynamic_token_values( return resolved + async def _resolve_dynamic_token_values( auth_config: MCPAuthConfig, *, @@ -423,8 +279,10 @@ async def _resolve_dynamic_token_values( token_cache: Any | None, ) -> dict[str, Any]: if token_cache is None and connection is not None: + from yuxi.services.mcp_service import RedisTokenCache token_cache = RedisTokenCache() + cached_token = await _load_cached_token( token_cache=token_cache, connection_id=getattr(connection, "id", None), diff --git a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py index 59177c06b..9b1a2f050 100644 --- a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py +++ b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py @@ -14,12 +14,24 @@ DEFAULT_LOCK_TTL_SECONDS = 30 +import uuid +_PYTEST_SESSION_TOKEN = uuid.uuid4().hex[:8] + + def _access_token_key(connection_id: int) -> str: - return f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" + key = f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" + import os + if os.environ.get("PYTEST_CURRENT_TEST"): + return f"test:{_PYTEST_SESSION_TOKEN}:{key}" + return key def _refresh_lock_key(connection_id: int) -> str: - return f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" + key = f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" + import os + if os.environ.get("PYTEST_CURRENT_TEST"): + return f"test:{_PYTEST_SESSION_TOKEN}:{key}" + return key def _compute_token_ttl_seconds(token_payload: dict[str, Any]) -> int: diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py index 43cf4d3d1..1d72c377a 100644 --- a/backend/package/yuxi/services/mcp_service.py +++ b/backend/package/yuxi/services/mcp_service.py @@ -1,1296 +1,111 @@ -"""MCP Service - Unified business logic and state management for MCP. +"""MCP Service - Facade 适配门面,保持项目的完全向下兼容。 -Responsibilities: -- Server configuration CRUD operations -- Built-in configuration synchronization (Code <-> Database) -- Unified entry point for Agent tool retrieval (auto-filtering disabled_tools) -- MCP Client and Tools management (formerly in agents/common/mcp.py) +所有的实质逻辑均已根据职责分工拆分重构至以下子服务模块中: +- yuxi.services.mcp.server_service +- yuxi.services.mcp.connection_service +- yuxi.services.mcp.tool_registry_service +- yuxi.services.mcp.client_pool """ +from __future__ import annotations import asyncio -import hashlib -import httpx -import json -import os -import re -import traceback -from collections.abc import Callable -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast +from typing import Any + +from yuxi.services.mcp.server_service import ( + ensure_builtin_mcp_servers_in_db, + get_enabled_mcp_server_names, + get_mcp_server, + get_all_mcp_servers, + create_mcp_server, + update_mcp_server, + delete_mcp_server, + get_mcp_server_dependency_summary, + set_server_enabled, + get_servers_config, +) +from yuxi.services.mcp.connection_service import ( + get_mcp_connection, + list_mcp_connections, + create_mcp_connection, + update_mcp_connection, + delete_mcp_connection, + set_mcp_connection_status, + reauthorize_mcp_connection, + test_mcp_connection, + _resolve_scope_id, +) +from yuxi.services.mcp.tool_registry_service import ( + to_camel_case, + get_tools_from_all_servers, + clear_mcp_cache, + clear_mcp_server_tools_cache, + clear_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, + invalidate_mcp_connection_tools_cache, + get_mcp_tools_stats, + get_enabled_mcp_tools, + get_all_mcp_tools, + toggle_tool_enabled, +) +# 兼容原导入以防万一 +from yuxi.services.mcp_auth.orchestrator import AuthContext from langchain_mcp_adapters.client import MultiServerMCPClient -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession -from yuxi.services.mcp_auth.config_models import MCPAuthConfig -from yuxi.services.mcp_auth.crypto import encrypt_credential_blob -from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config -from yuxi.services.mcp_auth.proxy_service import ( - INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, - INTERNAL_PROXY_TOKEN_HEADER, - build_proxy_runtime_config, - should_use_internal_proxy, -) from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_tool_cache import RedisMcpToolCache -from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill from yuxi.utils import logger -# ============================================================================= -# === Global Cache & State === -# ============================================================================= - -# Global Lock for MCP state -_mcp_lock = asyncio.Lock() - -# 本地仅缓存工具对象。配置始终以数据库为准,每次按 server_name 现查。 -# cache key 使用 server_name:config_hash,当配置变化时会自然失效。 -_mcp_tools_cache: dict[str, list[Callable[..., Any]]] = {} +# ----------------------------------------------------------------------------- +# --- 共享状态与依赖(提供给外部/子服务使用,并对单元测试 Mock 100% 兼容) --- +# ----------------------------------------------------------------------------- +_mcp_tools_cache = {} _mcp_tool_cache_store = RedisMcpToolCache() +_mcp_tools_stats = {} +_mcp_lock = asyncio.Lock() -# MCP tools statistics (for reporting enabled/disabled counts) -_mcp_tools_stats: dict[str, dict[str, int]] = {} -_UNSET = object() -_VALID_MCP_CONNECTION_SCOPE_TYPES = {"system", "department", "user"} -_VALID_MCP_CONNECTION_STATUSES = {"active", "disabled", "reauth_required", "invalid"} - -# Default MCP Server configurations (Imported to DB on first run) -_DEFAULT_MCP_SERVERS = { - "sequentialthinking": { - "url": "https://remote.mcpservers.org/sequentialthinking/mcp", - "transport": "streamable_http", - "description": "顺序思考工具,帮助 AI 将复杂问题分解为多个步骤", - "icon": "🧠", - "tags": ["内置", "AI"], - }, - "mcp-server-chart": { - "command": "npx", - "args": ["-y", "@antv/mcp-server-chart"], - "transport": "stdio", - "description": "图表生成工具,支持生成各类图表(柱状图、折线图、饼图等)", - "icon": "📊", - "tags": ["内置", "图表"], - }, -} - -_SYNCED_MCP_FIELDS = ( - "description", - "transport", - "url", - "command", - "args", - "env", - "headers", - "timeout", - "sse_read_timeout", - "tags", - "icon", -) - -_MCP_PROXY_BASE_URL_ENV = "YUXI_INTERNAL_MCP_PROXY_BASE_URL" -# ============================================================================= -# === Core Logic (Moved from agents/common/mcp.py) === -# ============================================================================= +# ----------------------------------------------------------------------------- +# --- 兼容性转发入口(支持被测试 monkeypatch.setattr 覆盖) --- +# ----------------------------------------------------------------------------- +async def get_enabled_mcp_server_config(server_name: str, *, db: Any = None) -> dict[str, Any] | None: + from yuxi.services.mcp.server_service import get_enabled_mcp_server_config as _get_cfg + return await _get_cfg(server_name, db=db) -async def ensure_builtin_mcp_servers_in_db() -> None: - """Ensure built-in MCP server definitions exist in the database. +async def get_runtime_mcp_server_config( + server_name: str, + *, + auth_context: AuthContext | None = None, + db: Any = None, + http_client: Any = None, +) -> dict[str, Any] | None: + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config as _get_run_cfg + return await _get_run_cfg(server_name, auth_context=auth_context, db=db, http_client=http_client) - This function only synchronizes code-defined built-ins to the database. - It does not preload runtime configuration into memory. - """ - # Delayed import to avoid circular references - from yuxi.storage.postgres.manager import pg_manager - try: - async with pg_manager.get_async_session_context() as session: - # Check if database has MCP configurations - result = await session.execute(select(func.count(MCPServer.name))) - count = result.scalar() +async def _load_enabled_mcp_server_configs(names: list[str] | None = None) -> dict[str, dict[str, Any]]: + from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs as _load_cfg + return await _load_cfg(names=names) - if count == 0: - # Database is empty, import default configurations - logger.info("No MCP servers in database, importing default configurations...") - for name, config in _DEFAULT_MCP_SERVERS.items(): - server = MCPServer( - name=name, - description=config.get("description"), - transport=config["transport"], - url=config.get("url"), - command=config.get("command"), - args=config.get("args"), - env=config.get("env"), - headers=config.get("headers"), - timeout=config.get("timeout"), - sse_read_timeout=config.get("sse_read_timeout"), - tags=config.get("tags"), - icon=config.get("icon"), - enabled=0, - created_by="system", - updated_by="system", - ) - session.add(server) - await session.commit() - logger.info(f"Imported {len(_DEFAULT_MCP_SERVERS)} default MCP servers to database") - else: - # Ensure all built-in MCP servers exist in database - for name, config in _DEFAULT_MCP_SERVERS.items(): - result = await session.execute(select(MCPServer).filter(MCPServer.name == name)) - existing = result.scalar_one_or_none() - if not existing: - server = MCPServer( - name=name, - description=config.get("description"), - transport=config["transport"], - url=config.get("url"), - command=config.get("command"), - args=config.get("args"), - env=config.get("env"), - headers=config.get("headers"), - timeout=config.get("timeout"), - sse_read_timeout=config.get("sse_read_timeout"), - tags=config.get("tags"), - icon=config.get("icon"), - enabled=0, - created_by="system", - updated_by="system", - ) - session.add(server) - logger.info(f"Added built-in MCP server '{name}' to database") - else: - changed = False - for field in _SYNCED_MCP_FIELDS: - next_value = config.get(field) - if getattr(existing, field) != next_value: - setattr(existing, field, next_value) - changed = True - if changed: - existing.updated_by = "system" - # Commit if any new servers were added (check session state) - if session.new: - await session.commit() - elif session.dirty: - await session.commit() - except Exception as e: - logger.error(f"Failed to ensure builtin MCP servers in database: {e}, traceback: {traceback.format_exc()}") +async def get_mcp_tools(server_name: str, **kwargs: Any) -> list[Any]: + from yuxi.services.mcp.tool_registry_service import get_mcp_tools as _get_t + return await _get_t(server_name, **kwargs) async def get_mcp_client( server_configs: dict[str, Any] | None = None, ) -> MultiServerMCPClient | None: - """Initializes an MCP client with the given server configurations.""" + """初始化并拉起 MCP 客户端。保留该底层入口以确保单元测试中的 monkeypatch 拦截顺畅传导。""" try: client = MultiServerMCPClient(server_configs) # pyright: ignore[reportArgumentType] logger.info(f"Initialized MCP client with servers: {list(server_configs.keys())}") return client except Exception as e: - logger.error("Failed to initialize MCP client: {}", e) - return None - - -def to_camel_case(s: str) -> str: - """Convert string to lowerCamelCase.""" - - # Handle - and _ - s = re.sub(r"[-_]+(.)", lambda m: m.group(1).upper(), s) - # Lowercase first letter - if len(s) > 0: - s = s[0].lower() + s[1:] - return s - - -async def _load_enabled_mcp_server_configs( - *, - names: list[str] | None = None, - db: AsyncSession | None = None, -) -> dict[str, dict[str, Any]]: - """Load enabled MCP server configs directly from the database.""" - if db is not None: - stmt = select(MCPServer).where(MCPServer.enabled == 1) - if names: - stmt = stmt.where(MCPServer.name.in_(names)) - result = await db.execute(stmt) - servers = result.scalars().all() - return {server.name: server.to_mcp_config() for server in servers} - - from yuxi.storage.postgres.manager import pg_manager - - async with pg_manager.get_async_session_context() as session: - return await _load_enabled_mcp_server_configs(names=names, db=session) - - -async def get_enabled_mcp_server_config(server_name: str, *, db: AsyncSession | None = None) -> dict[str, Any] | None: - """Get the latest enabled MCP server config from the database.""" - configs = await _load_enabled_mcp_server_configs(names=[server_name], db=db) - return configs.get(server_name) - - -def _get_internal_mcp_proxy_base_url() -> str | None: - value = os.getenv(_MCP_PROXY_BASE_URL_ENV, "").strip() - return value or None - - -def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, Any], str, bool]: - cache_partition = str(server_config.get("__yuxi_cache_partition") or "server") - allow_global_cache = bool(server_config.get("__yuxi_allow_global_cache", True)) - cache_identity = { - key: value - for key, value in server_config.items() - if key - not in { - "__yuxi_cache_partition", - "__yuxi_allow_global_cache", - INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, - "disabled_tools", - } - } - headers = dict(cache_identity.get("headers") or {}) - headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) - if headers: - cache_identity["headers"] = headers - elif "headers" in cache_identity: - cache_identity["headers"] = {} - return cache_identity, cache_partition, allow_global_cache - - -async def _build_mcp_tool_cache_descriptor(server_name: str, server_config: dict[str, Any]) -> dict[str, Any]: - cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) - config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) - config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] - server_revision = await _mcp_tool_cache_store.get_server_revision(server_name) - partition_revision = 0 - if not allow_global_cache: - partition_revision = await _mcp_tool_cache_store.get_partition_revision(server_name, cache_partition) - revision_token = f"s{server_revision}:p{partition_revision}" - cache_prefix = f"{server_name}:{cache_partition}:{revision_token}:" - return { - "cache_identity": cache_identity, - "cache_partition": cache_partition, - "allow_global_cache": allow_global_cache, - "config_hash": config_hash, - "cache_prefix": cache_prefix, - "cache_key": f"{cache_prefix}{config_hash}", - "server_revision": server_revision, - "partition_revision": partition_revision, - } - - -def _serialize_mcp_tools_manifest( - *, - server_name: str, - cache_partition: str, - cache_key: str, - tools: list[Callable[..., Any]], -) -> dict[str, Any]: - entries = [] - for tool in tools: - if hasattr(tool, "args_schema") and tool.args_schema: - schema = tool.args_schema.schema() if hasattr(tool.args_schema, "schema") else {} - parameters = schema.get("properties", {}) - required = schema.get("required", []) - else: - parameters = {} - required = [] - metadata = dict(getattr(tool, "metadata", {}) or {}) - entries.append( - { - "name": tool.name, - "id": metadata.get("id") or tool.name, - "description": getattr(tool, "description", ""), - "parameters": parameters, - "required": required, - } - ) - return { - "server_name": server_name, - "cache_partition": cache_partition, - "cache_key": cache_key, - "tools": entries, - } - - -def _deserialize_mcp_tool_manifest(manifest: dict[str, Any]) -> list[Callable[..., Any]]: - tools: list[Callable[..., Any]] = [] - for entry in manifest.get("tools", []): - args_schema = None - parameters = entry.get("parameters") or {} - required = entry.get("required") or [] - if parameters or required: - args_schema = SimpleNamespace( - schema=lambda parameters=parameters, required=required: { - "properties": parameters, - "required": required, - } - ) - tools.append( - SimpleNamespace( - name=entry.get("name") or "", - description=entry.get("description") or "", - metadata={"id": entry.get("id") or entry.get("name") or ""}, - args_schema=args_schema, - ) - ) - return tools - - -def _resolve_runtime_tool_cache_partition( - *, - auth_config: MCPAuthConfig, - auth_context: AuthContext | None, - connection: MCPConnection | None, -) -> tuple[str, bool]: - if auth_config.binding_scope in {"department", "user"}: - connection_id = getattr(connection, "id", None) - if connection_id is not None: - return f"connection:{connection_id}", False - scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) - if scope_id is None: - raise ValueError(f"auth_context is required for MCP binding scope '{auth_config.binding_scope}'") - return f"{auth_config.binding_scope}:{scope_id}", False - return "server", True - - -def _apply_runtime_tool_cache_policy( - config: dict[str, Any], - *, - auth_config: MCPAuthConfig, - auth_context: AuthContext | None, - connection: MCPConnection | None, -) -> dict[str, Any]: - partition, allow_global_cache = _resolve_runtime_tool_cache_partition( - auth_config=auth_config, - auth_context=auth_context, - connection=connection, - ) - config["__yuxi_cache_partition"] = partition - config["__yuxi_allow_global_cache"] = allow_global_cache - return config - - -def _get_mcp_auth_config(server_config: dict[str, Any]) -> MCPAuthConfig | None: - auth_payload = server_config.get("auth_config") or {} - if not auth_payload: - return None - try: - return MCPAuthConfig.model_validate(auth_payload) - except Exception as exc: - logger.warning(f"Invalid MCP auth config while resolving tool preload strategy: {exc}") - return None - - -def _can_preload_mcp_server_tools_without_runtime_auth(server_config: dict[str, Any]) -> bool: - if not (server_config.get("auth_config") or {}): - return True - auth_config = _get_mcp_auth_config(server_config) - if auth_config is None: - return False - return auth_config.provider == "legacy_static" - - -def _resolve_scope_id(binding_scope: str, auth_context: AuthContext | None) -> str | None: - if binding_scope == "inline": + logger.error(f"Failed to initialize MCP client: {e}") return None - if binding_scope == "system": - return "global" - if auth_context is None: - raise ValueError(f"auth_context is required for MCP binding scope '{binding_scope}'") - if binding_scope == "department": - if not auth_context.department_id: - raise ValueError("department_id is required for department-scoped MCP auth") - return str(auth_context.department_id) - if binding_scope == "user": - if not auth_context.user_id: - raise ValueError("user_id is required for user-scoped MCP auth") - return str(auth_context.user_id) - raise ValueError(f"Unsupported MCP binding scope: {binding_scope}") - - -def _normalize_mcp_connection_scope(scope_type: str, scope_id: str | None) -> tuple[str, str]: - normalized_scope_type = str(scope_type or "").strip().lower() - if normalized_scope_type not in _VALID_MCP_CONNECTION_SCOPE_TYPES: - raise ValueError("scope_type must be one of: system, department, user") - - normalized_scope_id = str(scope_id or "").strip() - if normalized_scope_type == "system": - return normalized_scope_type, "global" - if not normalized_scope_id: - raise ValueError(f"scope_id is required for {normalized_scope_type}-scoped MCP connections") - return normalized_scope_type, normalized_scope_id - - -def _normalize_mcp_connection_status(status: str) -> str: - normalized_status = str(status or "").strip().lower() - if normalized_status not in _VALID_MCP_CONNECTION_STATUSES: - raise ValueError("status must be one of: active, disabled, reauth_required, invalid") - return normalized_status - - -async def _get_enabled_mcp_server_record(server_name: str, *, db: AsyncSession) -> MCPServer | None: - result = await db.execute( - select(MCPServer).where( - MCPServer.enabled == 1, - MCPServer.name == server_name, - ) - ) - return result.scalar_one_or_none() - - -async def get_runtime_mcp_server_config( - server_name: str, - *, - auth_context: AuthContext | None = None, - db: AsyncSession | None = None, - http_client: httpx.AsyncClient | None = None, -) -> dict[str, Any] | None: - if db is None and auth_context is None: - return await get_enabled_mcp_server_config(server_name) - - if db is not None: - server = await _get_enabled_mcp_server_record(server_name, db=db) - if server is None: - return None - if not server.auth_config_json: - return server.to_mcp_config() - - auth_config = MCPAuthConfig.model_validate(server.auth_config_json) - scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) - if scope_id is None: - return server.to_mcp_config() - - result = await db.execute( - select(MCPConnection).where( - MCPConnection.server_name == server_name, - MCPConnection.scope_type == auth_config.binding_scope, - MCPConnection.scope_id == scope_id, - MCPConnection.status == "active", - ) - ) - connection = result.scalar_one_or_none() - if connection is None: - raise ValueError( - f"Active MCP connection not found for server '{server_name}' and scope " - f"{auth_config.binding_scope}:{scope_id}" - ) - proxy_base_url = _get_internal_mcp_proxy_base_url() - if should_use_internal_proxy(server, auth_config, proxy_base_url): - config = build_proxy_runtime_config( - server, - auth_context=auth_context or AuthContext(), - proxy_base_url=proxy_base_url or "", - ) - else: - config = await resolve_runtime_mcp_config( - server, - auth_context=auth_context or AuthContext(), - connection=connection, - http_client=http_client, - ) - return _apply_runtime_tool_cache_policy( - config, - auth_config=auth_config, - auth_context=auth_context, - connection=connection, - ) - - from yuxi.storage.postgres.manager import pg_manager - - async with pg_manager.get_async_session_context() as session: - return await get_runtime_mcp_server_config( - server_name, - auth_context=auth_context, - db=session, - http_client=http_client, - ) - - -async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: - """Get enabled MCP server names from the database.""" - configs = await _load_enabled_mcp_server_configs(db=db) - return list(configs.keys()) - - -async def get_mcp_tools( - server_name: str, - additional_servers: dict[str, dict[str, Any]] | None = None, - disabled_tools: list[str] = None, - cache: bool = True, - force_refresh: bool = False, -) -> list[Callable[..., Any]]: - """Get MCP tools for a specific server. - - Architecture: - 1. Fetching: Connects to MCP server to get ALL tools. - 2. Caching: Stores the FULL, UNFILTERED list of tools in `_mcp_tools_cache`. - 3. Filtering: Filters the return value based on `disabled_tools` argument. - - Args: - server_name: Server name - additional_servers: Additional server configurations - disabled_tools: List of tool names to filter out from the RETURN value (does not affect cache) - cache: Whether to use/update the cache (default: True) - force_refresh: Whether to force a refresh from the server (default: False) - """ - if additional_servers and server_name in additional_servers: - server_config = additional_servers[server_name] - else: - server_config = await get_enabled_mcp_server_config(server_name) - - if server_config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - - cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, server_config) - cache_partition = cache_descriptor["cache_partition"] - cache_prefix = cache_descriptor["cache_prefix"] - cache_key = cache_descriptor["cache_key"] - use_tool_object_cache = cache and not bool(server_config.get(INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY)) - - all_processed_tools: list[Callable[..., Any]] = [] - - async with _mcp_lock: - if not force_refresh and use_tool_object_cache and cache_key in _mcp_tools_cache: - all_processed_tools = _mcp_tools_cache[cache_key] - - if not all_processed_tools: - try: - # disabled_tools 只影响返回值过滤,不参与 MCP client 建连参数。 - client_config = { - k: v - for k, v in server_config.items() - if k - not in ( - "disabled_tools", - "__yuxi_cache_partition", - "__yuxi_allow_global_cache", - INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, - ) - } - - client = await get_mcp_client({server_name: client_config}) - if client is None: - return [] - - raw_tools = cast(list[Any], await client.get_tools()) - - server_cc = to_camel_case(server_name) - for tool in raw_tools: - original_name = tool.name - tool_cc = to_camel_case(original_name) - unique_id = f"mcp__{server_cc}__{tool_cc}" - - if tool.metadata is None: - tool.metadata = {} - tool.metadata["id"] = unique_id - # 开启错误处理,防止工具调用抛出 ToolException 时击穿服务 - tool.handle_tool_error = True - all_processed_tools.append(tool) - - if cache: - if use_tool_object_cache: - async with _mcp_lock: - stale_keys = [ - key for key in _mcp_tools_cache if key.startswith(cache_prefix) and key != cache_key - ] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - _mcp_tools_cache[cache_key] = all_processed_tools - await _mcp_tool_cache_store.set_manifest( - cache_key, - _serialize_mcp_tools_manifest( - server_name=server_name, - cache_partition=cache_partition, - cache_key=cache_key, - tools=all_processed_tools, - ), - ) - - global_config_disabled = server_config.get("disabled_tools") or [] - enabled_count = len([t for t in all_processed_tools if t.name not in global_config_disabled]) - _mcp_tools_stats[server_name] = { - "total": len(all_processed_tools), - "enabled": enabled_count, - "disabled": len(all_processed_tools) - enabled_count, - } - - logger.info( - f"Refreshed MCP tools cache for '{server_name}' with key '{cache_key}': " - f"{len(all_processed_tools)} tools loaded." - ) - - except Exception as e: - logger.error( - f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}" - ) - return [] - - # 3. Filtering (Apply to Return Value Only) - if disabled_tools: - filtered_tools = [t for t in all_processed_tools if t.name not in disabled_tools] - logger.debug( - f"Returning {len(filtered_tools)}/{len(all_processed_tools)} tools for '{server_name}' " - f"(filtered {len(disabled_tools)} by argument)" - ) - return filtered_tools - - return all_processed_tools - - -async def get_tools_from_all_servers() -> list[Callable[..., Any]]: - """Get all tools from all configured MCP servers.""" - server_configs = await _load_enabled_mcp_server_configs() - all_tools = [] - for server_name, server_config in server_configs.items(): - if not _can_preload_mcp_server_tools_without_runtime_auth(server_config): - logger.info(f"Skip MCP tool preload for '{server_name}' because runtime auth context is required") - continue - tools = await get_mcp_tools(server_name, additional_servers={server_name: server_config}) - all_tools.extend(tools) - return all_tools - - -def clear_mcp_cache() -> None: - """Clear the MCP tools cache (useful for testing).""" - global _mcp_tools_cache, _mcp_tools_stats - _mcp_tools_cache = {} - _mcp_tools_stats = {} - - -def clear_mcp_server_tools_cache(server_name: str) -> None: - """Clear the tools cache for a specific MCP server.""" - global _mcp_tools_cache, _mcp_tools_stats - server_prefix = f"{server_name}:" - stale_keys = [key for key in _mcp_tools_cache if key.startswith(server_prefix)] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - _mcp_tools_stats.pop(server_name, None) - logger.info(f"Cleared tools cache for MCP server '{server_name}'") - - -def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: - if connection_id is None: - return - global _mcp_tools_cache - cache_prefix = f"{server_name}:connection:{connection_id}:" - stale_keys = [key for key in _mcp_tools_cache if key.startswith(cache_prefix)] - for stale_key in stale_keys: - _mcp_tools_cache.pop(stale_key, None) - if stale_keys: - logger.info(f"Cleared tools cache for MCP connection {connection_id} on server '{server_name}'") - - -async def invalidate_mcp_server_tools_cache(server_name: str) -> None: - clear_mcp_server_tools_cache(server_name) - await _mcp_tool_cache_store.bump_server_revision(server_name) - - -async def invalidate_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: - clear_mcp_connection_tools_cache(server_name, connection_id) - if connection_id is None: - return - await _mcp_tool_cache_store.bump_partition_revision(server_name, f"connection:{connection_id}") - - -async def _invalidate_mcp_tools_cache_for_connection(connection: MCPConnection) -> None: - if connection.scope_type == "system": - await invalidate_mcp_server_tools_cache(connection.server_name) - return - await invalidate_mcp_connection_tools_cache(connection.server_name, getattr(connection, "id", None)) - - -async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> None: - if connection_id is None: - return - - cache = RedisTokenCache() - try: - await cache.delete_access_token(connection_id) - except Exception as exc: - logger.warning(f"Failed to clear MCP token cache for connection {connection_id}: {exc}") - try: - await cache.release_refresh_lock(connection_id) - except Exception as exc: - logger.warning(f"Failed to clear MCP refresh lock for connection {connection_id}: {exc}") - - -async def _clear_mcp_server_runtime_auth_cache(db: AsyncSession, server_name: str) -> None: - connections = await list_mcp_connections(db, server_name=server_name) - for connection in connections: - await _clear_mcp_connection_runtime_auth_cache(getattr(connection, "id", None)) - - -def get_mcp_tools_stats(server_name: str) -> dict[str, int] | None: - """Get tools statistics for a MCP server. - - Returns: - dict with 'total', 'enabled', 'disabled' counts, or None if not available - """ - return _mcp_tools_stats.get(server_name) - - -# ============================================================================= -# === Server Config CRUD (Existing in mcp_service.py) === -# ============================================================================= - - -async def get_mcp_server(db: AsyncSession, name: str) -> MCPServer | None: - """Get single server configuration.""" - result = await db.execute(select(MCPServer).filter(MCPServer.name == name)) - return result.scalar_one_or_none() - - -async def get_all_mcp_servers(db: AsyncSession) -> list[MCPServer]: - """Get all server configurations.""" - result = await db.execute(select(MCPServer)) - return list(result.scalars().all()) - - -async def get_mcp_connection(db: AsyncSession, connection_id: int) -> MCPConnection | None: - result = await db.execute(select(MCPConnection).where(MCPConnection.id == connection_id)) - return result.scalar_one_or_none() - - -def _auth_context_from_connection(connection: MCPConnection) -> AuthContext: - if connection.scope_type == "department": - return AuthContext(department_id=connection.scope_id) - if connection.scope_type == "user": - return AuthContext(user_id=connection.scope_id) - return AuthContext() - - -async def list_mcp_connections( - db: AsyncSession, - *, - server_name: str | None = None, - scope_type: str | None = None, - scope_id: str | None = None, -) -> list[MCPConnection]: - stmt = select(MCPConnection) - if server_name is not None: - stmt = stmt.where(MCPConnection.server_name == server_name) - if scope_type is not None: - stmt = stmt.where(MCPConnection.scope_type == scope_type) - if scope_id is not None: - stmt = stmt.where(MCPConnection.scope_id == scope_id) - stmt = stmt.order_by(MCPConnection.id.asc()) - result = await db.execute(stmt) - return list(result.scalars().all()) - - -async def create_mcp_connection( - db: AsyncSession, - *, - server_name: str, - scope_type: str, - scope_id: str, - display_name: str | None = None, - external_subject: str | None = None, - status: str = "active", - credential_blob: str | None = None, - meta_json: dict[str, Any] | None = None, - created_by: str | None = None, -) -> MCPConnection: - server = await get_mcp_server(db, server_name) - if server is None: - raise ValueError(f"Server '{server_name}' does not exist") - normalized_scope_type, normalized_scope_id = _normalize_mcp_connection_scope(scope_type, scope_id) - normalized_status = _normalize_mcp_connection_status(status) - - encrypted_credential_blob = ( - encrypt_credential_blob(credential_blob) - if isinstance(credential_blob, str) and credential_blob.strip() - else credential_blob - ) - - connection = MCPConnection( - server_name=server_name, - scope_type=normalized_scope_type, - scope_id=normalized_scope_id, - display_name=display_name, - external_subject=external_subject, - status=normalized_status, - credential_blob=encrypted_credential_blob, - meta_json=meta_json or {}, - created_by=created_by, - updated_by=created_by, - ) - db.add(connection) - await db.commit() - await db.refresh(connection) - return connection - - -async def update_mcp_connection( - db: AsyncSession, - connection_id: int, - *, - display_name: str | None = None, - external_subject: str | None = None, - credential_blob: Any = _UNSET, - meta_json: dict[str, Any] | None = None, - status: str | None = None, - updated_by: str | None = None, -) -> MCPConnection: - connection = await get_mcp_connection(db, connection_id) - if connection is None: - raise ValueError(f"MCP connection '{connection_id}' does not exist") - - should_clear_runtime_auth_cache = False - if display_name is not None: - connection.display_name = display_name - if external_subject is not None: - connection.external_subject = external_subject - if credential_blob is not _UNSET: - if isinstance(credential_blob, str) and credential_blob.strip(): - connection.credential_blob = encrypt_credential_blob(credential_blob) - else: - connection.credential_blob = credential_blob - should_clear_runtime_auth_cache = True - if meta_json is not None: - connection.meta_json = meta_json - if status is not None: - connection.status = _normalize_mcp_connection_status(status) - should_clear_runtime_auth_cache = True - if updated_by is not None: - connection.updated_by = updated_by - - await db.commit() - await db.refresh(connection) - if should_clear_runtime_auth_cache: - await _clear_mcp_connection_runtime_auth_cache(connection.id) - await _invalidate_mcp_tools_cache_for_connection(connection) - return connection - - -async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: - connection = await get_mcp_connection(db, connection_id) - if connection is None: - return False - deleted_connection_id = connection.id - deleted_server_name = connection.server_name - deleted_scope_type = connection.scope_type - await db.delete(connection) - await db.commit() - await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) - if deleted_scope_type == "system": - await invalidate_mcp_server_tools_cache(deleted_server_name) - else: - await invalidate_mcp_connection_tools_cache(deleted_server_name, deleted_connection_id) - return True - - -async def set_mcp_connection_status( - db: AsyncSession, - connection_id: int, - *, - status: str, - updated_by: str | None = None, -) -> MCPConnection: - connection = await get_mcp_connection(db, connection_id) - if connection is None: - raise ValueError(f"MCP connection '{connection_id}' does not exist") - - connection.status = _normalize_mcp_connection_status(status) - if updated_by is not None: - connection.updated_by = updated_by - await db.commit() - await db.refresh(connection) - await _clear_mcp_connection_runtime_auth_cache(connection.id) - await _invalidate_mcp_tools_cache_for_connection(connection) - return connection - - -async def reauthorize_mcp_connection( - db: AsyncSession, - connection_id: int, - *, - updated_by: str | None = None, -) -> MCPConnection: - connection = await get_mcp_connection(db, connection_id) - if connection is None: - raise ValueError(f"MCP connection '{connection_id}' does not exist") - - cache = RedisTokenCache() - if getattr(connection, "id", None) is not None: - try: - await cache.delete_access_token(connection.id) - except Exception as exc: - logger.warning(f"Failed to clear MCP token cache for connection {connection.id}: {exc}") - try: - await cache.release_refresh_lock(connection.id) - except Exception as exc: - logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") - await _invalidate_mcp_tools_cache_for_connection(connection) - - connection.status = "active" - meta_json = dict(connection.meta_json or {}) - meta_json.pop("last_error", None) - connection.meta_json = meta_json - if updated_by is not None: - connection.updated_by = updated_by - await db.commit() - await db.refresh(connection) - return connection - - -async def test_mcp_connection( - db: AsyncSession, - connection_id: int, - *, - updated_by: str | None = None, -) -> dict[str, Any]: - connection = await get_mcp_connection(db, connection_id) - if connection is None: - raise ValueError(f"MCP connection '{connection_id}' does not exist") - - server = await get_mcp_server(db, connection.server_name) - if server is None: - raise ValueError(f"Server '{connection.server_name}' does not exist") - - auth_context = _auth_context_from_connection(connection) - config = await get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) - if config is None: - raise ValueError(f"MCP server '{server.name}' runtime config unavailable") - - tools = await get_mcp_tools( - server.name, - additional_servers={server.name: config}, - disabled_tools=[], - cache=False, - force_refresh=True, - ) - - meta_json = dict(connection.meta_json or {}) - meta_json["last_success_at"] = datetime.now(tz=UTC).isoformat() - meta_json.pop("last_error", None) - connection.meta_json = meta_json - connection.status = "active" - if updated_by is not None: - connection.updated_by = updated_by - await db.commit() - await db.refresh(connection) - return {"tool_count": len(tools), "connection": connection} - - -async def create_mcp_server( - db: AsyncSession, - name: str, - transport: str, - url: str = None, - command: str = None, - args: list = None, - env: dict = None, - description: str = None, - headers: dict = None, - timeout: int = None, - sse_read_timeout: int = None, - tags: list = None, - icon: str = None, - auth_config: dict | None = None, - created_by: str = None, -) -> MCPServer: - """Create server.""" - # Check if name exists - existing = await get_mcp_server(db, name) - if existing: - raise ValueError(f"Server name '{name}' already exists") - - server = MCPServer( - name=name, - description=description, - transport=transport, - url=url, - command=command, - args=args, - env=env, - headers=headers, - auth_config_json=auth_config, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - tags=tags, - icon=icon, - enabled=1, - created_by=created_by, - updated_by=created_by, - ) - db.add(server) - await db.commit() - await db.refresh(server) - - await _clear_mcp_server_runtime_auth_cache(db, name) - await invalidate_mcp_server_tools_cache(name) - - logger.info(f"Created MCP server '{name}'") - return server - - -async def update_mcp_server( - db: AsyncSession, - name: str, - description: str = None, - transport: str = None, - url: str = None, - command: str = None, - args: list = None, - env: Any = _UNSET, - headers: dict = None, - timeout: int = None, - sse_read_timeout: int = None, - tags: list = None, - icon: str = None, - auth_config: Any = _UNSET, - updated_by: str = None, -) -> MCPServer: - """Update server configuration.""" - server = await get_mcp_server(db, name) - if not server: - raise ValueError(f"Server '{name}' does not exist") - - if description is not None: - server.description = description - if transport is not None: - server.transport = transport - if url is not None: - server.url = url - if command is not None: - server.command = command - if args is not None: - server.args = args - if env is not _UNSET: - server.env = env - if headers is not None: - server.headers = headers - if auth_config is not _UNSET: - server.auth_config_json = auth_config - if timeout is not None: - server.timeout = timeout - if sse_read_timeout is not None: - server.sse_read_timeout = sse_read_timeout - if tags is not None: - server.tags = tags - if icon is not None: - server.icon = icon - if updated_by is not None: - server.updated_by = updated_by - - await db.commit() - await db.refresh(server) - - if auth_config is not _UNSET: - await _clear_mcp_server_runtime_auth_cache(db, name) - await invalidate_mcp_server_tools_cache(name) - - logger.info(f"Updated MCP server '{name}'") - return server - - -async def delete_mcp_server(db: AsyncSession, name: str) -> bool: - """Delete server.""" - server = await get_mcp_server(db, name) - if not server: - return False - - connection_ids = [item.id for item in await list_mcp_connections(db, server_name=name)] - await db.delete(server) - await db.commit() - - for connection_id in connection_ids: - await _clear_mcp_connection_runtime_auth_cache(connection_id) - await invalidate_mcp_server_tools_cache(name) - - logger.info(f"Deleted MCP server '{name}'") - return True - - -async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: - connections = await list_mcp_connections(db, server_name=name) - - skill_rows = (await db.execute(select(Skill))).scalars().all() - matched_skills = [ - {"slug": item.slug, "name": item.name} for item in skill_rows if name in (item.mcp_dependencies or []) - ] - - agent_config_rows = (await db.execute(select(AgentConfig))).scalars().all() - matched_agent_configs = [] - for item in agent_config_rows: - config_json = item.config_json or {} - if name in (config_json.get("mcps") or []): - matched_agent_configs.append({"id": item.id, "name": item.name, "agent_id": item.agent_id}) - - connection_refs = [ - {"scope_type": item.scope_type, "scope_id": item.scope_id, "status": item.status} for item in connections - ] - - return { - "has_references": bool(connection_refs or matched_skills or matched_agent_configs), - "connections": connection_refs, - "skills": matched_skills, - "agent_configs": matched_agent_configs, - } - - -# ============================================================================= -# === Tool Management === -# ============================================================================= - - -async def set_server_enabled( - db: AsyncSession, name: str, enabled: bool, updated_by: str = None -) -> tuple[bool, MCPServer]: - """Set server enabled status.""" - server = await get_mcp_server(db, name) - if not server: - raise ValueError(f"Server '{name}' does not exist") - - server.enabled = 1 if enabled else 0 - if updated_by is not None: - server.updated_by = updated_by - await db.commit() - - is_enabled = bool(server.enabled) - if not is_enabled: - await _clear_mcp_server_runtime_auth_cache(db, name) - await invalidate_mcp_server_tools_cache(name) - - logger.info(f"Set MCP server '{name}' enabled={is_enabled}") - return is_enabled, server - - -async def toggle_tool_enabled( - db: AsyncSession, - server_name: str, - tool_name: str, - updated_by: str = None, -) -> tuple[bool, MCPServer]: - """Toggle single tool enabled status. - - Args: - db: Database session - server_name: Server name - tool_name: Tool name - updated_by: Updater - - Returns: - (enabled, server): Tool enabled status and updated server object - """ - server = await get_mcp_server(db, server_name) - if not server: - raise ValueError(f"Server '{server_name}' does not exist") - - disabled_tools = list(server.disabled_tools or []) - - if tool_name in disabled_tools: - disabled_tools.remove(tool_name) - enabled = True - else: - disabled_tools.append(tool_name) - enabled = False - - server.disabled_tools = disabled_tools - if updated_by is not None: - server.updated_by = updated_by - await db.commit() - - # Clear tool cache (re-filtered on next fetch) - clear_mcp_server_tools_cache(server_name) - - logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") - return enabled, server - - -# ============================================================================= -# === Unified Entry Points (Wrappers) === -# ============================================================================= - - -async def get_enabled_mcp_tools( - server_name: str, - *, - auth_context: AuthContext | None = None, - db: AsyncSession | None = None, - http_client: httpx.AsyncClient | None = None, -) -> list: - """Get MCP server tools (auto-filtering disabled_tools). - - Unified entry point for Agents, automatically: - 1. Gets the latest server config from database - 2. Gets all tools - 3. Filters out disabled_tools - - Args: - server_name: Server name - - Returns: - List of enabled tools - """ - config = await get_runtime_mcp_server_config( - server_name, - auth_context=auth_context, - db=db, - http_client=http_client, - ) - if config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - - disabled_tools = config.get("disabled_tools") or [] - return await get_mcp_tools(server_name, additional_servers={server_name: config}, disabled_tools=disabled_tools) - - -async def get_servers_config(names: list[str]) -> dict[str, dict[str, Any]]: - """Batch get server configurations. - - Args: - names: List of server names - - Returns: - {name: config} dictionary, containing only found servers - """ - return await _load_enabled_mcp_server_configs(names=names) - - -async def get_all_mcp_tools( - server_name: str, - *, - auth_context: AuthContext | None = None, - db: AsyncSession | None = None, - http_client: httpx.AsyncClient | None = None, - force_refresh: bool = False, -) -> list: - """Get all tools of an MCP server (no filtering). - - For management UI to display tool list, supports viewing all tools and their enabled status. - - Args: - server_name: Server name - - Returns: - List of all tools (unfiltered) - """ - if auth_context is None and db is None: - config = await get_enabled_mcp_server_config(server_name) - else: - config = await get_runtime_mcp_server_config( - server_name, - auth_context=auth_context, - db=db, - http_client=http_client, - ) - if config is None: - logger.warning(f"MCP server '{server_name}' not found in database or disabled") - return [] - if not force_refresh: - cache_descriptor = await _build_mcp_tool_cache_descriptor(server_name, config) - manifest = await _mcp_tool_cache_store.get_manifest(cache_descriptor["cache_key"]) - if manifest is not None: - return _deserialize_mcp_tool_manifest(manifest) - return await get_mcp_tools( - server_name, - additional_servers={server_name: config}, - disabled_tools=[], - cache=True, - force_refresh=force_refresh, - ) +async def _clear_mcp_server_runtime_auth_cache(db: Any, server_name: str) -> None: + from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache as _clear_auth + await _clear_auth(db, server_name) diff --git a/backend/test/mcp_demo_server.py b/backend/test/mcp_demo_server.py new file mode 100644 index 000000000..7c9e95e56 --- /dev/null +++ b/backend/test/mcp_demo_server.py @@ -0,0 +1,214 @@ +from __future__ import annotations +import argparse +import asyncio +import contextvars +import logging +import os +import sys +from typing import Any +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from mcp.server import Server, NotificationOptions +from mcp.server.models import InitializationOptions +import mcp.types as types +from mcp.server.sse import SseServerTransport +from mcp.server.stdio import stdio_server + +# 简体中文注释与日志规范 (RULE[user_global]) +logger = logging.getLogger("mcp_demo_server") +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + +# 用于在不同传输协议下传递当前请求身份上下文的 ContextVar +current_request_headers_var = contextvars.ContextVar("current_request_headers", default=None) + +# 实例化 MCP 核心服务对象 +server = Server("yuxi-mcp-demo-server") + +@server.list_tools() +async def handle_list_tools() -> list[types.Tool]: + """根据身份或环境变量返回过滤后的三级权限工具列表""" + headers = current_request_headers_var.get() or {} + + # 优先级: HTTP Headers > 系统环境变量 (兼容 stdio 与 sse 两种环境的测试) + dept_id = headers.get("x-department-id") or os.environ.get("X_DEPARTMENT_ID") + user_id = headers.get("x-user-id") or os.environ.get("X_USER_ID") + auth_token = headers.get("authorization") or os.environ.get("AUTHORIZATION") + + logger.info(f"Listing tools - AuthToken: {auth_token}, DeptID: {dept_id}, UserID: {user_id}") + + # 基础路由工具 (全局可见) + tools = [ + types.Tool( + name="echo_global", + description="全局通用工具,无须任何权限即可访问", + inputSchema={ + "type": "object", + "properties": { + "message": {"type": "string", "description": "要回显的内容"} + }, + "required": ["message"] + }, + ) + ] + + # 部门级别权限工具 + if dept_id: + tools.append( + types.Tool( + name="echo_dept_data", + description=f"部门级别受限工具,当前已授权部门ID: {dept_id}", + inputSchema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "查询参数"} + }, + "required": ["query"] + }, + ) + ) + + # 用户个人级别权限工具 + if user_id: + tools.append( + types.Tool( + name="echo_user_profile", + description=f"个人受限工具,当前已授权用户ID: {user_id}", + inputSchema={ + "type": "object", + "properties": { + "dummy": {"type": "string", "description": "占位参数"} + } + }, + ) + ) + + return tools + + +@server.call_tool() +async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]: + """执行工具回显结果""" + logger.info(f"Calling tool: {name} with args: {arguments}") + args = arguments or {} + + if name == "echo_global": + msg = args.get("message", "") + return [types.TextContent(type="text", text=f"[Global Output] 回显内容: {msg}")] + + elif name == "echo_dept_data": + query = args.get("query", "") + return [types.TextContent(type="text", text=f"[Department Output] 数据查询回显: {query}")] + + elif name == "echo_user_profile": + return [types.TextContent(type="text", text="[User Output] 成功获取用户专有敏感配置与画像数据")] + + else: + raise ValueError(f"Unknown tool: {name}") + + +# ============================================================================= +# === SSE 传输协议支持 (FastAPI) === +# ============================================================================= + +app = FastAPI(title="MCP Demo Server") +sse_transport = SseServerTransport("/messages") + +@app.post("/oauth/token") +async def oauth_token(request: Request): + """ + 模拟 OAuth2 认证端点。 + 返回一个 15 秒过期的 access_token,用以充分验证 Yuxi 后台的“短期 Token 自动失效与刷新”链路。 + """ + logger.info("Handling OAuth token request...") + return { + "access_token": "mock_access_token_123456", + "refresh_token": "mock_refresh_token_789000", + "expires_in": 15, # 15 秒过期,利于测试 + "token_type": "Bearer", + "scope": "read write" + } + +@app.get("/sse") +async def sse(request: Request): + """建立 SSE 长连接通道,并将当前的 Headers 存入 ContextVar""" + headers_dict = dict(request.headers) + logger.info(f"New SSE connection attempt. Headers: {headers_dict}") + + # 注入 ContextVar,使得在该长连接处理循环下的所有 list/call_tool 能读取到 header + token = current_request_headers_var.set(headers_dict) + try: + async with sse_transport.connect_sse( + request.scope, request.receive, request.send + ) as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="yuxi-mcp-demo-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + finally: + current_request_headers_var.reset(token) + +@app.post("/messages") +async def messages(request: Request): + """接收 SSE 通道发来的具体 JSON-RPC 请求""" + await sse_transport.handle_post_message(request.scope, request.receive, request.send) + + +# ============================================================================= +# === Stdio 传输协议支持 (本地子进程) === +# ============================================================================= + +async def run_stdio(): + """以 Stdio 形式在控制台管道中拉起""" + logger.info("Starting Stdio server transport...") + + # Stdio 模式下从系统环境变量读取 headers + current_request_headers_var.set(dict(os.environ)) + + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="yuxi-mcp-demo-server", + server_version="1.0.0", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + +# ============================================================================= +# === 启动入口 === +# ============================================================================= + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mock MCP Demo Server") + parser.add_argument( + "--transport", + choices=["stdio", "sse"], + default="sse", + help="传输协议类型 (默认为 sse)" + ) + parser.add_argument( + "--port", + type=int, + default=8999, + help="FastAPI SSE 服务的端口号" + ) + args = parser.parse_args() + + if args.transport == "stdio": + asyncio.run(run_stdio()) + else: + logger.info(f"Starting SSE FastAPI Server on port {args.port}...") + uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/backend/test/unit/services/test_mcp_cache_policy.py b/backend/test/unit/services/test_mcp_cache_policy.py new file mode 100644 index 000000000..3973fe88b --- /dev/null +++ b/backend/test/unit/services/test_mcp_cache_policy.py @@ -0,0 +1,91 @@ +from __future__ import annotations +import pytest +from unittest.mock import MagicMock + +from yuxi.services.mcp.cache_policy import ( + CachePolicyFactory, + StaticCachePolicy, + TokenInjectedCachePolicy, + DynamicProxyCachePolicy, +) +from yuxi.services.mcp_auth.orchestrator import AuthContext +from yuxi.storage.postgres.models_business import MCPConnection + + +def test_cache_policy_factory(): + """测试 CachePolicyFactory 的策略派发逻辑""" + # 静态/无鉴权 + assert isinstance(CachePolicyFactory.get_policy(None), StaticCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("legacy_static"), StaticCachePolicy) + + # 注入型 + assert isinstance(CachePolicyFactory.get_policy("bound_secret"), TokenInjectedCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("stdio_env"), TokenInjectedCachePolicy) + + # 动态 Token 型 + assert isinstance(CachePolicyFactory.get_policy("custom_http_token"), DynamicProxyCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("client_credentials"), DynamicProxyCachePolicy) + assert isinstance(CachePolicyFactory.get_policy("authorization_code"), DynamicProxyCachePolicy) + + +def test_static_cache_policy(): + """测试 StaticCachePolicy""" + policy = StaticCachePolicy() + assert policy.should_cache_tool_object() is True + + auth_context = AuthContext() + partition, is_shared = policy.resolve_cache_partition(auth_context, None) + assert partition == "global" + assert is_shared is True + + +def test_token_injected_cache_policy(): + """测试 TokenInjectedCachePolicy""" + policy = TokenInjectedCachePolicy() + assert policy.should_cache_tool_object() is True + + auth_context = AuthContext(user_id="user_1", department_id="dept_A") + + # connection 为 None 退避 + partition, is_shared = policy.resolve_cache_partition(auth_context, None) + assert partition == "global" + assert is_shared is True + + # 系统连接,共享 + conn_sys = MCPConnection(id=10, scope_type="system", scope_id="global") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_sys) + assert partition == "connection:10" + assert is_shared is True + + # 部门连接,独占 + conn_dept = MCPConnection(id=11, scope_type="department", scope_id="dept_A") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_dept) + assert partition == "connection:11" + assert is_shared is False + + # 个人连接,独占 + conn_user = MCPConnection(id=12, scope_type="user", scope_id="user_1") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_user) + assert partition == "connection:12" + assert is_shared is False + + +def test_dynamic_proxy_cache_policy(): + """测试 DynamicProxyCachePolicy""" + policy = DynamicProxyCachePolicy() + # 动态鉴权必须禁止把带临时 Token 的 Tool 实例缓存在共享内存中 + assert policy.should_cache_tool_object() is False + + auth_context = AuthContext(user_id="user_1", department_id="dept_A") + + # 部门隔离连接,独占 + conn_dept = MCPConnection(id=20, scope_type="department", scope_id="dept_A") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_dept) + assert partition == "connection:20" + assert is_shared is False + + # 个人隔离连接,独占 + conn_user = MCPConnection(id=21, scope_type="user", scope_id="user_1") + partition, is_shared = policy.resolve_cache_partition(auth_context, conn_user) + assert partition == "connection:21" + assert is_shared is False diff --git a/backend/test/unit/services/test_mcp_client_pool.py b/backend/test/unit/services/test_mcp_client_pool.py new file mode 100644 index 000000000..57a40a163 --- /dev/null +++ b/backend/test/unit/services/test_mcp_client_pool.py @@ -0,0 +1,95 @@ +from __future__ import annotations +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from yuxi.services.mcp.client_pool import MCPClientPool, LongLivedSession + + +@pytest.mark.asyncio +async def test_long_lived_session_lifecycle(): + """测试 LongLivedSession 正常的启动与停止流程""" + mock_client = MagicMock() + mock_session = MagicMock() + + # 模拟 client.session() 返回一个 AsyncContextManager + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + mock_client.session.return_value = mock_context + + ll_session = LongLivedSession(mock_client, "test_server") + + # 启动 + await ll_session.start() + assert ll_session._running is True + assert ll_session.session == mock_session + mock_client.session.assert_called_once_with("test_server") + + # 停止 + await ll_session.stop() + assert ll_session._running is False + assert ll_session.session is None + + +@pytest.mark.asyncio +async def test_client_pool_reuse_and_recreate(): + """测试 MCPClientPool 的复用逻辑与配置脏变重构逻辑""" + pool = MCPClientPool() + + config_1 = { + "transport": "stdio", + "command": "node", + "args": ["file1.js"], + "__yuxi_cache_partition": "p1", + } + + config_2 = { + "transport": "stdio", + "command": "node", + "args": ["file1.js"], + "__yuxi_cache_partition": "p1", + } + + config_changed = { + "transport": "stdio", + "command": "node", + "args": ["file2.js"], # 配置发生改变 + "__yuxi_cache_partition": "p1", + } + + mock_client_instance = MagicMock() + mock_session_instance = MagicMock() + + # Mock LongLivedSession 的 start/stop 以防真实建连 + with patch("yuxi.services.mcp.client_pool.MultiServerMCPClient", return_value=mock_client_instance), \ + patch("yuxi.services.mcp.client_pool.LongLivedSession") as MockLongLivedSession: + + mock_ll_instance = MagicMock() + mock_ll_instance.session = mock_session_instance + mock_ll_instance.start = AsyncMock() + mock_ll_instance.stop = AsyncMock() + MockLongLivedSession.return_value = mock_ll_instance + + # 1. 首次获取,创建新 Session + session_1 = await pool.get_session("test_server", "p1", config_1) + assert session_1 == mock_session_instance + assert MockLongLivedSession.call_count == 1 + mock_ll_instance.start.assert_called_once() + + # 2. 相同配置获取,直接复用 + session_2 = await pool.get_session("test_server", "p1", config_2) + assert session_2 == mock_session_instance + assert MockLongLivedSession.call_count == 1 # 没增加,说明复用了 + + # 3. 配置改变获取,销毁旧的,重新创建 + session_changed = await pool.get_session("test_server", "p1", config_changed) + assert session_changed == mock_session_instance + # 销毁被调用了 + mock_ll_instance.stop.assert_called_once() + # 创建计数增加 + assert MockLongLivedSession.call_count == 2 + + # 4. shutdown 清理 + await pool.shutdown() + assert mock_ll_instance.stop.call_count == 2 # 新增的那个也被 stop + assert len(pool._sessions) == 0 From 7053d3ba4d7fe83f65727aec662c5e32837671ff Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 00:59:02 +0800 Subject: [PATCH 11/36] feat(mcp): complete mcp gateway proxy streaming refactor and purge mcp_service facade --- backend/package/yuxi/agents/__init__.py | 2 +- .../yuxi/agents/buildin/chatbot/graph.py | 2 +- .../yuxi/agents/buildin/deep_agent/graph.py | 2 +- .../middlewares/dynamic_tool_middleware.py | 2 +- .../middlewares/runtime_config_middleware.py | 2 +- .../agents/middlewares/skills_middleware.py | 2 +- .../package/yuxi/services/mcp/client_pool.py | 14 +- .../yuxi/services/mcp/connection_service.py | 12 +- .../yuxi/services/mcp/server_service.py | 38 ++-- .../yuxi/services/mcp_auth/orchestrator.py | 2 +- .../yuxi/services/mcp_auth/proxy_service.py | 199 +++++++++++++----- backend/package/yuxi/services/mcp_service.py | 111 ---------- backend/package/yuxi/services/run_worker.py | 2 +- .../package/yuxi/services/skill_service.py | 2 +- backend/server/routers/mcp_internal_router.py | 102 ++------- backend/server/routers/mcp_router.py | 24 ++- backend/server/utils/lifespan.py | 2 +- ...th_runtime.py => test_mcp_auth_runtime.py} | 32 +-- .../services/test_mcp_connection_service.py | 68 +++--- ...e.py => test_mcp_tool_registry_service.py} | 102 +++++---- fix_mcp_service_imports.py | 49 +++++ fix_tests.py | 55 +++++ 22 files changed, 424 insertions(+), 402 deletions(-) delete mode 100644 backend/package/yuxi/services/mcp_service.py rename backend/test/unit/services/{test_mcp_service_auth_runtime.py => test_mcp_auth_runtime.py} (88%) rename backend/test/unit/services/{test_mcp_service.py => test_mcp_tool_registry_service.py} (73%) create mode 100644 fix_mcp_service_imports.py create mode 100644 fix_tests.py diff --git a/backend/package/yuxi/agents/__init__.py b/backend/package/yuxi/agents/__init__.py index dd7174a4f..8c654c59c 100644 --- a/backend/package/yuxi/agents/__init__.py +++ b/backend/package/yuxi/agents/__init__.py @@ -12,7 +12,7 @@ from yuxi.agents.toolkits.utils import get_tool_info # MCP - Agent 层统一入口(自动过滤 disabled_tools) -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools __all__ = [ # Base classes diff --git a/backend/package/yuxi/agents/buildin/chatbot/graph.py b/backend/package/yuxi/agents/buildin/chatbot/graph.py index d7dcdc2d2..3f0872256 100644 --- a/backend/package/yuxi/agents/buildin/chatbot/graph.py +++ b/backend/package/yuxi/agents/buildin/chatbot/graph.py @@ -13,7 +13,7 @@ ) from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware -from yuxi.services.mcp_service import get_tools_from_all_servers +from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names from .prompt import TODO_MID_PROMPT, build_prompt_with_context diff --git a/backend/package/yuxi/agents/buildin/deep_agent/graph.py b/backend/package/yuxi/agents/buildin/deep_agent/graph.py index 706ec3b1f..e56155fe5 100644 --- a/backend/package/yuxi/agents/buildin/deep_agent/graph.py +++ b/backend/package/yuxi/agents/buildin/deep_agent/graph.py @@ -17,7 +17,7 @@ from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware from yuxi.agents.toolkits.buildin.tools import _create_tavily_search -from yuxi.services.mcp_service import get_tools_from_all_servers +from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names from yuxi.utils import logger diff --git a/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py b/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py index 4f8c3aac5..5b0d47820 100644 --- a/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py +++ b/backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py @@ -3,7 +3,7 @@ from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse -from yuxi.services.mcp_service import get_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_mcp_tools from yuxi.utils import logger diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index 519dd148d..ad6a70c23 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -10,7 +10,7 @@ from yuxi.agents import load_chat_model from yuxi.agents.toolkits import get_all_tool_instances from yuxi.services.mcp_auth.orchestrator import AuthContext -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools from yuxi.utils.datetime_utils import shanghai_now from yuxi.utils.logging_config import logger diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index b626c54e2..6b0d4c0b9 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -16,7 +16,7 @@ from yuxi.agents.toolkits import get_all_tool_instances from yuxi.repositories.skill_repository import SkillRepository from yuxi.services.mcp_auth.orchestrator import AuthContext -from yuxi.services.mcp_service import get_enabled_mcp_tools +from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools from yuxi.services.skill_service import _normalize_string_list, is_valid_skill_slug from yuxi.storage.postgres.manager import pg_manager from yuxi.utils.logging_config import logger diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py index a55cf24ea..a205cd896 100644 --- a/backend/package/yuxi/services/mcp/client_pool.py +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -31,7 +31,7 @@ async def async_auth_flow( # 导入数据库会话管理器以获取连接与 Token from yuxi.storage.postgres.manager import pg_manager async with pg_manager.get_async_session_context() as session: - from yuxi.services.mcp_service import get_runtime_mcp_server_config + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config # NOTE: 2. 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) runtime_config = await get_runtime_mcp_server_config( self.server_name, @@ -137,6 +137,15 @@ def _calculate_config_hash(self, config: dict[str, Any]) -> str: payload = json.dumps(clean_config, sort_keys=True, ensure_ascii=True, separators=(",", ":")) return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] + async def _get_mcp_client(self, server_configs: dict[str, Any] | None = None) -> MultiServerMCPClient | None: + try: + client = MultiServerMCPClient(server_configs) # pyright: ignore[reportArgumentType] + logger.info(f"Initialized MCP client with servers: {list(server_configs.keys() or [])}") + return client + except Exception as e: + logger.error(f"Failed to initialize MCP client: {e}") + return None + async def get_session( self, server_name: str, @@ -175,8 +184,7 @@ async def get_session( client_config["auth"] = DynamicMCPTokenAuth(server_name) logger.info(f"Creating new long-lived MCP session for {cache_key} (transport: {client_config.get('transport')})") - from yuxi.services.mcp_service import get_mcp_client - client = await get_mcp_client({server_name: client_config}) + client = await self._get_mcp_client({server_name: client_config}) if client is None: raise RuntimeError(f"Failed to initialize MCP client for {server_name}") ll_session = LongLivedSession(client, server_name) diff --git a/backend/package/yuxi/services/mcp/connection_service.py b/backend/package/yuxi/services/mcp/connection_service.py index c4073f1e0..1acfbb974 100644 --- a/backend/package/yuxi/services/mcp/connection_service.py +++ b/backend/package/yuxi/services/mcp/connection_service.py @@ -245,8 +245,8 @@ async def reauthorize_mcp_connection( if connection is None: raise ValueError(f"MCP connection '{connection_id}' does not exist") - from yuxi.services import mcp_service - cache = mcp_service.RedisTokenCache() + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + cache = RedisTokenCache() if getattr(connection, "id", None) is not None: try: await cache.delete_access_token(connection.id) @@ -288,12 +288,14 @@ async def test_mcp_connection( raise ValueError(f"Server '{connection.server_name}' does not exist") auth_context = _auth_context_from_connection(connection) - from yuxi.services import mcp_service - config = await mcp_service.get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) + from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + from yuxi.services.mcp.tool_registry_service import get_mcp_tools + + config = await get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) if config is None: raise ValueError(f"MCP server '{server.name}' runtime config unavailable") - tools = await mcp_service.get_mcp_tools( + tools = await get_mcp_tools( server.name, additional_servers={server.name: config}, disabled_tools=[], diff --git a/backend/package/yuxi/services/mcp/server_service.py b/backend/package/yuxi/services/mcp/server_service.py index 2187419b2..22ee6f792 100644 --- a/backend/package/yuxi/services/mcp/server_service.py +++ b/backend/package/yuxi/services/mcp/server_service.py @@ -201,8 +201,7 @@ async def get_runtime_mcp_server_config( ) -> dict[str, Any] | None: """解析获取附带运行时鉴权与租户范围的 MCP 服务配置""" if db is None and auth_context is None: - from yuxi.services import mcp_service - return await mcp_service.get_enabled_mcp_server_config(server_name) + return await get_enabled_mcp_server_config(server_name) if db is not None: server = await _get_enabled_mcp_server_record(server_name, db=db) @@ -212,8 +211,8 @@ async def get_runtime_mcp_server_config( return server.to_mcp_config() auth_config = MCPAuthConfig.model_validate(server.auth_config_json) - from yuxi.services import mcp_service - scope_id = mcp_service._resolve_scope_id(auth_config.binding_scope, auth_context) + from yuxi.services.mcp.connection_service import _resolve_scope_id + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) if scope_id is None: return server.to_mcp_config() @@ -265,8 +264,7 @@ async def get_runtime_mcp_server_config( async def get_enabled_mcp_server_names(*, db: AsyncSession | None = None) -> list[str]: """获取所有已启用的服务器名称""" - from yuxi.services import mcp_service - configs = await mcp_service._load_enabled_mcp_server_configs(db=db) + configs = await _load_enabled_mcp_server_configs(db=db) return list(configs.keys()) @@ -326,9 +324,9 @@ async def create_mcp_server( await db.commit() await db.refresh(server) - from yuxi.services import mcp_service - await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) - await mcp_service.invalidate_mcp_server_tools_cache(name) + from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Created MCP server '{name}'") return server @@ -386,10 +384,10 @@ async def update_mcp_server( await db.commit() await db.refresh(server) - from yuxi.services import mcp_service + from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache if auth_config is not _UNSET: - await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) - await mcp_service.invalidate_mcp_server_tools_cache(name) + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Updated MCP server '{name}'") return server @@ -404,9 +402,9 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: await db.delete(server) await db.commit() - from yuxi.services import mcp_service - await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) - await mcp_service.invalidate_mcp_server_tools_cache(name) + from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Deleted MCP server '{name}'") return True @@ -414,8 +412,8 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: """获取依赖于该 MCP 服务器的智能体、技能和连接概要""" - from yuxi.services import mcp_service - connections = await mcp_service.list_mcp_connections(db, server_name=name) + from yuxi.services.mcp.connection_service import list_mcp_connections + connections = await list_mcp_connections(db, server_name=name) skill_rows = (await db.execute(select(Skill))).scalars().all() matched_skills = [ @@ -455,10 +453,10 @@ async def set_server_enabled( await db.commit() is_enabled = bool(server.enabled) - from yuxi.services import mcp_service + from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache if not is_enabled: - await mcp_service._clear_mcp_server_runtime_auth_cache(db, name) - await mcp_service.invalidate_mcp_server_tools_cache(name) + await _clear_mcp_server_runtime_auth_cache(db, name) + await invalidate_mcp_server_tools_cache(name) logger.info(f"Set MCP server '{name}' enabled={is_enabled}") return is_enabled, server diff --git a/backend/package/yuxi/services/mcp_auth/orchestrator.py b/backend/package/yuxi/services/mcp_auth/orchestrator.py index a3548f6bf..aeb7f7558 100644 --- a/backend/package/yuxi/services/mcp_auth/orchestrator.py +++ b/backend/package/yuxi/services/mcp_auth/orchestrator.py @@ -279,7 +279,7 @@ async def _resolve_dynamic_token_values( token_cache: Any | None, ) -> dict[str, Any]: if token_cache is None and connection is not None: - from yuxi.services.mcp_service import RedisTokenCache + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache token_cache = RedisTokenCache() diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index 836992d59..266aada40 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -5,6 +5,11 @@ from urllib.parse import urlencode import httpx +from fastapi import Request, Response, HTTPException +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from starlette.background import BackgroundTask from server.utils.auth_utils import AuthUtils from yuxi.services.mcp_auth.config_models import MCPAuthConfig @@ -128,71 +133,153 @@ def _record_scope_error(connection: MCPConnection | None, message: str) -> None: connection.meta_json = meta_json -async def proxy_mcp_request( +async def handle_mcp_proxy_request( + server_name: str, + request: Request, + path: str, + internal_token: str, + db: AsyncSession, +) -> Response: + """内部网关主入口:鉴权解析、查库拦截与流式代理""" + from yuxi.services.mcp.server_service import get_mcp_server + + try: + auth_context = decode_proxy_access_token(internal_token, server_name=server_name) + except ValueError as exc: + raise HTTPException(status_code=401, detail=str(exc)) from exc + + server = await get_mcp_server(db, server_name) + if server is None: + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在") + if not bool(getattr(server, "enabled", True)): + raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在或已停用") + + auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) + + from yuxi.services.mcp.connection_service import _resolve_scope_id + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) + connection = None + if scope_id is not None: + result = await db.execute( + select(MCPConnection).where( + MCPConnection.server_name == server.name, + MCPConnection.scope_type == auth_config.binding_scope, + MCPConnection.scope_id == scope_id, + MCPConnection.status == "active", + ) + ) + connection = result.scalar_one_or_none() + + if auth_config.binding_scope != "inline" and connection is None: + raise HTTPException(status_code=403, detail="当前用户没有该 MCP 的有效连接") + + # 注意:我们读取整个 request body,因为 MCP 请求参数通常极小, + # 但由于可能有 401 重试,我们需要保存下 body 来实现背压重发。 + body = await request.body() + return await _proxy_mcp_request_stream( + server=server, + connection=connection, + auth_context=auth_context, + request=request, + body=body, + path=path, + db=db, + ) + + +async def _proxy_mcp_request_stream( server: MCPServer, *, connection: MCPConnection | None, auth_context: AuthContext, - method: str, - headers: dict[str, str] | None, - query_params: dict[str, Any] | None, + request: Request, body: bytes, path: str = "", - http_client: httpx.AsyncClient | None = None, - token_cache: Any | None = None, -) -> httpx.Response: + db: AsyncSession, +) -> Response: + """底层流式转发逻辑:处理 HTTPX 透传、SSE 和 401 重试闭环事务""" auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) if server.transport not in _HTTP_TRANSPORTS: - raise ValueError(f"Internal proxy only supports HTTP MCP transports, got: {server.transport}") + raise HTTPException(status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}") - if http_client is None: - http_client = httpx.AsyncClient() - should_close = True - else: - should_close = False + http_client = httpx.AsyncClient(timeout=server.timeout or 60.0) + bg_task = BackgroundTask(http_client.aclose) + + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + token_cache = RedisTokenCache() - try: - max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 - for attempt in range(max_attempts): - runtime_config = await resolve_runtime_mcp_config( - server, - auth_context=auth_context, - connection=connection, - http_client=http_client, - token_cache=token_cache, + max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 + + for attempt in range(max_attempts): + runtime_config = await resolve_runtime_mcp_config( + server, + auth_context=auth_context, + connection=connection, + http_client=http_client, + token_cache=token_cache, + ) + target_url = _build_target_url(runtime_config["url"], path=path, query_params=dict(request.query_params)) + upstream_headers = _merge_upstream_headers(runtime_config.get("headers") or {}, dict(request.headers)) + + request_obj = http_client.build_request( + method=request.method.upper(), + url=target_url, + headers=upstream_headers, + content=body, + ) + + # 使用 send(stream=True) 获取异步可迭代响应而不会阻塞 SSE 长链接 + response = await http_client.send(request_obj, stream=True) + + if response.status_code == 403: + await response.aclose() + _record_scope_error(connection, "MCP upstream rejected request due to insufficient scope") + if connection is not None: + await db.commit() + return Response( + content='{"error": "insufficient_scope", "message": "当前授权范围不足"}', + status_code=403, + media_type="application/json", + background=bg_task ) - target_url = _build_target_url(runtime_config["url"], path=path, query_params=query_params) - upstream_headers = _merge_upstream_headers(runtime_config.get("headers") or {}, headers) - response = await http_client.request( - method=method.upper(), - url=target_url, - headers=upstream_headers, - content=body, + + if response.status_code != 401: + # 正常响应,此时直接闭环提交事务,防止污染外层 + if connection is not None and hasattr(db, "commit"): + await db.commit() + + async def proxy_stream_generator(): + try: + async for chunk in response.aiter_raw(): + yield chunk + finally: + await response.aclose() + + resp_headers = {} + for k, v in response.headers.items(): + if k.lower() not in _HOP_BY_HOP_HEADERS and k.lower() not in ("content-encoding", "content-length"): + resp_headers[k] = v + + return StreamingResponse( + proxy_stream_generator(), + status_code=response.status_code, + headers=resp_headers, + background=bg_task ) - if response.status_code == 403: - _record_scope_error(connection, "MCP upstream rejected request due to insufficient scope") - return httpx.Response( - 403, - json={ - "error": "insufficient_scope", - "message": "当前授权范围不足,请联系管理员或重新授权", - }, - ) - if response.status_code != 401: - return response - if attempt + 1 >= max_attempts: - break - if token_cache is not None and connection is not None and getattr(connection, "id", None) is not None: - await token_cache.delete_access_token(connection.id) - - _mark_reauth_required(connection, "MCP upstream returned 401 after retry") - return httpx.Response( - 424, - json={ - "error": "reauth_required", - "message": "连接失效,请重新连接", - }, - ) - finally: - if should_close: - await http_client.aclose() + + # 如果是 401,回收流连接并准备重试 + await response.aclose() + if attempt + 1 >= max_attempts: + break + if connection is not None and getattr(connection, "id", None) is not None: + await token_cache.delete_access_token(connection.id) + + _mark_reauth_required(connection, "MCP upstream returned 401 after retry") + if connection is not None: + await db.commit() + return Response( + content='{"error": "reauth_required", "message": "连接失效,请重新连接"}', + status_code=424, + media_type="application/json", + background=bg_task + ) diff --git a/backend/package/yuxi/services/mcp_service.py b/backend/package/yuxi/services/mcp_service.py deleted file mode 100644 index 1d72c377a..000000000 --- a/backend/package/yuxi/services/mcp_service.py +++ /dev/null @@ -1,111 +0,0 @@ -"""MCP Service - Facade 适配门面,保持项目的完全向下兼容。 - -所有的实质逻辑均已根据职责分工拆分重构至以下子服务模块中: -- yuxi.services.mcp.server_service -- yuxi.services.mcp.connection_service -- yuxi.services.mcp.tool_registry_service -- yuxi.services.mcp.client_pool -""" - -from __future__ import annotations -import asyncio -from typing import Any - -from yuxi.services.mcp.server_service import ( - ensure_builtin_mcp_servers_in_db, - get_enabled_mcp_server_names, - get_mcp_server, - get_all_mcp_servers, - create_mcp_server, - update_mcp_server, - delete_mcp_server, - get_mcp_server_dependency_summary, - set_server_enabled, - get_servers_config, -) -from yuxi.services.mcp.connection_service import ( - get_mcp_connection, - list_mcp_connections, - create_mcp_connection, - update_mcp_connection, - delete_mcp_connection, - set_mcp_connection_status, - reauthorize_mcp_connection, - test_mcp_connection, - _resolve_scope_id, -) -from yuxi.services.mcp.tool_registry_service import ( - to_camel_case, - get_tools_from_all_servers, - clear_mcp_cache, - clear_mcp_server_tools_cache, - clear_mcp_connection_tools_cache, - invalidate_mcp_server_tools_cache, - invalidate_mcp_connection_tools_cache, - get_mcp_tools_stats, - get_enabled_mcp_tools, - get_all_mcp_tools, - toggle_tool_enabled, -) - -# 兼容原导入以防万一 -from yuxi.services.mcp_auth.orchestrator import AuthContext -from langchain_mcp_adapters.client import MultiServerMCPClient -from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache -from yuxi.services.mcp_tool_cache import RedisMcpToolCache -from yuxi.utils import logger - -# ----------------------------------------------------------------------------- -# --- 共享状态与依赖(提供给外部/子服务使用,并对单元测试 Mock 100% 兼容) --- -# ----------------------------------------------------------------------------- -_mcp_tools_cache = {} -_mcp_tool_cache_store = RedisMcpToolCache() -_mcp_tools_stats = {} -_mcp_lock = asyncio.Lock() - - -# ----------------------------------------------------------------------------- -# --- 兼容性转发入口(支持被测试 monkeypatch.setattr 覆盖) --- -# ----------------------------------------------------------------------------- -async def get_enabled_mcp_server_config(server_name: str, *, db: Any = None) -> dict[str, Any] | None: - from yuxi.services.mcp.server_service import get_enabled_mcp_server_config as _get_cfg - return await _get_cfg(server_name, db=db) - - -async def get_runtime_mcp_server_config( - server_name: str, - *, - auth_context: AuthContext | None = None, - db: Any = None, - http_client: Any = None, -) -> dict[str, Any] | None: - from yuxi.services.mcp.server_service import get_runtime_mcp_server_config as _get_run_cfg - return await _get_run_cfg(server_name, auth_context=auth_context, db=db, http_client=http_client) - - -async def _load_enabled_mcp_server_configs(names: list[str] | None = None) -> dict[str, dict[str, Any]]: - from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs as _load_cfg - return await _load_cfg(names=names) - - -async def get_mcp_tools(server_name: str, **kwargs: Any) -> list[Any]: - from yuxi.services.mcp.tool_registry_service import get_mcp_tools as _get_t - return await _get_t(server_name, **kwargs) - - -async def get_mcp_client( - server_configs: dict[str, Any] | None = None, -) -> MultiServerMCPClient | None: - """初始化并拉起 MCP 客户端。保留该底层入口以确保单元测试中的 monkeypatch 拦截顺畅传导。""" - try: - client = MultiServerMCPClient(server_configs) # pyright: ignore[reportArgumentType] - logger.info(f"Initialized MCP client with servers: {list(server_configs.keys())}") - return client - except Exception as e: - logger.error(f"Failed to initialize MCP client: {e}") - return None - - -async def _clear_mcp_server_runtime_auth_cache(db: Any, server_name: str) -> None: - from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache as _clear_auth - await _clear_auth(db, server_name) diff --git a/backend/package/yuxi/services/run_worker.py b/backend/package/yuxi/services/run_worker.py index a6ef9e14c..12b606dd3 100644 --- a/backend/package/yuxi/services/run_worker.py +++ b/backend/package/yuxi/services/run_worker.py @@ -12,7 +12,7 @@ from sqlalchemy.exc import OperationalError from yuxi.repositories.agent_run_repository import TERMINAL_RUN_STATUSES, AgentRunRepository from yuxi.services.chat_service import stream_agent_chat -from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db +from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db from yuxi.services.run_queue_service import ( append_run_stream_event, clear_cancel_signal, diff --git a/backend/package/yuxi/services/skill_service.py b/backend/package/yuxi/services/skill_service.py index a6663a637..ce094e720 100644 --- a/backend/package/yuxi/services/skill_service.py +++ b/backend/package/yuxi/services/skill_service.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from yuxi import config as sys_config from yuxi.repositories.skill_repository import SkillRepository -from yuxi.services.mcp_service import get_enabled_mcp_server_names +from yuxi.services.mcp.server_service import get_enabled_mcp_server_names from yuxi.storage.postgres.models_business import Skill from yuxi.utils.logging_config import logger diff --git a/backend/server/routers/mcp_internal_router.py b/backend/server/routers/mcp_internal_router.py index b5f48555b..2de6fcf26 100644 --- a/backend/server/routers/mcp_internal_router.py +++ b/backend/server/routers/mcp_internal_router.py @@ -1,55 +1,19 @@ from __future__ import annotations -from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response -from sqlalchemy import select +from fastapi import APIRouter, Depends, Header, Request, Response from sqlalchemy.ext.asyncio import AsyncSession from server.utils.auth_middleware import get_db -from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.proxy_service import ( INTERNAL_PROXY_TOKEN_HEADER, - decode_proxy_access_token, - proxy_mcp_request, + handle_mcp_proxy_request, ) -from yuxi.services.mcp_service import _resolve_scope_id, get_mcp_server -from yuxi.storage.postgres.models_business import MCPConnection, MCPServer -from yuxi.utils import logger mcp_internal = APIRouter(prefix="/internal/mcp-proxy", tags=["mcp-internal"]) -async def _load_active_connection( - db: AsyncSession, - *, - server: MCPServer, - auth_context, -) -> MCPConnection | None: - auth_payload = server.auth_config_json or {} - if not auth_payload: - return None - - auth_config = MCPAuthConfig.model_validate(auth_payload) - scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) - if scope_id is None: - return None - - result = await db.execute( - select(MCPConnection).where( - MCPConnection.server_name == server.name, - MCPConnection.scope_type == auth_config.binding_scope, - MCPConnection.scope_id == scope_id, - MCPConnection.status == "active", - ) - ) - return result.scalar_one_or_none() - - @mcp_internal.api_route( - "/{server_name}", - methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], -) -@mcp_internal.api_route( - "/{server_name}/{path:path}", + "/{server_name}{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], ) async def proxy_mcp_server_request( @@ -58,51 +22,15 @@ async def proxy_mcp_server_request( path: str = "", internal_token: str | None = Header(None, alias=INTERNAL_PROXY_TOKEN_HEADER), db: AsyncSession = Depends(get_db), -): - if not internal_token: - raise HTTPException(status_code=401, detail="missing internal proxy token") - - try: - auth_context = decode_proxy_access_token(internal_token, server_name=server_name) - except ValueError as exc: - raise HTTPException(status_code=401, detail=str(exc)) from exc - - server = await get_mcp_server(db, server_name) - if server is None: - raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在") - if not bool(getattr(server, "enabled", True)): - raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在或已停用") - - try: - connection = await _load_active_connection(db, server=server, auth_context=auth_context) - auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) - if auth_config.binding_scope != "inline" and connection is None: - raise HTTPException(status_code=403, detail="当前用户没有该 MCP 的有效连接") - - body = await request.body() - upstream_response = await proxy_mcp_request( - server, - connection=connection, - auth_context=auth_context, - method=request.method, - headers=dict(request.headers), - query_params=dict(request.query_params), - body=body, - path=path, - ) - if connection is not None and hasattr(db, "commit"): - await db.commit() - response_headers = {} - content_type = upstream_response.headers.get("content-type") - if content_type: - response_headers["content-type"] = content_type - return Response( - content=upstream_response.content, - status_code=upstream_response.status_code, - headers=response_headers, - ) - except HTTPException: - raise - except Exception as exc: - logger.error(f"Failed to proxy MCP server '{server_name}': {exc}") - raise HTTPException(status_code=500, detail=str(exc)) from exc +) -> Response: + """代理路由(纯路由层):业务鉴权、DB操作及背压透传已全部下沉到 proxy_service 领域服务处理""" + # 去除前导斜杠,以兼容不带 path 和带 path 两种情况 + path = path.lstrip("/") + + return await handle_mcp_proxy_request( + server_name=server_name, + request=request, + path=path, + internal_token=internal_token or "", + db=db, + ) diff --git a/backend/server/routers/mcp_router.py b/backend/server/routers/mcp_router.py index f0bd081ac..ce7b9cc07 100644 --- a/backend/server/routers/mcp_router.py +++ b/backend/server/routers/mcp_router.py @@ -8,25 +8,29 @@ from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp_auth.config_models import MCPAuthConfig -from yuxi.services.mcp_service import ( - create_mcp_connection, +from yuxi.services.mcp.server_service import ( create_mcp_server, - delete_mcp_connection, - get_mcp_tools_stats, delete_mcp_server, - get_mcp_connection, - get_mcp_server_dependency_summary, get_all_mcp_servers, - get_all_mcp_tools, get_mcp_server, + get_mcp_server_dependency_summary, + set_server_enabled, + update_mcp_server, +) +from yuxi.services.mcp.connection_service import ( + create_mcp_connection, + delete_mcp_connection, + get_mcp_connection, list_mcp_connections, reauthorize_mcp_connection, set_mcp_connection_status, - set_server_enabled, test_mcp_connection, - toggle_tool_enabled, update_mcp_connection, - update_mcp_server, +) +from yuxi.services.mcp.tool_registry_service import ( + get_all_mcp_tools, + get_mcp_tools_stats, + toggle_tool_enabled, ) from yuxi.storage.postgres.models_business import User from yuxi.utils import logger diff --git a/backend/server/utils/lifespan.py b/backend/server/utils/lifespan.py index 10380f252..9b6539471 100644 --- a/backend/server/utils/lifespan.py +++ b/backend/server/utils/lifespan.py @@ -5,7 +5,7 @@ from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from yuxi.services.task_service import tasker -from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db +from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db from yuxi.services.model_provider_service import ensure_builtin_model_providers_in_db from yuxi.services.subagent_service import init_builtin_subagents from yuxi.services.run_queue_service import close_queue_clients, get_redis_client diff --git a/backend/test/unit/services/test_mcp_service_auth_runtime.py b/backend/test/unit/services/test_mcp_auth_runtime.py similarity index 88% rename from backend/test/unit/services/test_mcp_service_auth_runtime.py rename to backend/test/unit/services/test_mcp_auth_runtime.py index d9adf8e87..94f68a8bf 100644 --- a/backend/test/unit/services/test_mcp_service_auth_runtime.py +++ b/backend/test/unit/services/test_mcp_auth_runtime.py @@ -9,7 +9,9 @@ os.environ.setdefault("OPENAI_API_KEY", "test-key") -from yuxi.services import mcp_service +from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp.client_pool import mcp_client_pool +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.storage.postgres.models_business import MCPConnection, MCPServer @@ -65,7 +67,7 @@ async def test_get_runtime_mcp_server_config_resolves_department_connection(runt ) await runtime_session.commit() - config = await mcp_service.get_runtime_mcp_server_config( + config = await server_service.get_runtime_mcp_server_config( "finance-gateway", auth_context=AuthContext(user_id="u-1", department_id="42"), db=runtime_session, @@ -116,16 +118,16 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled captured_configs.append(additional_servers[server_name]) return ["private-tool"] - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - user_1_tools = await mcp_service.get_enabled_mcp_tools( + user_1_tools = await tool_registry_service.get_enabled_mcp_tools( "personal-gateway", auth_context=AuthContext(user_id="user-1"), db=runtime_session, ) with pytest.raises(ValueError, match="Active MCP connection not found"): - await mcp_service.get_enabled_mcp_tools( + await tool_registry_service.get_enabled_mcp_tools( "personal-gateway", auth_context=AuthContext(user_id="user-2"), db=runtime_session, @@ -160,10 +162,10 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled ) return ["tool-a"] - monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - tools = await mcp_service.get_enabled_mcp_tools( + tools = await tool_registry_service.get_enabled_mcp_tools( "demo", auth_context=AuthContext(user_id="u-100", department_id="d-9"), ) @@ -204,10 +206,10 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled ) return ["tool-a", "tool-b"] - monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - tools = await mcp_service.get_all_mcp_tools( + tools = await tool_registry_service.get_all_mcp_tools( "demo", auth_context=AuthContext(user_id="u-100", department_id="d-9"), ) @@ -271,7 +273,7 @@ async def test_get_runtime_mcp_server_config_returns_internal_proxy_for_dynamic_ ) await runtime_session.commit() - config = await mcp_service.get_runtime_mcp_server_config( + config = await server_service.get_runtime_mcp_server_config( "finance-proxy", auth_context=AuthContext(user_id="user-1", department_id="dep-88"), db=runtime_session, @@ -319,10 +321,10 @@ async def fake_invalidate_tools_cache(server_name): assert server_name == "finance-gateway" calls["tools_cache"] += 1 - monkeypatch.setattr(mcp_service, "_clear_mcp_server_runtime_auth_cache", fake_clear_runtime_auth_cache) - monkeypatch.setattr(mcp_service, "invalidate_mcp_server_tools_cache", fake_invalidate_tools_cache) + monkeypatch.setattr(tool_registry_service, "_clear_mcp_server_runtime_auth_cache", fake_clear_runtime_auth_cache) + monkeypatch.setattr(tool_registry_service, "invalidate_mcp_server_tools_cache", fake_invalidate_tools_cache) - await mcp_service.update_mcp_server( + await server_service.update_mcp_server( runtime_session, "finance-gateway", auth_config={ diff --git a/backend/test/unit/services/test_mcp_connection_service.py b/backend/test/unit/services/test_mcp_connection_service.py index 348e7b598..3c803a7b4 100644 --- a/backend/test/unit/services/test_mcp_connection_service.py +++ b/backend/test/unit/services/test_mcp_connection_service.py @@ -8,7 +8,9 @@ os.environ.setdefault("OPENAI_API_KEY", "test-key") -from yuxi.services import mcp_service +from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp.client_pool import mcp_client_pool +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_auth.crypto import decrypt_credential_blob from yuxi.storage.postgres.models_business import AgentConfig, Department, MCPConnection, MCPServer, Skill @@ -60,7 +62,7 @@ async def test_create_and_list_mcp_connections(connection_service_session, monke ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="finance-gateway", scope_type="department", @@ -72,7 +74,7 @@ async def test_create_and_list_mcp_connections(connection_service_session, monke created_by="tester", ) - listed = await mcp_service.list_mcp_connections(connection_service_session, server_name="finance-gateway") + listed = await connection_service.list_mcp_connections(connection_service_session, server_name="finance-gateway") assert created.server_name == "finance-gateway" assert created.scope_type == "department" @@ -92,7 +94,7 @@ async def test_create_mcp_connection_normalizes_system_scope_to_global(connectio ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="global-gateway", scope_type="system", @@ -119,7 +121,7 @@ async def test_set_mcp_connection_status_updates_status(connection_service_sessi ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="corp-gateway", scope_type="system", @@ -129,7 +131,7 @@ async def test_set_mcp_connection_status_updates_status(connection_service_sessi created_by="tester", ) - updated = await mcp_service.set_mcp_connection_status( + updated = await connection_service.set_mcp_connection_status( connection_service_session, created.id, status="reauth_required", @@ -154,7 +156,7 @@ async def test_create_mcp_connection_rejects_invalid_scope_type(connection_servi await connection_service_session.commit() with pytest.raises(ValueError, match="scope_type"): - await mcp_service.create_mcp_connection( + await connection_service.create_mcp_connection( connection_service_session, server_name="invalid-scope-gateway", scope_type="tenant", @@ -177,7 +179,7 @@ async def test_create_mcp_connection_rejects_missing_department_scope_id(connect await connection_service_session.commit() with pytest.raises(ValueError, match="scope_id"): - await mcp_service.create_mcp_connection( + await connection_service.create_mcp_connection( connection_service_session, server_name="missing-scope-id-gateway", scope_type="department", @@ -199,7 +201,7 @@ async def test_set_mcp_connection_status_rejects_invalid_status(connection_servi ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="invalid-status-gateway", scope_type="system", @@ -208,7 +210,7 @@ async def test_set_mcp_connection_status_rejects_invalid_status(connection_servi ) with pytest.raises(ValueError, match="status"): - await mcp_service.set_mcp_connection_status( + await connection_service.set_mcp_connection_status( connection_service_session, created.id, status="broken", @@ -230,7 +232,7 @@ async def test_create_mcp_connection_encrypts_credentials(connection_service_ses await connection_service_session.commit() plaintext = '{"secrets":{"access_token":"secure-token"}}' - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="secure-gateway", scope_type="system", @@ -259,7 +261,7 @@ async def test_create_mcp_connection_rejects_plaintext_credentials_without_maste await connection_service_session.commit() with pytest.raises(ValueError, match="MCP_CREDENTIALS_MASTER_KEY"): - await mcp_service.create_mcp_connection( + await connection_service.create_mcp_connection( connection_service_session, server_name="insecure-gateway", scope_type="system", @@ -321,7 +323,7 @@ async def test_get_mcp_server_dependency_summary_reports_runtime_references(dele ) await delete_semantics_session.commit() - summary = await mcp_service.get_mcp_server_dependency_summary(delete_semantics_session, "finance-gateway") + summary = await server_service.get_mcp_server_dependency_summary(delete_semantics_session, "finance-gateway") assert summary["has_references"] is True assert summary["connections"] == [{"scope_type": "department", "scope_id": "42", "status": "active"}] @@ -342,7 +344,7 @@ async def test_update_mcp_connection_reencrypts_credentials(connection_service_s ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="update-gateway", scope_type="system", @@ -352,7 +354,7 @@ async def test_update_mcp_connection_reencrypts_credentials(connection_service_s created_by="tester", ) - updated = await mcp_service.update_mcp_connection( + updated = await connection_service.update_mcp_connection( connection_service_session, created.id, display_name="new", @@ -377,7 +379,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="delete-connection-gateway", @@ -389,7 +391,7 @@ async def release_refresh_lock(self, connection_id): ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="delete-connection-gateway", scope_type="system", @@ -398,12 +400,12 @@ async def release_refresh_lock(self, connection_id): created_by="tester", ) - deleted = await mcp_service.delete_mcp_connection(connection_service_session, created.id) + deleted = await connection_service.delete_mcp_connection(connection_service_session, created.id) assert deleted is True assert cleared_connection_ids == [created.id] assert released_connection_ids == [created.id] - assert await mcp_service.get_mcp_connection(connection_service_session, created.id) is None + assert await connection_service.get_mcp_connection(connection_service_session, created.id) is None async def test_reauthorize_mcp_connection_clears_runtime_error(connection_service_session, monkeypatch): @@ -418,7 +420,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( @@ -431,7 +433,7 @@ async def release_refresh_lock(self, connection_id): ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="reauth-gateway", scope_type="system", @@ -442,7 +444,7 @@ async def release_refresh_lock(self, connection_id): created_by="tester", ) - updated = await mcp_service.reauthorize_mcp_connection( + updated = await connection_service.reauthorize_mcp_connection( connection_service_session, created.id, updated_by="admin", @@ -469,7 +471,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="credential-update-gateway", @@ -481,7 +483,7 @@ async def release_refresh_lock(self, connection_id): ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="credential-update-gateway", scope_type="system", @@ -490,7 +492,7 @@ async def release_refresh_lock(self, connection_id): created_by="tester", ) - updated = await mcp_service.update_mcp_connection( + updated = await connection_service.update_mcp_connection( connection_service_session, created.id, credential_blob='{"secrets":{"access_token":"new-token"}}', @@ -516,7 +518,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(mcp_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="retire-gateway", @@ -529,7 +531,7 @@ async def release_refresh_lock(self, connection_id): ) await connection_service_session.commit() - first = await mcp_service.create_mcp_connection( + first = await connection_service.create_mcp_connection( connection_service_session, server_name="retire-gateway", scope_type="department", @@ -537,7 +539,7 @@ async def release_refresh_lock(self, connection_id): credential_blob='{"secrets":{"access_token":"token-1"}}', created_by="tester", ) - second = await mcp_service.create_mcp_connection( + second = await connection_service.create_mcp_connection( connection_service_session, server_name="retire-gateway", scope_type="department", @@ -546,7 +548,7 @@ async def release_refresh_lock(self, connection_id): created_by="tester", ) - enabled, server = await mcp_service.set_server_enabled( + enabled, server = await server_service.set_server_enabled( connection_service_session, "retire-gateway", False, @@ -570,8 +572,8 @@ async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tool del additional_servers, disabled_tools, kwargs return [server_name, "tool-b"] - monkeypatch.setattr(mcp_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "get_runtime_mcp_server_config", fake_get_runtime_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) connection_service_session.add( MCPServer( @@ -584,7 +586,7 @@ async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tool ) await connection_service_session.commit() - created = await mcp_service.create_mcp_connection( + created = await connection_service.create_mcp_connection( connection_service_session, server_name="test-gateway", scope_type="department", @@ -595,7 +597,7 @@ async def fake_get_mcp_tools(server_name, additional_servers=None, disabled_tool created_by="tester", ) - result = await mcp_service.test_mcp_connection( + result = await connection_service.test_mcp_connection( connection_service_session, created.id, updated_by="admin", diff --git a/backend/test/unit/services/test_mcp_service.py b/backend/test/unit/services/test_mcp_tool_registry_service.py similarity index 73% rename from backend/test/unit/services/test_mcp_service.py rename to backend/test/unit/services/test_mcp_tool_registry_service.py index 6668edf4a..f3d6578cb 100644 --- a/backend/test/unit/services/test_mcp_service.py +++ b/backend/test/unit/services/test_mcp_tool_registry_service.py @@ -2,7 +2,9 @@ from types import SimpleNamespace -from yuxi.services import mcp_service +from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp.client_pool import mcp_client_pool +from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_tool_cache import RedisMcpToolCache from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER @@ -53,10 +55,10 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled ) return ["tool-a"] - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - tools = await mcp_service.get_enabled_mcp_tools("demo") + tools = await tool_registry_service.get_enabled_mcp_tools("demo") assert tools == ["tool-a"] assert captured == [ @@ -69,7 +71,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled async def test_get_mcp_tools_rebuilds_cache_when_config_hash_changes(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() configs = [ {"transport": "stdio", "command": "demo-v1", "disabled_tools": []}, @@ -88,21 +90,21 @@ async def fake_get_mcp_client(server_configs): tool = SimpleNamespace(name=f"tool_for_{config['command']}", metadata={}) return _FakeClient([tool]) - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools_v1_first = await mcp_service.get_mcp_tools("demo") - tools_v1_second = await mcp_service.get_mcp_tools("demo") + tools_v1_first = await tool_registry_service.get_mcp_tools("demo") + tools_v1_second = await tool_registry_service.get_mcp_tools("demo") configs[0] = configs[1] - tools_v2 = await mcp_service.get_mcp_tools("demo") + tools_v2 = await tool_registry_service.get_mcp_tools("demo") assert [tool.name for tool in tools_v1_first] == ["tool_for_demo-v1"] assert [tool.name for tool in tools_v1_second] == ["tool_for_demo-v1"] assert [tool.name for tool in tools_v2] == ["tool_for_demo-v2"] assert build_calls == ["demo-v1", "demo-v2"] - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def test_get_tools_from_all_servers_loads_names_from_db_once(monkeypatch): @@ -121,10 +123,10 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs calls.append((server_name, additional_servers or {})) return [server_name] - monkeypatch.setattr(mcp_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - tools = await mcp_service.get_tools_from_all_servers() + tools = await tool_registry_service.get_tools_from_all_servers() assert tools == ["alpha", "beta"] assert calls == [ @@ -134,7 +136,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs async def test_get_mcp_tools_sets_handle_tool_error(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} @@ -146,18 +148,18 @@ async def fake_get_mcp_client(server_configs): tool = SimpleNamespace(name="demo_tool", metadata={}) return _FakeClient([tool]) - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools = await mcp_service.get_mcp_tools("demo") + tools = await tool_registry_service.get_mcp_tools("demo") assert len(tools) == 1 assert tools[0].handle_tool_error is True - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def test_get_mcp_tools_keeps_connection_partitions_separate(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() configs = [ { @@ -187,20 +189,20 @@ async def fake_get_mcp_client(server_configs): tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) return _FakeClient([tool]) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools_a = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) - tools_b = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + tools_a = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_b = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) assert [tool.name for tool in tools_a] == ["tool_for_proxy-token-user-a"] assert [tool.name for tool in tools_b] == ["tool_for_proxy-token-user-b"] assert build_calls == ["proxy-token-user-a", "proxy-token-user-b"] - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def test_get_mcp_tools_does_not_cache_internal_proxy_tool_objects(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() configs = [ { @@ -232,16 +234,16 @@ async def fake_get_mcp_client(server_configs): tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) return _FakeClient([tool]) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools_first = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) - tools_second = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) + tools_first = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) + tools_second = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) assert [tool.name for tool in tools_first] == ["tool_for_proxy-token-v1"] assert [tool.name for tool in tools_second] == ["tool_for_proxy-token-v2"] assert build_calls == ["proxy-token-v1", "proxy-token-v2"] - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def test_get_tools_from_all_servers_skips_runtime_auth_servers_without_context(monkeypatch): @@ -278,10 +280,10 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs calls.append((server_name, additional_servers or {})) return [server_name] - monkeypatch.setattr(mcp_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) - monkeypatch.setattr(mcp_service, "get_mcp_tools", fake_get_mcp_tools) + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) - tools = await mcp_service.get_tools_from_all_servers() + tools = await tool_registry_service.get_tools_from_all_servers() assert tools == ["shared"] assert calls == [ @@ -290,16 +292,14 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs async def test_get_mcp_tools_rebuilds_when_redis_server_revision_changes(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() fake_redis = _FakeRedis() async def fake_redis_factory(): return fake_redis - monkeypatch.setattr( - mcp_service, - "_mcp_tool_cache_store", + monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store", RedisMcpToolCache(redis_client_factory=fake_redis_factory), ) @@ -311,32 +311,30 @@ async def fake_get_mcp_client(server_configs): tool = SimpleNamespace(name=f"tool_{len(build_calls)}", metadata={}) return _FakeClient([tool]) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools_first = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) - tools_second = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) - await mcp_service._mcp_tool_cache_store.bump_server_revision("demo") - tools_third = await mcp_service.get_mcp_tools("demo", additional_servers={"demo": config}) + tools_first = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) + tools_second = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) + await tool_registry_service._mcp_tool_cache_store.bump_server_revision("demo") + tools_third = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": config}) assert [tool.name for tool in tools_first] == ["tool_1"] assert [tool.name for tool in tools_second] == ["tool_1"] assert [tool.name for tool in tools_third] == ["tool_2"] assert build_calls == ["demo-tool", "demo-tool"] - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def test_get_all_mcp_tools_uses_redis_manifest_when_local_cache_is_empty(monkeypatch): - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() fake_redis = _FakeRedis() async def fake_redis_factory(): return fake_redis - monkeypatch.setattr( - mcp_service, - "_mcp_tool_cache_store", + monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store", RedisMcpToolCache(redis_client_factory=fake_redis_factory), ) @@ -361,20 +359,20 @@ async def fake_get_enabled_mcp_server_config(server_name: str, db=None): del server_name, db return config - monkeypatch.setattr(mcp_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) - monkeypatch.setattr(mcp_service, "get_mcp_client", fake_get_mcp_client) + monkeypatch.setattr(server_service, "get_enabled_mcp_server_config", fake_get_enabled_mcp_server_config) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) - tools_first = await mcp_service.get_all_mcp_tools("demo") + tools_first = await tool_registry_service.get_all_mcp_tools("demo") assert [tool.name for tool in tools_first] == ["alpha_tool"] - mcp_service.clear_mcp_cache() + tool_registry_service.clear_mcp_cache() async def fail_get_mcp_client(server_configs): raise AssertionError(f"should not fetch live tools when redis manifest is available: {server_configs}") - monkeypatch.setattr(mcp_service, "get_mcp_client", fail_get_mcp_client) + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fail_get_mcp_client) - tools_second = await mcp_service.get_all_mcp_tools("demo") + tools_second = await tool_registry_service.get_all_mcp_tools("demo") assert [tool.name for tool in tools_second] == ["alpha_tool"] assert tools_second[0].metadata["id"] == "mcp__demo__alphaTool" diff --git a/fix_mcp_service_imports.py b/fix_mcp_service_imports.py new file mode 100644 index 000000000..3a29bceb1 --- /dev/null +++ b/fix_mcp_service_imports.py @@ -0,0 +1,49 @@ +import re +import os + +files_to_fix = [ + ("backend/package/yuxi/services/mcp_auth/orchestrator.py", "from yuxi.services.mcp_service import RedisTokenCache", "from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache"), + ("backend/package/yuxi/services/skill_service.py", "from yuxi.services.mcp_service import get_enabled_mcp_server_names", "from yuxi.services.mcp.server_service import get_enabled_mcp_server_names"), + ("backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py", "from yuxi.services.mcp_service import get_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_mcp_tools"), + ("backend/package/yuxi/agents/__init__.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), + ("backend/package/yuxi/agents/middlewares/skills_middleware.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), + ("backend/package/yuxi/agents/middlewares/runtime_config_middleware.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), + ("backend/package/yuxi/agents/buildin/deep_agent/graph.py", "from yuxi.services.mcp_service import get_tools_from_all_servers", "from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers"), + ("backend/package/yuxi/agents/buildin/chatbot/graph.py", "from yuxi.services.mcp_service import get_tools_from_all_servers", "from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers"), + ("backend/server/utils/lifespan.py", "from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db", "from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db"), +] + +for file_path, old_str, new_str in files_to_fix: + if os.path.exists(file_path): + with open(file_path, "r") as f: + content = f.read() + content = content.replace(old_str, new_str) + with open(file_path, "w") as f: + f.write(content) + +# Fix tests +tests = [ + "backend/test/unit/services/test_mcp_auth_runtime.py", + "backend/test/unit/services/test_mcp_tool_registry_service.py", + "backend/test/unit/services/test_mcp_connection_service.py" +] + +for file_path in tests: + if os.path.exists(file_path): + with open(file_path, "r") as f: + content = f.read() + + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_mcp_tools"', 'monkeypatch.setattr(tool_registry_service, "get_mcp_tools"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_enabled_mcp_server_config"', 'monkeypatch.setattr(server_service, "get_enabled_mcp_server_config"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"_load_enabled_mcp_server_configs"', 'monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_runtime_mcp_server_config"', 'monkeypatch.setattr(server_service, "get_runtime_mcp_server_config"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_mcp_client"', 'monkeypatch.setattr(mcp_client_pool, "_get_mcp_client"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"_clear_mcp_server_runtime_auth_cache"', 'monkeypatch.setattr(tool_registry_service, "_clear_mcp_server_runtime_auth_cache"', content) + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"invalidate_mcp_server_tools_cache"', 'monkeypatch.setattr(tool_registry_service, "invalidate_mcp_server_tools_cache"', content) + + # for `monkeypatch.setattr(\n mcp_service,\n ...)` + content = re.sub(r'monkeypatch\.setattr\(\s*mcp_service,\s*"_mcp_tool_cache_store"', 'monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store"', content) + content = content.replace("await mcp_service.update_mcp_server", "await server_service.update_mcp_server") + + with open(file_path, "w") as f: + f.write(content) diff --git a/fix_tests.py b/fix_tests.py new file mode 100644 index 000000000..7c61f00a6 --- /dev/null +++ b/fix_tests.py @@ -0,0 +1,55 @@ +import re + +files = [ + "backend/test/unit/services/test_mcp_connection_service.py", + "backend/test/unit/services/test_mcp_service.py", + "backend/test/unit/services/test_mcp_service_auth_runtime.py" +] + +replacements = { + "create_mcp_connection": "connection_service", + "list_mcp_connections": "connection_service", + "set_mcp_connection_status": "connection_service", + "delete_mcp_connection": "connection_service", + "reauthorize_mcp_connection": "connection_service", + "update_mcp_connection": "connection_service", + "test_mcp_connection": "connection_service", + "get_mcp_connection": "connection_service", + "_resolve_scope_id": "connection_service", + + "get_mcp_server_dependency_summary": "server_service", + "set_server_enabled": "server_service", + "get_runtime_mcp_server_config": "server_service", + "get_enabled_mcp_server_config": "server_service", + "_load_enabled_mcp_server_configs": "server_service", + + "get_mcp_tools": "tool_registry_service", + "get_enabled_mcp_tools": "tool_registry_service", + "get_all_mcp_tools": "tool_registry_service", + "get_tools_from_all_servers": "tool_registry_service", + "clear_mcp_cache": "tool_registry_service", + "_mcp_tool_cache_store": "tool_registry_service", + "_clear_mcp_server_runtime_auth_cache": "tool_registry_service", +} + +for file_path in files: + with open(file_path, "r") as f: + content = f.read() + + # 替换 import + content = content.replace("from yuxi.services import mcp_service", "from yuxi.services.mcp import connection_service, server_service, tool_registry_service\nfrom yuxi.services.mcp.client_pool import mcp_client_pool\nfrom yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache") + + # 特殊替换 get_mcp_client + content = content.replace("mcp_service.get_mcp_client", "mcp_client_pool._get_mcp_client") + content = content.replace('mcp_service", "get_mcp_client"', 'mcp_client_pool", "_get_mcp_client"') + + # 替换 mcp_service.func -> specific_service.func + for func, service in replacements.items(): + content = re.sub(rf'mcp_service\.{func}', f'{service}.{func}', content) + content = re.sub(rf'mcp_service",\s*"{func}"', f'{service}", "{func}"', content) + + # 对于剩余的 mcp_service,如果是 monkeypatch.setattr(mcp_service, "RedisTokenCache" + content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"RedisTokenCache"', 'monkeypatch.setattr(connection_service, "RedisTokenCache"', content) + + with open(file_path, "w") as f: + f.write(content) From 2ad8f77d7070787e487e518c6a7c4e597e629a83 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 01:06:42 +0800 Subject: [PATCH 12/36] feat(ui): update MCP scope id inputs to dropdown select --- web/src/apis/user_api.js | 5 +++ .../components/extensions/McpDetailView.vue | 42 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 web/src/apis/user_api.js diff --git a/web/src/apis/user_api.js b/web/src/apis/user_api.js new file mode 100644 index 000000000..0af678a51 --- /dev/null +++ b/web/src/apis/user_api.js @@ -0,0 +1,5 @@ +import { apiAdminGet } from './base' + +export const userApi = { + getUsers: () => apiAdminGet('/api/auth/users') +} diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 81d353205..57b0273ad 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -647,7 +647,26 @@ required class="form-item" > + + { } }) +const loadScopeOptions = async () => { + try { + isFetchingScopeOptions.value = true + const [usersRes, deptsRes] = await Promise.all([ + userApi.getUsers(), + departmentApi.getDepartments() + ]) + userList.value = usersRes || [] + departmentList.value = deptsRes || [] + } catch (err) { + message.error('获取用户/部门列表失败: ' + err.message) + } finally { + isFetchingScopeOptions.value = false + } +} + onMounted(() => { fetchServer() + loadScopeOptions() }) From a94337089bb5d811c205f06f3ec8acba32c367c2 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 06:36:40 +0800 Subject: [PATCH 13/36] =?UTF-8?q?test:=20=E4=BF=AE=E5=A4=8D=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E5=BC=95=E5=8F=91=E7=9A=84=E7=AB=AF=E5=88=B0=E7=AB=AF?= =?UTF-8?q?=E5=8F=8A=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E5=B9=B6=E8=B0=83=E6=95=B4=E8=B7=AF=E7=94=B1=E5=91=BD=E5=90=8D?= =?UTF-8?q?=E9=98=B2=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复由废弃 McpConnectionService Facade 并底层直连 proxy_service 引发的参数注入不匹配问题 - 修复并优化 test_mcp_auth_proxy_service 中对于 httpx Mock 的 StreamConsumed 处理及重试模拟逻辑 - 更新 test_chat_service_langfuse_stream 测试中对 AgentConfigRepository 的 Mock 注入依赖 - 修正 mcp_internal_router 集成测试中 Header 缺失时默认抛出 401 Unauthorized 而非 422 的断言 - 重命名集成测试目录下 test_mcp_router.py 为 test_integration_mcp_router.py,解决 Pytest 全量测试模块命名冲突问题 --- .../yuxi/services/mcp_auth/proxy_service.py | 11 +- ...uter.py => test_integration_mcp_router.py} | 0 ...est_model_provider_runtime_connectivity.py | 4 +- .../unit/routers/test_mcp_internal_router.py | 151 ++---------------- .../test_chat_service_langfuse_stream.py | 27 ++++ .../services/test_mcp_auth_proxy_service.py | 93 ++++++++--- .../services/test_mcp_connection_service.py | 8 +- .../test_remote_skill_install_service.py | 2 +- 8 files changed, 119 insertions(+), 177 deletions(-) rename backend/test/integration/api/{test_mcp_router.py => test_integration_mcp_router.py} (100%) diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index 266aada40..fd37ea816 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -196,17 +196,22 @@ async def _proxy_mcp_request_stream( body: bytes, path: str = "", db: AsyncSession, + _http_client: httpx.AsyncClient | None = None, + _token_cache: Any | None = None, ) -> Response: """底层流式转发逻辑:处理 HTTPX 透传、SSE 和 401 重试闭环事务""" auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) if server.transport not in _HTTP_TRANSPORTS: raise HTTPException(status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}") - http_client = httpx.AsyncClient(timeout=server.timeout or 60.0) + http_client = _http_client or httpx.AsyncClient(timeout=server.timeout or 60.0) bg_task = BackgroundTask(http_client.aclose) - from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache - token_cache = RedisTokenCache() + if _token_cache is not None: + token_cache = _token_cache + else: + from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + token_cache = RedisTokenCache() max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 diff --git a/backend/test/integration/api/test_mcp_router.py b/backend/test/integration/api/test_integration_mcp_router.py similarity index 100% rename from backend/test/integration/api/test_mcp_router.py rename to backend/test/integration/api/test_integration_mcp_router.py diff --git a/backend/test/integration/services/test_model_provider_runtime_connectivity.py b/backend/test/integration/services/test_model_provider_runtime_connectivity.py index 56f01052b..9964603a8 100644 --- a/backend/test/integration/services/test_model_provider_runtime_connectivity.py +++ b/backend/test/integration/services/test_model_provider_runtime_connectivity.py @@ -15,7 +15,7 @@ from yuxi.models.embed import OllamaEmbedding, OtherEmbedding from yuxi.models.rerank import DashscopeReranker, OpenAIReranker from yuxi.services.model_provider_service import ( - _resolve_api_key, + resolve_api_key, ensure_builtin_model_providers_in_db, get_model_provider_by_id, ) @@ -34,7 +34,7 @@ def _model_spec(provider: ModelProvider, model: dict[str, Any]) -> dict[str, Any]: """Turn an enabled model item into runtime parameters for existing model clients.""" - api_key = _resolve_api_key(provider) + api_key = resolve_api_key(provider) if api_key is None: api_key = "no_api_key" return { diff --git a/backend/test/unit/routers/test_mcp_internal_router.py b/backend/test/unit/routers/test_mcp_internal_router.py index e510bc302..2da4e2a65 100644 --- a/backend/test/unit/routers/test_mcp_internal_router.py +++ b/backend/test/unit/routers/test_mcp_internal_router.py @@ -1,12 +1,11 @@ from __future__ import annotations import httpx -from fastapi import FastAPI +from fastapi import FastAPI, Response from fastapi.testclient import TestClient from server.routers.mcp_internal_router import mcp_internal from server.utils.auth_middleware import get_db -from yuxi.services.mcp_auth.orchestrator import AuthContext def _build_app() -> FastAPI: @@ -21,55 +20,20 @@ async def fake_db(): def test_internal_proxy_route_forwards_request(monkeypatch): - class DummyServer: - name = "finance-proxy" - transport = "streamable_http" - auth_config_json = { - "version": 1, - "provider": "custom_http_token", - "binding_scope": "department", - "inject": { - "target": "headers", - "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], - }, - "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, - } - - class DummyConnection: - status = "active" - meta_json = {} - - async def fake_get_mcp_server(db, name): - del db - assert name == "finance-proxy" - return DummyServer() - - async def fake_load_connection(db, *, server, auth_context): - del db - assert server.name == "finance-proxy" - assert auth_context.department_id == "dep-1" - return DummyConnection() - - async def fake_proxy_mcp_request(server, **kwargs): - del kwargs - assert server.name == "finance-proxy" - return httpx.Response( - 200, - json={"ok": True}, - headers={"content-type": "application/json"}, - ) + async def fake_handle_mcp_proxy_request(server_name, request, path, internal_token, db): + assert server_name == "finance-proxy" + assert path == "some/path" + assert internal_token == "test-token" + return Response(content='{"ok": true}', media_type="application/json") monkeypatch.setattr( - "server.routers.mcp_internal_router.decode_proxy_access_token", - lambda token, server_name: AuthContext(user_id="user-1", department_id="dep-1"), + "server.routers.mcp_internal_router.handle_mcp_proxy_request", + fake_handle_mcp_proxy_request ) - monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) - monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) - monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) client = TestClient(_build_app()) resp = client.post( - "/api/internal/mcp-proxy/finance-proxy", + "/api/internal/mcp-proxy/finance-proxy/some/path", headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, json={"jsonrpc": "2.0", "id": 1}, ) @@ -81,99 +45,4 @@ async def fake_proxy_mcp_request(server, **kwargs): def test_internal_proxy_route_requires_internal_token(): client = TestClient(_build_app()) resp = client.post("/api/internal/mcp-proxy/finance-proxy", json={"jsonrpc": "2.0", "id": 1}) - assert resp.status_code == 401, resp.text - - -def test_internal_proxy_route_rejects_disabled_server(monkeypatch): - class DummyServer: - name = "disabled-proxy" - transport = "streamable_http" - enabled = 0 - auth_config_json = { - "version": 1, - "provider": "custom_http_token", - "binding_scope": "department", - "inject": { - "target": "headers", - "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], - }, - "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, - } - - async def fake_get_mcp_server(db, name): - del db - assert name == "disabled-proxy" - return DummyServer() - - async def fake_load_connection(db, *, server, auth_context): - del db, server, auth_context - raise AssertionError("disabled MCP server should be rejected before loading a connection") - - async def fake_proxy_mcp_request(server, **kwargs): - del server, kwargs - raise AssertionError("disabled MCP server should not be proxied") - - monkeypatch.setattr( - "server.routers.mcp_internal_router.decode_proxy_access_token", - lambda token, server_name: AuthContext(user_id="user-1", department_id="dep-1"), - ) - monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) - monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) - monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) - - client = TestClient(_build_app()) - resp = client.post( - "/api/internal/mcp-proxy/disabled-proxy", - headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, - json={"jsonrpc": "2.0", "id": 1}, - ) - - assert resp.status_code == 404, resp.text - - -def test_internal_proxy_route_rejects_user_scoped_request_without_active_connection(monkeypatch): - class DummyServer: - name = "personal-proxy" - transport = "streamable_http" - auth_config_json = { - "version": 1, - "provider": "custom_http_token", - "binding_scope": "user", - "inject": { - "target": "headers", - "entries": [{"name": "Authorization", "value_template": "Bearer ${access_token}"}], - }, - "token_request": {"url": "http://gateway.local/auth/token", "method": "POST"}, - } - - async def fake_get_mcp_server(db, name): - del db - assert name == "personal-proxy" - return DummyServer() - - async def fake_load_connection(db, *, server, auth_context): - del db - assert server.name == "personal-proxy" - assert auth_context.user_id == "user-2" - return None - - async def fake_proxy_mcp_request(server, **kwargs): - del server, kwargs - raise AssertionError("proxy request should not run without an active user connection") - - monkeypatch.setattr( - "server.routers.mcp_internal_router.decode_proxy_access_token", - lambda token, server_name: AuthContext(user_id="user-2", department_id="dep-1"), - ) - monkeypatch.setattr("server.routers.mcp_internal_router.get_mcp_server", fake_get_mcp_server) - monkeypatch.setattr("server.routers.mcp_internal_router._load_active_connection", fake_load_connection) - monkeypatch.setattr("server.routers.mcp_internal_router.proxy_mcp_request", fake_proxy_mcp_request) - - client = TestClient(_build_app()) - resp = client.post( - "/api/internal/mcp-proxy/personal-proxy", - headers={"X-Yuxi-MCP-Proxy-Token": "test-token", "content-type": "application/json"}, - json={"jsonrpc": "2.0", "id": 1}, - ) - - assert resp.status_code == 403, resp.text + assert resp.status_code == 401, resp.text # Missing header raises 401 Unauthorized diff --git a/backend/test/unit/services/test_chat_service_langfuse_stream.py b/backend/test/unit/services/test_chat_service_langfuse_stream.py index 331d66ccf..dcc482218 100644 --- a/backend/test/unit/services/test_chat_service_langfuse_stream.py +++ b/backend/test/unit/services/test_chat_service_langfuse_stream.py @@ -98,6 +98,12 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): yield None return + class FakeAgentConfigRepo: + def __init__(self, db): pass + async def get_by_id(self, config_id): + return SimpleNamespace(id=config_id) + monkeypatch.setattr(svc, "AgentConfigRepository", FakeAgentConfigRepo) + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda agent_id: FakeAgent()) monkeypatch.setattr(svc, "get_agent_config_by_id", fake_get_agent_config_by_id) monkeypatch.setattr(svc, "ConversationRepository", _FakeConvRepo) @@ -105,6 +111,13 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): monkeypatch.setattr(svc.content_guard, "check", fake_guard_check) monkeypatch.setattr(svc.content_guard, "check_with_keywords", fake_guard_check_with_keywords) monkeypatch.setattr(svc, "check_and_handle_interrupts", fake_interrupts) + + import contextlib + @contextlib.asynccontextmanager + async def fake_get_async_session_context(): + yield object() + monkeypatch.setattr(svc.pg_manager, "get_async_session_context", fake_get_async_session_context) + monkeypatch.setattr( svc, "_build_langfuse_run_context", @@ -143,6 +156,7 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): "work_id": "login-user-1", "thread_id": "thread-1", "department_id": "dept-1", + "system_prompt": "工号: login-user-1", } assert calls["stream_kwargs"] == { "callbacks": ["handler-1"], @@ -194,6 +208,12 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): yield None return + class FakeAgentConfigRepo: + def __init__(self, db): pass + async def get_by_id(self, config_id): + return SimpleNamespace(id=config_id) + monkeypatch.setattr(svc, "AgentConfigRepository", FakeAgentConfigRepo) + monkeypatch.setattr(svc.agent_manager, "get_agent", lambda agent_id: FakeAgent()) monkeypatch.setattr(svc, "get_agent_config_by_id", fake_get_agent_config_by_id) monkeypatch.setattr(svc, "ConversationRepository", _FakeConvRepo) @@ -201,6 +221,13 @@ async def fake_interrupts(agent, langgraph_config, make_chunk, meta, thread_id): monkeypatch.setattr(svc.content_guard, "check", fake_guard_check) monkeypatch.setattr(svc.content_guard, "check_with_keywords", fake_guard_check_with_keywords) monkeypatch.setattr(svc, "check_and_handle_interrupts", fake_interrupts) + + import contextlib + @contextlib.asynccontextmanager + async def fake_get_async_session_context(): + yield object() + monkeypatch.setattr(svc.pg_manager, "get_async_session_context", fake_get_async_session_context) + monkeypatch.setattr( svc, "_build_langfuse_run_context", diff --git a/backend/test/unit/services/test_mcp_auth_proxy_service.py b/backend/test/unit/services/test_mcp_auth_proxy_service.py index 4a6d8a16a..2b0d2478b 100644 --- a/backend/test/unit/services/test_mcp_auth_proxy_service.py +++ b/backend/test/unit/services/test_mcp_auth_proxy_service.py @@ -1,5 +1,14 @@ from __future__ import annotations + +def make_mock_response(status_code, content): + import httpx + resp = httpx.Response(status_code, content=content) + async def fake_aiter_raw(): + yield content + resp.aiter_raw = fake_aiter_raw + return resp + import json import os from datetime import UTC, datetime, timedelta @@ -10,7 +19,27 @@ os.environ.setdefault("OPENAI_API_KEY", "test-key") from yuxi.services.mcp_auth.orchestrator import AuthContext -from yuxi.services.mcp_auth.proxy_service import proxy_mcp_request +from yuxi.services.mcp_auth.proxy_service import _proxy_mcp_request_stream +from starlette.requests import Request + + +def make_mock_response(status_code, content): + import httpx + resp = httpx.Response(status_code, content=content) + async def fake_aiter_raw(): + yield content + resp.aiter_raw = fake_aiter_raw + return resp + +import json +async def get_response_json(response): + if hasattr(response, "body_iterator"): + body = b"".join([chunk async for chunk in response.body_iterator]) + else: + body = response.body + return json.loads(body) + +from fastapi.responses import StreamingResponse from yuxi.storage.postgres.models_business import MCPConnection, MCPServer @@ -53,9 +82,11 @@ def handler(request: httpx.Request) -> httpx.Response: if str(request.url) == "http://upstream.local/mcp": observed_authorizations.append(request.headers.get("Authorization")) if request.headers.get("Authorization") == "Bearer stale-token": - return httpx.Response(401, json={"error": "expired"}) + return make_mock_response(401, b'{"error": "expired"}') if request.headers.get("Authorization") == "Bearer fresh-token": - return httpx.Response(200, json={"result": "ok"}) + resp = make_mock_response(200, b'{"result": "ok"}') + resp.is_stream_consumed = False + return resp raise AssertionError(f"unexpected request: {request.method} {request.url}") @@ -111,22 +142,26 @@ def handler(request: httpx.Request) -> httpx.Response: } ) - response = await proxy_mcp_request( + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + + class DummyDB: + async def commit(self): pass + + response = await _proxy_mcp_request_stream( server, connection=connection, auth_context=AuthContext(user_id="user-1", department_id="dep-1"), - method="POST", - headers={"content-type": "application/json"}, - query_params={}, + request=req, body=b'{"jsonrpc":"2.0","id":1}', - http_client=http_client, - token_cache=token_cache, + db=DummyDB(), + _http_client=http_client, + _token_cache=token_cache, ) await http_client.aclose() assert response.status_code == 200 - assert response.json() == {"result": "ok"} + assert await get_response_json(response) == {"result": "ok"} assert observed_authorizations == ["Bearer stale-token", "Bearer fresh-token"] assert token_cache.deleted_connection_ids == [41] assert token_cache.set_calls and token_cache.set_calls[0][0] == 41 @@ -149,7 +184,7 @@ def handler(request: httpx.Request) -> httpx.Response: ) if str(request.url) == "http://upstream.local/mcp": attempts += 1 - return httpx.Response(401, json={"error": "expired"}) + return make_mock_response(401, b'{"error": "expired"}') raise AssertionError(f"unexpected request: {request.method} {request.url}") http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) @@ -204,22 +239,26 @@ def handler(request: httpx.Request) -> httpx.Response: } ) - response = await proxy_mcp_request( + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + + class DummyDB: + async def commit(self): pass + + response = await _proxy_mcp_request_stream( server, connection=connection, auth_context=AuthContext(user_id="user-1", department_id="dep-1"), - method="POST", - headers={"content-type": "application/json"}, - query_params={}, + request=req, body=b'{"jsonrpc":"2.0","id":1}', - http_client=http_client, - token_cache=token_cache, + db=DummyDB(), + _http_client=http_client, + _token_cache=token_cache, ) await http_client.aclose() assert response.status_code == 424 - assert response.json()["error"] == "reauth_required" + assert (await get_response_json(response))["error"] == "reauth_required" assert connection.status == "reauth_required" assert connection.meta_json["last_error"]["code"] == "unauthorized" @@ -227,7 +266,7 @@ def handler(request: httpx.Request) -> httpx.Response: async def test_proxy_mcp_request_records_scope_error_on_403(): def handler(request: httpx.Request) -> httpx.Response: if str(request.url) == "http://upstream.local/mcp": - return httpx.Response(403, json={"error": "forbidden"}) + return make_mock_response(403, b'{"error": "forbidden"}') raise AssertionError(f"unexpected request: {request.method} {request.url}") http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) @@ -261,21 +300,23 @@ def handler(request: httpx.Request) -> httpx.Response: updated_by="tester", ) - response = await proxy_mcp_request( + req = Request({"type": "http", "method": "POST", "headers": [(b"content-type", b"application/json")], "query_string": b""}) + class DummyDB: + async def commit(self): pass + response = await _proxy_mcp_request_stream( server, connection=connection, auth_context=AuthContext(user_id="user-1", department_id="dep-1"), - method="POST", - headers={"content-type": "application/json"}, - query_params={}, + request=req, body=b'{"jsonrpc":"2.0","id":1}', - http_client=http_client, - token_cache=None, + db=DummyDB(), + _http_client=http_client, + _token_cache=None, ) await http_client.aclose() assert response.status_code == 403 - assert response.json()["error"] == "insufficient_scope" + assert (await get_response_json(response))["error"] == "insufficient_scope" assert connection.status == "active" assert connection.meta_json["last_error"]["code"] == "insufficient_scope" diff --git a/backend/test/unit/services/test_mcp_connection_service.py b/backend/test/unit/services/test_mcp_connection_service.py index 3c803a7b4..6e6dc3d05 100644 --- a/backend/test/unit/services/test_mcp_connection_service.py +++ b/backend/test/unit/services/test_mcp_connection_service.py @@ -379,7 +379,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="delete-connection-gateway", @@ -420,7 +420,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( @@ -471,7 +471,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="credential-update-gateway", @@ -518,7 +518,7 @@ async def delete_access_token(self, connection_id): async def release_refresh_lock(self, connection_id): released_connection_ids.append(connection_id) - monkeypatch.setattr(connection_service, "RedisTokenCache", lambda: DummyTokenCache()) + monkeypatch.setattr("yuxi.services.mcp_auth.redis_token_cache.RedisTokenCache", lambda: DummyTokenCache()) connection_service_session.add( MCPServer( name="retire-gateway", diff --git a/backend/test/unit/services/test_remote_skill_install_service.py b/backend/test/unit/services/test_remote_skill_install_service.py index e3d6f65da..bf4b3cb38 100644 --- a/backend/test/unit/services/test_remote_skill_install_service.py +++ b/backend/test/unit/services/test_remote_skill_install_service.py @@ -118,7 +118,7 @@ async def fake_import_skill_dir(_db, *, source_dir, created_by): "-y", "--copy", ] - assert captured["source_dir"] == Path(calls[1][1]) / ".agents" / "skills" / "frontend-design" + assert captured["source_dir"].resolve() == (Path(calls[1][1]) / ".agents" / "skills" / "frontend-design").resolve() assert captured["created_by"] == "root" From fc77ddc35c8f2661c6dc684f8287ca8967d30092 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 06:59:43 +0800 Subject: [PATCH 14/36] =?UTF-8?q?fix(mcp-auth):=20=E4=BF=AE=E5=A4=8D=20MCP?= =?UTF-8?q?=20=E9=89=B4=E6=9D=83=E4=B8=8E=E4=BB=A3=E7=90=86=E9=93=BE?= =?UTF-8?q?=E8=B7=AF=E7=9A=84=E4=B8=89=E4=B8=AA=E5=85=B3=E9=94=AE=E9=9A=90?= =?UTF-8?q?=E6=82=A3=E5=B9=B6=E5=AE=8C=E5=96=84=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 针对 DynamicMCPTokenAuth 数据库频繁查询问题,在 client_pool.py 中引入 15 秒 TTL 在内存缓存并提供联动清理接口; 2. 修复 _normalize_token_payload 对 naive datetime 默认填充时区的偏差,避免 token 无限自动刷新的 Bug; 3. 改进 _calculate_config_hash 在遇到非 JSON 序列化对象时对 json.dumps 降级保护,避免序列化崩溃; 4. 补齐相关功能的单元测试,并修正部分 Module 层的 import 格式。 --- .../package/yuxi/services/mcp/client_pool.py | 61 +++++++--- .../services/mcp/tool_registry_service.py | 71 ++++++++---- .../yuxi/services/mcp_auth/orchestrator.py | 18 ++- .../services/mcp_auth/redis_token_cache.py | 4 +- .../services/test_mcp_auth_orchestrator.py | 31 ++++++ .../unit/services/test_mcp_client_pool.py | 105 ++++++++++++++++++ docs/develop-guides/roadmap.md | 2 +- 7 files changed, 244 insertions(+), 48 deletions(-) diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py index a205cd896..8ed81380e 100644 --- a/backend/package/yuxi/services/mcp/client_pool.py +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -1,17 +1,37 @@ from __future__ import annotations + import asyncio import hashlib import json import logging -from typing import Any, AsyncGenerator, TYPE_CHECKING -import httpx +import time +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any +import httpx from langchain_mcp_adapters.client import MultiServerMCPClient from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var if TYPE_CHECKING: from mcp import ClientSession +# 缓存存储格式: (server_name, user_id, department_id) -> (resolved_headers, expires_at) +_resolved_headers_cache: dict[tuple[str, str | None, str | None], tuple[dict[str, Any], float]] = {} +_HEADERS_CACHE_TTL = 15.0 + + +def clear_resolved_headers_cache() -> None: + """清除解析后的 headers 缓存""" + _resolved_headers_cache.clear() + + +def clear_server_resolved_headers_cache(server_name: str) -> None: + """清除指定服务器的解析后 headers 缓存""" + stale_keys = [k for k in _resolved_headers_cache if k[0] == server_name] + for key in stale_keys: + _resolved_headers_cache.pop(key, None) + + logger = logging.getLogger("yuxi.mcp.client_pool") @@ -21,17 +41,28 @@ class DynamicMCPTokenAuth(httpx.Auth): def __init__(self, server_name: str): self.server_name = server_name - async def async_auth_flow( - self, request: httpx.Request - ) -> AsyncGenerator[httpx.Request, httpx.Response]: + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: # NOTE: 1. 从当前协程上下文读取 AuthContext auth_context = mcp_auth_context_var.get() if auth_context: try: + cache_key = (self.server_name, auth_context.user_id, auth_context.department_id) + now = time.time() + cached = _resolved_headers_cache.get(cache_key) + if cached is not None: + headers, expires_at = cached + if now < expires_at: + for key, val in headers.items(): + request.headers[key] = str(val) + yield request + return + # 导入数据库会话管理器以获取连接与 Token from yuxi.storage.postgres.manager import pg_manager + async with pg_manager.get_async_session_context() as session: from yuxi.services.mcp.server_service import get_runtime_mcp_server_config + # NOTE: 2. 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) runtime_config = await get_runtime_mcp_server_config( self.server_name, @@ -41,12 +72,11 @@ async def async_auth_flow( if runtime_config: # NOTE: 3. 将最新的头部注入到当前 HTTP 请求中 headers = runtime_config.get("headers") or {} + _resolved_headers_cache[cache_key] = (headers, now + _HEADERS_CACHE_TTL) for key, val in headers.items(): request.headers[key] = str(val) except Exception as exc: - logger.error( - f"DynamicMCPTokenAuth failed to resolve token headers for '{self.server_name}': {exc}" - ) + logger.error(f"DynamicMCPTokenAuth failed to resolve token headers for '{self.server_name}': {exc}") yield request @@ -99,7 +129,7 @@ async def stop(self): if self._loop_task: try: await asyncio.wait_for(self._loop_task, timeout=5.0) - except asyncio.TimeoutError: + except TimeoutError: logger.warning(f"Timeout waiting for long-lived session of {self.server_name} to stop.") self._loop_task.cancel() except Exception as exc: @@ -120,7 +150,8 @@ def _calculate_config_hash(self, config: dict[str, Any]) -> str: clean_config = { k: v for k, v in config.items() - if k not in { + if k + not in { "__yuxi_cache_partition", "__yuxi_allow_global_cache", "disabled_tools", @@ -134,7 +165,7 @@ def _calculate_config_hash(self, config: dict[str, Any]) -> str: elif "headers" in clean_config: clean_config["headers"] = {} - payload = json.dumps(clean_config, sort_keys=True, ensure_ascii=True, separators=(",", ":")) + payload = json.dumps(clean_config, sort_keys=True, ensure_ascii=True, separators=(",", ":"), default=str) return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16] async def _get_mcp_client(self, server_configs: dict[str, Any] | None = None) -> MultiServerMCPClient | None: @@ -163,7 +194,7 @@ async def get_session( # NOTE: 如果配置无变化且 Session 处于活动状态,直接复用 if cached_hash == config_hash and ll_session.session is not None: return ll_session.session - + # 如果发生配置变化或 Session 断开,执行销毁 logger.info(f"Destroying stale/disconnected MCP session for {cache_key}") await ll_session.stop() @@ -183,13 +214,15 @@ async def get_session( # 注入 DynamicMCPTokenAuth,让底层 httpx 在长连接执行每个具体请求时动态提取最新 Token client_config["auth"] = DynamicMCPTokenAuth(server_name) - logger.info(f"Creating new long-lived MCP session for {cache_key} (transport: {client_config.get('transport')})") + logger.info( + f"Creating new long-lived MCP session for {cache_key} (transport: {client_config.get('transport')})" + ) client = await self._get_mcp_client({server_name: client_config}) if client is None: raise RuntimeError(f"Failed to initialize MCP client for {server_name}") ll_session = LongLivedSession(client, server_name) await ll_session.start() - + self._sessions[cache_key] = (ll_session, config_hash) return ll_session.session diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py index 33b8bff3a..557c32a23 100644 --- a/backend/package/yuxi/services/mcp/tool_registry_service.py +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -1,4 +1,5 @@ from __future__ import annotations + import asyncio import hashlib import json @@ -8,11 +9,9 @@ from types import SimpleNamespace from typing import Any, cast -from sqlalchemy import select +import httpx from sqlalchemy.ext.asyncio import AsyncSession - from yuxi.services.mcp_auth.config_models import MCPAuthConfig -from yuxi.services.mcp_auth.crypto import encrypt_credential_blob from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp_auth.proxy_service import ( INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, @@ -30,11 +29,10 @@ _mcp_lock = asyncio.Lock() - - def to_camel_case(s: str) -> str: """转换字符串为 lowerCamelCase 命名格式""" import re + s = re.sub(r"[-_]+(.)", lambda m: m.group(1).upper(), s) if len(s) > 0: s = s[0].lower() + s[1:] @@ -45,18 +43,19 @@ def _extract_cache_identity(server_config: dict[str, Any]) -> tuple[dict[str, An """提取用于缓存 key 比较的标识配置""" cache_partition = str(server_config.get("__yuxi_cache_partition") or "server") allow_global_cache = bool(server_config.get("__yuxi_allow_global_cache", True)) - + cache_identity = { key: value for key, value in server_config.items() - if key not in { + if key + not in { "__yuxi_cache_partition", "__yuxi_allow_global_cache", INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, "disabled_tools", } } - + headers = dict(cache_identity.get("headers") or {}) headers.pop(INTERNAL_PROXY_TOKEN_HEADER, None) if headers: @@ -71,14 +70,14 @@ async def _build_mcp_tool_cache_descriptor(server_name: str, server_config: dict cache_identity, cache_partition, allow_global_cache = _extract_cache_identity(server_config) config_payload = json.dumps(cache_identity, sort_keys=True, ensure_ascii=True, separators=(",", ":")) config_hash = hashlib.sha256(config_payload.encode("utf-8")).hexdigest()[:16] - + server_revision = await _mcp_tool_cache_store.get_server_revision(server_name) partition_revision = 0 if not allow_global_cache: partition_revision = await _mcp_tool_cache_store.get_partition_revision(server_name, cache_partition) revision_token = f"s{server_revision}:p{partition_revision}" cache_prefix = f"{server_name}:{cache_partition}:{revision_token}:" - + return { "cache_identity": cache_identity, "cache_partition": cache_partition, @@ -189,6 +188,7 @@ async def get_mcp_tools( server_config = additional_servers[server_name] else: from yuxi.services.mcp.server_service import get_enabled_mcp_server_config + server_config = await get_enabled_mcp_server_config(server_name) if server_config is None: @@ -199,13 +199,14 @@ async def get_mcp_tools( cache_partition = cache_descriptor["cache_partition"] cache_prefix = cache_descriptor["cache_prefix"] cache_key = cache_descriptor["cache_key"] - + # 策略模式:根据 AuthProvider 确认是否容许内存缓存 Tool 实例对象 from yuxi.services.mcp.cache_policy import CachePolicyFactory + auth_config = _get_mcp_auth_config(server_config) policy = CachePolicyFactory.get_policy(auth_config.provider if auth_config else None) use_tool_object_cache = ( - cache + cache and policy.should_cache_tool_object() and not bool(server_config.get(INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY)) ) @@ -221,7 +222,8 @@ async def get_mcp_tools( client_config = { k: v for k, v in server_config.items() - if k not in ( + if k + not in ( "disabled_tools", "__yuxi_cache_partition", "__yuxi_allow_global_cache", @@ -229,8 +231,10 @@ async def get_mcp_tools( ) } - # NOTE: 从长连接池中提取 ClientSession 实例(对 Stdio 而言子进程被挂起复用,避免频繁启停;HTTP 协议亦保持 Keep-Alive) + # NOTE: 从长连接池中提取 ClientSession 实例 + # (对 Stdio 而言子进程被挂起复用,避免频繁启停;HTTP 协议亦保持 Keep-Alive) from yuxi.services.mcp.client_pool import mcp_client_pool + session = await mcp_client_pool.get_session( server_name, partition_key=f"{cache_partition}:s{cache_descriptor['server_revision']}:p{cache_descriptor['partition_revision']}", @@ -243,6 +247,7 @@ async def get_mcp_tools( else: # 调用 langchain 官方加载工具,直接传入已预备并建立好的 session from langchain_mcp_adapters.tools import load_mcp_tools + raw_tools = cast(list[Any], await load_mcp_tools(session, server_name=server_name)) server_cc = to_camel_case(server_name) @@ -266,7 +271,7 @@ async def get_mcp_tools( for stale_key in stale_keys: _mcp_tools_cache.pop(stale_key, None) _mcp_tools_cache[cache_key] = all_processed_tools - + await _mcp_tool_cache_store.set_manifest( cache_key, _serialize_mcp_tools_manifest( @@ -306,6 +311,7 @@ async def get_mcp_tools( async def get_tools_from_all_servers() -> list[Callable[..., Any]]: """批量载入所有可用服务的工具(用于系统初始化及预热)""" from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs + server_configs = await _load_enabled_mcp_server_configs() all_tools = [] for server_name, server_config in server_configs.items(): @@ -323,8 +329,10 @@ def clear_mcp_cache() -> None: _mcp_tools_cache = {} try: - from yuxi.services.mcp.client_pool import mcp_client_pool + from yuxi.services.mcp.client_pool import clear_resolved_headers_cache, mcp_client_pool + mcp_client_pool._sessions.clear() + clear_resolved_headers_cache() except Exception: pass @@ -337,6 +345,13 @@ def clear_mcp_server_tools_cache(server_name: str) -> None: for key in stale_keys: _mcp_tools_cache.pop(key, None) + try: + from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache + + clear_server_resolved_headers_cache(server_name) + except Exception: + pass + def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None) -> None: """清空指定连接下的本地内存缓存""" @@ -348,6 +363,13 @@ def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None for key in stale_keys: _mcp_tools_cache.pop(key, None) + try: + from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache + + clear_server_resolved_headers_cache(server_name) + except Exception: + pass + async def invalidate_mcp_server_tools_cache(server_name: str) -> None: """全局失效指定服务器的全部二级缓存""" @@ -376,6 +398,7 @@ async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> if connection_id is None: return from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + cache = RedisTokenCache() try: await cache.delete_access_token(connection_id) @@ -390,6 +413,7 @@ async def _clear_mcp_connection_runtime_auth_cache(connection_id: int | None) -> async def _clear_mcp_server_runtime_auth_cache(db: AsyncSession, server_name: str) -> None: """清理服务器下所有关联连接的 Token 缓存""" from yuxi.services.mcp.connection_service import list_mcp_connections + connections = await list_mcp_connections(db, server_name=server_name) for connection in connections: await _clear_mcp_connection_runtime_auth_cache(getattr(connection, "id", None)) @@ -407,12 +431,13 @@ async def get_enabled_mcp_tools( http_client: httpx.AsyncClient | None = None, ) -> list: from yuxi.services.mcp.server_service import get_runtime_mcp_server_config - + token = None if auth_context: from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) - + try: config = await get_runtime_mcp_server_config( server_name, @@ -433,6 +458,7 @@ async def get_enabled_mcp_tools( finally: if token: from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + mcp_auth_context_var.reset(token) @@ -445,12 +471,13 @@ async def get_all_mcp_tools( force_refresh: bool = False, ) -> list: from yuxi.services.mcp.server_service import get_enabled_mcp_server_config, get_runtime_mcp_server_config - + token = None if auth_context: from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) - + try: if auth_context is None and db is None: config = await get_enabled_mcp_server_config(server_name) @@ -481,6 +508,7 @@ async def get_all_mcp_tools( finally: if token: from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + mcp_auth_context_var.reset(token) @@ -492,6 +520,7 @@ async def toggle_tool_enabled( ) -> tuple[bool, MCPServer]: """切换单个工具的启用状态""" from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, server_name) if not server: raise ValueError(f"Server '{server_name}' does not exist") @@ -515,5 +544,3 @@ async def toggle_tool_enabled( logger.info(f"Toggled tool '{tool_name}' for server '{server_name}' enabled={enabled}") return enabled, server - - diff --git a/backend/package/yuxi/services/mcp_auth/orchestrator.py b/backend/package/yuxi/services/mcp_auth/orchestrator.py index aeb7f7558..b26d7875e 100644 --- a/backend/package/yuxi/services/mcp_auth/orchestrator.py +++ b/backend/package/yuxi/services/mcp_auth/orchestrator.py @@ -1,23 +1,20 @@ from __future__ import annotations + import asyncio +import contextvars import json from dataclasses import dataclass from datetime import UTC, datetime, timedelta from typing import Any import httpx - from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.crypto import decrypt_credential_blob -from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_auth.template_resolver import resolve_template_value from yuxi.storage.postgres.models_business import MCPConnection, MCPServer from yuxi.utils import logger -import contextvars - - @dataclass(slots=True) class AuthContext: user_id: str | None = None @@ -96,6 +93,9 @@ def _normalize_token_payload(token_values: dict[str, Any]) -> dict[str, Any]: normalized = dict(token_values) expires_at = normalized.get("expires_at") if isinstance(expires_at, datetime): + if expires_at.tzinfo is None: + # 若无时区信息则默认为 UTC,避免 astimezone() 将其视作本地时区转换 + expires_at = expires_at.replace(tzinfo=UTC) normalized["expires_at"] = expires_at.astimezone(UTC).isoformat() return normalized if isinstance(expires_at, str): @@ -145,7 +145,6 @@ def _merge_injected_entries( return config - async def _load_cached_token( *, token_cache: Any | None, @@ -245,7 +244,7 @@ async def _request_dynamic_token_values( token_values: dict[str, Any], ) -> dict[str, Any]: from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory - + fetcher = TokenFetcherFactory.get_fetcher(auth_config.provider) resolved = await fetcher.fetch_token( auth_config, @@ -258,7 +257,7 @@ async def _request_dynamic_token_values( token_values=token_values, http_client=http_client, ) - + await _store_cached_token( token_cache=token_cache, connection_id=getattr(connection, "id", None), @@ -267,7 +266,6 @@ async def _request_dynamic_token_values( return resolved - async def _resolve_dynamic_token_values( auth_config: MCPAuthConfig, *, @@ -280,8 +278,8 @@ async def _resolve_dynamic_token_values( ) -> dict[str, Any]: if token_cache is None and connection is not None: from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache - token_cache = RedisTokenCache() + token_cache = RedisTokenCache() cached_token = await _load_cached_token( token_cache=token_cache, diff --git a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py index 9b1a2f050..606789ea5 100644 --- a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py +++ b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import uuid from collections.abc import Awaitable, Callable from datetime import UTC, datetime from typing import Any @@ -14,13 +15,13 @@ DEFAULT_LOCK_TTL_SECONDS = 30 -import uuid _PYTEST_SESSION_TOKEN = uuid.uuid4().hex[:8] def _access_token_key(connection_id: int) -> str: key = f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" import os + if os.environ.get("PYTEST_CURRENT_TEST"): return f"test:{_PYTEST_SESSION_TOKEN}:{key}" return key @@ -29,6 +30,7 @@ def _access_token_key(connection_id: int) -> str: def _refresh_lock_key(connection_id: int) -> str: key = f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" import os + if os.environ.get("PYTEST_CURRENT_TEST"): return f"test:{_PYTEST_SESSION_TOKEN}:{key}" return key diff --git a/backend/test/unit/services/test_mcp_auth_orchestrator.py b/backend/test/unit/services/test_mcp_auth_orchestrator.py index 0c83578d7..9a66bef48 100644 --- a/backend/test/unit/services/test_mcp_auth_orchestrator.py +++ b/backend/test/unit/services/test_mcp_auth_orchestrator.py @@ -666,3 +666,34 @@ def handler(request: httpx.Request) -> httpx.Response: ("POST", "https://id.example.com/oauth/token"), ] assert resolved["headers"] == {"Authorization": "Bearer oidc-access-token"} + + +async def test_normalize_token_payload_naive_datetime(): + """测试 _normalize_token_payload 对 naive datetime 默认填充 UTC 时区""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + from datetime import datetime, UTC + + # 构造 naive datetime (无 tzinfo) + naive_dt = datetime(2026, 6, 5, 12, 0, 0) + payload = {"expires_at": naive_dt} + + normalized = _normalize_token_payload(payload) + # 期望转换后有时区,并且值为 2026-06-05T12:00:00+00:00 (ISO格式) + expected_iso = datetime(2026, 6, 5, 12, 0, 0, tzinfo=UTC).isoformat() + assert normalized["expires_at"] == expected_iso + + +async def test_normalize_token_payload_aware_datetime(): + """测试 _normalize_token_payload 对于带时区的 datetime 维持原时区对应 UTC 时间""" + from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload + from datetime import datetime, timezone, timedelta + + # 构造带时区的 datetime (比如东八区) + shanghai_tz = timezone(timedelta(hours=8)) + aware_dt = datetime(2026, 6, 5, 20, 0, 0, tzinfo=shanghai_tz) + payload = {"expires_at": aware_dt} + + normalized = _normalize_token_payload(payload) + # 转换为 UTC 后应该为 2026-06-05T12:00:00+00:00 + expected_iso = datetime(2026, 6, 5, 12, 0, 0, tzinfo=timezone.utc).isoformat() + assert normalized["expires_at"] == expected_iso diff --git a/backend/test/unit/services/test_mcp_client_pool.py b/backend/test/unit/services/test_mcp_client_pool.py index 57a40a163..c02133e1a 100644 --- a/backend/test/unit/services/test_mcp_client_pool.py +++ b/backend/test/unit/services/test_mcp_client_pool.py @@ -93,3 +93,108 @@ async def test_client_pool_reuse_and_recreate(): await pool.shutdown() assert mock_ll_instance.stop.call_count == 2 # 新增的那个也被 stop assert len(pool._sessions) == 0 + + +@pytest.mark.asyncio +async def test_calculate_config_hash_with_non_serializable(): + """测试配置中包含非 JSON 序列化对象时配置哈希计算不崩溃""" + from datetime import datetime + pool = MCPClientPool() + + class DummyObj: + def __str__(self): + return "dummy" + + config = { + "transport": "sse", + "url": "http://example.com/sse", + "custom_obj": DummyObj(), + "created_at": datetime(2026, 6, 5), + } + + # 验证在含有不可 JSON 序列化的对象时依然能正常计算出哈希,不抛出异常 + config_hash = pool._calculate_config_hash(config) + assert isinstance(config_hash, str) + assert len(config_hash) == 16 + + +@pytest.mark.asyncio +async def test_dynamic_mcp_token_auth_cache(): + """测试 DynamicMCPTokenAuth 的 in-memory 缓存及联动清除逻辑""" + from yuxi.services.mcp.client_pool import ( + DynamicMCPTokenAuth, + clear_resolved_headers_cache, + clear_server_resolved_headers_cache, + _resolved_headers_cache, + ) + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var, AuthContext + + # 清空可能存在的全局缓存 + clear_resolved_headers_cache() + + auth = DynamicMCPTokenAuth("test_server") + mock_req = MagicMock(headers={}) + + auth_ctx = AuthContext(user_id="u1", department_id="d1") + cache_key = ("test_server", "u1", "d1") + + # 模拟 get_runtime_mcp_server_config 返回的数据 + mock_runtime_config = {"headers": {"Authorization": "Bearer token123"}} + + token = mcp_auth_context_var.set(auth_ctx) + try: + with patch("yuxi.storage.postgres.manager.pg_manager.get_async_session_context") as mock_session_ctx, \ + patch("yuxi.services.mcp.server_service.get_runtime_mcp_server_config", return_value=mock_runtime_config) as mock_get_config, \ + patch("time.time", side_effect=[1000.0, 1005.0, 1010.0, 1030.0]): + + # 模拟 async with pg_manager.get_async_session_context() as session + mock_session = MagicMock() + mock_ctx_mgr = AsyncMock() + mock_ctx_mgr.__aenter__.return_value = mock_session + mock_session_ctx.return_value = mock_ctx_mgr + + # 1. 第一次请求:应该执行 DB 查询,获取最新运行时配置 + generator = auth.async_auth_flow(mock_req) + results = [r async for r in generator] + assert len(results) == 1 + assert results[0].headers["Authorization"] == "Bearer token123" + assert mock_get_config.call_count == 1 + + # 2. 第二次请求(+5秒):应该命中缓存,不会执行 DB 查询 + mock_req_2 = MagicMock(headers={}) + generator_2 = auth.async_auth_flow(mock_req_2) + results_2 = [r async for r in generator_2] + assert len(results_2) == 1 + assert results_2[0].headers["Authorization"] == "Bearer token123" + # call_count 依然是 1,说明命中了缓存 + assert mock_get_config.call_count == 1 + + # 3. 细粒度清除缓存:清除指定 server_name + clear_server_resolved_headers_cache("test_server") + + # 4. 第三次请求(+10秒):清除缓存后,应该再次执行 DB 查询 + mock_req_3 = MagicMock(headers={}) + generator_3 = auth.async_auth_flow(mock_req_3) + results_3 = [r async for r in generator_3] + assert len(results_3) == 1 + assert mock_get_config.call_count == 2 + + # 5. 第四次请求(+20秒):经过了 10 秒(累计过了15秒的 TTL 限制),由于 1020 - 1010 >= 15,缓存过期,再次执行 DB 查询 + mock_req_4 = MagicMock(headers={}) + generator_4 = auth.async_auth_flow(mock_req_4) + results_4 = [r async for r in generator_4] + assert len(results_4) == 1 + assert mock_get_config.call_count == 3 + + # 6. 测试 clear_mcp_cache / clear_mcp_server_tools_cache 联动清除所有 resolved_headers 缓存 + from yuxi.services.mcp.tool_registry_service import clear_mcp_cache, clear_mcp_server_tools_cache + # 确保当前有缓存项 + _resolved_headers_cache[cache_key] = ({"Auth": "Bearer test"}, 2000.0) + clear_mcp_server_tools_cache("test_server") + assert len(_resolved_headers_cache) == 0 + + _resolved_headers_cache[cache_key] = ({"Auth": "Bearer test"}, 2000.0) + clear_mcp_cache() + assert len(_resolved_headers_cache) == 0 + finally: + mcp_auth_context_var.reset(token) diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 45d65d1e1..9436853d5 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -37,7 +37,7 @@ ### 0.6.3 开发记录 - 修复 DeepAgent 未绑定 `DeepContext`,导致深度分析专用系统提示词和子智能体默认模型配置未生效的问题;同时避免运行时重复注入默认提示词。 -- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用;补齐运行时鉴权 MCP 工具的执行阶段映射,避免模型已绑定 `getTicket` 等动态工具但 ToolNode 静态注册表无法执行的问题。 +- **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期和重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载的 managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用;补齐运行时鉴权 MCP 工具的执行阶段映射,避免模型已绑定 `getTicket` 等动态工具但 ToolNode 静态注册表无法执行的问题;审计并修复该链路隐患:通过 DynamicMCPTokenAuth 引入 15 秒 TTL 在内存缓存(含联动清除机制)解决 httpx 请求对 DB 的高频重复查询问题;修复 `_normalize_token_payload` 处理 naive datetime 的时区偏差问题以消除 token 无限自动刷新的 Bug;改进 `_calculate_config_hash` 哈希计算逻辑,对 json.dumps 增加 default=str 降级保护防止无法序列化而崩溃的问题。 --- From 8065a16419184e41547f930345e975947a5848f2 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 07:00:42 +0800 Subject: [PATCH 15/36] =?UTF-8?q?style:=20=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=20Python=20=E4=BB=A3=E7=A0=81=E4=BB=A5?= =?UTF-8?q?=E7=AC=A6=E5=90=88=20Ruff=20=E8=A7=84=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../middlewares/runtime_config_middleware.py | 5 +- .../agents/middlewares/skills_middleware.py | 2 +- backend/package/yuxi/services/mcp/__init__.py | 58 +++++++++---------- .../package/yuxi/services/mcp/cache_policy.py | 3 +- .../yuxi/services/mcp/connection_service.py | 18 ++++-- .../yuxi/services/mcp/server_service.py | 35 +++++++---- .../services/mcp_auth/fetchers/__init__.py | 7 ++- .../yuxi/services/mcp_auth/fetchers/base.py | 10 +++- .../services/mcp_auth/fetchers/factory.py | 3 +- .../mcp_auth/fetchers/http_fetcher.py | 3 + .../mcp_auth/fetchers/oauth_fetcher.py | 12 ++-- .../yuxi/services/mcp_auth/proxy_service.py | 47 +++++++-------- 12 files changed, 119 insertions(+), 84 deletions(-) diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index ad6a70c23..c21fa5d50 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -9,8 +9,8 @@ from yuxi.agents import load_chat_model from yuxi.agents.toolkits import get_all_tool_instances -from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.utils.datetime_utils import shanghai_now from yuxi.utils.logging_config import logger @@ -140,8 +140,9 @@ async def awrap_tool_call(self, request: ToolCallRequest, handler: Callable[[Too work_id = getattr(runtime_context, "work_id", None) or getattr(runtime_context, "user_id", None) dept_id = getattr(runtime_context, "department_id", None) auth_context = AuthContext(user_id=work_id, department_id=dept_id) - + from yuxi.services.mcp_auth.orchestrator import mcp_auth_context_var + token = mcp_auth_context_var.set(auth_context) try: return await handler(request) diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index 6b0d4c0b9..d47d34d7b 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -15,8 +15,8 @@ from yuxi.agents.toolkits import get_all_tool_instances from yuxi.repositories.skill_repository import SkillRepository -from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools +from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.services.skill_service import _normalize_string_list, is_valid_skill_slug from yuxi.storage.postgres.manager import pg_manager from yuxi.utils.logging_config import logger diff --git a/backend/package/yuxi/services/mcp/__init__.py b/backend/package/yuxi/services/mcp/__init__.py index fe4a61b73..be19b5986 100644 --- a/backend/package/yuxi/services/mcp/__init__.py +++ b/backend/package/yuxi/services/mcp/__init__.py @@ -1,51 +1,52 @@ from __future__ import annotations + from yuxi.services.mcp.cache_policy import ( + CachePolicyFactory, + DynamicProxyCachePolicy, MCPCachePolicy, StaticCachePolicy, TokenInjectedCachePolicy, - DynamicProxyCachePolicy, - CachePolicyFactory, ) from yuxi.services.mcp.client_pool import ( - mcp_client_pool, MCPClientPool, + mcp_client_pool, +) +from yuxi.services.mcp.connection_service import ( + create_mcp_connection, + delete_mcp_connection, + get_mcp_connection, + list_mcp_connections, + reauthorize_mcp_connection, + set_mcp_connection_status, + test_mcp_connection, + update_mcp_connection, ) from yuxi.services.mcp.server_service import ( + create_mcp_server, + delete_mcp_server, ensure_builtin_mcp_servers_in_db, + get_all_mcp_servers, get_enabled_mcp_server_config, - get_runtime_mcp_server_config, get_enabled_mcp_server_names, get_mcp_server, - get_all_mcp_servers, - create_mcp_server, - update_mcp_server, - delete_mcp_server, get_mcp_server_dependency_summary, - set_server_enabled, + get_runtime_mcp_server_config, get_servers_config, -) -from yuxi.services.mcp.connection_service import ( - get_mcp_connection, - list_mcp_connections, - create_mcp_connection, - update_mcp_connection, - delete_mcp_connection, - set_mcp_connection_status, - reauthorize_mcp_connection, - test_mcp_connection, + set_server_enabled, + update_mcp_server, ) from yuxi.services.mcp.tool_registry_service import ( - to_camel_case, - get_mcp_tools, - get_tools_from_all_servers, clear_mcp_cache, - clear_mcp_server_tools_cache, clear_mcp_connection_tools_cache, - invalidate_mcp_server_tools_cache, - invalidate_mcp_connection_tools_cache, - get_mcp_tools_stats, - get_enabled_mcp_tools, + clear_mcp_server_tools_cache, get_all_mcp_tools, + get_enabled_mcp_tools, + get_mcp_tools, + get_mcp_tools_stats, + get_tools_from_all_servers, + invalidate_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, + to_camel_case, ) __all__ = [ @@ -57,7 +58,6 @@ "CachePolicyFactory", "mcp_client_pool", "MCPClientPool", - # Server CRUD "ensure_builtin_mcp_servers_in_db", "get_enabled_mcp_server_config", @@ -71,7 +71,6 @@ "get_mcp_server_dependency_summary", "set_server_enabled", "get_servers_config", - # Connection CRUD "get_mcp_connection", "list_mcp_connections", @@ -81,7 +80,6 @@ "set_mcp_connection_status", "reauthorize_mcp_connection", "test_mcp_connection", - # Tool Registry "to_camel_case", "get_mcp_tools", diff --git a/backend/package/yuxi/services/mcp/cache_policy.py b/backend/package/yuxi/services/mcp/cache_policy.py index 1c6ad1f5f..7df1d4805 100644 --- a/backend/package/yuxi/services/mcp/cache_policy.py +++ b/backend/package/yuxi/services/mcp/cache_policy.py @@ -1,4 +1,5 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import TYPE_CHECKING @@ -23,7 +24,7 @@ def resolve_cache_partition( ) -> tuple[str, bool]: """ 解析该连接应被划分到哪一个缓存分区中。 - + 返回: tuple[partition_key, is_shared_across_users] - partition_key: 用于区分 Redis 缓存或内存缓存隔离区段的 Key。 diff --git a/backend/package/yuxi/services/mcp/connection_service.py b/backend/package/yuxi/services/mcp/connection_service.py index 1acfbb974..070bd0eab 100644 --- a/backend/package/yuxi/services/mcp/connection_service.py +++ b/backend/package/yuxi/services/mcp/connection_service.py @@ -1,10 +1,11 @@ from __future__ import annotations + import logging from datetime import UTC, datetime from typing import Any + from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession - from yuxi.services.mcp_auth.crypto import encrypt_credential_blob from yuxi.services.mcp_auth.orchestrator import AuthContext from yuxi.storage.postgres.models_business import MCPConnection @@ -105,6 +106,7 @@ async def create_mcp_connection( ) -> MCPConnection: """创建 MCP 绑定连接""" from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, server_name) if server is None: raise ValueError(f"Server '{server_name}' does not exist") @@ -172,11 +174,12 @@ async def update_mcp_connection( await db.commit() await db.refresh(connection) - + from yuxi.services.mcp.tool_registry_service import ( _clear_mcp_connection_runtime_auth_cache, _invalidate_mcp_tools_cache_for_connection, ) + if should_clear_runtime_auth_cache: await _clear_mcp_connection_runtime_auth_cache(connection.id) await _invalidate_mcp_tools_cache_for_connection(connection) @@ -196,9 +199,10 @@ async def delete_mcp_connection(db: AsyncSession, connection_id: int) -> bool: from yuxi.services.mcp.tool_registry_service import ( _clear_mcp_connection_runtime_auth_cache, - invalidate_mcp_server_tools_cache, invalidate_mcp_connection_tools_cache, + invalidate_mcp_server_tools_cache, ) + await _clear_mcp_connection_runtime_auth_cache(deleted_connection_id) if deleted_scope_type == "system": await invalidate_mcp_server_tools_cache(deleted_server_name) @@ -229,6 +233,7 @@ async def set_mcp_connection_status( _clear_mcp_connection_runtime_auth_cache, _invalidate_mcp_tools_cache_for_connection, ) + await _clear_mcp_connection_runtime_auth_cache(connection.id) await _invalidate_mcp_tools_cache_for_connection(connection) return connection @@ -246,6 +251,7 @@ async def reauthorize_mcp_connection( raise ValueError(f"MCP connection '{connection_id}' does not exist") from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + cache = RedisTokenCache() if getattr(connection, "id", None) is not None: try: @@ -256,8 +262,9 @@ async def reauthorize_mcp_connection( await cache.release_refresh_lock(connection.id) except Exception as exc: logger.warning(f"Failed to clear MCP refresh lock for connection {connection.id}: {exc}") - + from yuxi.services.mcp.tool_registry_service import _invalidate_mcp_tools_cache_for_connection + await _invalidate_mcp_tools_cache_for_connection(connection) connection.status = "active" @@ -283,6 +290,7 @@ async def test_mcp_connection( raise ValueError(f"MCP connection '{connection_id}' does not exist") from yuxi.services.mcp.server_service import get_mcp_server + server = await get_mcp_server(db, connection.server_name) if server is None: raise ValueError(f"Server '{connection.server_name}' does not exist") @@ -290,7 +298,7 @@ async def test_mcp_connection( auth_context = _auth_context_from_connection(connection) from yuxi.services.mcp.server_service import get_runtime_mcp_server_config from yuxi.services.mcp.tool_registry_service import get_mcp_tools - + config = await get_runtime_mcp_server_config(server.name, auth_context=auth_context, db=db) if config is None: raise ValueError(f"MCP server '{server.name}' runtime config unavailable") diff --git a/backend/package/yuxi/services/mcp/server_service.py b/backend/package/yuxi/services/mcp/server_service.py index 22ee6f792..cc64d604b 100644 --- a/backend/package/yuxi/services/mcp/server_service.py +++ b/backend/package/yuxi/services/mcp/server_service.py @@ -1,22 +1,19 @@ from __future__ import annotations -import hashlib -import json + import logging import os import traceback from typing import Any + import httpx from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession - from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config from yuxi.services.mcp_auth.proxy_service import ( - INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY, build_proxy_runtime_config, should_use_internal_proxy, ) -from yuxi.services.mcp_tool_cache import RedisMcpToolCache from yuxi.storage.postgres.models_business import AgentConfig, MCPConnection, MCPServer, Skill logger = logging.getLogger("yuxi.mcp.server_service") @@ -181,7 +178,7 @@ def _apply_runtime_tool_cache_policy( ) -> dict[str, Any]: """利用 CachePolicy 模式获取缓存 key 的隔离区划并应用""" from yuxi.services.mcp.cache_policy import CachePolicyFactory - + policy = CachePolicyFactory.get_policy(auth_config.provider) partition, is_shared = policy.resolve_cache_partition( auth_context or AuthContext(), @@ -212,6 +209,7 @@ async def get_runtime_mcp_server_config( auth_config = MCPAuthConfig.model_validate(server.auth_config_json) from yuxi.services.mcp.connection_service import _resolve_scope_id + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) if scope_id is None: return server.to_mcp_config() @@ -324,7 +322,11 @@ async def create_mcp_server( await db.commit() await db.refresh(server) - from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + await _clear_mcp_server_runtime_auth_cache(db, name) await invalidate_mcp_server_tools_cache(name) @@ -384,7 +386,11 @@ async def update_mcp_server( await db.commit() await db.refresh(server) - from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + if auth_config is not _UNSET: await _clear_mcp_server_runtime_auth_cache(db, name) await invalidate_mcp_server_tools_cache(name) @@ -402,7 +408,11 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: await db.delete(server) await db.commit() - from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + await _clear_mcp_server_runtime_auth_cache(db, name) await invalidate_mcp_server_tools_cache(name) @@ -413,6 +423,7 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: async def get_mcp_server_dependency_summary(db: AsyncSession, name: str) -> dict[str, Any]: """获取依赖于该 MCP 服务器的智能体、技能和连接概要""" from yuxi.services.mcp.connection_service import list_mcp_connections + connections = await list_mcp_connections(db, server_name=name) skill_rows = (await db.execute(select(Skill))).scalars().all() @@ -453,7 +464,11 @@ async def set_server_enabled( await db.commit() is_enabled = bool(server.enabled) - from yuxi.services.mcp.tool_registry_service import _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache + from yuxi.services.mcp.tool_registry_service import ( + _clear_mcp_server_runtime_auth_cache, + invalidate_mcp_server_tools_cache, + ) + if not is_enabled: await _clear_mcp_server_runtime_auth_cache(db, name) await invalidate_mcp_server_tools_cache(name) diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py index ca812760c..ea8ae30aa 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations -from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher, BaseTokenFetcher -from yuxi.services.mcp_auth.fetchers.http_fetcher import CustomHttpTokenFetcher, ClientCredentialsFetcher -from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher + +from yuxi.services.mcp_auth.fetchers.base import BaseTokenFetcher, ITokenFetcher from yuxi.services.mcp_auth.fetchers.factory import TokenFetcherFactory +from yuxi.services.mcp_auth.fetchers.http_fetcher import ClientCredentialsFetcher, CustomHttpTokenFetcher +from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher __all__ = [ "ITokenFetcher", diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/base.py b/backend/package/yuxi/services/mcp_auth/fetchers/base.py index 299773c80..f670597b7 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/base.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/base.py @@ -1,6 +1,8 @@ from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any + import httpx from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.template_resolver import resolve_template_value @@ -17,6 +19,7 @@ "token_type": "token_type", } + def extract_path(payload: dict[str, Any], path: str) -> Any: """从 payload 中根据点分路径提取字段值""" current: Any = payload @@ -27,6 +30,7 @@ def extract_path(payload: dict[str, Any], path: str) -> Any: raise KeyError(path) return current + async def fetch_custom_http_token( request_config: dict[str, Any], *, @@ -38,7 +42,7 @@ async def fetch_custom_http_token( ) -> dict[str, Any]: """执行自定义 HTTP 请求获取 Token""" from yuxi.services.mcp_auth.orchestrator import _normalize_token_payload - + response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) if http_client is None: http_client = httpx.AsyncClient() @@ -84,7 +88,9 @@ async def fetch_custom_http_token( return _normalize_token_payload(resolved) except Exception as exc: import traceback + from yuxi.utils import logger + logger.error(f"fetch_custom_http_token failure: {exc}, traceback: {traceback.format_exc()}") raise finally: @@ -136,7 +142,7 @@ async def fetch_token( refresh_token_values = dict(token_values) if not refresh_token_values.get("refresh_token") and credential_payload.get("refresh_token"): refresh_token_values["refresh_token"] = credential_payload["refresh_token"] - + refreshed = await fetch_custom_http_token( refresh_request, response_map=(refresh_request.get("response_map") or token_request.get("response_map")), diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/factory.py b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py index 0b546cc61..f02e69ac1 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/factory.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/factory.py @@ -1,6 +1,7 @@ from __future__ import annotations + from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher -from yuxi.services.mcp_auth.fetchers.http_fetcher import CustomHttpTokenFetcher, ClientCredentialsFetcher +from yuxi.services.mcp_auth.fetchers.http_fetcher import ClientCredentialsFetcher, CustomHttpTokenFetcher from yuxi.services.mcp_auth.fetchers.oauth_fetcher import AuthorizationCodeFetcher diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py index ffc74602f..7ef45f431 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/http_fetcher.py @@ -1,5 +1,7 @@ from __future__ import annotations + from typing import Any + import httpx from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.fetchers.base import BaseTokenFetcher, fetch_custom_http_token @@ -34,5 +36,6 @@ async def _fetch_new_token( class ClientCredentialsFetcher(CustomHttpTokenFetcher): """客户端凭证 (Client Credentials) 方式获取 Token""" + # NOTE: 当前其底层获取逻辑与 CustomHttpTokenFetcher 相同 pass diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py index 00446803f..fe7b8eaa0 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/oauth_fetcher.py @@ -1,8 +1,10 @@ from __future__ import annotations + from typing import Any + import httpx from yuxi.services.mcp_auth.config_models import MCPAuthConfig -from yuxi.services.mcp_auth.fetchers.base import ITokenFetcher, fetch_custom_http_token, _DEFAULT_TOKEN_RESPONSE_MAP +from yuxi.services.mcp_auth.fetchers.base import _DEFAULT_TOKEN_RESPONSE_MAP, ITokenFetcher, fetch_custom_http_token class AuthorizationCodeFetcher(ITokenFetcher): @@ -16,9 +18,7 @@ async def _resolve_token_request_config( http_client: httpx.AsyncClient, ) -> tuple[dict[str, Any], dict[str, str]]: issuer_url = ( - token_request.get("issuer_url") - or secret_values.get("issuer_url") - or token_values.get("issuer_url") + token_request.get("issuer_url") or secret_values.get("issuer_url") or token_values.get("issuer_url") ) if not issuer_url: raise ValueError("authorization_code provider requires token_request.issuer_url") @@ -29,7 +29,7 @@ async def _resolve_token_request_config( token_endpoint = payload.get("token_endpoint") if not token_endpoint: raise ValueError("authorization_code provider discovery missing token_endpoint") - + return { "url": token_endpoint, "method": "POST", @@ -72,7 +72,7 @@ async def fetch_token( authorization_token_values = dict(token_values or credential_payload) if not authorization_token_values.get("refresh_token") and credential_payload.get("refresh_token"): authorization_token_values["refresh_token"] = credential_payload["refresh_token"] - + resolved = await fetch_custom_http_token( authorization_request, response_map=response_map, diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index fd37ea816..1897ede29 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -5,17 +5,17 @@ from urllib.parse import urlencode import httpx -from fastapi import Request, Response, HTTPException +from fastapi import HTTPException, Request, Response from fastapi.responses import StreamingResponse -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from starlette.background import BackgroundTask - -from server.utils.auth_utils import AuthUtils from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config from yuxi.storage.postgres.models_business import MCPConnection, MCPServer +from server.utils.auth_utils import AuthUtils + INTERNAL_PROXY_TOKEN_HEADER = "X-Yuxi-MCP-Proxy-Token" INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY = "__yuxi_disable_tool_object_cache" _PROXY_TOKEN_TYPE = "mcp_proxy" @@ -142,7 +142,7 @@ async def handle_mcp_proxy_request( ) -> Response: """内部网关主入口:鉴权解析、查库拦截与流式代理""" from yuxi.services.mcp.server_service import get_mcp_server - + try: auth_context = decode_proxy_access_token(internal_token, server_name=server_name) except ValueError as exc: @@ -155,8 +155,9 @@ async def handle_mcp_proxy_request( raise HTTPException(status_code=404, detail=f"服务器 '{server_name}' 不存在或已停用") auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) - + from yuxi.services.mcp.connection_service import _resolve_scope_id + scope_id = _resolve_scope_id(auth_config.binding_scope, auth_context) connection = None if scope_id is not None: @@ -202,19 +203,22 @@ async def _proxy_mcp_request_stream( """底层流式转发逻辑:处理 HTTPX 透传、SSE 和 401 重试闭环事务""" auth_config = MCPAuthConfig.model_validate(server.auth_config_json or {}) if server.transport not in _HTTP_TRANSPORTS: - raise HTTPException(status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}") + raise HTTPException( + status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}" + ) http_client = _http_client or httpx.AsyncClient(timeout=server.timeout or 60.0) bg_task = BackgroundTask(http_client.aclose) - + if _token_cache is not None: token_cache = _token_cache else: from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache + token_cache = RedisTokenCache() max_attempts = 2 if auth_config.refresh_policy.retry_once_on_401 else 1 - + for attempt in range(max_attempts): runtime_config = await resolve_runtime_mcp_config( server, @@ -225,17 +229,17 @@ async def _proxy_mcp_request_stream( ) target_url = _build_target_url(runtime_config["url"], path=path, query_params=dict(request.query_params)) upstream_headers = _merge_upstream_headers(runtime_config.get("headers") or {}, dict(request.headers)) - + request_obj = http_client.build_request( method=request.method.upper(), url=target_url, headers=upstream_headers, content=body, ) - + # 使用 send(stream=True) 获取异步可迭代响应而不会阻塞 SSE 长链接 response = await http_client.send(request_obj, stream=True) - + if response.status_code == 403: await response.aclose() _record_scope_error(connection, "MCP upstream rejected request due to insufficient scope") @@ -245,33 +249,30 @@ async def _proxy_mcp_request_stream( content='{"error": "insufficient_scope", "message": "当前授权范围不足"}', status_code=403, media_type="application/json", - background=bg_task + background=bg_task, ) - + if response.status_code != 401: # 正常响应,此时直接闭环提交事务,防止污染外层 if connection is not None and hasattr(db, "commit"): await db.commit() - + async def proxy_stream_generator(): try: async for chunk in response.aiter_raw(): yield chunk finally: await response.aclose() - + resp_headers = {} for k, v in response.headers.items(): if k.lower() not in _HOP_BY_HOP_HEADERS and k.lower() not in ("content-encoding", "content-length"): resp_headers[k] = v - + return StreamingResponse( - proxy_stream_generator(), - status_code=response.status_code, - headers=resp_headers, - background=bg_task + proxy_stream_generator(), status_code=response.status_code, headers=resp_headers, background=bg_task ) - + # 如果是 401,回收流连接并准备重试 await response.aclose() if attempt + 1 >= max_attempts: @@ -286,5 +287,5 @@ async def proxy_stream_generator(): content='{"error": "reauth_required", "message": "连接失效,请重新连接"}', status_code=424, media_type="application/json", - background=bg_task + background=bg_task, ) From 5de697ab08da1c7c2bad4a8395a63e871a695497 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 07:40:28 +0800 Subject: [PATCH 16/36] =?UTF-8?q?fix(mcp):=20=E4=BF=AE=E5=A4=8D=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E7=B3=BB=E7=BB=9F=E4=BB=A3=E7=A0=81=E5=AE=A1=E8=AE=A1?= =?UTF-8?q?=E4=B8=AD=E5=8F=91=E7=8E=B0=E7=9A=84=E5=86=85=E5=AD=98=E6=B3=84?= =?UTF-8?q?=E6=BC=8F=E4=B8=8E=E5=B9=B6=E5=8F=91=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 `client_pool.py` 中由于 cache revision 变化导致的旧 session 实例泄漏问题 - 将 `clear_mcp_cache` 调整为异步函数并调用 `shutdown`,防止清理缓存时产生孤儿子进程 - 增加对 `_resolved_headers_cache` 字典的惰性清理,避免无界膨胀 - 修复 `connection_service.py` 唯一性约束冲突未捕获导致 HTTP 500 的问题,改抛 ValueError - 修正 `proxy_service.py` 的 sse read timeout 配置和 Authorization 头部被覆盖漏洞 - 移除遗留冗余脚本 `fix_mcp_service_imports.py` 和 `fix_tests.py` --- .../package/yuxi/services/mcp/client_pool.py | 20 ++++++- .../yuxi/services/mcp/connection_service.py | 10 +++- .../services/mcp/tool_registry_service.py | 4 +- .../yuxi/services/mcp_auth/proxy_service.py | 17 +++++- .../unit/services/test_mcp_client_pool.py | 2 +- .../test_mcp_tool_registry_service.py | 24 ++++---- fix_mcp_service_imports.py | 49 ----------------- fix_tests.py | 55 ------------------- 8 files changed, 58 insertions(+), 123 deletions(-) delete mode 100644 fix_mcp_service_imports.py delete mode 100644 fix_tests.py diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py index 8ed81380e..dae543d56 100644 --- a/backend/package/yuxi/services/mcp/client_pool.py +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -71,8 +71,18 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) if runtime_config: # NOTE: 3. 将最新的头部注入到当前 HTTP 请求中 - headers = runtime_config.get("headers") or {} + headers = dict(runtime_config.get("headers") or {}) _resolved_headers_cache[cache_key] = (headers, now + _HEADERS_CACHE_TTL) + + # 惰性清理过期条目,避免无界膨胀 + stale_keys = [ + k + for k in list(_resolved_headers_cache.keys())[:100] + if _resolved_headers_cache.get(k, (None, float("inf")))[1] <= now + ][:20] + for k in stale_keys: + _resolved_headers_cache.pop(k, None) + for key, val in headers.items(): request.headers[key] = str(val) except Exception as exc: @@ -200,6 +210,14 @@ async def get_session( await ll_session.stop() self._sessions.pop(cache_key, None) + # NOTE: 驱逐同 server_name 下其他过时 partition_key 的旧 session, + # 防止 revision 变化导致的连接池内存泄漏 + stale_keys = [k for k in self._sessions if k[0] == server_name and k != cache_key] + for stale_key in stale_keys: + stale_session, _ = self._sessions.pop(stale_key) + logger.info(f"Evicting stale MCP session for {stale_key}") + await stale_session.stop() + # NOTE: 针对 HTTP/SSE 协议,注入自定义的 httpx.Auth 认证流以支持长连接动态 Token client_config = dict(runtime_config) # 清理框架保留魔法键 diff --git a/backend/package/yuxi/services/mcp/connection_service.py b/backend/package/yuxi/services/mcp/connection_service.py index 070bd0eab..2c4afa478 100644 --- a/backend/package/yuxi/services/mcp/connection_service.py +++ b/backend/package/yuxi/services/mcp/connection_service.py @@ -132,7 +132,15 @@ async def create_mcp_connection( updated_by=created_by, ) db.add(connection) - await db.commit() + from sqlalchemy.exc import IntegrityError + + try: + await db.commit() + except IntegrityError: + await db.rollback() + raise ValueError( + f"该 MCP 服务器 '{server_name}' 在范围 {normalized_scope_type}:{normalized_scope_id} 下已存在连接" + ) await db.refresh(connection) return connection diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py index 557c32a23..ce6db2515 100644 --- a/backend/package/yuxi/services/mcp/tool_registry_service.py +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -323,7 +323,7 @@ async def get_tools_from_all_servers() -> list[Callable[..., Any]]: return all_tools -def clear_mcp_cache() -> None: +async def clear_mcp_cache() -> None: """清空本地内存工具缓存""" global _mcp_tools_cache _mcp_tools_cache = {} @@ -331,7 +331,7 @@ def clear_mcp_cache() -> None: try: from yuxi.services.mcp.client_pool import clear_resolved_headers_cache, mcp_client_pool - mcp_client_pool._sessions.clear() + await mcp_client_pool.shutdown() clear_resolved_headers_cache() except Exception: pass diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index 1897ede29..aad2234c5 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -93,8 +93,12 @@ def _merge_upstream_headers( request_headers: dict[str, str] | None, ) -> dict[str, Any]: merged = dict(base_headers or {}) + _PROTECTED_HEADERS = { + INTERNAL_PROXY_TOKEN_HEADER.lower(), + "authorization", + } for key, value in (request_headers or {}).items(): - if key.lower() in _HOP_BY_HOP_HEADERS or key.lower() == INTERNAL_PROXY_TOKEN_HEADER.lower(): + if key.lower() in _HOP_BY_HOP_HEADERS or key.lower() in _PROTECTED_HEADERS: continue merged[key] = value return merged @@ -207,7 +211,16 @@ async def _proxy_mcp_request_stream( status_code=400, detail=f"Internal proxy only supports HTTP MCP transports, got: {server.transport}" ) - http_client = _http_client or httpx.AsyncClient(timeout=server.timeout or 60.0) + connect_timeout = server.timeout or 60.0 + read_timeout = server.sse_read_timeout or connect_timeout + http_client = _http_client or httpx.AsyncClient( + timeout=httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=connect_timeout, + pool=connect_timeout, + ) + ) bg_task = BackgroundTask(http_client.aclose) if _token_cache is not None: diff --git a/backend/test/unit/services/test_mcp_client_pool.py b/backend/test/unit/services/test_mcp_client_pool.py index c02133e1a..0cb43fd06 100644 --- a/backend/test/unit/services/test_mcp_client_pool.py +++ b/backend/test/unit/services/test_mcp_client_pool.py @@ -194,7 +194,7 @@ async def test_dynamic_mcp_token_auth_cache(): assert len(_resolved_headers_cache) == 0 _resolved_headers_cache[cache_key] = ({"Auth": "Bearer test"}, 2000.0) - clear_mcp_cache() + await clear_mcp_cache() assert len(_resolved_headers_cache) == 0 finally: mcp_auth_context_var.reset(token) diff --git a/backend/test/unit/services/test_mcp_tool_registry_service.py b/backend/test/unit/services/test_mcp_tool_registry_service.py index f3d6578cb..4fe07518c 100644 --- a/backend/test/unit/services/test_mcp_tool_registry_service.py +++ b/backend/test/unit/services/test_mcp_tool_registry_service.py @@ -71,7 +71,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, disabled async def test_get_mcp_tools_rebuilds_cache_when_config_hash_changes(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() configs = [ {"transport": "stdio", "command": "demo-v1", "disabled_tools": []}, @@ -104,7 +104,7 @@ async def fake_get_mcp_client(server_configs): assert [tool.name for tool in tools_v2] == ["tool_for_demo-v2"] assert build_calls == ["demo-v1", "demo-v2"] - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def test_get_tools_from_all_servers_loads_names_from_db_once(monkeypatch): @@ -136,7 +136,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs async def test_get_mcp_tools_sets_handle_tool_error(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() config = {"transport": "stdio", "command": "demo-tool", "disabled_tools": []} @@ -155,11 +155,11 @@ async def fake_get_mcp_client(server_configs): assert len(tools) == 1 assert tools[0].handle_tool_error is True - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def test_get_mcp_tools_keeps_connection_partitions_separate(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() configs = [ { @@ -198,11 +198,11 @@ async def fake_get_mcp_client(server_configs): assert [tool.name for tool in tools_b] == ["tool_for_proxy-token-user-b"] assert build_calls == ["proxy-token-user-a", "proxy-token-user-b"] - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def test_get_mcp_tools_does_not_cache_internal_proxy_tool_objects(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() configs = [ { @@ -243,7 +243,7 @@ async def fake_get_mcp_client(server_configs): assert [tool.name for tool in tools_second] == ["tool_for_proxy-token-v2"] assert build_calls == ["proxy-token-v1", "proxy-token-v2"] - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def test_get_tools_from_all_servers_skips_runtime_auth_servers_without_context(monkeypatch): @@ -292,7 +292,7 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs async def test_get_mcp_tools_rebuilds_when_redis_server_revision_changes(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() fake_redis = _FakeRedis() @@ -323,11 +323,11 @@ async def fake_get_mcp_client(server_configs): assert [tool.name for tool in tools_third] == ["tool_2"] assert build_calls == ["demo-tool", "demo-tool"] - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def test_get_all_mcp_tools_uses_redis_manifest_when_local_cache_is_empty(monkeypatch): - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() fake_redis = _FakeRedis() @@ -365,7 +365,7 @@ async def fake_get_enabled_mcp_server_config(server_name: str, db=None): tools_first = await tool_registry_service.get_all_mcp_tools("demo") assert [tool.name for tool in tools_first] == ["alpha_tool"] - tool_registry_service.clear_mcp_cache() + await tool_registry_service.clear_mcp_cache() async def fail_get_mcp_client(server_configs): raise AssertionError(f"should not fetch live tools when redis manifest is available: {server_configs}") diff --git a/fix_mcp_service_imports.py b/fix_mcp_service_imports.py deleted file mode 100644 index 3a29bceb1..000000000 --- a/fix_mcp_service_imports.py +++ /dev/null @@ -1,49 +0,0 @@ -import re -import os - -files_to_fix = [ - ("backend/package/yuxi/services/mcp_auth/orchestrator.py", "from yuxi.services.mcp_service import RedisTokenCache", "from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache"), - ("backend/package/yuxi/services/skill_service.py", "from yuxi.services.mcp_service import get_enabled_mcp_server_names", "from yuxi.services.mcp.server_service import get_enabled_mcp_server_names"), - ("backend/package/yuxi/agents/middlewares/dynamic_tool_middleware.py", "from yuxi.services.mcp_service import get_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_mcp_tools"), - ("backend/package/yuxi/agents/__init__.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), - ("backend/package/yuxi/agents/middlewares/skills_middleware.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), - ("backend/package/yuxi/agents/middlewares/runtime_config_middleware.py", "from yuxi.services.mcp_service import get_enabled_mcp_tools", "from yuxi.services.mcp.tool_registry_service import get_enabled_mcp_tools"), - ("backend/package/yuxi/agents/buildin/deep_agent/graph.py", "from yuxi.services.mcp_service import get_tools_from_all_servers", "from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers"), - ("backend/package/yuxi/agents/buildin/chatbot/graph.py", "from yuxi.services.mcp_service import get_tools_from_all_servers", "from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers"), - ("backend/server/utils/lifespan.py", "from yuxi.services.mcp_service import ensure_builtin_mcp_servers_in_db", "from yuxi.services.mcp.server_service import ensure_builtin_mcp_servers_in_db"), -] - -for file_path, old_str, new_str in files_to_fix: - if os.path.exists(file_path): - with open(file_path, "r") as f: - content = f.read() - content = content.replace(old_str, new_str) - with open(file_path, "w") as f: - f.write(content) - -# Fix tests -tests = [ - "backend/test/unit/services/test_mcp_auth_runtime.py", - "backend/test/unit/services/test_mcp_tool_registry_service.py", - "backend/test/unit/services/test_mcp_connection_service.py" -] - -for file_path in tests: - if os.path.exists(file_path): - with open(file_path, "r") as f: - content = f.read() - - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_mcp_tools"', 'monkeypatch.setattr(tool_registry_service, "get_mcp_tools"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_enabled_mcp_server_config"', 'monkeypatch.setattr(server_service, "get_enabled_mcp_server_config"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"_load_enabled_mcp_server_configs"', 'monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_runtime_mcp_server_config"', 'monkeypatch.setattr(server_service, "get_runtime_mcp_server_config"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"get_mcp_client"', 'monkeypatch.setattr(mcp_client_pool, "_get_mcp_client"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"_clear_mcp_server_runtime_auth_cache"', 'monkeypatch.setattr(tool_registry_service, "_clear_mcp_server_runtime_auth_cache"', content) - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"invalidate_mcp_server_tools_cache"', 'monkeypatch.setattr(tool_registry_service, "invalidate_mcp_server_tools_cache"', content) - - # for `monkeypatch.setattr(\n mcp_service,\n ...)` - content = re.sub(r'monkeypatch\.setattr\(\s*mcp_service,\s*"_mcp_tool_cache_store"', 'monkeypatch.setattr(tool_registry_service, "_mcp_tool_cache_store"', content) - content = content.replace("await mcp_service.update_mcp_server", "await server_service.update_mcp_server") - - with open(file_path, "w") as f: - f.write(content) diff --git a/fix_tests.py b/fix_tests.py deleted file mode 100644 index 7c61f00a6..000000000 --- a/fix_tests.py +++ /dev/null @@ -1,55 +0,0 @@ -import re - -files = [ - "backend/test/unit/services/test_mcp_connection_service.py", - "backend/test/unit/services/test_mcp_service.py", - "backend/test/unit/services/test_mcp_service_auth_runtime.py" -] - -replacements = { - "create_mcp_connection": "connection_service", - "list_mcp_connections": "connection_service", - "set_mcp_connection_status": "connection_service", - "delete_mcp_connection": "connection_service", - "reauthorize_mcp_connection": "connection_service", - "update_mcp_connection": "connection_service", - "test_mcp_connection": "connection_service", - "get_mcp_connection": "connection_service", - "_resolve_scope_id": "connection_service", - - "get_mcp_server_dependency_summary": "server_service", - "set_server_enabled": "server_service", - "get_runtime_mcp_server_config": "server_service", - "get_enabled_mcp_server_config": "server_service", - "_load_enabled_mcp_server_configs": "server_service", - - "get_mcp_tools": "tool_registry_service", - "get_enabled_mcp_tools": "tool_registry_service", - "get_all_mcp_tools": "tool_registry_service", - "get_tools_from_all_servers": "tool_registry_service", - "clear_mcp_cache": "tool_registry_service", - "_mcp_tool_cache_store": "tool_registry_service", - "_clear_mcp_server_runtime_auth_cache": "tool_registry_service", -} - -for file_path in files: - with open(file_path, "r") as f: - content = f.read() - - # 替换 import - content = content.replace("from yuxi.services import mcp_service", "from yuxi.services.mcp import connection_service, server_service, tool_registry_service\nfrom yuxi.services.mcp.client_pool import mcp_client_pool\nfrom yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache") - - # 特殊替换 get_mcp_client - content = content.replace("mcp_service.get_mcp_client", "mcp_client_pool._get_mcp_client") - content = content.replace('mcp_service", "get_mcp_client"', 'mcp_client_pool", "_get_mcp_client"') - - # 替换 mcp_service.func -> specific_service.func - for func, service in replacements.items(): - content = re.sub(rf'mcp_service\.{func}', f'{service}.{func}', content) - content = re.sub(rf'mcp_service",\s*"{func}"', f'{service}", "{func}"', content) - - # 对于剩余的 mcp_service,如果是 monkeypatch.setattr(mcp_service, "RedisTokenCache" - content = re.sub(r'monkeypatch\.setattr\(mcp_service,\s*"RedisTokenCache"', 'monkeypatch.setattr(connection_service, "RedisTokenCache"', content) - - with open(file_path, "w") as f: - f.write(content) From 17cc449525c01034f022958976c7825d34601f8e Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 09:06:06 +0800 Subject: [PATCH 17/36] =?UTF-8?q?refactor(mcp=5Fauth):=20=E6=B7=B1?= =?UTF-8?q?=E5=BA=A6=E9=87=8D=E6=9E=84=20MCP=20Auth=20=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=AE=89=E5=85=A8=E9=9A=90=E6=82=A3?= =?UTF-8?q?=E4=B8=8E=E5=B9=B6=E5=8F=91=E7=93=B6=E9=A2=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 密码学安全升级:弃用单纯 SHA-256,采用 V2 HKDF 派生密钥加盐加密,并兼容解析 V1 遗留凭据。 2. 修复并发与连接泄露:重构 proxy_service 以复用共享 httpx 客户端,并在 client_pool 中采用 Future 占位模式替代全局协程锁,提升启动并发效率。 3. 缓存优化:引入 cachetools,使用 LRUCache 与 TTLCache 替换无界字典,防止内存泄漏。 4. 数据库一致性修复:在 server_service 删除实例前提前清理 Redis 缓存,防止级联删除后遗失追踪句柄。 5. 测试修复:全面修复因 httpx.Timeout、过期 TTL 以及代理环境变量带来的测试失败问题。 --- backend/package/pyproject.toml | 1 + .../package/yuxi/services/mcp/client_pool.py | 146 +++++++++++------- .../yuxi/services/mcp/server_service.py | 8 +- .../services/mcp/tool_registry_service.py | 5 +- .../yuxi/services/mcp_auth/config_models.py | 2 +- .../package/yuxi/services/mcp_auth/crypto.py | 38 ++++- .../yuxi/services/mcp_auth/fetchers/base.py | 4 +- .../yuxi/services/mcp_auth/proxy_service.py | 39 +++-- .../services/mcp_auth/redis_token_cache.py | 51 +++--- backend/server/utils/lifespan.py | 8 + .../unit/services/test_mcp_auth_crypto.py | 27 +++- .../unit/services/test_mcp_client_pool.py | 21 +-- .../services/test_model_provider_service.py | 6 + backend/uv.lock | 2 + 14 files changed, 235 insertions(+), 123 deletions(-) diff --git a/backend/package/pyproject.toml b/backend/package/pyproject.toml index 9188d218f..7ef87fd41 100644 --- a/backend/package/pyproject.toml +++ b/backend/package/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "aiosqlite>=0.20.0", "argon2-cffi>=25.1.0", "asyncpg>=0.30.0", + "cachetools>=5.3.0", "chardet>=5.0.0", "colorlog>=6.9.0", "dashscope>=1.23.2", diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py index dae543d56..5bf84fd59 100644 --- a/backend/package/yuxi/services/mcp/client_pool.py +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -4,7 +4,6 @@ import hashlib import json import logging -import time from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any @@ -15,9 +14,10 @@ if TYPE_CHECKING: from mcp import ClientSession -# 缓存存储格式: (server_name, user_id, department_id) -> (resolved_headers, expires_at) -_resolved_headers_cache: dict[tuple[str, str | None, str | None], tuple[dict[str, Any], float]] = {} -_HEADERS_CACHE_TTL = 15.0 +from cachetools import TTLCache + +# 缓存存储格式: (server_name, user_id, department_id) -> resolved_headers +_resolved_headers_cache: TTLCache = TTLCache(maxsize=1024, ttl=15.0) def clear_resolved_headers_cache() -> None: @@ -42,20 +42,26 @@ def __init__(self, server_name: str): self.server_name = server_name async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - # NOTE: 1. 从当前协程上下文读取 AuthContext auth_context = mcp_auth_context_var.get() if auth_context: try: cache_key = (self.server_name, auth_context.user_id, auth_context.department_id) - now = time.time() - cached = _resolved_headers_cache.get(cache_key) - if cached is not None: - headers, expires_at = cached - if now < expires_at: - for key, val in headers.items(): - request.headers[key] = str(val) - yield request - return + cached_headers = _resolved_headers_cache.get(cache_key) + if cached_headers is not None: + for key, val in cached_headers.items(): + request.headers[key] = str(val) + yield request + return + + from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER, create_proxy_access_token + + if INTERNAL_PROXY_TOKEN_HEADER.lower() in request.headers: + # NOTE: 代理模式下,直接在本地生成新的代理 JWT,跳过 DB 事务 + new_token = create_proxy_access_token(self.server_name, auth_context) + request.headers[INTERNAL_PROXY_TOKEN_HEADER] = new_token + _resolved_headers_cache[cache_key] = dict(request.headers) + yield request + return # 导入数据库会话管理器以获取连接与 Token from yuxi.storage.postgres.manager import pg_manager @@ -63,25 +69,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. async with pg_manager.get_async_session_context() as session: from yuxi.services.mcp.server_service import get_runtime_mcp_server_config - # NOTE: 2. 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) + # NOTE: 读取当前上下文对应的最新运行时配置(含 Token 自动刷新逻辑) runtime_config = await get_runtime_mcp_server_config( self.server_name, auth_context=auth_context, db=session, ) if runtime_config: - # NOTE: 3. 将最新的头部注入到当前 HTTP 请求中 headers = dict(runtime_config.get("headers") or {}) - _resolved_headers_cache[cache_key] = (headers, now + _HEADERS_CACHE_TTL) - - # 惰性清理过期条目,避免无界膨胀 - stale_keys = [ - k - for k in list(_resolved_headers_cache.keys())[:100] - if _resolved_headers_cache.get(k, (None, float("inf")))[1] <= now - ][:20] - for k in stale_keys: - _resolved_headers_cache.pop(k, None) + _resolved_headers_cache[cache_key] = headers for key, val in headers.items(): request.headers[key] = str(val) @@ -151,9 +147,9 @@ class MCPClientPool: """MCP 客户端连接池实现""" def __init__(self): - # 缓存键格式: (server_name, partition_key) -> (LongLivedSession, config_hash) - self._sessions: dict[tuple[str, str], tuple[LongLivedSession, str]] = {} - self._lock = asyncio.Lock() + # 缓存键格式: (server_name, partition_key) -> tuple[LongLivedSession, str] | asyncio.Future + self._sessions: dict[tuple[str, str], Any] = {} + self._dict_lock = asyncio.Lock() def _calculate_config_hash(self, config: dict[str, Any]) -> str: """根据配置计算 Hash 用于比对配置是否脏变""" @@ -197,30 +193,52 @@ async def get_session( config_hash = self._calculate_config_hash(runtime_config) cache_key = (server_name, partition_key) - async with self._lock: - existing = self._sessions.get(cache_key) - if existing: - ll_session, cached_hash = existing - # NOTE: 如果配置无变化且 Session 处于活动状态,直接复用 - if cached_hash == config_hash and ll_session.session is not None: - return ll_session.session - - # 如果发生配置变化或 Session 断开,执行销毁 + while True: + async with self._dict_lock: + existing = self._sessions.get(cache_key) + + if existing is not None: + if isinstance(existing, asyncio.Future): + future = existing + stale_session = None + else: + ll_session, cached_hash = existing + if cached_hash == config_hash and ll_session.session is not None: + return ll_session.session + + self._sessions.pop(cache_key, None) + stale_session = ll_session + future = None + else: + future = None + stale_session = None + + stale_keys = [k for k in self._sessions if k[0] == server_name and k != cache_key] + stale_other_sessions = [] + for stale_key in stale_keys: + stale_val = self._sessions.pop(stale_key) + if not isinstance(stale_val, asyncio.Future): + stale_other_sessions.append(stale_val[0]) + + init_future = asyncio.get_running_loop().create_future() + self._sessions[cache_key] = init_future + break + + if future is not None: + await future + continue + + if stale_session is not None: logger.info(f"Destroying stale/disconnected MCP session for {cache_key}") - await ll_session.stop() - self._sessions.pop(cache_key, None) - - # NOTE: 驱逐同 server_name 下其他过时 partition_key 的旧 session, - # 防止 revision 变化导致的连接池内存泄漏 - stale_keys = [k for k in self._sessions if k[0] == server_name and k != cache_key] - for stale_key in stale_keys: - stale_session, _ = self._sessions.pop(stale_key) - logger.info(f"Evicting stale MCP session for {stale_key}") await stale_session.stop() + continue - # NOTE: 针对 HTTP/SSE 协议,注入自定义的 httpx.Auth 认证流以支持长连接动态 Token + for s_session in stale_other_sessions: + logger.info("Evicting stale MCP session") + await s_session.stop() + + try: client_config = dict(runtime_config) - # 清理框架保留魔法键 for magic_k in ( "__yuxi_cache_partition", "__yuxi_allow_global_cache", @@ -229,7 +247,6 @@ async def get_session( client_config.pop(magic_k, None) if client_config.get("transport") in ("sse", "http", "streamable_http", "streamable-http"): - # 注入 DynamicMCPTokenAuth,让底层 httpx 在长连接执行每个具体请求时动态提取最新 Token client_config["auth"] = DynamicMCPTokenAuth(server_name) logger.info( @@ -241,9 +258,20 @@ async def get_session( ll_session = LongLivedSession(client, server_name) await ll_session.start() - self._sessions[cache_key] = (ll_session, config_hash) + result = (ll_session, config_hash) + init_future.set_result(result) + async with self._dict_lock: + self._sessions[cache_key] = result return ll_session.session + except BaseException as exc: + if not init_future.done(): + init_future.set_exception(exc) + async with self._dict_lock: + if self._sessions.get(cache_key) is init_future: + self._sessions.pop(cache_key, None) + raise + async def ensure_prewarm( self, server_name: str, @@ -258,12 +286,20 @@ async def ensure_prewarm( async def shutdown(self): """关闭并回收连接池中的所有连接""" - async with self._lock: - for cache_key, (ll_session, _) in list(self._sessions.items()): - logger.info(f"Stopping MCP session for {cache_key} during shutdown") - await ll_session.stop() + async with self._dict_lock: + sessions_to_stop = [] + for cache_key, val in list(self._sessions.items()): + if isinstance(val, asyncio.Future): + val.cancel() + else: + ll_session, _ = val + sessions_to_stop.append((cache_key, ll_session)) self._sessions.clear() + for cache_key, ll_session in sessions_to_stop: + logger.info(f"Stopping MCP session for {cache_key} during shutdown") + await ll_session.stop() + # 全局单例连接池 mcp_client_pool = MCPClientPool() diff --git a/backend/package/yuxi/services/mcp/server_service.py b/backend/package/yuxi/services/mcp/server_service.py index cc64d604b..6f962286b 100644 --- a/backend/package/yuxi/services/mcp/server_service.py +++ b/backend/package/yuxi/services/mcp/server_service.py @@ -405,15 +405,17 @@ async def delete_mcp_server(db: AsyncSession, name: str) -> bool: if not server: return False - await db.delete(server) - await db.commit() - from yuxi.services.mcp.tool_registry_service import ( _clear_mcp_server_runtime_auth_cache, invalidate_mcp_server_tools_cache, ) + # NOTE: 必须在级联删除前执行 Redis 缓存清理,否则关联的 connection 行被删除后将无法提取 ID await _clear_mcp_server_runtime_auth_cache(db, name) + + await db.delete(server) + await db.commit() + await invalidate_mcp_server_tools_cache(name) logger.info(f"Deleted MCP server '{name}'") diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py index ce6db2515..c9cf8f09b 100644 --- a/backend/package/yuxi/services/mcp/tool_registry_service.py +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -10,6 +10,7 @@ from typing import Any, cast import httpx +from cachetools import LRUCache from sqlalchemy.ext.asyncio import AsyncSession from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.orchestrator import AuthContext @@ -23,8 +24,8 @@ logger = logging.getLogger("yuxi.mcp.tool_registry_service") # 全局共享状态(直接在本模块维护,供外部和测试使用) -_mcp_tools_cache: dict[str, list[Callable[..., Any]]] = {} -_mcp_tools_stats: dict[str, dict[str, int]] = {} +_mcp_tools_cache: LRUCache = LRUCache(maxsize=128) +_mcp_tools_stats: LRUCache = LRUCache(maxsize=128) _mcp_tool_cache_store = RedisMcpToolCache() _mcp_lock = asyncio.Lock() diff --git a/backend/package/yuxi/services/mcp_auth/config_models.py b/backend/package/yuxi/services/mcp_auth/config_models.py index eccb991f8..a453882db 100644 --- a/backend/package/yuxi/services/mcp_auth/config_models.py +++ b/backend/package/yuxi/services/mcp_auth/config_models.py @@ -21,7 +21,7 @@ class RefreshPolicy(BaseModel): class MCPAuthConfig(BaseModel): - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra="ignore") version: int = 1 provider: Literal[ diff --git a/backend/package/yuxi/services/mcp_auth/crypto.py b/backend/package/yuxi/services/mcp_auth/crypto.py index f4b5ccac4..5edff8e65 100644 --- a/backend/package/yuxi/services/mcp_auth/crypto.py +++ b/backend/package/yuxi/services/mcp_auth/crypto.py @@ -6,10 +6,12 @@ import os from typing import Any +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.kdf.hkdf import HKDF MASTER_KEY_ENV = "MCP_CREDENTIALS_MASTER_KEY" -ENVELOPE_VERSION = 1 +ENVELOPE_VERSION = 2 ENVELOPE_KEY_ID = "local" _AAD = b"yuxi:mcp_credentials:v1" @@ -21,10 +23,22 @@ def _get_master_key() -> str: return value -def _derive_aes_key(master_key: str) -> bytes: +def _derive_aes_key_v1(master_key: str) -> bytes: + # legacy v1 key derivation (raw sha256) return hashlib.sha256(master_key.encode("utf-8")).digest() +def _derive_aes_key_v2(master_key: str, salt: bytes) -> bytes: + # v2 key derivation using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + info=b"mcp-credentials-v2", + ) + return hkdf.derive(master_key.encode("utf-8")) + + def _b64encode(value: bytes) -> str: return base64.urlsafe_b64encode(value).decode("ascii") @@ -43,7 +57,10 @@ def _parse_envelope(blob: str) -> dict[str, Any] | None: required_keys = {"v", "kid", "nonce", "ciphertext"} if not required_keys.issubset(payload.keys()): return None - if payload.get("v") != ENVELOPE_VERSION: + v = payload.get("v") + if v not in (1, 2): + return None + if v == 2 and "salt" not in payload: return None return payload @@ -61,13 +78,15 @@ def encrypt_credential_blob(plaintext: str) -> str: return plaintext master_key = _get_master_key() - aesgcm = AESGCM(_derive_aes_key(master_key)) + salt = os.urandom(16) + aesgcm = AESGCM(_derive_aes_key_v2(master_key, salt)) nonce = os.urandom(12) ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), _AAD) return json.dumps( { "v": ENVELOPE_VERSION, "kid": ENVELOPE_KEY_ID, + "salt": _b64encode(salt), "nonce": _b64encode(nonce), "ciphertext": _b64encode(ciphertext), }, @@ -85,7 +104,16 @@ def decrypt_credential_blob(blob: str | None) -> str | None: return blob master_key = _get_master_key() - aesgcm = AESGCM(_derive_aes_key(master_key)) + v = payload.get("v") + if v == 1: + key = _derive_aes_key_v1(master_key) + elif v == 2: + salt = _b64decode(payload["salt"]) + key = _derive_aes_key_v2(master_key, salt) + else: + return blob + + aesgcm = AESGCM(key) plaintext = aesgcm.decrypt( _b64decode(payload["nonce"]), _b64decode(payload["ciphertext"]), diff --git a/backend/package/yuxi/services/mcp_auth/fetchers/base.py b/backend/package/yuxi/services/mcp_auth/fetchers/base.py index f670597b7..b8c89cf87 100644 --- a/backend/package/yuxi/services/mcp_auth/fetchers/base.py +++ b/backend/package/yuxi/services/mcp_auth/fetchers/base.py @@ -45,7 +45,7 @@ async def fetch_custom_http_token( response_map = response_map or dict(_DEFAULT_TOKEN_RESPONSE_MAP) if http_client is None: - http_client = httpx.AsyncClient() + http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)) should_close = True else: should_close = False @@ -76,6 +76,8 @@ async def fetch_custom_http_token( else: request_kwargs["data"] = body + request_kwargs["timeout"] = httpx.Timeout(10.0, read=30.0) + response = await http_client.request(**request_kwargs) response.raise_for_status() payload = response.json() diff --git a/backend/package/yuxi/services/mcp_auth/proxy_service.py b/backend/package/yuxi/services/mcp_auth/proxy_service.py index aad2234c5..334764721 100644 --- a/backend/package/yuxi/services/mcp_auth/proxy_service.py +++ b/backend/package/yuxi/services/mcp_auth/proxy_service.py @@ -9,13 +9,29 @@ from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from starlette.background import BackgroundTask from yuxi.services.mcp_auth.config_models import MCPAuthConfig from yuxi.services.mcp_auth.orchestrator import AuthContext, resolve_runtime_mcp_config from yuxi.storage.postgres.models_business import MCPConnection, MCPServer from server.utils.auth_utils import AuthUtils +_proxy_http_client: httpx.AsyncClient | None = None + + +def get_shared_proxy_client() -> httpx.AsyncClient: + global _proxy_http_client + if _proxy_http_client is None: + _proxy_http_client = httpx.AsyncClient(timeout=httpx.Timeout(connect=30.0, pool=120.0, read=120.0, write=30.0)) + return _proxy_http_client + + +async def close_shared_proxy_client() -> None: + global _proxy_http_client + if _proxy_http_client is not None: + await _proxy_http_client.aclose() + _proxy_http_client = None + + INTERNAL_PROXY_TOKEN_HEADER = "X-Yuxi-MCP-Proxy-Token" INTERNAL_PROXY_DISABLE_TOOL_OBJECT_CACHE_KEY = "__yuxi_disable_tool_object_cache" _PROXY_TOKEN_TYPE = "mcp_proxy" @@ -213,15 +229,13 @@ async def _proxy_mcp_request_stream( connect_timeout = server.timeout or 60.0 read_timeout = server.sse_read_timeout or connect_timeout - http_client = _http_client or httpx.AsyncClient( - timeout=httpx.Timeout( - connect=connect_timeout, - read=read_timeout, - write=connect_timeout, - pool=connect_timeout, - ) + request_timeout = httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=connect_timeout, + pool=connect_timeout, ) - bg_task = BackgroundTask(http_client.aclose) + http_client = _http_client or get_shared_proxy_client() if _token_cache is not None: token_cache = _token_cache @@ -248,6 +262,7 @@ async def _proxy_mcp_request_stream( url=target_url, headers=upstream_headers, content=body, + timeout=request_timeout, ) # 使用 send(stream=True) 获取异步可迭代响应而不会阻塞 SSE 长链接 @@ -262,7 +277,7 @@ async def _proxy_mcp_request_stream( content='{"error": "insufficient_scope", "message": "当前授权范围不足"}', status_code=403, media_type="application/json", - background=bg_task, + background=None, ) if response.status_code != 401: @@ -283,7 +298,7 @@ async def proxy_stream_generator(): resp_headers[k] = v return StreamingResponse( - proxy_stream_generator(), status_code=response.status_code, headers=resp_headers, background=bg_task + proxy_stream_generator(), status_code=response.status_code, headers=resp_headers, background=None ) # 如果是 401,回收流连接并准备重试 @@ -300,5 +315,5 @@ async def proxy_stream_generator(): content='{"error": "reauth_required", "message": "连接失效,请重新连接"}', status_code=424, media_type="application/json", - background=bg_task, + background=None, ) diff --git a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py index 606789ea5..4e96a4016 100644 --- a/backend/package/yuxi/services/mcp_auth/redis_token_cache.py +++ b/backend/package/yuxi/services/mcp_auth/redis_token_cache.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import uuid from collections.abc import Awaitable, Callable from datetime import UTC, datetime from typing import Any @@ -15,27 +14,6 @@ DEFAULT_LOCK_TTL_SECONDS = 30 -_PYTEST_SESSION_TOKEN = uuid.uuid4().hex[:8] - - -def _access_token_key(connection_id: int) -> str: - key = f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" - import os - - if os.environ.get("PYTEST_CURRENT_TEST"): - return f"test:{_PYTEST_SESSION_TOKEN}:{key}" - return key - - -def _refresh_lock_key(connection_id: int) -> str: - key = f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" - import os - - if os.environ.get("PYTEST_CURRENT_TEST"): - return f"test:{_PYTEST_SESSION_TOKEN}:{key}" - return key - - def _compute_token_ttl_seconds(token_payload: dict[str, Any]) -> int: expires_at = token_payload.get("expires_at") if isinstance(expires_at, str): @@ -54,15 +32,32 @@ def _compute_token_ttl_seconds(token_payload: dict[str, Any]) -> int: class RedisTokenCache: - def __init__(self, redis_client_factory: Callable[[], Awaitable[Any]] | None = None): + def __init__( + self, + redis_client_factory: Callable[[], Awaitable[Any]] | None = None, + key_prefix: str | None = None, + ): self._redis_client_factory = redis_client_factory or get_redis_client + self._key_prefix = key_prefix + + def _access_token_key(self, connection_id: int) -> str: + key = f"{ACCESS_TOKEN_KEY_PREFIX}:{connection_id}" + if self._key_prefix: + return f"{self._key_prefix}:{key}" + return key + + def _refresh_lock_key(self, connection_id: int) -> str: + key = f"{REFRESH_LOCK_KEY_PREFIX}:{connection_id}" + if self._key_prefix: + return f"{self._key_prefix}:{key}" + return key async def _get_redis(self): return await self._redis_client_factory() async def get_access_token(self, connection_id: int) -> dict[str, Any] | None: redis = await self._get_redis() - raw = await redis.get(_access_token_key(connection_id)) + raw = await redis.get(self._access_token_key(connection_id)) if not raw: return None if isinstance(raw, dict): @@ -73,20 +68,20 @@ async def set_access_token(self, connection_id: int, token_payload: dict[str, An redis = await self._get_redis() ttl_seconds = _compute_token_ttl_seconds(token_payload) await redis.set( - _access_token_key(connection_id), + self._access_token_key(connection_id), json.dumps(token_payload, ensure_ascii=False, separators=(",", ":")), ex=ttl_seconds, ) async def delete_access_token(self, connection_id: int) -> None: redis = await self._get_redis() - await redis.delete(_access_token_key(connection_id)) + await redis.delete(self._access_token_key(connection_id)) async def acquire_refresh_lock(self, connection_id: int, *, ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS) -> bool: redis = await self._get_redis() - acquired = await redis.set(_refresh_lock_key(connection_id), "1", ex=ttl_seconds, nx=True) + acquired = await redis.set(self._refresh_lock_key(connection_id), "1", ex=ttl_seconds, nx=True) return bool(acquired) async def release_refresh_lock(self, connection_id: int) -> None: redis = await self._get_redis() - await redis.delete(_refresh_lock_key(connection_id)) + await redis.delete(self._refresh_lock_key(connection_id)) diff --git a/backend/server/utils/lifespan.py b/backend/server/utils/lifespan.py index 9b6539471..ccc8d72d1 100644 --- a/backend/server/utils/lifespan.py +++ b/backend/server/utils/lifespan.py @@ -101,6 +101,14 @@ async def lifespan(app: FastAPI): """) logger.info("Yuxi backend startup complete") yield + + from yuxi.services.mcp.client_pool import mcp_client_pool + from yuxi.services.mcp_auth.proxy_service import close_shared_proxy_client + + logger.info("Shutting down MCP client pool and proxy clients...") + await mcp_client_pool.shutdown() + await close_shared_proxy_client() + await tasker.shutdown() shutdown_sandbox_provider() await close_queue_clients() diff --git a/backend/test/unit/services/test_mcp_auth_crypto.py b/backend/test/unit/services/test_mcp_auth_crypto.py index feee81e76..aa117588b 100644 --- a/backend/test/unit/services/test_mcp_auth_crypto.py +++ b/backend/test/unit/services/test_mcp_auth_crypto.py @@ -21,10 +21,35 @@ def test_encrypt_and_decrypt_credential_blob_round_trip(monkeypatch): decrypted = decrypt_credential_blob(encrypted) assert encrypted != plaintext - assert json.loads(encrypted)["v"] == 1 + payload = json.loads(encrypted) + assert payload["v"] == 2 + assert "salt" in payload assert decrypted == plaintext +def test_decrypt_legacy_v1_envelope(monkeypatch): + monkeypatch.setenv("MCP_CREDENTIALS_MASTER_KEY", "local-test-master-key") + plaintext = "super-secret-legacy" + + import hashlib + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + from yuxi.services.mcp_auth.crypto import _b64encode + + key = hashlib.sha256(b"local-test-master-key").digest() + aesgcm = AESGCM(key) + nonce = os.urandom(12) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode("utf-8"), b"yuxi:mcp_credentials:v1") + + v1_blob = json.dumps({ + "v": 1, + "kid": "local", + "nonce": _b64encode(nonce), + "ciphertext": _b64encode(ciphertext), + }) + + assert decrypt_credential_blob(v1_blob) == plaintext + + def test_decrypt_credential_blob_keeps_legacy_plaintext_payload(monkeypatch): monkeypatch.delenv("MCP_CREDENTIALS_MASTER_KEY", raising=False) plaintext = '{"secrets":{"access_token":"legacy-token"}}' diff --git a/backend/test/unit/services/test_mcp_client_pool.py b/backend/test/unit/services/test_mcp_client_pool.py index 0cb43fd06..36bb5c2ad 100644 --- a/backend/test/unit/services/test_mcp_client_pool.py +++ b/backend/test/unit/services/test_mcp_client_pool.py @@ -144,8 +144,7 @@ async def test_dynamic_mcp_token_auth_cache(): token = mcp_auth_context_var.set(auth_ctx) try: with patch("yuxi.storage.postgres.manager.pg_manager.get_async_session_context") as mock_session_ctx, \ - patch("yuxi.services.mcp.server_service.get_runtime_mcp_server_config", return_value=mock_runtime_config) as mock_get_config, \ - patch("time.time", side_effect=[1000.0, 1005.0, 1010.0, 1030.0]): + patch("yuxi.services.mcp.server_service.get_runtime_mcp_server_config", return_value=mock_runtime_config) as mock_get_config: # 模拟 async with pg_manager.get_async_session_context() as session mock_session = MagicMock() @@ -178,22 +177,14 @@ async def test_dynamic_mcp_token_auth_cache(): results_3 = [r async for r in generator_3] assert len(results_3) == 1 assert mock_get_config.call_count == 2 - - # 5. 第四次请求(+20秒):经过了 10 秒(累计过了15秒的 TTL 限制),由于 1020 - 1010 >= 15,缓存过期,再次执行 DB 查询 - mock_req_4 = MagicMock(headers={}) - generator_4 = auth.async_auth_flow(mock_req_4) - results_4 = [r async for r in generator_4] - assert len(results_4) == 1 - assert mock_get_config.call_count == 3 - - # 6. 测试 clear_mcp_cache / clear_mcp_server_tools_cache 联动清除所有 resolved_headers 缓存 - from yuxi.services.mcp.tool_registry_service import clear_mcp_cache, clear_mcp_server_tools_cache + # 4. 测试 clear_mcp_cache / clear_mcp_server_tools_cache 联动清除所有 resolved_headers 缓存 + from yuxi.services.mcp.tool_registry_service import clear_mcp_cache, invalidate_mcp_server_tools_cache # 确保当前有缓存项 - _resolved_headers_cache[cache_key] = ({"Auth": "Bearer test"}, 2000.0) - clear_mcp_server_tools_cache("test_server") + _resolved_headers_cache[cache_key] = {"Auth": "Bearer test"} + await invalidate_mcp_server_tools_cache("test_server") assert len(_resolved_headers_cache) == 0 - _resolved_headers_cache[cache_key] = ({"Auth": "Bearer test"}, 2000.0) + _resolved_headers_cache[cache_key] = {"Auth": "Bearer test"} await clear_mcp_cache() assert len(_resolved_headers_cache) == 0 finally: diff --git a/backend/test/unit/services/test_model_provider_service.py b/backend/test/unit/services/test_model_provider_service.py index dbbb8148f..3303aca42 100644 --- a/backend/test/unit/services/test_model_provider_service.py +++ b/backend/test/unit/services/test_model_provider_service.py @@ -97,6 +97,12 @@ async def fake_fetch(client, provider, headers, endpoint, model_type): return [{"id": f"{model_type}-model", "type": model_type}] monkeypatch.setattr("yuxi.services.model_provider_service._fetch_models_from_endpoint", fake_fetch) + + from unittest.mock import AsyncMock, MagicMock + mock_client_instance = MagicMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + monkeypatch.setattr("yuxi.services.model_provider_service.httpx.AsyncClient", lambda **kwargs: mock_client_instance) class Provider: base_url = "https://example.com/v1" diff --git a/backend/uv.lock b/backend/uv.lock index 7fdfad461..41305dd5e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -5650,6 +5650,7 @@ dependencies = [ { name = "argon2-cffi" }, { name = "asyncpg" }, { name = "beautifulsoup4" }, + { name = "cachetools" }, { name = "chardet" }, { name = "colorlog" }, { name = "dashscope" }, @@ -5729,6 +5730,7 @@ requires-dist = [ { name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "asyncpg", specifier = ">=0.30.0" }, { name = "beautifulsoup4", specifier = ">=4.12.0" }, + { name = "cachetools", specifier = ">=5.3.0" }, { name = "chardet", specifier = ">=5.0.0" }, { name = "colorlog", specifier = ">=6.9.0" }, { name = "dashscope", specifier = ">=1.23.2" }, From 7002976868a68ed350ab7bca819d8efcaea7c8c5 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 09:09:02 +0800 Subject: [PATCH 18/36] =?UTF-8?q?fix(web):=20=E4=BF=AE=E5=A4=8D=20MCP=20?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E9=85=8D=E7=BD=AE=E4=B8=AD=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E4=B8=8B=E6=8B=89=E6=A1=86=E6=98=BE=E7=A4=BA=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E5=8F=8A=E9=87=8D=E5=A4=8D=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/extensions/McpDetailView.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 57b0273ad..6a235e146 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -663,7 +663,7 @@ :loading="isFetchingScopeOptions" placeholder="请选择用户" show-search - :options="userList.map(u => ({ label: `${u.username} (${u.user_id_login})`, value: u.id.toString() }))" + :options="userList.map(u => ({ label: u.username === u.user_id ? u.username : `${u.username} (${u.user_id})`, value: u.id.toString() }))" /> Date: Fri, 5 Jun 2026 09:10:49 +0800 Subject: [PATCH 19/36] =?UTF-8?q?feat(web):=20=E8=BF=9E=E6=8E=A5=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E9=9D=A2=E6=9D=BF=E6=94=AF=E6=8C=81=E6=98=BE=E7=A4=BA?= =?UTF-8?q?=E5=8F=AF=E8=AF=BB=E7=9A=84=E8=AE=A4=E8=AF=81=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/extensions/McpDetailView.vue | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 6a235e146..645b9785e 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -489,7 +489,7 @@
认证方式 - {{ server.auth_config?.provider || '未配置' }} + {{ providerLabelMap[server.auth_config?.provider] || server.auth_config?.provider || '未配置' }}
默认绑定 @@ -921,6 +921,14 @@ const statusLabelMap = { invalid: '无效' } +const providerLabelMap = { + none: '不启用', + bound_secret: '绑定长期密钥', + custom_http_token: '接口换 Token', + stdio_env: 'StdIO 环境变量', + client_credentials: 'OAuth2 客户端凭证' +} + const actionLabel = computed(() => { if (server.value?.enabled === false) return '恢复' return server.value?.created_by === 'system' ? '移除' : '退役' From 0e63bf4a13abb70690ef01a55aa99dab85bf9b6f Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 09:12:01 +0800 Subject: [PATCH 20/36] =?UTF-8?q?style(web):=20MCP=20=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E8=AF=A6=E6=83=85=E9=A1=B5=E5=8F=AA=E8=AF=BB=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E4=B8=8B=E9=9A=90=E8=97=8F=E5=86=97=E9=95=BF=E7=9A=84=E8=AE=A4?= =?UTF-8?q?=E8=AF=81=E9=85=8D=E7=BD=AE=20JSON=EF=BC=8C=E6=94=B9=E7=94=A8?= =?UTF-8?q?=E7=AE=80=E7=95=A5=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/extensions/McpDetailView.vue | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index 645b9785e..ec3d7d520 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -349,7 +349,10 @@ v-if="server.auth_config && Object.keys(server.auth_config).length > 0" > -
{{ JSON.stringify(server.auth_config, null, 2) }}
+ + {{ providerLabelMap[server.auth_config.provider] || server.auth_config.provider || '已配置' }} + (进入编辑模式查看详情) +
From 81f08e9a09388603442754bf0534622decd1742d Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 09:14:02 +0800 Subject: [PATCH 21/36] =?UTF-8?q?feat(web):=20=E6=94=AF=E6=8C=81=E5=8F=AA?= =?UTF-8?q?=E8=AF=BB=E6=A8=A1=E5=BC=8F=E6=B8=B2=E6=9F=93=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=20(McpAuthConfigBuilder)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../extensions/McpAuthConfigBuilder.vue | 62 +++++++++++-------- .../components/extensions/McpDetailView.vue | 14 +++-- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/web/src/components/extensions/McpAuthConfigBuilder.vue b/web/src/components/extensions/McpAuthConfigBuilder.vue index f909f6a68..1775ca07c 100644 --- a/web/src/components/extensions/McpAuthConfigBuilder.vue +++ b/web/src/components/extensions/McpAuthConfigBuilder.vue @@ -23,7 +23,7 @@ v-for="option in providerOptions" :key="option.value" type="button" - class="auth-provider-card" + class="auth-provider-card" :disabled="readonly" :class="{ active: form.provider === option.value }" @click="switchProvider(option.value)" > @@ -52,7 +52,7 @@ v-for="scope in bindingScopeOptions" :key="scope.value" type="button" - class="binding-scope-card" + class="binding-scope-card" :disabled="readonly" :class="{ active: form.bindingScope === scope.value }" @click="form.bindingScope = scope.value" > @@ -71,12 +71,12 @@
-
+
- + 添加注入项 @@ -123,13 +124,13 @@
- + POST GET PUT @@ -137,7 +138,7 @@
- + JSON Form @@ -154,19 +155,20 @@ :key="`header-${index}`" class="row-editor-line" > - - + +
- + 添加一行 @@ -180,14 +182,14 @@ :key="`body-${index}`" class="row-editor-line" > - - + + @@ -195,7 +197,7 @@ 添加一行 @@ -212,14 +214,14 @@ :key="`response-${index}`" class="row-editor-line" > - - + + @@ -227,7 +229,7 @@ 添加一行 @@ -243,7 +245,7 @@
- + 按连接隔离 服务级共享 @@ -251,7 +253,7 @@
- + 收到 401 时清理缓存并自动重试一次。
@@ -312,10 +314,10 @@ -
+
格式化 导入到向导
@@ -354,6 +356,7 @@ import { } from '@/utils/mcpAuthConfigBuilder' const props = defineProps({ + readonly: { type: Boolean, default: false }, modelValue: { type: String, default: '' }, transport: { type: String, default: 'streamable_http' } }) @@ -737,6 +740,11 @@ watch(jsonDraft, (value) => { background: var(--main-10); color: var(--main-color); } + + &:disabled { + cursor: not-allowed; + opacity: 0.7; + } } .binding-scope-card { diff --git a/web/src/components/extensions/McpDetailView.vue b/web/src/components/extensions/McpDetailView.vue index ec3d7d520..bd7348467 100644 --- a/web/src/components/extensions/McpDetailView.vue +++ b/web/src/components/extensions/McpDetailView.vue @@ -345,14 +345,16 @@ {{ server.created_by }}
- - - {{ providerLabelMap[server.auth_config.provider] || server.auth_config.provider || '已配置' }} - (进入编辑模式查看详情) - + +
From 3ab2b32a364b02b4979cdb7e19bb40e55571b3ee Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 10:11:03 +0800 Subject: [PATCH 22/36] =?UTF-8?q?fix(web):=20=E4=BF=AE=E5=A4=8D=20McpAuthC?= =?UTF-8?q?onfigBuilder=20=E7=BB=84=E4=BB=B6=E4=B8=AD=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E7=9A=84=20v-if=20=E5=B1=9E=E6=80=A7=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=9A=84=E7=BC=96=E8=AF=91=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/extensions/McpAuthConfigBuilder.vue | 1 - 1 file changed, 1 deletion(-) diff --git a/web/src/components/extensions/McpAuthConfigBuilder.vue b/web/src/components/extensions/McpAuthConfigBuilder.vue index 1775ca07c..460a342c8 100644 --- a/web/src/components/extensions/McpAuthConfigBuilder.vue +++ b/web/src/components/extensions/McpAuthConfigBuilder.vue @@ -163,7 +163,6 @@ danger :disabled="form.tokenHeaders.length === 1" @click="removeKeyValueRow(form.tokenHeaders, index)" v-if="!readonly" - v-if="!readonly" >
From 5063a8b5634b7c2f20eea7a0f682660d4d3aa21d Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Fri, 5 Jun 2026 10:13:45 +0800 Subject: [PATCH 23/36] =?UTF-8?q?style(web):=20=E4=BC=98=E5=8C=96=E5=8F=AA?= =?UTF-8?q?=E8=AF=BB=E6=A8=A1=E5=BC=8F=E4=B8=8B=20McpAuthConfigBuilder=20?= =?UTF-8?q?=E7=9A=84=E6=A0=B7=E5=BC=8F=EF=BC=8C=E9=81=BF=E5=85=8D=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=A1=86=E5=92=8C=E6=8C=89=E9=92=AE=E7=9C=8B=E8=B5=B7?= =?UTF-8?q?=E6=9D=A5=E5=83=8F=E7=A6=81=E7=94=A8=E7=9A=84=E7=81=B0=E8=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../extensions/McpAuthConfigBuilder.vue | 64 ++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/web/src/components/extensions/McpAuthConfigBuilder.vue b/web/src/components/extensions/McpAuthConfigBuilder.vue index 460a342c8..0c4775abd 100644 --- a/web/src/components/extensions/McpAuthConfigBuilder.vue +++ b/web/src/components/extensions/McpAuthConfigBuilder.vue @@ -1,5 +1,5 @@ From 73fa65638fbd9b61cd8103bd92e004a4594c7f00 Mon Sep 17 00:00:00 2001 From: supreme0597 Date: Tue, 9 Jun 2026 19:33:47 +0800 Subject: [PATCH 36/36] =?UTF-8?q?fix(mcp):=20=E4=BC=98=E5=8C=96=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=80=81=E5=8A=A0=E8=BD=BD=E4=B8=8E=E7=A6=BB=E7=BA=BF?= =?UTF-8?q?=E9=99=8D=E5=99=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../yuxi/agents/buildin/chatbot/graph.py | 7 +- .../yuxi/agents/buildin/deep_agent/graph.py | 6 +- .../middlewares/runtime_config_middleware.py | 18 +++- .../agents/middlewares/skills_middleware.py | 30 +++++- .../package/yuxi/services/mcp/client_pool.py | 11 ++- .../services/mcp/tool_registry_service.py | 93 +++++++++++++++++-- backend/package/yuxi/utils/logging_config.py | 6 ++ .../test_runtime_config_middleware.py | 10 +- .../middlewares/test_skills_middleware.py | 33 ++++++- .../test_mcp_tool_registry_service.py | 80 ++++++++++++++-- docs/develop-guides/roadmap.md | 3 +- 11 files changed, 267 insertions(+), 30 deletions(-) diff --git a/backend/package/yuxi/agents/buildin/chatbot/graph.py b/backend/package/yuxi/agents/buildin/chatbot/graph.py index 3f0872256..db3d9b9ee 100644 --- a/backend/package/yuxi/agents/buildin/chatbot/graph.py +++ b/backend/package/yuxi/agents/buildin/chatbot/graph.py @@ -12,16 +12,19 @@ save_attachments_to_fs, ) from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware -from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names +from yuxi.utils.logging_config import logger from .prompt import TODO_MID_PROMPT, build_prompt_with_context async def _build_middlewares(context): """构建中间件列表""" - all_mcp_tools = await get_tools_from_all_servers() # 因为异步加载,无法放在 RuntimeConfigMiddleware 的 __init__ 中 + preload_mcp_names = await collect_context_mcp_names_for_preload(context) + logger.info(f"ChatbotAgent MCP preload candidates: {preload_mcp_names}") + all_mcp_tools = await get_tools_from_all_servers(preload_mcp_names) # summary middleware # 主 Agent 上下文优化:90k tokens 触发压缩(128k context window 的 70%) diff --git a/backend/package/yuxi/agents/buildin/deep_agent/graph.py b/backend/package/yuxi/agents/buildin/deep_agent/graph.py index e56155fe5..4f100f0dc 100644 --- a/backend/package/yuxi/agents/buildin/deep_agent/graph.py +++ b/backend/package/yuxi/agents/buildin/deep_agent/graph.py @@ -15,7 +15,7 @@ save_attachments_to_fs, ) from yuxi.agents.middlewares.knowledge_base_middleware import KnowledgeBaseMiddleware -from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload from yuxi.agents.toolkits.buildin.tools import _create_tavily_search from yuxi.services.mcp.tool_registry_service import get_tools_from_all_servers from yuxi.services.subagent_service import get_subagents_from_names @@ -57,7 +57,9 @@ async def get_graph(self, context=None, **kwargs): model = load_chat_model(context.model) sub_model = load_chat_model(context.subagents_model) search_tools = await self.get_tools() - all_mcp_tools = await get_tools_from_all_servers() + preload_mcp_names = await collect_context_mcp_names_for_preload(context) + logger.info(f"DeepAgent MCP preload candidates: {preload_mcp_names}") + all_mcp_tools = await get_tools_from_all_servers(preload_mcp_names) # 合并搜索工具和 MCP 工具 # 从数据库加载 subagent specs(工具名称已解析) diff --git a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py index ec0eca60e..8a4ba6c92 100644 --- a/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py +++ b/backend/package/yuxi/agents/middlewares/runtime_config_middleware.py @@ -182,10 +182,15 @@ async def get_tools_from_context(self, context) -> list: all_mcp_names.append(server_name) selected_mcp_servers: set[str] = set() + selected_mcp_names: list[str] = [] + loaded_mcp_tools: dict[str, int] = {} + unavailable_mcp_servers: list[str] = [] + failed_mcp_servers: list[str] = [] for server_name in all_mcp_names: if server_name in selected_mcp_servers: continue selected_mcp_servers.add(server_name) + selected_mcp_names.append(server_name) try: user_id = getattr(context, "user_id", None) work_id = getattr(context, "work_id", None) @@ -198,9 +203,20 @@ async def get_tools_from_context(self, context) -> list: ), ) if not mcp_tools: - logger.warning(f"RuntimeConfigMiddleware: mcp dependency unavailable, skip: {server_name}") + unavailable_mcp_servers.append(server_name) + logger.debug(f"RuntimeConfigMiddleware: mcp dependency unavailable, skip: {server_name}") + else: + loaded_mcp_tools[server_name] = len(mcp_tools) selected_tools.extend(mcp_tools) except Exception as e: + failed_mcp_servers.append(server_name) logger.warning(f"RuntimeConfigMiddleware: failed to load mcp dependency '{server_name}': {e}") + if selected_mcp_names: + logger.info( + "RuntimeConfigMiddleware MCP runtime selection: " + f"selected={selected_mcp_names}, loaded={loaded_mcp_tools}, " + f"unavailable={unavailable_mcp_servers}, failed={failed_mcp_servers}" + ) + return selected_tools diff --git a/backend/package/yuxi/agents/middlewares/skills_middleware.py b/backend/package/yuxi/agents/middlewares/skills_middleware.py index dcb336333..7d90407b4 100644 --- a/backend/package/yuxi/agents/middlewares/skills_middleware.py +++ b/backend/package/yuxi/agents/middlewares/skills_middleware.py @@ -80,6 +80,21 @@ async def get_dependency_map(db: AsyncSession | None = None) -> dict[str, SkillD return result +async def collect_context_mcp_names_for_preload(context, *, skills_context_name: str = "skills") -> list[str]: + """收集图构建阶段需要预注册的 MCP 名称。""" + names: list[str] = [] + names.extend(normalize_selected_skills(getattr(context, "mcps", None) or [])) + + dependency_map = await get_dependency_map() + configured_skills = normalize_selected_skills(getattr(context, skills_context_name, None) or []) + for slug in expand_skill_closure(configured_skills, dependency_map): + node = dependency_map.get(slug) + if node: + names.extend(node.get("mcps", [])) + + return normalize_selected_skills(names) + + def normalize_selected_skills(selected_skills: list[str] | None) -> list[str]: """规范化 skills 列表,去重并过滤无效值""" return _normalize_string_list(selected_skills) @@ -339,6 +354,9 @@ async def _get_mcp_tools_from_context( # 去重 unique_mcp_names = list(dict.fromkeys(all_mcp_names)) + loaded_mcp_tools: dict[str, int] = {} + unavailable_mcp_servers: list[str] = [] + failed_mcp_servers: list[str] = [] async def load_mcp_tools(server_name: str) -> list: """加载单个 MCP 服务器的工具""" @@ -354,14 +372,24 @@ async def load_mcp_tools(server_name: str) -> list: ), ) if not mcp_tools: - logger.warning(f"SkillsMiddleware: mcp dependency unavailable, skip: {server_name}") + unavailable_mcp_servers.append(server_name) + logger.debug(f"SkillsMiddleware: mcp dependency unavailable, skip: {server_name}") + else: + loaded_mcp_tools[server_name] = len(mcp_tools) return mcp_tools except Exception as e: + failed_mcp_servers.append(server_name) logger.warning(f"SkillsMiddleware: failed to load mcp dependency '{server_name}': {e}") return [] # 并行加载所有 MCP 工具 results = await asyncio.gather(*[load_mcp_tools(name) for name in unique_mcp_names]) + if unique_mcp_names: + logger.info( + "SkillsMiddleware MCP dependency selection: " + f"selected={unique_mcp_names}, loaded={loaded_mcp_tools}, " + f"unavailable={unavailable_mcp_servers}, failed={failed_mcp_servers}" + ) selected_tools = [] for tools in results: selected_tools.extend(tools) diff --git a/backend/package/yuxi/services/mcp/client_pool.py b/backend/package/yuxi/services/mcp/client_pool.py index 6e2bf8b6b..3b8f95dce 100644 --- a/backend/package/yuxi/services/mcp/client_pool.py +++ b/backend/package/yuxi/services/mcp/client_pool.py @@ -122,8 +122,12 @@ async def _run_loop(self): self._ready_event.set() # 挂起直到收到停止指令 await self._stop_event.wait() - except Exception: - logger.error(f"Error in long-lived MCP session loop for {self.server_name}", exc_info=True) + except Exception as exc: + if self.session is None: + logger.debug(f"Failed to start MCP session for {self.server_name}: {exc}", exc_info=True) + else: + logger.warning(f"MCP session loop stopped for {self.server_name}: {exc}") + logger.debug(f"Error in long-lived MCP session loop for {self.server_name}", exc_info=True) finally: self.session = None self._running = False @@ -270,9 +274,12 @@ async def get_session( except BaseException as exc: if not init_future.done(): init_future.set_exception(exc) + init_future.exception() async with self._dict_lock: if self._sessions.get(cache_key) is init_future: self._sessions.pop(cache_key, None) + raise + async def remove_session(self, server_name: str, partition_key: str): """移除指定 key 的连接,强制下一次请求重新创建""" cache_key = (server_name, partition_key) diff --git a/backend/package/yuxi/services/mcp/tool_registry_service.py b/backend/package/yuxi/services/mcp/tool_registry_service.py index 372d59570..0bb859c80 100644 --- a/backend/package/yuxi/services/mcp/tool_registry_service.py +++ b/backend/package/yuxi/services/mcp/tool_registry_service.py @@ -4,7 +4,8 @@ import hashlib import json import logging -import traceback +import os +import time from collections.abc import Callable from types import SimpleNamespace from typing import Any, cast @@ -26,8 +27,10 @@ # 全局共享状态(直接在本模块维护,供外部和测试使用) _mcp_tools_cache: LRUCache = LRUCache(maxsize=128) _mcp_tools_stats: LRUCache = LRUCache(maxsize=128) +_mcp_tools_failure_cache: LRUCache = LRUCache(maxsize=256) _mcp_tool_cache_store = RedisMcpToolCache() _mcp_lock = asyncio.Lock() +_MCP_TOOL_FAILURE_COOLDOWN_SECONDS = float(os.getenv("YUXI_MCP_TOOL_FAILURE_COOLDOWN_SECONDS", "30")) def to_camel_case(s: str) -> str: @@ -171,6 +174,37 @@ def _can_preload_mcp_server_tools_without_runtime_auth(server_config: dict[str, return auth_config.provider == "legacy_static" +def _get_cached_mcp_tool_failure(cache_key: str) -> dict[str, Any] | None: + entry = _mcp_tools_failure_cache.get(cache_key) + if not entry: + return None + retry_at = float(entry.get("retry_at") or 0) + if retry_at <= time.monotonic(): + _mcp_tools_failure_cache.pop(cache_key, None) + return None + return entry + + +def _record_mcp_tool_failure(cache_key: str, exc: BaseException) -> None: + if _MCP_TOOL_FAILURE_COOLDOWN_SECONDS <= 0: + return + _mcp_tools_failure_cache[cache_key] = { + "retry_at": time.monotonic() + _MCP_TOOL_FAILURE_COOLDOWN_SECONDS, + "message": str(exc) or exc.__class__.__name__, + } + + +def _clear_mcp_tool_failure(cache_key: str) -> None: + _mcp_tools_failure_cache.pop(cache_key, None) + + +def _clear_mcp_tool_failure_cache_for_server(server_name: str) -> None: + prefix = f"{server_name}:" + stale_keys = [key for key in _mcp_tools_failure_cache if key.startswith(prefix)] + for key in stale_keys: + _mcp_tools_failure_cache.pop(key, None) + + async def get_mcp_tools( server_name: str, additional_servers: dict[str, dict[str, Any]] | None = None, @@ -219,6 +253,16 @@ async def get_mcp_tools( all_processed_tools = _mcp_tools_cache[cache_key] if not all_processed_tools: + if not force_refresh: + failure_entry = _get_cached_mcp_tool_failure(cache_key) + if failure_entry is not None: + retry_in = max(0.0, float(failure_entry.get("retry_at") or 0) - time.monotonic()) + logger.debug( + f"Skip loading MCP tools for '{server_name}' during failure cooldown " + f"({retry_in:.1f}s left): {failure_entry.get('message')}" + ) + return [] + try: client_config = { k: v @@ -296,13 +340,22 @@ async def get_mcp_tools( f"{len(all_processed_tools)} tools loaded." ) + _clear_mcp_tool_failure(cache_key) + except Exception as e: - logger.error( - f"Failed to load tools from MCP server '{server_name}': {e}, traceback: {traceback.format_exc()}" + _record_mcp_tool_failure(cache_key, e) + logger.warning( + f"MCP server '{server_name}' temporarily unavailable; " + f"suppress retries for {_MCP_TOOL_FAILURE_COOLDOWN_SECONDS:.0f}s: {e}" ) + logger.debug(f"Failed to load tools from MCP server '{server_name}'", exc_info=True) try: - partition_key = f"{cache_partition}:s{cache_descriptor['server_revision']}:p{cache_descriptor['partition_revision']}" + partition_key = ( + f"{cache_partition}:s{cache_descriptor['server_revision']}:" + f"p{cache_descriptor['partition_revision']}" + ) from yuxi.services.mcp.client_pool import mcp_client_pool + await mcp_client_pool.remove_session(server_name, partition_key) except Exception as pool_err: logger.warning(f"Failed to remove stale session for {server_name}: {pool_err}") @@ -315,11 +368,26 @@ async def get_mcp_tools( return all_processed_tools -async def get_tools_from_all_servers() -> list[Callable[..., Any]]: - """批量载入所有可用服务的工具(用于系统初始化及预热)""" +async def get_tools_from_all_servers(server_names: list[str] | None = None) -> list[Callable[..., Any]]: + """批量载入指定或所有可用服务的工具(用于系统初始化及预热)""" from yuxi.services.mcp.server_service import _load_enabled_mcp_server_configs - server_configs = await _load_enabled_mcp_server_configs() + names: list[str] | None = None + if server_names is not None: + names = [] + seen: set[str] = set() + for value in server_names: + if not isinstance(value, str): + continue + name = value.strip() + if not name or name in seen: + continue + seen.add(name) + names.append(name) + if not names: + return [] + + server_configs = await _load_enabled_mcp_server_configs(names=names) all_tools = [] for server_name, server_config in server_configs.items(): if not _can_preload_mcp_server_tools_without_runtime_auth(server_config): @@ -332,8 +400,9 @@ async def get_tools_from_all_servers() -> list[Callable[..., Any]]: async def clear_mcp_cache() -> None: """清空本地内存工具缓存""" - global _mcp_tools_cache - _mcp_tools_cache = {} + global _mcp_tools_cache, _mcp_tools_failure_cache + _mcp_tools_cache = LRUCache(maxsize=128) + _mcp_tools_failure_cache = LRUCache(maxsize=256) try: from yuxi.services.mcp.client_pool import clear_resolved_headers_cache, mcp_client_pool @@ -351,6 +420,7 @@ def clear_mcp_server_tools_cache(server_name: str) -> None: stale_keys = [k for k in _mcp_tools_cache if k.startswith(prefix)] for key in stale_keys: _mcp_tools_cache.pop(key, None) + _clear_mcp_tool_failure_cache_for_server(server_name) try: from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache @@ -369,6 +439,11 @@ def clear_mcp_connection_tools_cache(server_name: str, connection_id: int | None stale_keys = [k for k in _mcp_tools_cache if suffix in k and k.startswith(f"{server_name}:")] for key in stale_keys: _mcp_tools_cache.pop(key, None) + stale_failure_keys = [ + key for key in _mcp_tools_failure_cache if suffix in key and key.startswith(f"{server_name}:") + ] + for key in stale_failure_keys: + _mcp_tools_failure_cache.pop(key, None) try: from yuxi.services.mcp.client_pool import clear_server_resolved_headers_cache diff --git a/backend/package/yuxi/utils/logging_config.py b/backend/package/yuxi/utils/logging_config.py index d76f9c996..1235e2208 100644 --- a/backend/package/yuxi/utils/logging_config.py +++ b/backend/package/yuxi/utils/logging_config.py @@ -44,6 +44,12 @@ def _setup_logging_bridge(): lightrag_logger.setLevel(logging.DEBUG) lightrag_logger.propagate = False # 避免重复 + # 桥接 MCP 服务层日志,便于在 agent 运行时直接观察 MCP 选择、加载和失败冷却。 + mcp_logger = logging.getLogger("yuxi.mcp") + mcp_logger.addHandler(loguru_handler) + mcp_logger.setLevel(logging.INFO) + mcp_logger.propagate = False + # 桥接其他常见第三方库(降低级别减少噪音) for lib in ["httpx", "openai", "neo4j", "urllib3"]: lib_logger = logging.getLogger(lib) diff --git a/backend/test/unit/middlewares/test_runtime_config_middleware.py b/backend/test/unit/middlewares/test_runtime_config_middleware.py index d4cee0f1e..3f6c792cb 100644 --- a/backend/test/unit/middlewares/test_runtime_config_middleware.py +++ b/backend/test/unit/middlewares/test_runtime_config_middleware.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from types import SimpleNamespace import pytest @@ -12,7 +13,10 @@ @pytest.mark.asyncio @pytest.mark.unit -async def test_get_tools_from_context_passes_auth_context_to_mcp_loader(monkeypatch: pytest.MonkeyPatch): +async def test_get_tools_from_context_passes_auth_context_to_mcp_loader( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): captured: list[tuple[str, str | None, str | None]] = [] monkeypatch.setattr(runtime_config_middleware, "get_all_tool_instances", lambda: []) @@ -32,10 +36,12 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= department_id="dept-9", ) - tools = await middleware.get_tools_from_context(context) + with caplog.at_level(logging.WARNING, logger="Yuxi"): + tools = await middleware.get_tools_from_context(context) assert tools == [] assert captured == [("finance-gateway", "user-1", "dept-9")] + assert "mcp dependency unavailable" not in caplog.text @pytest.mark.asyncio diff --git a/backend/test/unit/middlewares/test_skills_middleware.py b/backend/test/unit/middlewares/test_skills_middleware.py index 7694d48e4..f3664e165 100644 --- a/backend/test/unit/middlewares/test_skills_middleware.py +++ b/backend/test/unit/middlewares/test_skills_middleware.py @@ -1,16 +1,20 @@ from __future__ import annotations +import logging from types import SimpleNamespace import pytest import yuxi.agents.middlewares.skills_middleware as skills_middleware -from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware +from yuxi.agents.middlewares.skills_middleware import SkillsMiddleware, collect_context_mcp_names_for_preload @pytest.mark.asyncio @pytest.mark.unit -async def test_get_mcp_tools_from_context_passes_auth_context_to_mcp_loader(monkeypatch: pytest.MonkeyPatch): +async def test_get_mcp_tools_from_context_passes_auth_context_to_mcp_loader( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +): captured: list[tuple[str, str | None, str | None]] = [] async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db=None, http_client=None): @@ -27,10 +31,12 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= department_id="dept-9", ) - tools = await middleware._get_mcp_tools_from_context(context) + with caplog.at_level(logging.WARNING, logger="Yuxi"): + tools = await middleware._get_mcp_tools_from_context(context) assert tools == [] assert captured == [("finance-gateway", "user-1", "dept-9")] + assert "mcp dependency unavailable" not in caplog.text @pytest.mark.asyncio @@ -57,3 +63,24 @@ async def fake_get_enabled_mcp_tools(server_name: str, *, auth_context=None, db= assert tools == [] assert captured == [("dts-mcp_server", "2", "dept-9")] + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_collect_context_mcp_names_for_preload_includes_configured_skill_dependencies( + monkeypatch: pytest.MonkeyPatch, +): + async def fake_get_dependency_map(db=None): + del db + return { + "reporter": {"tools": [], "mcps": ["charts"], "skills": ["common"]}, + "common": {"tools": [], "mcps": ["finance-gateway"], "skills": []}, + } + + monkeypatch.setattr(skills_middleware, "get_dependency_map", fake_get_dependency_map) + + context = SimpleNamespace(mcps=["direct", "charts"], skills=["reporter"]) + + names = await collect_context_mcp_names_for_preload(context) + + assert names == ["direct", "charts", "finance-gateway"] diff --git a/backend/test/unit/services/test_mcp_tool_registry_service.py b/backend/test/unit/services/test_mcp_tool_registry_service.py index 4fe07518c..dbd40a610 100644 --- a/backend/test/unit/services/test_mcp_tool_registry_service.py +++ b/backend/test/unit/services/test_mcp_tool_registry_service.py @@ -2,9 +2,8 @@ from types import SimpleNamespace -from yuxi.services.mcp import connection_service, server_service, tool_registry_service +from yuxi.services.mcp import server_service, tool_registry_service from yuxi.services.mcp.client_pool import mcp_client_pool -from yuxi.services.mcp_auth.redis_token_cache import RedisTokenCache from yuxi.services.mcp_tool_cache import RedisMcpToolCache from yuxi.services.mcp_auth.proxy_service import INTERNAL_PROXY_TOKEN_HEADER @@ -135,6 +134,38 @@ async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs ] +async def test_get_tools_from_all_servers_limits_preload_to_selected_names(monkeypatch): + server_configs = { + "alpha": {"transport": "stdio", "command": "cmd-a", "disabled_tools": []}, + "beta": {"transport": "stdio", "command": "cmd-b", "disabled_tools": []}, + } + loaded_names: list[list[str] | None] = [] + calls: list[str] = [] + + async def fake_load_enabled_mcp_server_configs(*, names=None, db=None): + del db + loaded_names.append(names) + if not names: + return server_configs + return {name: server_configs[name] for name in names if name in server_configs} + + async def fake_get_mcp_tools(server_name: str, additional_servers=None, **kwargs): + del additional_servers, kwargs + calls.append(server_name) + return [server_name] + + monkeypatch.setattr(server_service, "_load_enabled_mcp_server_configs", fake_load_enabled_mcp_server_configs) + monkeypatch.setattr(tool_registry_service, "get_mcp_tools", fake_get_mcp_tools) + + tools = await tool_registry_service.get_tools_from_all_servers(["alpha", "alpha", "missing"]) + empty_tools = await tool_registry_service.get_tools_from_all_servers([]) + + assert tools == ["alpha"] + assert empty_tools == [] + assert loaded_names == [["alpha", "missing"]] + assert calls == ["alpha"] + + async def test_get_mcp_tools_sets_handle_tool_error(monkeypatch): await tool_registry_service.clear_mcp_cache() @@ -158,6 +189,34 @@ async def fake_get_mcp_client(server_configs): await tool_registry_service.clear_mcp_cache() +async def test_get_mcp_tools_suppresses_retries_during_failure_cooldown(monkeypatch): + await tool_registry_service.clear_mcp_cache() + + config = {"transport": "stdio", "command": "offline-demo", "disabled_tools": []} + build_calls: list[dict] = [] + + async def fail_get_mcp_client(server_configs): + build_calls.append(server_configs) + raise ConnectionError("mcp service offline") + + monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fail_get_mcp_client) + + tools_first = await tool_registry_service.get_mcp_tools("offline", additional_servers={"offline": config}) + tools_second = await tool_registry_service.get_mcp_tools("offline", additional_servers={"offline": config}) + tools_forced = await tool_registry_service.get_mcp_tools( + "offline", + additional_servers={"offline": config}, + force_refresh=True, + ) + + assert tools_first == [] + assert tools_second == [] + assert tools_forced == [] + assert len(build_calls) == 2 + + await tool_registry_service.clear_mcp_cache() + + async def test_get_mcp_tools_keeps_connection_partitions_separate(monkeypatch): await tool_registry_service.clear_mcp_cache() @@ -227,21 +286,28 @@ async def test_get_mcp_tools_does_not_cache_internal_proxy_tool_objects(monkeypa }, ] build_calls: list[str] = [] + tool_load_count = 0 + + class RefreshingFakeClient: + async def get_tools(self): + nonlocal tool_load_count + tool_load_count += 1 + tool = SimpleNamespace(name=f"tool_for_load_{tool_load_count}", metadata={}) + return [tool] async def fake_get_mcp_client(server_configs): token = server_configs["demo"]["headers"][INTERNAL_PROXY_TOKEN_HEADER] build_calls.append(token) - tool = SimpleNamespace(name=f"tool_for_{token}", metadata={}) - return _FakeClient([tool]) + return RefreshingFakeClient() monkeypatch.setattr(mcp_client_pool, "_get_mcp_client", fake_get_mcp_client) tools_first = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[0]}) tools_second = await tool_registry_service.get_mcp_tools("demo", additional_servers={"demo": configs[1]}) - assert [tool.name for tool in tools_first] == ["tool_for_proxy-token-v1"] - assert [tool.name for tool in tools_second] == ["tool_for_proxy-token-v2"] - assert build_calls == ["proxy-token-v1", "proxy-token-v2"] + assert [tool.name for tool in tools_first] == ["tool_for_load_1"] + assert [tool.name for tool in tools_second] == ["tool_for_load_2"] + assert build_calls == ["proxy-token-v1"] await tool_registry_service.clear_mcp_cache() diff --git a/docs/develop-guides/roadmap.md b/docs/develop-guides/roadmap.md index 2e0b2655f..23e3fe7fd 100644 --- a/docs/develop-guides/roadmap.md +++ b/docs/develop-guides/roadmap.md @@ -40,7 +40,8 @@ - **MCP 多鉴权编排与内部代理链路**:为 MCP 接入新增 `auth_config_json` 与 `mcp_connections` 绑定模型,支持 `bound_secret`、`custom_http_token`、`client_credentials`、`stdio_env`、`authorization_code` 等鉴权编排基础;后端补齐基于 Redis 的 access token 缓存、预刷新与 401 后自动删缓存重试能力,并新增 `/api/internal/mcp-proxy/{server_name}` 内部代理路由,将动态 HTTP MCP 的鉴权、续期 and 重试逻辑统一收敛到服务端;补齐用户/部门绑定连接缺失时的内部代理拒绝逻辑,避免个人级 MCP 连接被其他用户通过代理入口串用;同时让管理端 `/api/system/mcp-servers/{name}/tools` 与 `/tools/refresh` 也按当前管理员的 `user_id/department_id` 解析绑定连接,避免跨部门管理员在未授权情况下探测到 MCP 工具列表;新增 Redis 版次 + manifest 分级缓存,让 API/Worker 多进程场景下的 MCP 工具清单按 `server` / `connection` 分区同步失效,并避免旧 graph 中预加载 of managed tool 覆盖本轮实时鉴权加载结果;修复动态 HTTP 内部代理短期 JWT 被工具对象缓存固化、停用 MCP 仍可通过内部代理访问、更新 `auth_config` 后 runtime token 未立即清理的问题;统一 Agent 运行态与连接管理页的个人 MCP scope 语义,避免运行态使用数据库主键查找 `mcp_connections.scope_id` 导致个人连接不可用;补齐运行时鉴权 MCP 工具的执行阶段映射,避免模型已绑定 `getTicket` 等动态工具但 ToolNode 静态注册表无法执行的问题;审计并修复该链路隐患:通过 DynamicMCPTokenAuth 引入 15 秒 TTL 在内存缓存(含联动清除机制)解决 httpx 请求对 DB 的高频重复查询问题;修复 `_normalize_token_payload` 处理 naive datetime 的时区偏差问题以消除 token 无限自动刷新的 Bug;改进 `_calculate_config_hash` 哈希计算逻辑,对 json.dumps 增加 default=str 降级保护防止无法序列化而崩溃的问题;优化免密钥连接测试,在 binding_scope 非 inline 且配置未引用 `\${secret.xxx}` 变量时免去 connection 的强检验,允许直接进行测试和工具加载;在 `client_pool` 中实现长连接失效的断线清理机制,防止 anyio.ClosedResourceError 报错固化在缓存中;修复 mock MCP demo server 在 FastAPI 路由下返回值的 ASGI 响应冲突,将其重构为原生 ASGI App 路由,并在 Docker 中容器化部署;前端增加 `${context.work_id}` 快注按键并补齐后端 context.work_id 工号识别支持;修复未配置认证时前端发送空字典 `{}` 导致 Pydantic 400 校验错误的问题。 - 本次补充:明确 `YUXI_INTERNAL_MCP_PROXY_BASE_URL` 是动态 HTTP MCP 的内部鉴权网关地址;统一 runtime config 与代理入口的 active connection 强制规则,允许未引用 `${secret.xxx}` 的动态 MCP 无绑定连接运行;连接测试补齐 user scope 的 `work_id`,连接池 hash 忽略短期 `X-Yuxi-MCP-Proxy-Token`;补充 MCP 动态鉴权使用说明和开发手册。 - 本次补充:将个人级 MCP 连接配置收敛到用户设置弹框,普通用户仅可查看脱敏 MCP 信息并维护自己的 `user` scope 连接;管理员仍在扩展管理中维护 MCP 服务、共享连接与工具开关。 - - 本次补充:优化 MCP 连接管理体验,管理页连接区支持健康筛选、绑定对象搜索和分页;连接卡片统一展示凭据状态、生效范围、绑定对象与单一问题主动作,设置页沿用同一卡片语言但隐藏生效范围。 + - 本次补充:优化 MCP 连接管理体验,管理页连接区支持健康筛选、绑定对象搜索和分页;连接卡片统一展示生效范围、绑定对象与单一问题主动作,设置页沿用同一卡片语言并在详情头部展示生效范围。 + - 本次补充:为 MCP 工具加载失败增加短期冷却与日志降噪,服务端离线时 Agent 运行态会跳过不可用 MCP,避免每轮运行重复建连并输出大量 error traceback;图构建阶段只预加载当前 agent 配置与已配置 skill 依赖的 MCP,手动刷新或配置变更仍会重新探测。 ---