diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index ae20eb8e1c..b5ee75ca24 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,15 +1,262 @@ +from __future__ import annotations + import asyncio -import re +import time +from collections.abc import Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING from astrbot import logger from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType +from astrbot.core.utils.error_redaction import safe_error + +if TYPE_CHECKING: + from astrbot.core.provider.provider import Provider + + +MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT = 30.0 +MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT = 4 +MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16 +MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds" +MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency" +MODEL_CACHE_MAX_ENTRIES = 512 + + +@dataclass(frozen=True) +class _ModelLookupConfig: + umo: str | None + cache_ttl_seconds: float + max_concurrency: int + + +class _ModelCache: + def __init__(self) -> None: + self._store: dict[tuple[str, str | None], tuple[float, list[str]]] = {} + + def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None: + if ttl <= 0: + return None + entry = self._store.get((provider_id, umo)) + if not entry: + return None + timestamp, models = entry + if time.monotonic() - timestamp > ttl: + self._store.pop((provider_id, umo), None) + return None + return models + + def set( + self, provider_id: str, umo: str | None, models: list[str], ttl: float + ) -> None: + if ttl <= 0: + return + self._store[(provider_id, umo)] = (time.monotonic(), list(models)) + self._evict_if_needed() + + def _evict_if_needed(self) -> None: + if len(self._store) <= MODEL_CACHE_MAX_ENTRIES: + return + # Drop oldest entries first when cache grows too large. + overflow = len(self._store) - MODEL_CACHE_MAX_ENTRIES + for key, _ in sorted( + self._store.items(), + key=lambda item: item[1][0], + )[:overflow]: + self._store.pop(key, None) + + def invalidate( + self, provider_id: str | None = None, *, umo: str | None = None + ) -> None: + if provider_id is None: + self._store.clear() + return + if umo is not None: + self._store.pop((provider_id, umo), None) + return + stale_keys = [ + cache_key for cache_key in self._store if cache_key[0] == provider_id + ] + for cache_key in stale_keys: + self._store.pop(cache_key, None) class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context + self._model_cache = _ModelCache() + self._register_provider_change_hook() + + def _register_provider_change_hook(self) -> None: + set_change_callback = getattr( + self.context.provider_manager, + "set_provider_change_callback", + None, + ) + if callable(set_change_callback): + set_change_callback(self._on_provider_manager_changed) + return + register_change_hook = getattr( + self.context.provider_manager, + "register_provider_change_hook", + None, + ) + if callable(register_change_hook): + register_change_hook(self._on_provider_manager_changed) + + def invalidate_provider_models_cache( + self, provider_id: str | None = None, *, umo: str | None = None + ) -> None: + """Public hook for cache invalidation on external provider config changes.""" + self._model_cache.invalidate(provider_id, umo=umo) + + def _on_provider_manager_changed( + self, + provider_id: str, + provider_type: ProviderType, + umo: str | None, + ) -> None: + if provider_type == ProviderType.CHAT_COMPLETION: + self.invalidate_provider_models_cache(provider_id, umo=umo) + + def _get_provider_settings(self, umo: str | None) -> dict: + if not umo: + return {} + try: + return self.context.get_config(umo).get("provider_settings", {}) or {} + except Exception as e: + logger.debug( + "读取 provider_settings 失败,使用默认值: %s", + safe_error("", e), + ) + return {} + + def _get_model_cache_ttl(self, umo: str | None) -> float: + settings = self._get_provider_settings(umo) + raw = settings.get( + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + ) + try: + return max(float(raw), 0.0) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + safe_error("", e), + ) + return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + + def _get_model_lookup_concurrency(self, umo: str | None) -> int: + settings = self._get_provider_settings(umo) + raw = settings.get( + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + ) + try: + value = int(raw) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + safe_error("", e), + ) + value = MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT + return min(max(value, 1), MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + + def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig: + return _ModelLookupConfig( + umo=umo, + cache_ttl_seconds=self._get_model_cache_ttl(umo), + max_concurrency=self._get_model_lookup_concurrency(umo), + ) + + def _resolve_model_name( + self, + model_name: str, + models: Sequence[str], + ) -> str | None: + """Resolve model name with precedence: + exact > case-insensitive > provider-qualified suffix. + """ + requested = model_name.strip() + if not requested: + return None + + requested_norm = requested.casefold() + + # exact / case-insensitive match + for candidate in models: + if candidate == requested or candidate.casefold() == requested_norm: + return candidate + + # provider-qualified suffix match: + # e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`. + for candidate in models: + cand_norm = candidate.casefold() + if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith( + f":{requested_norm}" + ): + return candidate + + return None + + def _apply_model( + self, prov: Provider, model_name: str, *, umo: str | None = None + ) -> str: + prov.set_model(model_name) + self.invalidate_provider_models_cache(prov.meta().id, umo=umo) + return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]" + + async def _get_provider_models( + self, + provider: Provider, + *, + config: _ModelLookupConfig, + use_cache: bool = True, + ) -> list[str]: + provider_id = provider.meta().id + ttl_seconds = config.cache_ttl_seconds + umo = config.umo + if use_cache: + cached = self._model_cache.get(provider_id, umo, ttl_seconds) + if cached is not None: + return cached + + models = list(await provider.get_models()) + if use_cache: + self._model_cache.set(provider_id, umo, models, ttl_seconds) + return models + + async def _get_models_or_reply_error( + self, + message: AstrMessageEvent, + prov: Provider, + config: _ModelLookupConfig, + *, + error_prefix: str, + disable_t2i: bool = False, + warning_log: str | None = None, + ) -> list[str] | None: + try: + return await self._get_provider_models(prov, config=config) + except asyncio.CancelledError: + raise + except Exception as e: + if warning_log is not None: + logger.warning( + warning_log, + prov.meta().id, + safe_error("", e), + ) + result = MessageEventResult().message(safe_error(error_prefix, e)) + if disable_t2i: + result = result.use_t2i(False) + message.set_result(result) + return None def _log_reachability_failure( self, @@ -38,12 +285,96 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = str(e) + err_reason = safe_error("", e) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) return False, err_code, err_reason + async def _find_provider_for_model( + self, + model_name: str, + *, + exclude_provider_id: str | None = None, + config: _ModelLookupConfig, + use_cache: bool = True, + ) -> tuple[Provider | None, str | None]: + all_providers = [] + for provider in self.context.get_all_providers(): + provider_meta = provider.meta() + if provider_meta.provider_type != ProviderType.CHAT_COMPLETION: + continue + if ( + exclude_provider_id is not None + and provider_meta.id == exclude_provider_id + ): + continue + all_providers.append(provider) + if not all_providers: + return None, None + + semaphore = asyncio.Semaphore(config.max_concurrency) + + async def fetch_models( + provider: Provider, + ) -> tuple[Provider, list[str] | None, str | None]: + async with semaphore: + try: + models = await self._get_provider_models( + provider, + config=config, + use_cache=use_cache, + ) + return provider, models, None + except asyncio.CancelledError: + raise + except Exception as e: + err = safe_error("", e) + logger.debug( + "跨提供商查找模型 %s 获取 %s 模型列表失败: %s", + model_name, + provider.meta().id, + err, + ) + return provider, None, err + + results = await asyncio.gather( + *(fetch_models(provider) for provider in all_providers) + ) + failed_provider_errors: list[tuple[str, str]] = [] + for provider, models, err in results: + if err is not None: + failed_provider_errors.append((provider.meta().id, err)) + continue + if models is None: + continue + + matched_model_name = self._resolve_model_name(model_name, models) + if matched_model_name is not None: + return provider, matched_model_name + + if failed_provider_errors and len(failed_provider_errors) == len(all_providers): + failed_ids = ",".join( + provider_id for provider_id, _ in failed_provider_errors + ) + logger.error( + "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", + model_name, + len(all_providers), + failed_ids, + ) + elif failed_provider_errors: + logger.debug( + "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", + model_name, + len(failed_provider_errors), + ",".join( + f"{provider_id}({error})" + for provider_id, error in failed_provider_errors + ), + ) + return None, None + async def provider( self, event: AstrMessageEvent, @@ -92,13 +423,15 @@ async def provider( id_ = meta.id error_code = None + if isinstance(reachable, asyncio.CancelledError): + raise reachable if isinstance(reachable, Exception): # 异常情况下兜底处理,避免单个 provider 导致列表失败 self._log_reachability_failure( p, None, reachable.__class__.__name__, - str(reachable), + safe_error("", reachable), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -224,6 +557,73 @@ async def provider( else: event.set_result(MessageEventResult().message("无效的参数。")) + async def _switch_model_by_name( + self, message: AstrMessageEvent, model_name: str, prov: Provider + ) -> None: + model_name = model_name.strip() + if not model_name: + message.set_result(MessageEventResult().message("模型名不能为空。")) + return + + umo = message.unified_msg_origin + config = self._get_model_lookup_config(umo) + curr_provider_id = prov.meta().id + + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取当前提供商模型列表失败: ", + warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", + ) + if models is None: + return + + matched_model_name = self._resolve_model_name(model_name, models) + if matched_model_name is not None: + message.set_result( + MessageEventResult().message( + self._apply_model(prov, matched_model_name, umo=umo) + ), + ) + return + + target_prov, matched_target_model_name = await self._find_provider_for_model( + model_name, + exclude_provider_id=curr_provider_id, + config=config, + ) + + if target_prov is None or matched_target_model_name is None: + message.set_result( + MessageEventResult().message( + f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。", + ), + ) + return + + target_id = target_prov.meta().id + try: + await self.context.provider_manager.set_provider( + provider_id=target_id, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + self._apply_model(target_prov, matched_target_model_name, umo=umo) + message.set_result( + MessageEventResult().message( + f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", + ), + ) + except asyncio.CancelledError: + raise + except Exception as e: + message.set_result( + MessageEventResult().message( + safe_error("跨提供商切换并设置模型失败: ", e) + ), + ) + async def model_ls( self, message: AstrMessageEvent, @@ -236,20 +636,17 @@ async def model_ls( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return - # 定义正则表达式匹配 API 密钥 - api_key_pattern = re.compile(r"key=[^&'\" ]+") + config = self._get_model_lookup_config(message.unified_msg_origin) if idx_or_name is None: - models = [] - try: - models = await prov.get_models() - except BaseException as e: - err_msg = api_key_pattern.sub("key=***", str(e)) - message.set_result( - MessageEventResult() - .message("获取模型列表失败: " + err_msg) - .use_t2i(False), - ) + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取模型列表失败: ", + disable_t2i=True, + ) + if models is None: return parts = ["下面列出了此模型提供商可用模型:"] for i, model in enumerate(models, 1): @@ -258,40 +655,43 @@ async def model_ls( curr_model = prov.get_model() or "无" parts.append(f"\n当前模型: [{curr_model}]") parts.append( - "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + "\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。" ) ret = "".join(parts) message.set_result(MessageEventResult().message(ret).use_t2i(False)) elif isinstance(idx_or_name, int): - models = [] - try: - models = await prov.get_models() - except BaseException as e: - message.set_result( - MessageEventResult().message("获取模型列表失败: " + str(e)), - ) + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取模型列表失败: ", + ) + if models is None: return if idx_or_name > len(models) or idx_or_name < 1: message.set_result(MessageEventResult().message("模型序号错误。")) else: try: new_model = models[idx_or_name - 1] - prov.set_model(new_model) - except BaseException as e: message.set_result( - MessageEventResult().message("切换模型未知错误: " + str(e)), + MessageEventResult().message( + self._apply_model( + prov, + new_model, + umo=message.unified_msg_origin, + ) + ), ) - message.set_result( - MessageEventResult().message( - f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", - ), - ) + except Exception as e: + message.set_result( + MessageEventResult().message( + safe_error("切换模型未知错误: ", e) + ), + ) + return else: - prov.set_model(idx_or_name) - message.set_result( - MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), - ) + await self._switch_model_by_name(message, idx_or_name, prov) async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) @@ -322,8 +722,15 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None try: new_key = keys_data[index - 1] prov.set_key(new_key) - except BaseException as e: + self.invalidate_provider_models_cache( + prov.meta().id, + umo=message.unified_msg_origin, + ) + message.set_result(MessageEventResult().message("切换 Key 成功。")) + except Exception as e: message.set_result( - MessageEventResult().message(f"切换 Key 未知错误: {e!s}"), + MessageEventResult().message( + safe_error("切换 Key 未知错误: ", e) + ), ) - message.set_result(MessageEventResult().message("切换 Key 成功。")) + return diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index a331c97e9b..2359a81371 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -2,11 +2,13 @@ import copy import os import traceback +from collections.abc import Callable from typing import Protocol, runtime_checkable from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase +from astrbot.core.utils.error_redaction import safe_error from ..persona_mgr import PersonaManager from .entities import ProviderType @@ -71,6 +73,56 @@ def __init__( self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper + self._provider_change_callback: ( + Callable[[str, ProviderType, str | None], None] | None + ) = None + self._provider_change_hooks: list[ + Callable[[str, ProviderType, str | None], None] + ] = [] + + def set_provider_change_callback( + self, + cb: Callable[[str, ProviderType, str | None], None] | None, + ) -> None: + # Backward-compatible single-callback setter. + # This callback coexists with register_provider_change_hook subscriptions. + self._provider_change_callback = cb + + def register_provider_change_hook( + self, + hook: Callable[[str, ProviderType, str | None], None], + ) -> None: + if hook not in self._provider_change_hooks: + self._provider_change_hooks.append(hook) + + def _notify_provider_changed( + self, + provider_id: str, + provider_type: ProviderType, + umo: str | None, + ) -> None: + if self._provider_change_callback is not None: + try: + self._provider_change_callback(provider_id, provider_type, umo) + except Exception as e: + logger.warning( + "调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s", + provider_id, + provider_type, + safe_error("", e), + ) + for hook in list(self._provider_change_hooks): + if hook is self._provider_change_callback: + continue + try: + hook(provider_id, provider_type, umo) + except Exception as e: + logger.warning( + "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s", + provider_id, + provider_type, + safe_error("", e), + ) @property def persona_configs(self) -> list: @@ -111,6 +163,7 @@ async def set_provider( f"provider_perf_{provider_type.value}", provider_id, ) + self._notify_provider_changed(provider_id, provider_type, umo) return # 不启用提供商会话隔离模式的情况 @@ -126,6 +179,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( prov, STTProvider, @@ -137,6 +191,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( prov, Provider, @@ -148,6 +203,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_changed(provider_id, provider_type, umo) async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py new file mode 100644 index 0000000000..dcab07ac58 --- /dev/null +++ b/astrbot/core/utils/error_redaction.py @@ -0,0 +1,82 @@ +import re + +_SECRET_KEYS = ( + r"(?:api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)" +) + +_JSON_FIELD_PATTERN = re.compile( + rf"(?i)(?P(?P['\"]){_SECRET_KEYS}(?P=kq)\s*:\s*)(?P['\"])(?P[^'\"]+)(?P=vq)" +) +_AUTH_JSON_FIELD_PATTERN = re.compile( + r"(?i)(?P(?P['\"])authorization(?P=kq)\s*:\s*)(?P['\"])bearer\s+[^'\"]+(?P=vq)" +) +_QUERY_FIELD_PATTERN = re.compile( + rf"(?i)(?P{_SECRET_KEYS}\s*=\s*)(?P[^&'\" ]+)" +) +_QUERY_PARAM_PATTERN = re.compile( + r"(?i)(?P[?&](?:api_?key|key|access_?token|auth_?token)=)(?P[^&'\" ]+)" +) +_AUTH_HEADER_PATTERN = re.compile( + r"(?i)(?P\bauthorization\s*:\s*bearer\s+)(?P[A-Za-z0-9._\-]+)" +) +_BEARER_PATTERN = re.compile(r"(?i)(?P\bbearer\s+)(?P[A-Za-z0-9._\-]+)") +_SK_PATTERN = re.compile(r"\bsk-[A-Za-z0-9]{16,}\b") + + +def _redact_json_field(match: re.Match[str]) -> str: + quote = match.group("vq") + return f"{match.group('prefix')}{quote}[REDACTED]{quote}" + + +def _redact_auth_json_field(match: re.Match[str]) -> str: + quote = match.group("vq") + return f"{match.group('prefix')}{quote}Bearer [REDACTED]{quote}" + + +def _redact_prefixed_value(match: re.Match[str]) -> str: + return f"{match.group('prefix')}[REDACTED]" + + +def _redact_bearer_token(match: re.Match[str]) -> str: + return f"{match.group('prefix')}[REDACTED]" + + +def _redact_json_like(text: str) -> str: + text = _JSON_FIELD_PATTERN.sub(_redact_json_field, text) + return _AUTH_JSON_FIELD_PATTERN.sub(_redact_auth_json_field, text) + + +def _redact_query_like(text: str) -> str: + text = _QUERY_FIELD_PATTERN.sub(_redact_prefixed_value, text) + return _QUERY_PARAM_PATTERN.sub(_redact_prefixed_value, text) + + +def _redact_tokens(text: str) -> str: + text = _AUTH_HEADER_PATTERN.sub(_redact_bearer_token, text) + text = _BEARER_PATTERN.sub(_redact_bearer_token, text) + return _SK_PATTERN.sub("[REDACTED]", text) + + +def redact_sensitive_text(text: str) -> str: + text = _redact_json_like(text) + text = _redact_query_like(text) + text = _redact_tokens(text) + return text + + +def safe_error( + prefix: str, + error: Exception | BaseException | str, + *, + redact: bool = True, +) -> str: + try: + text = str(error) + except Exception: + try: + text = repr(error) + except Exception: + text = "" + if redact: + text = redact_sensitive_text(text) + return prefix + text