From e5704694cfb702f028238d3f96118ba99ca6cc2a Mon Sep 17 00:00:00 2001 From: pandyzhou Date: Sun, 1 Mar 2026 00:16:54 +0800 Subject: [PATCH 01/28] fix: /model command now auto-switches provider when model exists elsewhere Made-with: Cursor --- .../builtin_commands/commands/provider.py | 73 +++++++++++++++++-- 1 file changed, 67 insertions(+), 6 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index ae20eb8e1c..67f9796a9b 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,11 +1,15 @@ import asyncio import re +from typing import TYPE_CHECKING from astrbot import logger from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType +if TYPE_CHECKING: + from astrbot.core.provider.provider import Provider + class ProviderCommands: def __init__(self, context: star.Context) -> None: @@ -44,6 +48,25 @@ async def _test_provider_capability(self, provider): ) return False, err_code, err_reason + async def _find_provider_for_model( + self, model_name: str, exclude_provider_id: str | None = None + ) -> tuple["Provider" | None, str | None]: + """在所有 LLM 提供商中查找包含指定模型的提供商。返回 (provider, provider_id) 或 (None, None)。""" + all_providers = self.context.get_all_providers() + results = await asyncio.gather( + *[p.get_models() for p in all_providers], + return_exceptions=True, + ) + for provider, result in zip(all_providers, results): + if isinstance(result, BaseException): + continue + provider_id = provider.meta().id + if exclude_provider_id and provider_id == exclude_provider_id: + continue + if model_name in result: + return provider, provider_id + return None, None + async def provider( self, event: AstrMessageEvent, @@ -258,7 +281,7 @@ async def model_ls( curr_model = prov.get_model() or "无" 
parts.append(f"\n当前模型: [{curr_model}]") parts.append( - "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + "\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。" ) ret = "".join(parts) @@ -278,20 +301,58 @@ async def model_ls( try: new_model = models[idx_or_name - 1] prov.set_model(new_model) + message.set_result( + MessageEventResult().message( + f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", + ), + ) except BaseException as e: message.set_result( MessageEventResult().message("切换模型未知错误: " + str(e)), ) + else: + # 字符串:模型名,需智能解析是否跨提供商 + model_name = idx_or_name + umo = message.unified_msg_origin + curr_provider_id = prov.meta().id + + # 1. 检查当前提供商 + models = [] + try: + models = await prov.get_models() + except BaseException: + models = [] + if model_name in models: + prov.set_model(model_name) message.set_result( MessageEventResult().message( - f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", + f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", ), ) - else: - prov.set_model(idx_or_name) - message.set_result( - MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), + return + + # 2. 
在其他提供商中查找 + target_prov, target_id = await self._find_provider_for_model( + model_name, exclude_provider_id=curr_provider_id ) + if target_prov and target_id: + await self.context.provider_manager.set_provider( + provider_id=target_id, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + target_prov.set_model(model_name) + message.set_result( + MessageEventResult().message( + f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", + ), + ) + else: + message.set_result( + MessageEventResult().message( + f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。", + ), + ) async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) From ec7995ebd2a27ab54a9b81f9924de98e789f59fd Mon Sep 17 00:00:00 2001 From: pandyzhou Date: Sun, 1 Mar 2026 00:39:40 +0800 Subject: [PATCH 02/28] fix: address Sourcery review - log get_models() failures in cross-provider lookup Made-with: Cursor --- .../builtin_commands/commands/provider.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 67f9796a9b..222996a79f 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -59,12 +59,23 @@ async def _find_provider_for_model( ) for provider, result in zip(all_providers, results): if isinstance(result, BaseException): + logger.warning( + "跨提供商查找模型时,提供商 %s 的 get_models() 失败: %s", + provider.meta().id, + result, + ) continue provider_id = provider.meta().id if exclude_provider_id and provider_id == exclude_provider_id: continue if model_name in result: return provider, provider_id + if results and all(isinstance(r, BaseException) for r in results): + logger.error( + "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败,请检查配置或网络", + model_name, + 
len(all_providers), + ) return None, None async def provider( @@ -320,7 +331,7 @@ async def model_ls( models = [] try: models = await prov.get_models() - except BaseException: + except BaseException as e: models = [] if model_name in models: prov.set_model(model_name) From f9cd8422ecc994df5e9ddb23493265a4f45a0ad3 Mon Sep 17 00:00:00 2001 From: pandyzhou Date: Sun, 1 Mar 2026 00:42:00 +0800 Subject: [PATCH 03/28] fix: integer branch exception handling and API key masking in model command Made-with: Cursor --- .../builtin_stars/builtin_commands/commands/provider.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 222996a79f..b21477cbe7 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -302,8 +302,9 @@ async def model_ls( try: models = await prov.get_models() except BaseException as e: + err_msg = api_key_pattern.sub("key=***", str(e)) message.set_result( - MessageEventResult().message("获取模型列表失败: " + str(e)), + MessageEventResult().message("获取模型列表失败: " + err_msg), ) return if idx_or_name > len(models) or idx_or_name < 1: @@ -318,9 +319,11 @@ async def model_ls( ), ) except BaseException as e: + err_msg = api_key_pattern.sub("key=***", str(e)) message.set_result( - MessageEventResult().message("切换模型未知错误: " + str(e)), + MessageEventResult().message("切换模型未知错误: " + err_msg), ) + return else: # 字符串:模型名,需智能解析是否跨提供商 model_name = idx_or_name @@ -331,7 +334,7 @@ async def model_ls( models = [] try: models = await prov.get_models() - except BaseException as e: + except BaseException: models = [] if model_name in models: prov.set_model(model_name) From 83fb0e8cf0073660576e3c2033982510744b242d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:06:43 +0900 Subject: [PATCH 
04/28] fix: harden cross-provider model resolution --- .../builtin_commands/commands/provider.py | 68 +++++++++++++++---- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index b21477cbe7..c8c233427b 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,5 +1,6 @@ import asyncio import re +import time from typing import TYPE_CHECKING from astrbot import logger @@ -10,10 +11,33 @@ if TYPE_CHECKING: from astrbot.core.provider.provider import Provider +_API_KEY_PATTERN = re.compile(r"key=[^&'\" ]+") + class ProviderCommands: + _MODEL_LIST_CACHE_TTL_SECONDS = 30.0 + def __init__(self, context: star.Context) -> None: self.context = context + self._provider_models_cache: dict[str, tuple[float, list[str]]] = {} + + @staticmethod + def _mask_sensitive_text(value: str) -> str: + return _API_KEY_PATTERN.sub("key=***", value) + + async def _get_provider_models( + self, provider: "Provider", *, use_cache: bool = True + ) -> list[str]: + provider_id = provider.meta().id + now = time.monotonic() + if use_cache: + cached = self._provider_models_cache.get(provider_id) + if cached and now - cached[0] <= self._MODEL_LIST_CACHE_TTL_SECONDS: + return list(cached[1]) + + models = list(await provider.get_models()) + self._provider_models_cache[provider_id] = (now, models) + return list(models) def _log_reachability_failure( self, @@ -52,22 +76,27 @@ async def _find_provider_for_model( self, model_name: str, exclude_provider_id: str | None = None ) -> tuple["Provider" | None, str | None]: """在所有 LLM 提供商中查找包含指定模型的提供商。返回 (provider, provider_id) 或 (None, None)。""" - all_providers = self.context.get_all_providers() + all_providers = [ + p + for p in self.context.get_all_providers() + if not exclude_provider_id or p.meta().id != exclude_provider_id + ] + if not 
all_providers: + return None, None results = await asyncio.gather( - *[p.get_models() for p in all_providers], + *[self._get_provider_models(p) for p in all_providers], return_exceptions=True, ) for provider, result in zip(all_providers, results): if isinstance(result, BaseException): + masked_error = self._mask_sensitive_text(str(result)) logger.warning( "跨提供商查找模型时,提供商 %s 的 get_models() 失败: %s", provider.meta().id, - result, + masked_error, ) continue provider_id = provider.meta().id - if exclude_provider_id and provider_id == exclude_provider_id: - continue if model_name in result: return provider, provider_id if results and all(isinstance(r, BaseException) for r in results): @@ -270,15 +299,13 @@ async def model_ls( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return - # 定义正则表达式匹配 API 密钥 - api_key_pattern = re.compile(r"key=[^&'\" ]+") if idx_or_name is None: models = [] try: - models = await prov.get_models() + models = await self._get_provider_models(prov) except BaseException as e: - err_msg = api_key_pattern.sub("key=***", str(e)) + err_msg = self._mask_sensitive_text(str(e)) message.set_result( MessageEventResult() .message("获取模型列表失败: " + err_msg) @@ -300,9 +327,9 @@ async def model_ls( elif isinstance(idx_or_name, int): models = [] try: - models = await prov.get_models() + models = await self._get_provider_models(prov) except BaseException as e: - err_msg = api_key_pattern.sub("key=***", str(e)) + err_msg = self._mask_sensitive_text(str(e)) message.set_result( MessageEventResult().message("获取模型列表失败: " + err_msg), ) @@ -319,7 +346,7 @@ async def model_ls( ), ) except BaseException as e: - err_msg = api_key_pattern.sub("key=***", str(e)) + err_msg = self._mask_sensitive_text(str(e)) message.set_result( MessageEventResult().message("切换模型未知错误: " + err_msg), ) @@ -333,9 +360,20 @@ async def model_ls( # 1. 
检查当前提供商 models = [] try: - models = await prov.get_models() - except BaseException: - models = [] + models = await self._get_provider_models(prov) + except BaseException as e: + err_msg = self._mask_sensitive_text(str(e)) + logger.warning( + "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", + curr_provider_id, + err_msg, + ) + message.set_result( + MessageEventResult().message( + "获取当前提供商模型列表失败: " + err_msg + ), + ) + return if model_name in models: prov.set_model(model_name) message.set_result( From 4d5c8aeb4528977df4e9de142f5e3923e577c8f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:11:16 +0900 Subject: [PATCH 05/28] fix: improve model lookup resilience and cache hygiene --- .../builtin_commands/commands/provider.py | 70 ++++++++++++++----- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index c8c233427b..32989ec5d6 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -21,6 +21,12 @@ def __init__(self, context: star.Context) -> None: self.context = context self._provider_models_cache: dict[str, tuple[float, list[str]]] = {} + def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: + if provider_id is None: + self._provider_models_cache.clear() + return + self._provider_models_cache.pop(provider_id, None) + @staticmethod def _mask_sensitive_text(value: str) -> str: return _API_KEY_PATTERN.sub("key=***", value) @@ -87,23 +93,34 @@ async def _find_provider_for_model( *[self._get_provider_models(p) for p in all_providers], return_exceptions=True, ) + failed_provider_errors: list[tuple[str, str]] = [] for provider, result in zip(all_providers, results): if isinstance(result, BaseException): masked_error = self._mask_sensitive_text(str(result)) - logger.warning( - 
"跨提供商查找模型时,提供商 %s 的 get_models() 失败: %s", - provider.meta().id, - masked_error, - ) + failed_provider_errors.append((provider.meta().id, masked_error)) continue provider_id = provider.meta().id if model_name in result: return provider, provider_id - if results and all(isinstance(r, BaseException) for r in results): + if failed_provider_errors and len(failed_provider_errors) == len(all_providers): + failed_ids = ",".join( + provider_id for provider_id, _ in failed_provider_errors + ) logger.error( - "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败,请检查配置或网络", + "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", model_name, len(all_providers), + failed_ids, + ) + elif failed_provider_errors: + logger.debug( + "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", + model_name, + len(failed_provider_errors), + ",".join( + f"{provider_id}({error})" + for provider_id, error in failed_provider_errors + ), ) return None, None @@ -340,6 +357,7 @@ async def model_ls( try: new_model = models[idx_or_name - 1] prov.set_model(new_model) + self._invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ -353,9 +371,12 @@ async def model_ls( return else: # 字符串:模型名,需智能解析是否跨提供商 - model_name = idx_or_name + model_name = idx_or_name.strip() umo = message.unified_msg_origin curr_provider_id = prov.meta().id + if not model_name: + message.set_result(MessageEventResult().message("模型名不能为空。")) + return # 1. 
检查当前提供商 models = [] @@ -376,6 +397,7 @@ async def model_ls( return if model_name in models: prov.set_model(model_name) + self._invalidate_provider_models_cache(curr_provider_id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", @@ -388,17 +410,26 @@ async def model_ls( model_name, exclude_provider_id=curr_provider_id ) if target_prov and target_id: - await self.context.provider_manager.set_provider( - provider_id=target_id, - provider_type=ProviderType.CHAT_COMPLETION, - umo=umo, - ) - target_prov.set_model(model_name) - message.set_result( - MessageEventResult().message( - f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", - ), - ) + try: + await self.context.provider_manager.set_provider( + provider_id=target_id, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + target_prov.set_model(model_name) + self._invalidate_provider_models_cache(target_id) + message.set_result( + MessageEventResult().message( + f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", + ), + ) + except BaseException as e: + err_msg = self._mask_sensitive_text(str(e)) + message.set_result( + MessageEventResult().message( + "跨提供商切换并设置模型失败: " + err_msg + ), + ) else: message.set_result( MessageEventResult().message( @@ -435,6 +466,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None try: new_key = keys_data[index - 1] prov.set_key(new_key) + self._invalidate_provider_models_cache(prov.meta().id) except BaseException as e: message.set_result( MessageEventResult().message(f"切换 Key 未知错误: {e!s}"), From 5371267283aff2c4bbbb89689db89f3d6b2c4899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:15:08 +0900 Subject: [PATCH 06/28] refactor: simplify model switch lookup flow --- .../builtin_commands/commands/provider.py | 179 ++++++++++-------- 1 file changed, 99 insertions(+), 80 deletions(-) diff --git 
a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 32989ec5d6..4612fdb1ea 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -11,7 +11,11 @@ if TYPE_CHECKING: from astrbot.core.provider.provider import Provider -_API_KEY_PATTERN = re.compile(r"key=[^&'\" ]+") +_API_KEY_PATTERN = re.compile(r"(?i)(api_?key|key)=[^&'\" ]+") + + +class _AllProvidersModelFetchFailedError(RuntimeError): + pass class ProviderCommands: @@ -19,7 +23,7 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._provider_models_cache: dict[str, tuple[float, list[str]]] = {} + self._provider_models_cache: dict[str, tuple[float, tuple[str, ...]]] = {} def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: if provider_id is None: @@ -31,6 +35,9 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N def _mask_sensitive_text(value: str) -> str: return _API_KEY_PATTERN.sub("key=***", value) + def _safe_err(self, e: BaseException) -> str: + return self._mask_sensitive_text(str(e)) + async def _get_provider_models( self, provider: "Provider", *, use_cache: bool = True ) -> list[str]: @@ -42,8 +49,8 @@ async def _get_provider_models( return list(cached[1]) models = list(await provider.get_models()) - self._provider_models_cache[provider_id] = (now, models) - return list(models) + self._provider_models_cache[provider_id] = (now, tuple(models)) + return models def _log_reachability_failure( self, @@ -80,15 +87,15 @@ async def _test_provider_capability(self, provider): async def _find_provider_for_model( self, model_name: str, exclude_provider_id: str | None = None - ) -> tuple["Provider" | None, str | None]: - """在所有 LLM 提供商中查找包含指定模型的提供商。返回 (provider, provider_id) 或 (None, None)。""" + ) -> "Provider | None": + """在所有 LLM 
提供商中查找包含指定模型的提供商。""" all_providers = [ p for p in self.context.get_all_providers() if not exclude_provider_id or p.meta().id != exclude_provider_id ] if not all_providers: - return None, None + return None results = await asyncio.gather( *[self._get_provider_models(p) for p in all_providers], return_exceptions=True, @@ -96,12 +103,11 @@ async def _find_provider_for_model( failed_provider_errors: list[tuple[str, str]] = [] for provider, result in zip(all_providers, results): if isinstance(result, BaseException): - masked_error = self._mask_sensitive_text(str(result)) + masked_error = self._safe_err(result) failed_provider_errors.append((provider.meta().id, masked_error)) continue - provider_id = provider.meta().id if model_name in result: - return provider, provider_id + return provider if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( provider_id for provider_id, _ in failed_provider_errors @@ -112,6 +118,9 @@ async def _find_provider_for_model( len(all_providers), failed_ids, ) + raise _AllProvidersModelFetchFailedError( + f"all providers failed to fetch models: {failed_ids}" + ) elif failed_provider_errors: logger.debug( "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", @@ -122,7 +131,7 @@ async def _find_provider_for_model( for provider_id, error in failed_provider_errors ), ) - return None, None + return None async def provider( self, @@ -304,6 +313,81 @@ async def provider( else: event.set_result(MessageEventResult().message("无效的参数。")) + async def _switch_model_by_name( + self, message: AstrMessageEvent, model_name: str, prov: "Provider" + ) -> None: + model_name = model_name.strip() + if not model_name: + message.set_result(MessageEventResult().message("模型名不能为空。")) + return + + umo = message.unified_msg_origin + curr_provider_id = prov.meta().id + + try: + models = await self._get_provider_models(prov) + except BaseException as e: + err_msg = self._safe_err(e) + logger.warning( + "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", + 
curr_provider_id, + err_msg, + ) + message.set_result( + MessageEventResult().message("获取当前提供商模型列表失败: " + err_msg), + ) + return + + if model_name in models: + prov.set_model(model_name) + self._invalidate_provider_models_cache(curr_provider_id) + message.set_result( + MessageEventResult().message( + f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", + ), + ) + return + + try: + target_prov = await self._find_provider_for_model( + model_name, exclude_provider_id=curr_provider_id + ) + except _AllProvidersModelFetchFailedError: + message.set_result( + MessageEventResult().message( + "跨提供商查询模型失败:所有提供商的模型列表均获取失败,请检查提供商配置或网络后重试。", + ), + ) + return + + if not target_prov: + message.set_result( + MessageEventResult().message( + f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。", + ), + ) + return + + target_id = target_prov.meta().id + try: + await self.context.provider_manager.set_provider( + provider_id=target_id, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + target_prov.set_model(model_name) + self._invalidate_provider_models_cache(target_id) + message.set_result( + MessageEventResult().message( + f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", + ), + ) + except BaseException as e: + err_msg = self._safe_err(e) + message.set_result( + MessageEventResult().message("跨提供商切换并设置模型失败: " + err_msg), + ) + async def model_ls( self, message: AstrMessageEvent, @@ -322,7 +406,7 @@ async def model_ls( try: models = await self._get_provider_models(prov) except BaseException as e: - err_msg = self._mask_sensitive_text(str(e)) + err_msg = self._safe_err(e) message.set_result( MessageEventResult() .message("获取模型列表失败: " + err_msg) @@ -346,7 +430,7 @@ async def model_ls( try: models = await self._get_provider_models(prov) except BaseException as e: - err_msg = self._mask_sensitive_text(str(e)) + err_msg = self._safe_err(e) message.set_result( MessageEventResult().message("获取模型列表失败: " + err_msg), ) @@ -364,78 +448,13 @@ async def 
model_ls( ), ) except BaseException as e: - err_msg = self._mask_sensitive_text(str(e)) + err_msg = self._safe_err(e) message.set_result( MessageEventResult().message("切换模型未知错误: " + err_msg), ) return else: - # 字符串:模型名,需智能解析是否跨提供商 - model_name = idx_or_name.strip() - umo = message.unified_msg_origin - curr_provider_id = prov.meta().id - if not model_name: - message.set_result(MessageEventResult().message("模型名不能为空。")) - return - - # 1. 检查当前提供商 - models = [] - try: - models = await self._get_provider_models(prov) - except BaseException as e: - err_msg = self._mask_sensitive_text(str(e)) - logger.warning( - "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", - curr_provider_id, - err_msg, - ) - message.set_result( - MessageEventResult().message( - "获取当前提供商模型列表失败: " + err_msg - ), - ) - return - if model_name in models: - prov.set_model(model_name) - self._invalidate_provider_models_cache(curr_provider_id) - message.set_result( - MessageEventResult().message( - f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", - ), - ) - return - - # 2. 
在其他提供商中查找 - target_prov, target_id = await self._find_provider_for_model( - model_name, exclude_provider_id=curr_provider_id - ) - if target_prov and target_id: - try: - await self.context.provider_manager.set_provider( - provider_id=target_id, - provider_type=ProviderType.CHAT_COMPLETION, - umo=umo, - ) - target_prov.set_model(model_name) - self._invalidate_provider_models_cache(target_id) - message.set_result( - MessageEventResult().message( - f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", - ), - ) - except BaseException as e: - err_msg = self._mask_sensitive_text(str(e)) - message.set_result( - MessageEventResult().message( - "跨提供商切换并设置模型失败: " + err_msg - ), - ) - else: - message.set_result( - MessageEventResult().message( - f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。", - ), - ) + await self._switch_model_by_name(message, idx_or_name, prov) async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) From 281c44fb18c3a75707780399d72020bc0f37c633 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:19:18 +0900 Subject: [PATCH 07/28] refactor: streamline provider model cache updates --- .../builtin_commands/commands/provider.py | 90 +++++++++++++------ 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 4612fdb1ea..b5028fbc49 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -31,6 +31,17 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N return self._provider_models_cache.pop(provider_id, None) + def _invalidate_cache_for(self, provider: "Provider") -> None: + self._invalidate_provider_models_cache(provider.meta().id) + 
+ def _set_model_and_invalidate(self, provider: "Provider", model_name: str) -> None: + provider.set_model(model_name) + self._invalidate_cache_for(provider) + + def _set_key_and_invalidate(self, provider: "Provider", key: str) -> None: + provider.set_key(key) + self._invalidate_cache_for(provider) + @staticmethod def _mask_sensitive_text(value: str) -> str: return _API_KEY_PATTERN.sub("key=***", value) @@ -38,6 +49,9 @@ def _mask_sensitive_text(value: str) -> str: def _safe_err(self, e: BaseException) -> str: return self._mask_sensitive_text(str(e)) + def _format_err(self, prefix: str, e: BaseException) -> str: + return f"{prefix}{self._safe_err(e)}" + async def _get_provider_models( self, provider: "Provider", *, use_cache: bool = True ) -> list[str]: @@ -96,18 +110,38 @@ async def _find_provider_for_model( ] if not all_providers: return None - results = await asyncio.gather( - *[self._get_provider_models(p) for p in all_providers], - return_exceptions=True, - ) + + async def _fetch_models( + provider: "Provider", + ) -> tuple["Provider", list[str] | None, BaseException | None]: + try: + return provider, await self._get_provider_models(provider), None + except BaseException as e: + return provider, None, e + + tasks = [ + asyncio.create_task(_fetch_models(provider)) for provider in all_providers + ] failed_provider_errors: list[tuple[str, str]] = [] - for provider, result in zip(all_providers, results): - if isinstance(result, BaseException): - masked_error = self._safe_err(result) + matched_provider: Provider | None = None + for task in asyncio.as_completed(tasks): + provider, models, error = await task + if error is not None: + masked_error = self._safe_err(error) failed_provider_errors.append((provider.meta().id, masked_error)) continue - if model_name in result: - return provider + + if models is not None and model_name in models: + matched_provider = provider + break + + if matched_provider is not None: + for task in tasks: + if not task.done(): + 
task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + return matched_provider + if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( provider_id for provider_id, _ in failed_provider_errors @@ -334,13 +368,14 @@ async def _switch_model_by_name( err_msg, ) message.set_result( - MessageEventResult().message("获取当前提供商模型列表失败: " + err_msg), + MessageEventResult().message( + self._format_err("获取当前提供商模型列表失败: ", e) + ) ) return if model_name in models: - prov.set_model(model_name) - self._invalidate_provider_models_cache(curr_provider_id) + self._set_model_and_invalidate(prov, model_name) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", @@ -375,17 +410,17 @@ async def _switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - target_prov.set_model(model_name) - self._invalidate_provider_models_cache(target_id) + self._set_model_and_invalidate(target_prov, model_name) message.set_result( MessageEventResult().message( f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", ), ) except BaseException as e: - err_msg = self._safe_err(e) message.set_result( - MessageEventResult().message("跨提供商切换并设置模型失败: " + err_msg), + MessageEventResult().message( + self._format_err("跨提供商切换并设置模型失败: ", e) + ), ) async def model_ls( @@ -406,10 +441,9 @@ async def model_ls( try: models = await self._get_provider_models(prov) except BaseException as e: - err_msg = self._safe_err(e) message.set_result( MessageEventResult() - .message("获取模型列表失败: " + err_msg) + .message(self._format_err("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -430,9 +464,10 @@ async def model_ls( try: models = await self._get_provider_models(prov) except BaseException as e: - err_msg = self._safe_err(e) message.set_result( - MessageEventResult().message("获取模型列表失败: " + err_msg), + MessageEventResult().message( + self._format_err("获取模型列表失败: ", e) + ), ) return if idx_or_name > 
len(models) or idx_or_name < 1: @@ -440,17 +475,17 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - prov.set_model(new_model) - self._invalidate_provider_models_cache(prov.meta().id) + self._set_model_and_invalidate(prov, new_model) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", ), ) except BaseException as e: - err_msg = self._safe_err(e) message.set_result( - MessageEventResult().message("切换模型未知错误: " + err_msg), + MessageEventResult().message( + self._format_err("切换模型未知错误: ", e) + ), ) return else: @@ -484,10 +519,11 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None else: try: new_key = keys_data[index - 1] - prov.set_key(new_key) - self._invalidate_provider_models_cache(prov.meta().id) + self._set_key_and_invalidate(prov, new_key) except BaseException as e: message.set_result( - MessageEventResult().message(f"切换 Key 未知错误: {e!s}"), + MessageEventResult().message( + self._format_err("切换 Key 未知错误: ", e) + ), ) message.set_result(MessageEventResult().message("切换 Key 成功。")) From ba1b1fff5b62e7d43518e59379d06c6b8026f02f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:24:56 +0900 Subject: [PATCH 08/28] fix: align provider annotations and key error flow --- .../builtin_commands/commands/provider.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index b5028fbc49..9c9cfbcde1 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import re import time @@ -31,14 +33,14 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N return 
self._provider_models_cache.pop(provider_id, None) - def _invalidate_cache_for(self, provider: "Provider") -> None: + def _invalidate_cache_for(self, provider: Provider) -> None: self._invalidate_provider_models_cache(provider.meta().id) - def _set_model_and_invalidate(self, provider: "Provider", model_name: str) -> None: + def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None: provider.set_model(model_name) self._invalidate_cache_for(provider) - def _set_key_and_invalidate(self, provider: "Provider", key: str) -> None: + def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: provider.set_key(key) self._invalidate_cache_for(provider) @@ -53,7 +55,7 @@ def _format_err(self, prefix: str, e: BaseException) -> str: return f"{prefix}{self._safe_err(e)}" async def _get_provider_models( - self, provider: "Provider", *, use_cache: bool = True + self, provider: Provider, *, use_cache: bool = True ) -> list[str]: provider_id = provider.meta().id now = time.monotonic() @@ -101,7 +103,7 @@ async def _test_provider_capability(self, provider): async def _find_provider_for_model( self, model_name: str, exclude_provider_id: str | None = None - ) -> "Provider | None": + ) -> Provider | None: """在所有 LLM 提供商中查找包含指定模型的提供商。""" all_providers = [ p @@ -112,8 +114,8 @@ async def _find_provider_for_model( return None async def _fetch_models( - provider: "Provider", - ) -> tuple["Provider", list[str] | None, BaseException | None]: + provider: Provider, + ) -> tuple[Provider, list[str] | None, BaseException | None]: try: return provider, await self._get_provider_models(provider), None except BaseException as e: @@ -348,7 +350,7 @@ async def provider( event.set_result(MessageEventResult().message("无效的参数。")) async def _switch_model_by_name( - self, message: AstrMessageEvent, model_name: str, prov: "Provider" + self, message: AstrMessageEvent, model_name: str, prov: Provider ) -> None: model_name = model_name.strip() if not model_name: @@ -520,10 
+522,11 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None try: new_key = keys_data[index - 1] self._set_key_and_invalidate(prov, new_key) + message.set_result(MessageEventResult().message("切换 Key 成功。")) except BaseException as e: message.set_result( MessageEventResult().message( self._format_err("切换 Key 未知错误: ", e) ), ) - message.set_result(MessageEventResult().message("切换 Key 成功。")) + return From d062abff1d0dba9f48d7cb46dfb42dbcf360e406 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:28:37 +0900 Subject: [PATCH 09/28] fix: narrow provider command exception handling --- .../builtin_commands/commands/provider.py | 75 ++++++++++++++----- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 9c9cfbcde1..079bceccb8 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -48,24 +48,48 @@ def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: def _mask_sensitive_text(value: str) -> str: return _API_KEY_PATTERN.sub("key=***", value) - def _safe_err(self, e: BaseException) -> str: + def _safe_err(self, e: Exception) -> str: return self._mask_sensitive_text(str(e)) - def _format_err(self, prefix: str, e: BaseException) -> str: + def _format_err(self, prefix: str, e: Exception) -> str: return f"{prefix}{self._safe_err(e)}" + def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: + ttl = self._MODEL_LIST_CACHE_TTL_SECONDS + if not umo: + return ttl + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + configured_ttl = cfg.get("model_list_cache_ttl_seconds") + if configured_ttl is not None: + ttl = float(configured_ttl) + except Exception as e: + logger.debug( + "读取 model_list_cache_ttl_seconds 
失败,回退默认值 %.1f: %s", + self._MODEL_LIST_CACHE_TTL_SECONDS, + e, + ) + ttl = self._MODEL_LIST_CACHE_TTL_SECONDS + return max(ttl, 0.0) + async def _get_provider_models( - self, provider: Provider, *, use_cache: bool = True + self, + provider: Provider, + *, + use_cache: bool = True, + umo: str | None = None, ) -> list[str]: provider_id = provider.meta().id now = time.monotonic() - if use_cache: + ttl_seconds = self._get_model_cache_ttl_seconds(umo) + if use_cache and ttl_seconds > 0: cached = self._provider_models_cache.get(provider_id) - if cached and now - cached[0] <= self._MODEL_LIST_CACHE_TTL_SECONDS: + if cached and now - cached[0] <= ttl_seconds: return list(cached[1]) models = list(await provider.get_models()) - self._provider_models_cache[provider_id] = (now, tuple(models)) + if use_cache and ttl_seconds > 0: + self._provider_models_cache[provider_id] = (now, tuple(models)) return models def _log_reachability_failure( @@ -102,7 +126,10 @@ async def _test_provider_capability(self, provider): return False, err_code, err_reason async def _find_provider_for_model( - self, model_name: str, exclude_provider_id: str | None = None + self, + model_name: str, + exclude_provider_id: str | None = None, + umo: str | None = None, ) -> Provider | None: """在所有 LLM 提供商中查找包含指定模型的提供商。""" all_providers = [ @@ -115,10 +142,14 @@ async def _find_provider_for_model( async def _fetch_models( provider: Provider, - ) -> tuple[Provider, list[str] | None, BaseException | None]: + ) -> tuple[Provider, list[str] | None, Exception | None]: try: - return provider, await self._get_provider_models(provider), None - except BaseException as e: + return ( + provider, + await self._get_provider_models(provider, umo=umo), + None, + ) + except Exception as e: return provider, None, e tasks = [ @@ -361,8 +392,8 @@ async def _switch_model_by_name( curr_provider_id = prov.meta().id try: - models = await self._get_provider_models(prov) - except BaseException as e: + models = await 
self._get_provider_models(prov, umo=umo) + except Exception as e: err_msg = self._safe_err(e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", @@ -387,7 +418,7 @@ async def _switch_model_by_name( try: target_prov = await self._find_provider_for_model( - model_name, exclude_provider_id=curr_provider_id + model_name, exclude_provider_id=curr_provider_id, umo=umo ) except _AllProvidersModelFetchFailedError: message.set_result( @@ -418,7 +449,7 @@ async def _switch_model_by_name( f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", ), ) - except BaseException as e: + except Exception as e: message.set_result( MessageEventResult().message( self._format_err("跨提供商切换并设置模型失败: ", e) @@ -441,8 +472,10 @@ async def model_ls( if idx_or_name is None: models = [] try: - models = await self._get_provider_models(prov) - except BaseException as e: + models = await self._get_provider_models( + prov, umo=message.unified_msg_origin + ) + except Exception as e: message.set_result( MessageEventResult() .message(self._format_err("获取模型列表失败: ", e)) @@ -464,8 +497,10 @@ async def model_ls( elif isinstance(idx_or_name, int): models = [] try: - models = await self._get_provider_models(prov) - except BaseException as e: + models = await self._get_provider_models( + prov, umo=message.unified_msg_origin + ) + except Exception as e: message.set_result( MessageEventResult().message( self._format_err("获取模型列表失败: ", e) @@ -483,7 +518,7 @@ async def model_ls( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", ), ) - except BaseException as e: + except Exception as e: message.set_result( MessageEventResult().message( self._format_err("切换模型未知错误: ", e) @@ -523,7 +558,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None new_key = keys_data[index - 1] self._set_key_and_invalidate(prov, new_key) message.set_result(MessageEventResult().message("切换 Key 成功。")) - except BaseException as e: + except Exception as e: message.set_result( 
MessageEventResult().message( self._format_err("切换 Key 未知错误: ", e) From 1d611d421717ce2e690f86c0fb9be5abdb875769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:35:54 +0900 Subject: [PATCH 10/28] refactor: harden provider command error redaction and flow --- .../builtin_commands/commands/provider.py | 122 ++++++++---------- 1 file changed, 53 insertions(+), 69 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 079bceccb8..446262a861 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,6 +3,7 @@ import asyncio import re import time +from dataclasses import dataclass from typing import TYPE_CHECKING from astrbot import logger @@ -13,11 +14,27 @@ if TYPE_CHECKING: from astrbot.core.provider.provider import Provider -_API_KEY_PATTERN = re.compile(r"(?i)(api_?key|key)=[^&'\" ]+") +_SECRET_PATTERNS = [ + re.compile( + r"(?i)\b(api_?key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+" + ), + re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), + re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), + re.compile(r"\bsk-[A-Za-z0-9]{16,}\b"), +] -class _AllProvidersModelFetchFailedError(RuntimeError): - pass +def redact_secrets(text: str) -> str: + redacted = text + for pattern in _SECRET_PATTERNS: + redacted = pattern.sub("[REDACTED]", redacted) + return redacted + + +@dataclass +class _ModelCacheEntry: + timestamp: float + models: tuple[str, ...] 
class ProviderCommands: @@ -25,7 +42,7 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._provider_models_cache: dict[str, tuple[float, tuple[str, ...]]] = {} + self._provider_models_cache: dict[str, _ModelCacheEntry] = {} def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: if provider_id is None: @@ -46,13 +63,10 @@ def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: @staticmethod def _mask_sensitive_text(value: str) -> str: - return _API_KEY_PATTERN.sub("key=***", value) - - def _safe_err(self, e: Exception) -> str: - return self._mask_sensitive_text(str(e)) + return redact_secrets(value) - def _format_err(self, prefix: str, e: Exception) -> str: - return f"{prefix}{self._safe_err(e)}" + def _format_safe_err(self, prefix: str, e: Exception) -> str: + return f"{prefix}{self._mask_sensitive_text(str(e))}" def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: ttl = self._MODEL_LIST_CACHE_TTL_SECONDS @@ -67,7 +81,7 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: logger.debug( "读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s", self._MODEL_LIST_CACHE_TTL_SECONDS, - e, + self._mask_sensitive_text(str(e)), ) ttl = self._MODEL_LIST_CACHE_TTL_SECONDS return max(ttl, 0.0) @@ -84,12 +98,15 @@ async def _get_provider_models( ttl_seconds = self._get_model_cache_ttl_seconds(umo) if use_cache and ttl_seconds > 0: cached = self._provider_models_cache.get(provider_id) - if cached and now - cached[0] <= ttl_seconds: - return list(cached[1]) + if cached and now - cached.timestamp <= ttl_seconds: + return list(cached.models) models = list(await provider.get_models()) if use_cache and ttl_seconds > 0: - self._provider_models_cache[provider_id] = (now, tuple(models)) + self._provider_models_cache[provider_id] = _ModelCacheEntry( + timestamp=now, + models=tuple(models), + ) return models def _log_reachability_failure( @@ -119,7 
+136,7 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = str(e) + err_reason = self._mask_sensitive_text(str(e)) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) @@ -139,41 +156,19 @@ async def _find_provider_for_model( ] if not all_providers: return None - - async def _fetch_models( - provider: Provider, - ) -> tuple[Provider, list[str] | None, Exception | None]: + failed_provider_errors: list[tuple[str, str]] = [] + for provider in all_providers: + provider_id = provider.meta().id try: - return ( - provider, - await self._get_provider_models(provider, umo=umo), - None, - ) + models = await self._get_provider_models(provider, umo=umo) except Exception as e: - return provider, None, e - - tasks = [ - asyncio.create_task(_fetch_models(provider)) for provider in all_providers - ] - failed_provider_errors: list[tuple[str, str]] = [] - matched_provider: Provider | None = None - for task in asyncio.as_completed(tasks): - provider, models, error = await task - if error is not None: - masked_error = self._safe_err(error) - failed_provider_errors.append((provider.meta().id, masked_error)) + failed_provider_errors.append( + (provider_id, self._format_safe_err("", e)) + ) continue - if models is not None and model_name in models: - matched_provider = provider - break - - if matched_provider is not None: - for task in tasks: - if not task.done(): - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - return matched_provider + if model_name in models: + return provider if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( @@ -185,9 +180,6 @@ async def _fetch_models( len(all_providers), failed_ids, ) - raise _AllProvidersModelFetchFailedError( - f"all providers failed to fetch models: {failed_ids}" - ) elif failed_provider_errors: logger.debug( "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: 
%s", @@ -254,7 +246,7 @@ async def provider( p, None, reachable.__class__.__name__, - str(reachable), + self._mask_sensitive_text(str(reachable)), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -394,7 +386,7 @@ async def _switch_model_by_name( try: models = await self._get_provider_models(prov, umo=umo) except Exception as e: - err_msg = self._safe_err(e) + err_msg = self._format_safe_err("", e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", curr_provider_id, @@ -402,7 +394,7 @@ async def _switch_model_by_name( ) message.set_result( MessageEventResult().message( - self._format_err("获取当前提供商模型列表失败: ", e) + self._format_safe_err("获取当前提供商模型列表失败: ", e) ) ) return @@ -416,22 +408,14 @@ async def _switch_model_by_name( ) return - try: - target_prov = await self._find_provider_for_model( - model_name, exclude_provider_id=curr_provider_id, umo=umo - ) - except _AllProvidersModelFetchFailedError: - message.set_result( - MessageEventResult().message( - "跨提供商查询模型失败:所有提供商的模型列表均获取失败,请检查提供商配置或网络后重试。", - ), - ) - return + target_prov = await self._find_provider_for_model( + model_name, exclude_provider_id=curr_provider_id, umo=umo + ) if not target_prov: message.set_result( MessageEventResult().message( - f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。", + f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。", ), ) return @@ -452,7 +436,7 @@ async def _switch_model_by_name( except Exception as e: message.set_result( MessageEventResult().message( - self._format_err("跨提供商切换并设置模型失败: ", e) + self._format_safe_err("跨提供商切换并设置模型失败: ", e) ), ) @@ -478,7 +462,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult() - .message(self._format_err("获取模型列表失败: ", e)) + .message(self._format_safe_err("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -503,7 +487,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_err("获取模型列表失败: ", e) + 
self._format_safe_err("获取模型列表失败: ", e) ), ) return @@ -521,7 +505,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_err("切换模型未知错误: ", e) + self._format_safe_err("切换模型未知错误: ", e) ), ) return @@ -561,7 +545,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None except Exception as e: message.set_result( MessageEventResult().message( - self._format_err("切换 Key 未知错误: ", e) + self._format_safe_err("切换 Key 未知错误: ", e) ), ) return From 5293aef6a35d432224d88b389ef5dbfbc4e7e846 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:41:08 +0900 Subject: [PATCH 11/28] fix: improve provider model lookup and secret redaction --- .../builtin_commands/commands/provider.py | 135 ++++++++++++------ 1 file changed, 90 insertions(+), 45 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 446262a861..5ee7adf194 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -16,7 +16,7 @@ _SECRET_PATTERNS = [ re.compile( - r"(?i)\b(api_?key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+" + r"(?i)\b(api_?key|key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+" ), re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), @@ -39,6 +39,7 @@ class _ModelCacheEntry: class ProviderCommands: _MODEL_LIST_CACHE_TTL_SECONDS = 30.0 + _MODEL_LOOKUP_MAX_CONCURRENCY = 4 def __init__(self, context: star.Context) -> None: self.context = context @@ -50,23 +51,47 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N return self._provider_models_cache.pop(provider_id, None) - def _invalidate_cache_for(self, provider: Provider) -> 
None: + def _update_provider_and_invalidate( + self, + provider: Provider, + *, + model_name: str | None = None, + key: str | None = None, + ) -> None: + if model_name is not None: + provider.set_model(model_name) + if key is not None: + provider.set_key(key) self._invalidate_provider_models_cache(provider.meta().id) - def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None: - provider.set_model(model_name) - self._invalidate_cache_for(provider) - - def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: - provider.set_key(key) - self._invalidate_cache_for(provider) + @staticmethod + def _safe_err(prefix: str, e: Exception) -> str: + return prefix + redact_secrets(str(e)) @staticmethod - def _mask_sensitive_text(value: str) -> str: - return redact_secrets(value) + def _normalize_model_name(model_name: str) -> str: + return model_name.strip().casefold() - def _format_safe_err(self, prefix: str, e: Exception) -> str: - return f"{prefix}{self._mask_sensitive_text(str(e))}" + def _resolve_model_name(self, model_name: str, models: list[str]) -> str | None: + normalized_model_name = self._normalize_model_name(model_name) + if not normalized_model_name: + return None + if model_name in models: + return model_name + + for candidate in models: + normalized_candidate = self._normalize_model_name(candidate) + if normalized_candidate == normalized_model_name: + return candidate + if normalized_candidate.endswith( + f"/{normalized_model_name}" + ) or normalized_candidate.endswith(f":{normalized_model_name}"): + return candidate + if normalized_model_name.endswith( + f"/{normalized_candidate}" + ) or normalized_model_name.endswith(f":{normalized_candidate}"): + return candidate + return None def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: ttl = self._MODEL_LIST_CACHE_TTL_SECONDS @@ -81,7 +106,7 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: logger.debug( "读取 
model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s", self._MODEL_LIST_CACHE_TTL_SECONDS, - self._mask_sensitive_text(str(e)), + redact_secrets(str(e)), ) ttl = self._MODEL_LIST_CACHE_TTL_SECONDS return max(ttl, 0.0) @@ -136,7 +161,7 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = self._mask_sensitive_text(str(e)) + err_reason = redact_secrets(str(e)) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) @@ -147,7 +172,7 @@ async def _find_provider_for_model( model_name: str, exclude_provider_id: str | None = None, umo: str | None = None, - ) -> Provider | None: + ) -> tuple[Provider | None, str | None]: """在所有 LLM 提供商中查找包含指定模型的提供商。""" all_providers = [ p @@ -155,20 +180,37 @@ async def _find_provider_for_model( if not exclude_provider_id or p.meta().id != exclude_provider_id ] if not all_providers: - return None + return None, None + + semaphore = asyncio.Semaphore(self._MODEL_LOOKUP_MAX_CONCURRENCY) + + async def _fetch_models( + provider: Provider, + ) -> tuple[Provider, list[str] | None, Exception | None]: + async with semaphore: + try: + return ( + provider, + await self._get_provider_models(provider, umo=umo), + None, + ) + except Exception as e: + return provider, None, e + + results = await asyncio.gather( + *[_fetch_models(provider) for provider in all_providers] + ) failed_provider_errors: list[tuple[str, str]] = [] - for provider in all_providers: + for provider, models, error in results: provider_id = provider.meta().id - try: - models = await self._get_provider_models(provider, umo=umo) - except Exception as e: - failed_provider_errors.append( - (provider_id, self._format_safe_err("", e)) - ) + if error is not None: + failed_provider_errors.append((provider_id, self._safe_err("", error))) continue - - if model_name in models: - return provider + if models is None: + continue + matched_model_name = 
self._resolve_model_name(model_name, models) + if matched_model_name is not None: + return provider, matched_model_name if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( @@ -190,7 +232,7 @@ async def _find_provider_for_model( for provider_id, error in failed_provider_errors ), ) - return None + return None, None async def provider( self, @@ -246,7 +288,7 @@ async def provider( p, None, reachable.__class__.__name__, - self._mask_sensitive_text(str(reachable)), + redact_secrets(str(reachable)), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -386,7 +428,7 @@ async def _switch_model_by_name( try: models = await self._get_provider_models(prov, umo=umo) except Exception as e: - err_msg = self._format_safe_err("", e) + err_msg = self._safe_err("", e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", curr_provider_id, @@ -394,25 +436,26 @@ async def _switch_model_by_name( ) message.set_result( MessageEventResult().message( - self._format_safe_err("获取当前提供商模型列表失败: ", e) + self._safe_err("获取当前提供商模型列表失败: ", e) ) ) return - if model_name in models: - self._set_model_and_invalidate(prov, model_name) + matched_model_name = self._resolve_model_name(model_name, models) + if matched_model_name is not None: + self._update_provider_and_invalidate(prov, model_name=matched_model_name) message.set_result( MessageEventResult().message( - f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]", + f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", ), ) return - target_prov = await self._find_provider_for_model( + target_prov, matched_target_model_name = await self._find_provider_for_model( model_name, exclude_provider_id=curr_provider_id, umo=umo ) - if not target_prov: + if target_prov is None or matched_target_model_name is None: message.set_result( MessageEventResult().message( f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。", @@ -427,16 +470,18 @@ async def 
_switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - self._set_model_and_invalidate(target_prov, model_name) + self._update_provider_and_invalidate( + target_prov, model_name=matched_target_model_name + ) message.set_result( MessageEventResult().message( - f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", + f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", ), ) except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_err("跨提供商切换并设置模型失败: ", e) + self._safe_err("跨提供商切换并设置模型失败: ", e) ), ) @@ -462,7 +507,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult() - .message(self._format_safe_err("获取模型列表失败: ", e)) + .message(self._safe_err("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -487,7 +532,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_err("获取模型列表失败: ", e) + self._safe_err("获取模型列表失败: ", e) ), ) return @@ -496,7 +541,7 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - self._set_model_and_invalidate(prov, new_model) + self._update_provider_and_invalidate(prov, model_name=new_model) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ -505,7 +550,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_err("切换模型未知错误: ", e) + self._safe_err("切换模型未知错误: ", e) ), ) return @@ -540,12 +585,12 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None else: try: new_key = keys_data[index - 1] - self._set_key_and_invalidate(prov, new_key) + self._update_provider_and_invalidate(prov, key=new_key) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_err("切换 Key 未知错误: ", e) + self._safe_err("切换 Key 未知错误: ", e) 
), ) return From 418a405d9bef294af9e6dcf5955ad0cb7dd9fc1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:46:06 +0900 Subject: [PATCH 12/28] refactor: cache normalized model names in provider lookup --- .../builtin_commands/commands/provider.py | 128 ++++++++++-------- 1 file changed, 75 insertions(+), 53 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 5ee7adf194..cb94d579fc 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -14,6 +14,8 @@ if TYPE_CHECKING: from astrbot.core.provider.provider import Provider +_MODEL_LIST_CACHE_TTL_CONFIG_KEY = "model_list_cache_ttl_seconds" + _SECRET_PATTERNS = [ re.compile( r"(?i)\b(api_?key|key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+" @@ -35,6 +37,7 @@ def redact_secrets(text: str) -> str: class _ModelCacheEntry: timestamp: float models: tuple[str, ...] + normalized_models: tuple[tuple[str, str], ...] 
class ProviderCommands: @@ -51,36 +54,31 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N return self._provider_models_cache.pop(provider_id, None) - def _update_provider_and_invalidate( - self, - provider: Provider, - *, - model_name: str | None = None, - key: str | None = None, - ) -> None: - if model_name is not None: - provider.set_model(model_name) - if key is not None: - provider.set_key(key) - self._invalidate_provider_models_cache(provider.meta().id) - @staticmethod - def _safe_err(prefix: str, e: Exception) -> str: + def _format_safe_error(prefix: str, e: Exception) -> str: return prefix + redact_secrets(str(e)) @staticmethod def _normalize_model_name(model_name: str) -> str: return model_name.strip().casefold() - def _resolve_model_name(self, model_name: str, models: list[str]) -> str | None: + def _build_normalized_model_index( + self, models: tuple[str, ...] + ) -> tuple[tuple[str, str], ...]: + return tuple((model, self._normalize_model_name(model)) for model in models) + + def _resolve_model_name( + self, + model_name: str, + normalized_models: tuple[tuple[str, str], ...], + ) -> str | None: normalized_model_name = self._normalize_model_name(model_name) if not normalized_model_name: return None - if model_name in models: - return model_name - for candidate in models: - normalized_candidate = self._normalize_model_name(candidate) + for candidate, normalized_candidate in normalized_models: + if candidate == model_name: + return candidate if normalized_candidate == normalized_model_name: return candidate if normalized_candidate.endswith( @@ -99,40 +97,56 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: return ttl try: cfg = self.context.get_config(umo).get("provider_settings", {}) - configured_ttl = cfg.get("model_list_cache_ttl_seconds") + configured_ttl = cfg.get(_MODEL_LIST_CACHE_TTL_CONFIG_KEY) if configured_ttl is not None: ttl = float(configured_ttl) except Exception as e: logger.debug( 
"读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s", self._MODEL_LIST_CACHE_TTL_SECONDS, - redact_secrets(str(e)), + self._format_safe_error("", e), ) ttl = self._MODEL_LIST_CACHE_TTL_SECONDS return max(ttl, 0.0) - async def _get_provider_models( + async def _get_provider_model_entry( self, provider: Provider, *, use_cache: bool = True, umo: str | None = None, - ) -> list[str]: + ) -> _ModelCacheEntry: provider_id = provider.meta().id now = time.monotonic() ttl_seconds = self._get_model_cache_ttl_seconds(umo) if use_cache and ttl_seconds > 0: cached = self._provider_models_cache.get(provider_id) if cached and now - cached.timestamp <= ttl_seconds: - return list(cached.models) + return cached - models = list(await provider.get_models()) + models = tuple(await provider.get_models()) + entry = _ModelCacheEntry( + timestamp=now, + models=models, + normalized_models=self._build_normalized_model_index(models), + ) if use_cache and ttl_seconds > 0: - self._provider_models_cache[provider_id] = _ModelCacheEntry( - timestamp=now, - models=tuple(models), - ) - return models + self._provider_models_cache[provider_id] = entry + return entry + + async def _get_provider_models( + self, + provider: Provider, + *, + use_cache: bool = True, + umo: str | None = None, + ) -> list[str]: + entry = await self._get_provider_model_entry( + provider, + use_cache=use_cache, + umo=umo, + ) + return list(entry.models) def _log_reachability_failure( self, @@ -161,7 +175,7 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = redact_secrets(str(e)) + err_reason = self._format_safe_error("", e) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) @@ -186,12 +200,12 @@ async def _find_provider_for_model( async def _fetch_models( provider: Provider, - ) -> tuple[Provider, list[str] | None, Exception | None]: + ) -> tuple[Provider, _ModelCacheEntry | None, Exception | None]: 
async with semaphore: try: return ( provider, - await self._get_provider_models(provider, umo=umo), + await self._get_provider_model_entry(provider, umo=umo), None, ) except Exception as e: @@ -201,14 +215,18 @@ async def _fetch_models( *[_fetch_models(provider) for provider in all_providers] ) failed_provider_errors: list[tuple[str, str]] = [] - for provider, models, error in results: + for provider, model_entry, error in results: provider_id = provider.meta().id if error is not None: - failed_provider_errors.append((provider_id, self._safe_err("", error))) + failed_provider_errors.append( + (provider_id, self._format_safe_error("", error)) + ) continue - if models is None: + if model_entry is None: continue - matched_model_name = self._resolve_model_name(model_name, models) + matched_model_name = self._resolve_model_name( + model_name, model_entry.normalized_models + ) if matched_model_name is not None: return provider, matched_model_name @@ -288,7 +306,7 @@ async def provider( p, None, reachable.__class__.__name__, - redact_secrets(str(reachable)), + self._format_safe_error("", reachable), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -426,9 +444,9 @@ async def _switch_model_by_name( curr_provider_id = prov.meta().id try: - models = await self._get_provider_models(prov, umo=umo) + model_entry = await self._get_provider_model_entry(prov, umo=umo) except Exception as e: - err_msg = self._safe_err("", e) + err_msg = self._format_safe_error("", e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", curr_provider_id, @@ -436,14 +454,17 @@ async def _switch_model_by_name( ) message.set_result( MessageEventResult().message( - self._safe_err("获取当前提供商模型列表失败: ", e) + self._format_safe_error("获取当前提供商模型列表失败: ", e) ) ) return - matched_model_name = self._resolve_model_name(model_name, models) + matched_model_name = self._resolve_model_name( + model_name, model_entry.normalized_models + ) if matched_model_name is not None: - 
self._update_provider_and_invalidate(prov, model_name=matched_model_name) + prov.set_model(matched_model_name) + self._invalidate_provider_models_cache(curr_provider_id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", @@ -470,9 +491,8 @@ async def _switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - self._update_provider_and_invalidate( - target_prov, model_name=matched_target_model_name - ) + target_prov.set_model(matched_target_model_name) + self._invalidate_provider_models_cache(target_id) message.set_result( MessageEventResult().message( f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", @@ -481,7 +501,7 @@ async def _switch_model_by_name( except Exception as e: message.set_result( MessageEventResult().message( - self._safe_err("跨提供商切换并设置模型失败: ", e) + self._format_safe_error("跨提供商切换并设置模型失败: ", e) ), ) @@ -507,7 +527,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult() - .message(self._safe_err("获取模型列表失败: ", e)) + .message(self._format_safe_error("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -532,7 +552,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._safe_err("获取模型列表失败: ", e) + self._format_safe_error("获取模型列表失败: ", e) ), ) return @@ -541,7 +561,8 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - self._update_provider_and_invalidate(prov, model_name=new_model) + prov.set_model(new_model) + self._invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ -550,7 +571,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._safe_err("切换模型未知错误: ", e) + self._format_safe_error("切换模型未知错误: ", e) ), ) return @@ -585,12 +606,13 @@ async def key(self, message: AstrMessageEvent, index: 
int | None = None) -> None else: try: new_key = keys_data[index - 1] - self._update_provider_and_invalidate(prov, key=new_key) + prov.set_key(new_key) + self._invalidate_provider_models_cache(prov.meta().id) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( MessageEventResult().message( - self._safe_err("切换 Key 未知错误: ", e) + self._format_safe_error("切换 Key 未知错误: ", e) ), ) return From abe31a30d1d8a2a2a9ff407aedf08b39c4a8e69d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:50:54 +0900 Subject: [PATCH 13/28] refactor: simplify provider model lookup helpers --- .../builtin_commands/commands/provider.py | 157 ++++++++---------- 1 file changed, 70 insertions(+), 87 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index cb94d579fc..f5f826b136 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,6 +3,7 @@ import asyncio import re import time +from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING @@ -15,29 +16,23 @@ from astrbot.core.provider.provider import Provider _MODEL_LIST_CACHE_TTL_CONFIG_KEY = "model_list_cache_ttl_seconds" +_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY = "model_lookup_max_concurrency" _SECRET_PATTERNS = [ re.compile( - r"(?i)\b(api_?key|key|access_?token|token|secret|auth_?token|session_?id|password)\s*=\s*[^&'\" ]+" + r"(?i)\b(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\s*=\s*[^&'\" ]+" ), + re.compile(r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+"), re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), re.compile(r"\bsk-[A-Za-z0-9]{16,}\b"), ] -def redact_secrets(text: str) -> 
str: - redacted = text - for pattern in _SECRET_PATTERNS: - redacted = pattern.sub("[REDACTED]", redacted) - return redacted - - @dataclass class _ModelCacheEntry: timestamp: float models: tuple[str, ...] - normalized_models: tuple[tuple[str, str], ...] class ProviderCommands: @@ -55,28 +50,27 @@ def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> N self._provider_models_cache.pop(provider_id, None) @staticmethod - def _format_safe_error(prefix: str, e: Exception) -> str: - return prefix + redact_secrets(str(e)) + def _safe_error(prefix: str, e: Exception) -> str: + text = str(e) + for pattern in _SECRET_PATTERNS: + text = pattern.sub("[REDACTED]", text) + return prefix + text @staticmethod def _normalize_model_name(model_name: str) -> str: return model_name.strip().casefold() - def _build_normalized_model_index( - self, models: tuple[str, ...] - ) -> tuple[tuple[str, str], ...]: - return tuple((model, self._normalize_model_name(model)) for model in models) - def _resolve_model_name( self, model_name: str, - normalized_models: tuple[tuple[str, str], ...], + models: Sequence[str], ) -> str | None: normalized_model_name = self._normalize_model_name(model_name) if not normalized_model_name: return None - for candidate, normalized_candidate in normalized_models: + for candidate in models: + normalized_candidate = self._normalize_model_name(candidate) if candidate == model_name: return candidate if normalized_candidate == normalized_model_name: @@ -104,49 +98,52 @@ def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: logger.debug( "读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s", self._MODEL_LIST_CACHE_TTL_SECONDS, - self._format_safe_error("", e), + self._safe_error("", e), ) ttl = self._MODEL_LIST_CACHE_TTL_SECONDS return max(ttl, 0.0) - async def _get_provider_model_entry( + def _get_model_lookup_max_concurrency(self, umo: str | None = None) -> int: + concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY + if not umo: + 
return max(concurrency, 1) + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + configured_concurrency = cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) + if configured_concurrency is not None: + concurrency = int(configured_concurrency) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %d: %s", + _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, + self._MODEL_LOOKUP_MAX_CONCURRENCY, + self._safe_error("", e), + ) + concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY + return max(concurrency, 1) + + async def _get_provider_models( self, provider: Provider, *, use_cache: bool = True, umo: str | None = None, - ) -> _ModelCacheEntry: + ) -> list[str]: provider_id = provider.meta().id now = time.monotonic() ttl_seconds = self._get_model_cache_ttl_seconds(umo) if use_cache and ttl_seconds > 0: cached = self._provider_models_cache.get(provider_id) if cached and now - cached.timestamp <= ttl_seconds: - return cached + return list(cached.models) models = tuple(await provider.get_models()) - entry = _ModelCacheEntry( - timestamp=now, - models=models, - normalized_models=self._build_normalized_model_index(models), - ) if use_cache and ttl_seconds > 0: - self._provider_models_cache[provider_id] = entry - return entry - - async def _get_provider_models( - self, - provider: Provider, - *, - use_cache: bool = True, - umo: str | None = None, - ) -> list[str]: - entry = await self._get_provider_model_entry( - provider, - use_cache=use_cache, - umo=umo, - ) - return list(entry.models) + self._provider_models_cache[provider_id] = _ModelCacheEntry( + timestamp=now, + models=models, + ) + return list(models) def _log_reachability_failure( self, @@ -175,7 +172,7 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = self._format_safe_error("", e) + err_reason = self._safe_error("", e) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) @@ 
-196,39 +193,27 @@ async def _find_provider_for_model( if not all_providers: return None, None - semaphore = asyncio.Semaphore(self._MODEL_LOOKUP_MAX_CONCURRENCY) - - async def _fetch_models( - provider: Provider, - ) -> tuple[Provider, _ModelCacheEntry | None, Exception | None]: - async with semaphore: - try: - return ( - provider, - await self._get_provider_model_entry(provider, umo=umo), - None, - ) - except Exception as e: - return provider, None, e - - results = await asyncio.gather( - *[_fetch_models(provider) for provider in all_providers] - ) failed_provider_errors: list[tuple[str, str]] = [] - for provider, model_entry, error in results: - provider_id = provider.meta().id - if error is not None: - failed_provider_errors.append( - (provider_id, self._format_safe_error("", error)) - ) - continue - if model_entry is None: - continue - matched_model_name = self._resolve_model_name( - model_name, model_entry.normalized_models + max_concurrency = self._get_model_lookup_max_concurrency(umo) + for start in range(0, len(all_providers), max_concurrency): + batch_providers = all_providers[start : start + max_concurrency] + batch_results = await asyncio.gather( + *[ + self._get_provider_models(provider, umo=umo) + for provider in batch_providers + ], + return_exceptions=True, ) - if matched_model_name is not None: - return provider, matched_model_name + for provider, result in zip(batch_providers, batch_results, strict=False): + provider_id = provider.meta().id + if isinstance(result, Exception): + failed_provider_errors.append( + (provider_id, self._safe_error("", result)) + ) + continue + matched_model_name = self._resolve_model_name(model_name, result) + if matched_model_name is not None: + return provider, matched_model_name if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( @@ -306,7 +291,7 @@ async def provider( p, None, reachable.__class__.__name__, - self._format_safe_error("", reachable), + 
self._safe_error("", reachable), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -444,9 +429,9 @@ async def _switch_model_by_name( curr_provider_id = prov.meta().id try: - model_entry = await self._get_provider_model_entry(prov, umo=umo) + models = await self._get_provider_models(prov, umo=umo) except Exception as e: - err_msg = self._format_safe_error("", e) + err_msg = self._safe_error("", e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", curr_provider_id, @@ -454,14 +439,12 @@ async def _switch_model_by_name( ) message.set_result( MessageEventResult().message( - self._format_safe_error("获取当前提供商模型列表失败: ", e) + self._safe_error("获取当前提供商模型列表失败: ", e) ) ) return - matched_model_name = self._resolve_model_name( - model_name, model_entry.normalized_models - ) + matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: prov.set_model(matched_model_name) self._invalidate_provider_models_cache(curr_provider_id) @@ -501,7 +484,7 @@ async def _switch_model_by_name( except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_error("跨提供商切换并设置模型失败: ", e) + self._safe_error("跨提供商切换并设置模型失败: ", e) ), ) @@ -527,7 +510,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult() - .message(self._format_safe_error("获取模型列表失败: ", e)) + .message(self._safe_error("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -552,7 +535,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_error("获取模型列表失败: ", e) + self._safe_error("获取模型列表失败: ", e) ), ) return @@ -571,7 +554,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._format_safe_error("切换模型未知错误: ", e) + self._safe_error("切换模型未知错误: ", e) ), ) return @@ -612,7 +595,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None except Exception as e: message.set_result( 
MessageEventResult().message( - self._format_safe_error("切换 Key 未知错误: ", e) + self._safe_error("切换 Key 未知错误: ", e) ), ) return From 34235c6d3bc051b3d1922af4d106700ce8194346 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 17:56:16 +0900 Subject: [PATCH 14/28] refactor: extract provider model lookup helpers --- .../builtin_commands/commands/provider.py | 305 ++++++++++-------- astrbot/core/utils/error_redaction.py | 22 ++ 2 files changed, 189 insertions(+), 138 deletions(-) create mode 100644 astrbot/core/utils/error_redaction.py diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index f5f826b136..a465f66a0c 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -1,9 +1,8 @@ from __future__ import annotations import asyncio -import re import time -from collections.abc import Sequence +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING @@ -11,6 +10,7 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType +from astrbot.core.utils.error_redaction import safe_error if TYPE_CHECKING: from astrbot.core.provider.provider import Provider @@ -18,16 +18,6 @@ _MODEL_LIST_CACHE_TTL_CONFIG_KEY = "model_list_cache_ttl_seconds" _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY = "model_lookup_max_concurrency" -_SECRET_PATTERNS = [ - re.compile( - r"(?i)\b(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\s*=\s*[^&'\" ]+" - ), - re.compile(r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+"), - re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), - re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), - 
re.compile(r"\bsk-[A-Za-z0-9]{16,}\b"), -] - @dataclass class _ModelCacheEntry: @@ -35,26 +25,157 @@ class _ModelCacheEntry: models: tuple[str, ...] +class _ModelListCache: + def __init__(self, context: star.Context) -> None: + self._context = context + self._cache: dict[str, _ModelCacheEntry] = {} + + def invalidate(self, provider_id: str | None = None) -> None: + if provider_id is None: + self._cache.clear() + return + self._cache.pop(provider_id, None) + + def _get_ttl_seconds(self, umo: str | None, default_ttl: float) -> float: + ttl = default_ttl + if not umo: + return max(ttl, 0.0) + try: + cfg = self._context.get_config(umo).get("provider_settings", {}) + configured_ttl = cfg.get(_MODEL_LIST_CACHE_TTL_CONFIG_KEY) + if configured_ttl is not None: + ttl = float(configured_ttl) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %.1f: %s", + _MODEL_LIST_CACHE_TTL_CONFIG_KEY, + default_ttl, + safe_error("", e), + ) + ttl = default_ttl + return max(ttl, 0.0) + + async def get_models( + self, + provider: Provider, + *, + use_cache: bool = True, + umo: str | None = None, + default_ttl: float = 30.0, + ) -> list[str]: + provider_id = provider.meta().id + now = time.monotonic() + ttl_seconds = self._get_ttl_seconds(umo, default_ttl) + if use_cache and ttl_seconds > 0: + cached = self._cache.get(provider_id) + if cached and now - cached.timestamp <= ttl_seconds: + return list(cached.models) + + models = tuple(await provider.get_models()) + if use_cache and ttl_seconds > 0: + self._cache[provider_id] = _ModelCacheEntry( + timestamp=now, + models=models, + ) + return list(models) + + +class _ProviderModelLookup: + def __init__(self, context: star.Context) -> None: + self._context = context + + def _get_max_concurrency(self, umo: str | None, default: int) -> int: + concurrency = default + if not umo: + return max(concurrency, 1) + try: + cfg = self._context.get_config(umo).get("provider_settings", {}) + configured_concurrency = 
cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) + if configured_concurrency is not None: + concurrency = int(configured_concurrency) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %d: %s", + _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, + default, + safe_error("", e), + ) + concurrency = default + return max(concurrency, 1) + + async def find_provider_for_model( + self, + model_name: str, + *, + exclude_provider_id: str | None, + umo: str | None, + default_max_concurrency: int, + get_models: Callable[..., Awaitable[list[str]]], + resolve_name: Callable[[str, Sequence[str]], str | None], + ) -> tuple[Provider | None, str | None]: + all_providers = [ + p + for p in self._context.get_all_providers() + if not exclude_provider_id or p.meta().id != exclude_provider_id + ] + if not all_providers: + return None, None + + failed_provider_errors: list[tuple[str, str]] = [] + max_concurrency = self._get_max_concurrency(umo, default_max_concurrency) + for start in range(0, len(all_providers), max_concurrency): + batch_providers = all_providers[start : start + max_concurrency] + batch_results = await asyncio.gather( + *[get_models(provider, umo=umo) for provider in batch_providers], + return_exceptions=True, + ) + for provider, result in zip(batch_providers, batch_results, strict=False): + provider_id = provider.meta().id + if isinstance(result, Exception): + failed_provider_errors.append((provider_id, safe_error("", result))) + continue + matched_model_name = resolve_name(model_name, result) + if matched_model_name is not None: + return provider, matched_model_name + + if failed_provider_errors and len(failed_provider_errors) == len(all_providers): + failed_ids = ",".join( + provider_id for provider_id, _ in failed_provider_errors + ) + logger.error( + "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", + model_name, + len(all_providers), + failed_ids, + ) + elif failed_provider_errors: + logger.debug( + "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", + model_name, + 
len(failed_provider_errors), + ",".join( + f"{provider_id}({error})" + for provider_id, error in failed_provider_errors + ), + ) + return None, None + + class ProviderCommands: _MODEL_LIST_CACHE_TTL_SECONDS = 30.0 _MODEL_LOOKUP_MAX_CONCURRENCY = 4 def __init__(self, context: star.Context) -> None: self.context = context - self._provider_models_cache: dict[str, _ModelCacheEntry] = {} + self._model_cache = _ModelListCache(context) + self._model_lookup = _ProviderModelLookup(context) def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: - if provider_id is None: - self._provider_models_cache.clear() - return - self._provider_models_cache.pop(provider_id, None) + self._model_cache.invalidate(provider_id) - @staticmethod - def _safe_error(prefix: str, e: Exception) -> str: - text = str(e) - for pattern in _SECRET_PATTERNS: - text = pattern.sub("[REDACTED]", text) - return prefix + text + def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: + """Public hook for cache invalidation on external provider config changes.""" + self._invalidate_provider_models_cache(provider_id) @staticmethod def _normalize_model_name(model_name: str) -> str: @@ -85,43 +206,6 @@ def _resolve_model_name( return candidate return None - def _get_model_cache_ttl_seconds(self, umo: str | None = None) -> float: - ttl = self._MODEL_LIST_CACHE_TTL_SECONDS - if not umo: - return ttl - try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - configured_ttl = cfg.get(_MODEL_LIST_CACHE_TTL_CONFIG_KEY) - if configured_ttl is not None: - ttl = float(configured_ttl) - except Exception as e: - logger.debug( - "读取 model_list_cache_ttl_seconds 失败,回退默认值 %.1f: %s", - self._MODEL_LIST_CACHE_TTL_SECONDS, - self._safe_error("", e), - ) - ttl = self._MODEL_LIST_CACHE_TTL_SECONDS - return max(ttl, 0.0) - - def _get_model_lookup_max_concurrency(self, umo: str | None = None) -> int: - concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY - if not 
umo: - return max(concurrency, 1) - try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - configured_concurrency = cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) - if configured_concurrency is not None: - concurrency = int(configured_concurrency) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %d: %s", - _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, - self._MODEL_LOOKUP_MAX_CONCURRENCY, - self._safe_error("", e), - ) - concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY - return max(concurrency, 1) - async def _get_provider_models( self, provider: Provider, @@ -129,21 +213,12 @@ async def _get_provider_models( use_cache: bool = True, umo: str | None = None, ) -> list[str]: - provider_id = provider.meta().id - now = time.monotonic() - ttl_seconds = self._get_model_cache_ttl_seconds(umo) - if use_cache and ttl_seconds > 0: - cached = self._provider_models_cache.get(provider_id) - if cached and now - cached.timestamp <= ttl_seconds: - return list(cached.models) - - models = tuple(await provider.get_models()) - if use_cache and ttl_seconds > 0: - self._provider_models_cache[provider_id] = _ModelCacheEntry( - timestamp=now, - models=models, - ) - return list(models) + return await self._model_cache.get_models( + provider, + use_cache=use_cache, + umo=umo, + default_ttl=self._MODEL_LIST_CACHE_TTL_SECONDS, + ) def _log_reachability_failure( self, @@ -172,7 +247,7 @@ async def _test_provider_capability(self, provider): return True, None, None except Exception as e: err_code = "TEST_FAILED" - err_reason = self._safe_error("", e) + err_reason = safe_error("", e) self._log_reachability_failure( provider, provider_capability_type, err_code, err_reason ) @@ -184,58 +259,14 @@ async def _find_provider_for_model( exclude_provider_id: str | None = None, umo: str | None = None, ) -> tuple[Provider | None, str | None]: - """在所有 LLM 提供商中查找包含指定模型的提供商。""" - all_providers = [ - p - for p in self.context.get_all_providers() - if not exclude_provider_id or 
p.meta().id != exclude_provider_id - ] - if not all_providers: - return None, None - - failed_provider_errors: list[tuple[str, str]] = [] - max_concurrency = self._get_model_lookup_max_concurrency(umo) - for start in range(0, len(all_providers), max_concurrency): - batch_providers = all_providers[start : start + max_concurrency] - batch_results = await asyncio.gather( - *[ - self._get_provider_models(provider, umo=umo) - for provider in batch_providers - ], - return_exceptions=True, - ) - for provider, result in zip(batch_providers, batch_results, strict=False): - provider_id = provider.meta().id - if isinstance(result, Exception): - failed_provider_errors.append( - (provider_id, self._safe_error("", result)) - ) - continue - matched_model_name = self._resolve_model_name(model_name, result) - if matched_model_name is not None: - return provider, matched_model_name - - if failed_provider_errors and len(failed_provider_errors) == len(all_providers): - failed_ids = ",".join( - provider_id for provider_id, _ in failed_provider_errors - ) - logger.error( - "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", - model_name, - len(all_providers), - failed_ids, - ) - elif failed_provider_errors: - logger.debug( - "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", - model_name, - len(failed_provider_errors), - ",".join( - f"{provider_id}({error})" - for provider_id, error in failed_provider_errors - ), - ) - return None, None + return await self._model_lookup.find_provider_for_model( + model_name, + exclude_provider_id=exclude_provider_id, + umo=umo, + default_max_concurrency=self._MODEL_LOOKUP_MAX_CONCURRENCY, + get_models=self._get_provider_models, + resolve_name=self._resolve_model_name, + ) async def provider( self, @@ -291,7 +322,7 @@ async def provider( p, None, reachable.__class__.__name__, - self._safe_error("", reachable), + safe_error("", reachable), ) reachable_flag = False error_code = reachable.__class__.__name__ @@ -431,7 +462,7 @@ async def _switch_model_by_name( try: 
models = await self._get_provider_models(prov, umo=umo) except Exception as e: - err_msg = self._safe_error("", e) + err_msg = safe_error("", e) logger.warning( "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", curr_provider_id, @@ -439,7 +470,7 @@ async def _switch_model_by_name( ) message.set_result( MessageEventResult().message( - self._safe_error("获取当前提供商模型列表失败: ", e) + safe_error("获取当前提供商模型列表失败: ", e) ) ) return @@ -484,7 +515,7 @@ async def _switch_model_by_name( except Exception as e: message.set_result( MessageEventResult().message( - self._safe_error("跨提供商切换并设置模型失败: ", e) + safe_error("跨提供商切换并设置模型失败: ", e) ), ) @@ -510,7 +541,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult() - .message(self._safe_error("获取模型列表失败: ", e)) + .message(safe_error("获取模型列表失败: ", e)) .use_t2i(False), ) return @@ -534,9 +565,7 @@ async def model_ls( ) except Exception as e: message.set_result( - MessageEventResult().message( - self._safe_error("获取模型列表失败: ", e) - ), + MessageEventResult().message(safe_error("获取模型列表失败: ", e)), ) return if idx_or_name > len(models) or idx_or_name < 1: @@ -554,7 +583,7 @@ async def model_ls( except Exception as e: message.set_result( MessageEventResult().message( - self._safe_error("切换模型未知错误: ", e) + safe_error("切换模型未知错误: ", e) ), ) return @@ -595,7 +624,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None except Exception as e: message.set_result( MessageEventResult().message( - self._safe_error("切换 Key 未知错误: ", e) + safe_error("切换 Key 未知错误: ", e) ), ) return diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py new file mode 100644 index 0000000000..f324c478bd --- /dev/null +++ b/astrbot/core/utils/error_redaction.py @@ -0,0 +1,22 @@ +import re + +_SECRET_PATTERNS = [ + re.compile( + r"(?i)\b(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\s*=\s*[^&'\" ]+" + ), + 
re.compile(r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+"), + re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), + re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), + re.compile(r"\bsk-[A-Za-z0-9]{16,}\b"), +] + + +def redact_sensitive_text(text: str) -> str: + redacted = text + for pattern in _SECRET_PATTERNS: + redacted = pattern.sub("[REDACTED]", redacted) + return redacted + + +def safe_error(prefix: str, error: Exception | BaseException | str) -> str: + return prefix + redact_sensitive_text(str(error)) From 40b7fd3c4edd9edd4e54910ade0b4971411b8900 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:01:41 +0900 Subject: [PATCH 15/28] fix: harden provider lookup cancellation and redaction --- .../builtin_commands/commands/provider.py | 210 +++++++++--------- astrbot/core/utils/error_redaction.py | 8 + 2 files changed, 108 insertions(+), 110 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index a465f66a0c..2c2d8e8180 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -2,7 +2,7 @@ import asyncio import time -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING @@ -17,6 +17,7 @@ _MODEL_LIST_CACHE_TTL_CONFIG_KEY = "model_list_cache_ttl_seconds" _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY = "model_lookup_max_concurrency" +_MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16 @dataclass @@ -80,87 +81,6 @@ async def get_models( return list(models) -class _ProviderModelLookup: - def __init__(self, context: star.Context) -> None: - self._context = context - - def _get_max_concurrency(self, umo: str | None, default: int) -> int: - concurrency = default - if not umo: - return 
max(concurrency, 1) - try: - cfg = self._context.get_config(umo).get("provider_settings", {}) - configured_concurrency = cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) - if configured_concurrency is not None: - concurrency = int(configured_concurrency) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %d: %s", - _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, - default, - safe_error("", e), - ) - concurrency = default - return max(concurrency, 1) - - async def find_provider_for_model( - self, - model_name: str, - *, - exclude_provider_id: str | None, - umo: str | None, - default_max_concurrency: int, - get_models: Callable[..., Awaitable[list[str]]], - resolve_name: Callable[[str, Sequence[str]], str | None], - ) -> tuple[Provider | None, str | None]: - all_providers = [ - p - for p in self._context.get_all_providers() - if not exclude_provider_id or p.meta().id != exclude_provider_id - ] - if not all_providers: - return None, None - - failed_provider_errors: list[tuple[str, str]] = [] - max_concurrency = self._get_max_concurrency(umo, default_max_concurrency) - for start in range(0, len(all_providers), max_concurrency): - batch_providers = all_providers[start : start + max_concurrency] - batch_results = await asyncio.gather( - *[get_models(provider, umo=umo) for provider in batch_providers], - return_exceptions=True, - ) - for provider, result in zip(batch_providers, batch_results, strict=False): - provider_id = provider.meta().id - if isinstance(result, Exception): - failed_provider_errors.append((provider_id, safe_error("", result))) - continue - matched_model_name = resolve_name(model_name, result) - if matched_model_name is not None: - return provider, matched_model_name - - if failed_provider_errors and len(failed_provider_errors) == len(all_providers): - failed_ids = ",".join( - provider_id for provider_id, _ in failed_provider_errors - ) - logger.error( - "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", - model_name, - len(all_providers), - 
failed_ids, - ) - elif failed_provider_errors: - logger.debug( - "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", - model_name, - len(failed_provider_errors), - ",".join( - f"{provider_id}({error})" - for provider_id, error in failed_provider_errors - ), - ) - return None, None - - class ProviderCommands: _MODEL_LIST_CACHE_TTL_SECONDS = 30.0 _MODEL_LOOKUP_MAX_CONCURRENCY = 4 @@ -168,14 +88,10 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context self._model_cache = _ModelListCache(context) - self._model_lookup = _ProviderModelLookup(context) - - def _invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: - self._model_cache.invalidate(provider_id) def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: """Public hook for cache invalidation on external provider config changes.""" - self._invalidate_provider_models_cache(provider_id) + self._model_cache.invalidate(provider_id) @staticmethod def _normalize_model_name(model_name: str) -> str: @@ -186,26 +102,57 @@ def _resolve_model_name( model_name: str, models: Sequence[str], ) -> str | None: - normalized_model_name = self._normalize_model_name(model_name) - if not normalized_model_name: + norm_name = self._normalize_model_name(model_name) + if not norm_name: return None + def is_exact(candidate: str) -> bool: + return candidate == model_name + + def is_case_insensitive(norm_candidate: str) -> bool: + return norm_candidate == norm_name + + def is_suffix_match(norm_candidate: str) -> bool: + return norm_candidate.endswith(f"/{norm_name}") or norm_candidate.endswith( + f":{norm_name}" + ) + + def is_reverse_suffix_match(norm_candidate: str) -> bool: + return norm_name.endswith(f"/{norm_candidate}") or norm_name.endswith( + f":{norm_candidate}" + ) + for candidate in models: - normalized_candidate = self._normalize_model_name(candidate) - if candidate == model_name: + norm_candidate = self._normalize_model_name(candidate) + if 
is_exact(candidate): return candidate - if normalized_candidate == normalized_model_name: + if is_case_insensitive(norm_candidate): return candidate - if normalized_candidate.endswith( - f"/{normalized_model_name}" - ) or normalized_candidate.endswith(f":{normalized_model_name}"): + if is_suffix_match(norm_candidate): return candidate - if normalized_model_name.endswith( - f"/{normalized_candidate}" - ) or normalized_model_name.endswith(f":{normalized_candidate}"): + if is_reverse_suffix_match(norm_candidate): return candidate return None + def _get_lookup_max_concurrency(self, umo: str | None) -> int: + concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY + if not umo: + return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + configured_concurrency = cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) + if configured_concurrency is not None: + concurrency = int(configured_concurrency) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %d: %s", + _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, + self._MODEL_LOOKUP_MAX_CONCURRENCY, + safe_error("", e), + ) + concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY + return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + async def _get_provider_models( self, provider: Provider, @@ -259,14 +206,57 @@ async def _find_provider_for_model( exclude_provider_id: str | None = None, umo: str | None = None, ) -> tuple[Provider | None, str | None]: - return await self._model_lookup.find_provider_for_model( - model_name, - exclude_provider_id=exclude_provider_id, - umo=umo, - default_max_concurrency=self._MODEL_LOOKUP_MAX_CONCURRENCY, - get_models=self._get_provider_models, - resolve_name=self._resolve_model_name, - ) + all_providers = [ + p + for p in self.context.get_all_providers() + if not exclude_provider_id or p.meta().id != exclude_provider_id + ] + if not all_providers: + return None, None + + failed_provider_errors: 
list[tuple[str, str]] = [] + max_concurrency = self._get_lookup_max_concurrency(umo) + for start in range(0, len(all_providers), max_concurrency): + batch_providers = all_providers[start : start + max_concurrency] + batch_results = await asyncio.gather( + *[ + self._get_provider_models(provider, umo=umo) + for provider in batch_providers + ], + return_exceptions=True, + ) + for provider, result in zip(batch_providers, batch_results, strict=False): + if isinstance(result, asyncio.CancelledError): + raise result + provider_id = provider.meta().id + if isinstance(result, Exception): + failed_provider_errors.append((provider_id, safe_error("", result))) + continue + matched_model_name = self._resolve_model_name(model_name, result) + if matched_model_name is not None: + return provider, matched_model_name + + if failed_provider_errors and len(failed_provider_errors) == len(all_providers): + failed_ids = ",".join( + provider_id for provider_id, _ in failed_provider_errors + ) + logger.error( + "跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络", + model_name, + len(all_providers), + failed_ids, + ) + elif failed_provider_errors: + logger.debug( + "跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s", + model_name, + len(failed_provider_errors), + ",".join( + f"{provider_id}({error})" + for provider_id, error in failed_provider_errors + ), + ) + return None, None async def provider( self, @@ -478,7 +468,7 @@ async def _switch_model_by_name( matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: prov.set_model(matched_model_name) - self._invalidate_provider_models_cache(curr_provider_id) + self.invalidate_provider_models_cache(curr_provider_id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", @@ -506,7 +496,7 @@ async def _switch_model_by_name( umo=umo, ) target_prov.set_model(matched_target_model_name) - self._invalidate_provider_models_cache(target_id) + 
self.invalidate_provider_models_cache(target_id) message.set_result( MessageEventResult().message( f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", @@ -574,7 +564,7 @@ async def model_ls( try: new_model = models[idx_or_name - 1] prov.set_model(new_model) - self._invalidate_provider_models_cache(prov.meta().id) + self.invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ -619,7 +609,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None try: new_key = keys_data[index - 1] prov.set_key(new_key) - self._invalidate_provider_models_cache(prov.meta().id) + self.invalidate_provider_models_cache(prov.meta().id) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py index f324c478bd..28830970e3 100644 --- a/astrbot/core/utils/error_redaction.py +++ b/astrbot/core/utils/error_redaction.py @@ -1,6 +1,14 @@ import re _SECRET_PATTERNS = [ + re.compile( + r"(?i)\"(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\"\s*:\s*\"[^\"]+\"" + ), + re.compile(r"(?i)\"authorization\"\s*:\s*\"bearer\s+[^\"]+\""), + re.compile( + r"(?i)'(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)'\s*:\s*'[^']+'" + ), + re.compile(r"(?i)'authorization'\s*:\s*'bearer\s+[^']+'"), re.compile( r"(?i)\b(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\s*=\s*[^&'\" ]+" ), From cf7da2fcd7376f714991a710ea10bde12b0e0555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:07:23 +0900 Subject: [PATCH 16/28] refactor: streamline provider cache and lookup settings --- .../builtin_commands/commands/provider.py | 191 +++++++++--------- 
astrbot/core/provider/manager.py | 32 +++ 2 files changed, 133 insertions(+), 90 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 2c2d8e8180..3e08b298dc 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -2,9 +2,9 @@ import asyncio import time -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from astrbot import logger from astrbot.api import star @@ -23,12 +23,11 @@ @dataclass class _ModelCacheEntry: timestamp: float - models: tuple[str, ...] + models: list[str] class _ModelListCache: - def __init__(self, context: star.Context) -> None: - self._context = context + def __init__(self) -> None: self._cache: dict[str, _ModelCacheEntry] = {} def invalidate(self, provider_id: str | None = None) -> None: @@ -37,48 +36,21 @@ def invalidate(self, provider_id: str | None = None) -> None: return self._cache.pop(provider_id, None) - def _get_ttl_seconds(self, umo: str | None, default_ttl: float) -> float: - ttl = default_ttl - if not umo: - return max(ttl, 0.0) - try: - cfg = self._context.get_config(umo).get("provider_settings", {}) - configured_ttl = cfg.get(_MODEL_LIST_CACHE_TTL_CONFIG_KEY) - if configured_ttl is not None: - ttl = float(configured_ttl) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %.1f: %s", - _MODEL_LIST_CACHE_TTL_CONFIG_KEY, - default_ttl, - safe_error("", e), - ) - ttl = default_ttl - return max(ttl, 0.0) - - async def get_models( - self, - provider: Provider, - *, - use_cache: bool = True, - umo: str | None = None, - default_ttl: float = 30.0, - ) -> list[str]: - provider_id = provider.meta().id - now = time.monotonic() - ttl_seconds = self._get_ttl_seconds(umo, default_ttl) - if use_cache and ttl_seconds > 
0: - cached = self._cache.get(provider_id) - if cached and now - cached.timestamp <= ttl_seconds: - return list(cached.models) + def get_models(self, provider_id: str, *, ttl_seconds: float) -> list[str] | None: + if ttl_seconds <= 0: + return None + cached = self._cache.get(provider_id) + if not cached: + return None + if time.monotonic() - cached.timestamp > ttl_seconds: + return None + return list(cached.models) - models = tuple(await provider.get_models()) - if use_cache and ttl_seconds > 0: - self._cache[provider_id] = _ModelCacheEntry( - timestamp=now, - models=models, - ) - return list(models) + def set_models(self, provider_id: str, models: list[str]) -> None: + self._cache[provider_id] = _ModelCacheEntry( + timestamp=time.monotonic(), + models=list(models), + ) class ProviderCommands: @@ -87,12 +59,52 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._model_cache = _ModelListCache(context) + self._model_cache = _ModelListCache() + register_change_hook = getattr( + self.context.provider_manager, + "register_provider_change_hook", + None, + ) + if callable(register_change_hook): + register_change_hook(self._on_provider_manager_changed) def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: """Public hook for cache invalidation on external provider config changes.""" self._model_cache.invalidate(provider_id) + def _on_provider_manager_changed( + self, + provider_id: str, + provider_type: ProviderType, + umo: str | None, + ) -> None: + if provider_type == ProviderType.CHAT_COMPLETION: + self.invalidate_provider_models_cache(provider_id) + + def _get_provider_setting( + self, + umo: str | None, + key: str, + cast: Callable[[Any], Any], + default: Any, + ) -> Any: + if not umo: + return default + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + raw = cfg.get(key) + if raw is None: + return default + return cast(raw) + except Exception as e: + 
logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + key, + default, + safe_error("", e), + ) + return default + @staticmethod def _normalize_model_name(model_name: str) -> str: return model_name.strip().casefold() @@ -102,57 +114,51 @@ def _resolve_model_name( model_name: str, models: Sequence[str], ) -> str | None: + """Resolve model name with precedence: + exact > case-insensitive > suffix > reverse suffix. + """ norm_name = self._normalize_model_name(model_name) if not norm_name: return None - def is_exact(candidate: str) -> bool: - return candidate == model_name - - def is_case_insensitive(norm_candidate: str) -> bool: - return norm_candidate == norm_name - - def is_suffix_match(norm_candidate: str) -> bool: - return norm_candidate.endswith(f"/{norm_name}") or norm_candidate.endswith( - f":{norm_name}" - ) - - def is_reverse_suffix_match(norm_candidate: str) -> bool: - return norm_name.endswith(f"/{norm_candidate}") or norm_name.endswith( - f":{norm_candidate}" - ) - for candidate in models: norm_candidate = self._normalize_model_name(candidate) - if is_exact(candidate): + # exact match + if candidate == model_name: return candidate - if is_case_insensitive(norm_candidate): + # case-insensitive match + if norm_candidate == norm_name: return candidate - if is_suffix_match(norm_candidate): + # suffix match: provider model is longer + if norm_candidate.endswith(f"/{norm_name}") or norm_candidate.endswith( + f":{norm_name}" + ): return candidate - if is_reverse_suffix_match(norm_candidate): + # reverse suffix match: requested model is longer + if norm_name.endswith(f"/{norm_candidate}") or norm_name.endswith( + f":{norm_candidate}" + ): return candidate return None def _get_lookup_max_concurrency(self, umo: str | None) -> int: - concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY - if not umo: - return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) - try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - configured_concurrency = 
cfg.get(_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY) - if configured_concurrency is not None: - concurrency = int(configured_concurrency) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %d: %s", - _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, - self._MODEL_LOOKUP_MAX_CONCURRENCY, - safe_error("", e), - ) - concurrency = self._MODEL_LOOKUP_MAX_CONCURRENCY + concurrency = self._get_provider_setting( + umo, + _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, + int, + self._MODEL_LOOKUP_MAX_CONCURRENCY, + ) return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + def _get_model_cache_ttl_seconds(self, umo: str | None) -> float: + ttl = self._get_provider_setting( + umo, + _MODEL_LIST_CACHE_TTL_CONFIG_KEY, + float, + self._MODEL_LIST_CACHE_TTL_SECONDS, + ) + return max(float(ttl), 0.0) + async def _get_provider_models( self, provider: Provider, @@ -160,12 +166,17 @@ async def _get_provider_models( use_cache: bool = True, umo: str | None = None, ) -> list[str]: - return await self._model_cache.get_models( - provider, - use_cache=use_cache, - umo=umo, - default_ttl=self._MODEL_LIST_CACHE_TTL_SECONDS, - ) + provider_id = provider.meta().id + ttl_seconds = self._get_model_cache_ttl_seconds(umo) + if use_cache: + cached = self._model_cache.get_models(provider_id, ttl_seconds=ttl_seconds) + if cached is not None: + return cached + + models = list(await provider.get_models()) + if use_cache and ttl_seconds > 0: + self._model_cache.set_models(provider_id, models) + return models def _log_reachability_failure( self, @@ -209,7 +220,7 @@ async def _find_provider_for_model( all_providers = [ p for p in self.context.get_all_providers() - if not exclude_provider_id or p.meta().id != exclude_provider_id + if exclude_provider_id is None or p.meta().id != exclude_provider_id ] if not all_providers: return None, None diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index a331c97e9b..d347cf7ebb 100644 --- a/astrbot/core/provider/manager.py 
+++ b/astrbot/core/provider/manager.py @@ -2,6 +2,7 @@ import copy import os import traceback +from collections.abc import Callable from typing import Protocol, runtime_checkable from astrbot.core import astrbot_config, logger, sp @@ -71,6 +72,33 @@ def __init__( self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper + self._provider_change_hooks: list[ + Callable[[str, ProviderType, str | None], None] + ] = [] + + def register_provider_change_hook( + self, + hook: Callable[[str, ProviderType, str | None], None], + ) -> None: + if hook not in self._provider_change_hooks: + self._provider_change_hooks.append(hook) + + def _notify_provider_change_hooks( + self, + provider_id: str, + provider_type: ProviderType, + umo: str | None, + ) -> None: + for hook in list(self._provider_change_hooks): + try: + hook(provider_id, provider_type, umo) + except Exception as e: + logger.warning( + "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s", + provider_id, + provider_type, + e, + ) @property def persona_configs(self) -> list: @@ -111,6 +139,7 @@ async def set_provider( f"provider_perf_{provider_type.value}", provider_id, ) + self._notify_provider_change_hooks(provider_id, provider_type, umo) return # 不启用提供商会话隔离模式的情况 @@ -126,6 +155,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_change_hooks(provider_id, provider_type, umo) elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( prov, STTProvider, @@ -137,6 +167,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_change_hooks(provider_id, provider_type, umo) elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( prov, Provider, @@ -148,6 +179,7 @@ async def set_provider( scope="global", scope_id="global", ) + self._notify_provider_change_hooks(provider_id, provider_type, umo) async def get_provider_by_id(self, 
provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" From 5aeccbb99822194958777c81a9ab293a4702ad0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:11:04 +0900 Subject: [PATCH 17/28] refactor: simplify provider command setting and update helpers --- .../builtin_commands/commands/provider.py | 76 ++++++++++++------- astrbot/core/provider/manager.py | 3 +- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 3e08b298dc..283cf6094e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -105,10 +105,44 @@ def _get_provider_setting( ) return default + def _get_int_provider_setting(self, umo: str | None, key: str, default: int) -> int: + return int(self._get_provider_setting(umo, key, int, default)) + + def _get_float_provider_setting( + self, umo: str | None, key: str, default: float + ) -> float: + return float(self._get_provider_setting(umo, key, float, default)) + + def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None: + provider.set_model(model_name) + self.invalidate_provider_models_cache(provider.meta().id) + + def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: + provider.set_key(key) + self.invalidate_provider_models_cache(provider.meta().id) + @staticmethod def _normalize_model_name(model_name: str) -> str: return model_name.strip().casefold() + def _match_exact_or_ci(self, requested: str, candidate: str) -> bool: + if candidate == requested: + return True + return self._normalize_model_name(candidate) == self._normalize_model_name( + requested + ) + + def _match_suffix_either_side(self, requested: str, candidate: str) -> bool: + req = self._normalize_model_name(requested) + cand = 
self._normalize_model_name(candidate) + if not req or not cand: + return False + if cand.endswith(f"/{req}") or cand.endswith(f":{req}"): + return True + if req.endswith(f"/{cand}") or req.endswith(f":{cand}"): + return True + return False + def _resolve_model_name( self, model_name: str, @@ -121,43 +155,33 @@ def _resolve_model_name( if not norm_name: return None + # exact / case-insensitive match for candidate in models: - norm_candidate = self._normalize_model_name(candidate) - # exact match - if candidate == model_name: - return candidate - # case-insensitive match - if norm_candidate == norm_name: - return candidate - # suffix match: provider model is longer - if norm_candidate.endswith(f"/{norm_name}") or norm_candidate.endswith( - f":{norm_name}" - ): + if self._match_exact_or_ci(model_name, candidate): return candidate - # reverse suffix match: requested model is longer - if norm_name.endswith(f"/{norm_candidate}") or norm_name.endswith( - f":{norm_candidate}" - ): + + # suffix / reverse suffix match + for candidate in models: + if self._match_suffix_either_side(model_name, candidate): return candidate + return None def _get_lookup_max_concurrency(self, umo: str | None) -> int: - concurrency = self._get_provider_setting( + concurrency = self._get_int_provider_setting( umo, _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, - int, self._MODEL_LOOKUP_MAX_CONCURRENCY, ) return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) def _get_model_cache_ttl_seconds(self, umo: str | None) -> float: - ttl = self._get_provider_setting( + ttl = self._get_float_provider_setting( umo, _MODEL_LIST_CACHE_TTL_CONFIG_KEY, - float, self._MODEL_LIST_CACHE_TTL_SECONDS, ) - return max(float(ttl), 0.0) + return max(ttl, 0.0) async def _get_provider_models( self, @@ -478,8 +502,7 @@ async def _switch_model_by_name( matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: - prov.set_model(matched_model_name) - 
self.invalidate_provider_models_cache(curr_provider_id) + self._set_model_and_invalidate(prov, matched_model_name) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", @@ -506,8 +529,7 @@ async def _switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - target_prov.set_model(matched_target_model_name) - self.invalidate_provider_models_cache(target_id) + self._set_model_and_invalidate(target_prov, matched_target_model_name) message.set_result( MessageEventResult().message( f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", @@ -574,8 +596,7 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - prov.set_model(new_model) - self.invalidate_provider_models_cache(prov.meta().id) + self._set_model_and_invalidate(prov, new_model) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ -619,8 +640,7 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None else: try: new_key = keys_data[index - 1] - prov.set_key(new_key) - self.invalidate_provider_models_cache(prov.meta().id) + self._set_key_and_invalidate(prov, new_key) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index d347cf7ebb..e1e003cbda 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -8,6 +8,7 @@ from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase +from astrbot.core.utils.error_redaction import safe_error from ..persona_mgr import PersonaManager from .entities import ProviderType @@ -97,7 +98,7 @@ def _notify_provider_change_hooks( "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s", provider_id, provider_type, - e, + 
safe_error("", e), ) @property From 747bede0395d4a21b028eb3b122ce1c1ef64051a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:16:12 +0900 Subject: [PATCH 18/28] refactor: streamline provider model lookup config usage --- .../builtin_commands/commands/provider.py | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 283cf6094e..287086b0a1 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -2,9 +2,9 @@ import asyncio import time -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from astrbot import logger from astrbot.api import star @@ -15,9 +15,17 @@ if TYPE_CHECKING: from astrbot.core.provider.provider import Provider -_MODEL_LIST_CACHE_TTL_CONFIG_KEY = "model_list_cache_ttl_seconds" -_MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY = "model_lookup_max_concurrency" -_MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16 + +@dataclass(frozen=True) +class _ModelLookupConfig: + list_cache_ttl_seconds: float = 30.0 + max_concurrency_default: int = 4 + max_concurrency_upper_bound: int = 16 + list_cache_ttl_key: str = "model_list_cache_ttl_seconds" + max_concurrency_key: str = "model_lookup_max_concurrency" + + +_MODEL_LOOKUP_CONFIG = _ModelLookupConfig() @dataclass @@ -54,12 +62,10 @@ def set_models(self, provider_id: str, models: list[str]) -> None: class ProviderCommands: - _MODEL_LIST_CACHE_TTL_SECONDS = 30.0 - _MODEL_LOOKUP_MAX_CONCURRENCY = 4 - def __init__(self, context: star.Context) -> None: self.context = context self._model_cache = _ModelListCache() + self._cfg = _MODEL_LOOKUP_CONFIG register_change_hook = getattr( 
self.context.provider_manager, "register_provider_change_hook", @@ -81,13 +87,7 @@ def _on_provider_manager_changed( if provider_type == ProviderType.CHAT_COMPLETION: self.invalidate_provider_models_cache(provider_id) - def _get_provider_setting( - self, - umo: str | None, - key: str, - cast: Callable[[Any], Any], - default: Any, - ) -> Any: + def _get_int_provider_setting(self, umo: str | None, key: str, default: int) -> int: if not umo: return default try: @@ -95,7 +95,7 @@ def _get_provider_setting( raw = cfg.get(key) if raw is None: return default - return cast(raw) + return int(raw) except Exception as e: logger.debug( "读取 %s 失败,回退默认值 %r: %s", @@ -105,13 +105,25 @@ def _get_provider_setting( ) return default - def _get_int_provider_setting(self, umo: str | None, key: str, default: int) -> int: - return int(self._get_provider_setting(umo, key, int, default)) - def _get_float_provider_setting( self, umo: str | None, key: str, default: float ) -> float: - return float(self._get_provider_setting(umo, key, float, default)) + if not umo: + return default + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + raw = cfg.get(key) + if raw is None: + return default + return float(raw) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + key, + default, + safe_error("", e), + ) + return default def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None: provider.set_model(model_name) @@ -121,28 +133,6 @@ def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: provider.set_key(key) self.invalidate_provider_models_cache(provider.meta().id) - @staticmethod - def _normalize_model_name(model_name: str) -> str: - return model_name.strip().casefold() - - def _match_exact_or_ci(self, requested: str, candidate: str) -> bool: - if candidate == requested: - return True - return self._normalize_model_name(candidate) == self._normalize_model_name( - requested - ) - - def _match_suffix_either_side(self, 
requested: str, candidate: str) -> bool: - req = self._normalize_model_name(requested) - cand = self._normalize_model_name(candidate) - if not req or not cand: - return False - if cand.endswith(f"/{req}") or cand.endswith(f":{req}"): - return True - if req.endswith(f"/{cand}") or req.endswith(f":{cand}"): - return True - return False - def _resolve_model_name( self, model_name: str, @@ -151,18 +141,28 @@ def _resolve_model_name( """Resolve model name with precedence: exact > case-insensitive > suffix > reverse suffix. """ - norm_name = self._normalize_model_name(model_name) - if not norm_name: + requested = model_name.strip() + if not requested: return None + requested_norm = requested.casefold() + # exact / case-insensitive match for candidate in models: - if self._match_exact_or_ci(model_name, candidate): + if candidate == requested or candidate.casefold() == requested_norm: return candidate # suffix / reverse suffix match + def _match_suffix(req: str, cand: str) -> bool: + return ( + cand.endswith(f"/{req}") + or cand.endswith(f":{req}") + or req.endswith(f"/{cand}") + or req.endswith(f":{cand}") + ) + for candidate in models: - if self._match_suffix_either_side(model_name, candidate): + if _match_suffix(requested_norm, candidate.casefold()): return candidate return None @@ -170,16 +170,16 @@ def _resolve_model_name( def _get_lookup_max_concurrency(self, umo: str | None) -> int: concurrency = self._get_int_provider_setting( umo, - _MODEL_LOOKUP_MAX_CONCURRENCY_CONFIG_KEY, - self._MODEL_LOOKUP_MAX_CONCURRENCY, + self._cfg.max_concurrency_key, + self._cfg.max_concurrency_default, ) - return min(max(concurrency, 1), _MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + return min(max(concurrency, 1), self._cfg.max_concurrency_upper_bound) def _get_model_cache_ttl_seconds(self, umo: str | None) -> float: ttl = self._get_float_provider_setting( umo, - _MODEL_LIST_CACHE_TTL_CONFIG_KEY, - self._MODEL_LIST_CACHE_TTL_SECONDS, + self._cfg.list_cache_ttl_key, + 
self._cfg.list_cache_ttl_seconds, ) return max(ttl, 0.0) From 9059c1bf8ec68bc870c8e9840ab9046a27262b8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:24:14 +0900 Subject: [PATCH 19/28] refactor: flatten provider lookup settings and filter model lookup providers --- .../builtin_commands/commands/provider.py | 91 +++++++++---------- astrbot/core/utils/error_redaction.py | 12 ++- 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 287086b0a1..9708b55dca 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -16,16 +16,11 @@ from astrbot.core.provider.provider import Provider -@dataclass(frozen=True) -class _ModelLookupConfig: - list_cache_ttl_seconds: float = 30.0 - max_concurrency_default: int = 4 - max_concurrency_upper_bound: int = 16 - list_cache_ttl_key: str = "model_list_cache_ttl_seconds" - max_concurrency_key: str = "model_lookup_max_concurrency" - - -_MODEL_LOOKUP_CONFIG = _ModelLookupConfig() +MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT = 30.0 +MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT = 4 +MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16 +MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds" +MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency" @dataclass @@ -65,7 +60,6 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context self._model_cache = _ModelListCache() - self._cfg = _MODEL_LOOKUP_CONFIG register_change_hook = getattr( self.context.provider_manager, "register_provider_change_hook", @@ -87,27 +81,13 @@ def _on_provider_manager_changed( if provider_type == ProviderType.CHAT_COMPLETION: self.invalidate_provider_models_cache(provider_id) - def _get_int_provider_setting(self, umo: str | None, key: str, default: int) 
-> int: - if not umo: - return default - try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - raw = cfg.get(key) - if raw is None: - return default - return int(raw) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %r: %s", - key, - default, - safe_error("", e), - ) - return default - - def _get_float_provider_setting( - self, umo: str | None, key: str, default: float - ) -> float: + def _get_numeric_provider_setting( + self, + umo: str | None, + key: str, + default: int | float, + cast: type[int] | type[float], + ) -> int | float: if not umo: return default try: @@ -115,7 +95,7 @@ def _get_float_provider_setting( raw = cfg.get(key) if raw is None: return default - return float(raw) + return cast(raw) except Exception as e: logger.debug( "读取 %s 失败,回退默认值 %r: %s", @@ -168,18 +148,27 @@ def _match_suffix(req: str, cand: str) -> bool: return None def _get_lookup_max_concurrency(self, umo: str | None) -> int: - concurrency = self._get_int_provider_setting( - umo, - self._cfg.max_concurrency_key, - self._cfg.max_concurrency_default, + concurrency = int( + self._get_numeric_provider_setting( + umo, + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + int, + ) + ) + return min( + max(concurrency, 1), + MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, ) - return min(max(concurrency, 1), self._cfg.max_concurrency_upper_bound) def _get_model_cache_ttl_seconds(self, umo: str | None) -> float: - ttl = self._get_float_provider_setting( - umo, - self._cfg.list_cache_ttl_key, - self._cfg.list_cache_ttl_seconds, + ttl = float( + self._get_numeric_provider_setting( + umo, + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + float, + ) ) return max(ttl, 0.0) @@ -241,11 +230,17 @@ async def _find_provider_for_model( exclude_provider_id: str | None = None, umo: str | None = None, ) -> tuple[Provider | None, str | None]: - all_providers = [ - p - for p in self.context.get_all_providers() - if exclude_provider_id is None or 
p.meta().id != exclude_provider_id - ] + all_providers = [] + for provider in self.context.get_all_providers(): + provider_meta = provider.meta() + if provider_meta.provider_type != ProviderType.CHAT_COMPLETION: + continue + if ( + exclude_provider_id is not None + and provider_meta.id == exclude_provider_id + ): + continue + all_providers.append(provider) if not all_providers: return None, None @@ -341,6 +336,8 @@ async def provider( id_ = meta.id error_code = None + if isinstance(reachable, asyncio.CancelledError): + raise reachable if isinstance(reachable, Exception): # 异常情况下兜底处理,避免单个 provider 导致列表失败 self._log_reachability_failure( diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py index 28830970e3..56a248db2a 100644 --- a/astrbot/core/utils/error_redaction.py +++ b/astrbot/core/utils/error_redaction.py @@ -26,5 +26,13 @@ def redact_sensitive_text(text: str) -> str: return redacted -def safe_error(prefix: str, error: Exception | BaseException | str) -> str: - return prefix + redact_sensitive_text(str(error)) +def safe_error( + prefix: str, + error: Exception | BaseException | str, + *, + redact: bool = True, +) -> str: + text = str(error) + if redact: + text = redact_sensitive_text(text) + return prefix + text From b76ba0c371417acc983ea7ba5db02c96fe346dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:30:39 +0900 Subject: [PATCH 20/28] refactor: simplify provider cache and callback flow --- .../builtin_commands/commands/provider.py | 97 +++++++++---------- astrbot/core/provider/manager.py | 47 +++++---- astrbot/core/utils/error_redaction.py | 60 ++++++++---- 3 files changed, 114 insertions(+), 90 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 9708b55dca..7391a6826e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ 
b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -46,6 +46,7 @@ def get_models(self, provider_id: str, *, ttl_seconds: float) -> list[str] | Non if not cached: return None if time.monotonic() - cached.timestamp > ttl_seconds: + self._cache.pop(provider_id, None) return None return list(cached.models) @@ -60,6 +61,14 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context self._model_cache = _ModelListCache() + set_change_callback = getattr( + self.context.provider_manager, + "set_provider_change_callback", + None, + ) + if callable(set_change_callback): + set_change_callback(self._on_provider_manager_changed) + return register_change_hook = getattr( self.context.provider_manager, "register_provider_change_hook", @@ -105,21 +114,13 @@ def _get_numeric_provider_setting( ) return default - def _set_model_and_invalidate(self, provider: Provider, model_name: str) -> None: - provider.set_model(model_name) - self.invalidate_provider_models_cache(provider.meta().id) - - def _set_key_and_invalidate(self, provider: Provider, key: str) -> None: - provider.set_key(key) - self.invalidate_provider_models_cache(provider.meta().id) - def _resolve_model_name( self, model_name: str, models: Sequence[str], ) -> str | None: """Resolve model name with precedence: - exact > case-insensitive > suffix > reverse suffix. + exact > case-insensitive > provider-qualified suffix. """ requested = model_name.strip() if not requested: @@ -132,46 +133,17 @@ def _resolve_model_name( if candidate == requested or candidate.casefold() == requested_norm: return candidate - # suffix / reverse suffix match - def _match_suffix(req: str, cand: str) -> bool: - return ( - cand.endswith(f"/{req}") - or cand.endswith(f":{req}") - or req.endswith(f"/{cand}") - or req.endswith(f":{cand}") - ) + # provider-qualified suffix match: + # e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`. 
+ def _match_qualified_suffix(req: str, cand: str) -> bool: + return cand.endswith(f"/{req}") or cand.endswith(f":{req}") for candidate in models: - if _match_suffix(requested_norm, candidate.casefold()): + if _match_qualified_suffix(requested_norm, candidate.casefold()): return candidate return None - def _get_lookup_max_concurrency(self, umo: str | None) -> int: - concurrency = int( - self._get_numeric_provider_setting( - umo, - MODEL_LOOKUP_MAX_CONCURRENCY_KEY, - MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, - int, - ) - ) - return min( - max(concurrency, 1), - MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, - ) - - def _get_model_cache_ttl_seconds(self, umo: str | None) -> float: - ttl = float( - self._get_numeric_provider_setting( - umo, - MODEL_LIST_CACHE_TTL_KEY, - MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, - float, - ) - ) - return max(ttl, 0.0) - async def _get_provider_models( self, provider: Provider, @@ -180,7 +152,13 @@ async def _get_provider_models( umo: str | None = None, ) -> list[str]: provider_id = provider.meta().id - ttl_seconds = self._get_model_cache_ttl_seconds(umo) + raw_ttl = self._get_numeric_provider_setting( + umo, + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + float, + ) + ttl_seconds = max(float(raw_ttl), 0.0) if use_cache: cached = self._model_cache.get_models(provider_id, ttl_seconds=ttl_seconds) if cached is not None: @@ -245,7 +223,16 @@ async def _find_provider_for_model( return None, None failed_provider_errors: list[tuple[str, str]] = [] - max_concurrency = self._get_lookup_max_concurrency(umo) + raw_concurrency = self._get_numeric_provider_setting( + umo, + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + int, + ) + max_concurrency = min( + max(int(raw_concurrency), 1), + MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, + ) for start in range(0, len(all_providers), max_concurrency): batch_providers = all_providers[start : start + max_concurrency] batch_results = await asyncio.gather( @@ -483,6 +470,8 @@ 
async def _switch_model_by_name( try: models = await self._get_provider_models(prov, umo=umo) + except asyncio.CancelledError: + raise except Exception as e: err_msg = safe_error("", e) logger.warning( @@ -499,7 +488,8 @@ async def _switch_model_by_name( matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: - self._set_model_and_invalidate(prov, matched_model_name) + prov.set_model(matched_model_name) + self.invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", @@ -526,12 +516,15 @@ async def _switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - self._set_model_and_invalidate(target_prov, matched_target_model_name) + target_prov.set_model(matched_target_model_name) + self.invalidate_provider_models_cache(target_prov.meta().id) message.set_result( MessageEventResult().message( f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", ), ) + except asyncio.CancelledError: + raise except Exception as e: message.set_result( MessageEventResult().message( @@ -558,6 +551,8 @@ async def model_ls( models = await self._get_provider_models( prov, umo=message.unified_msg_origin ) + except asyncio.CancelledError: + raise except Exception as e: message.set_result( MessageEventResult() @@ -583,6 +578,8 @@ async def model_ls( models = await self._get_provider_models( prov, umo=message.unified_msg_origin ) + except asyncio.CancelledError: + raise except Exception as e: message.set_result( MessageEventResult().message(safe_error("获取模型列表失败: ", e)), @@ -593,7 +590,8 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - self._set_model_and_invalidate(prov, new_model) + prov.set_model(new_model) + self.invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", @@ 
-637,7 +635,8 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None else: try: new_key = keys_data[index - 1] - self._set_key_and_invalidate(prov, new_key) + prov.set_key(new_key) + self.invalidate_provider_models_cache(prov.meta().id) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index e1e003cbda..69d6b8489b 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -73,33 +73,40 @@ def __init__( self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper - self._provider_change_hooks: list[ - Callable[[str, ProviderType, str | None], None] - ] = [] + self._on_provider_changed: ( + Callable[[str, ProviderType, str | None], None] | None + ) = None + + def set_provider_change_callback( + self, + cb: Callable[[str, ProviderType, str | None], None] | None, + ) -> None: + self._on_provider_changed = cb def register_provider_change_hook( self, hook: Callable[[str, ProviderType, str | None], None], ) -> None: - if hook not in self._provider_change_hooks: - self._provider_change_hooks.append(hook) + # Backward-compatible wrapper for older call sites. 
+ self.set_provider_change_callback(hook) - def _notify_provider_change_hooks( + def _notify_provider_changed( self, provider_id: str, provider_type: ProviderType, umo: str | None, ) -> None: - for hook in list(self._provider_change_hooks): - try: - hook(provider_id, provider_type, umo) - except Exception as e: - logger.warning( - "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s", - provider_id, - provider_type, - safe_error("", e), - ) + if not self._on_provider_changed: + return + try: + self._on_provider_changed(provider_id, provider_type, umo) + except Exception as e: + logger.warning( + "调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s", + provider_id, + provider_type, + safe_error("", e), + ) @property def persona_configs(self) -> list: @@ -140,7 +147,7 @@ async def set_provider( f"provider_perf_{provider_type.value}", provider_id, ) - self._notify_provider_change_hooks(provider_id, provider_type, umo) + self._notify_provider_changed(provider_id, provider_type, umo) return # 不启用提供商会话隔离模式的情况 @@ -156,7 +163,7 @@ async def set_provider( scope="global", scope_id="global", ) - self._notify_provider_change_hooks(provider_id, provider_type, umo) + self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( prov, STTProvider, @@ -168,7 +175,7 @@ async def set_provider( scope="global", scope_id="global", ) - self._notify_provider_change_hooks(provider_id, provider_type, umo) + self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( prov, Provider, @@ -180,7 +187,7 @@ async def set_provider( scope="global", scope_id="global", ) - self._notify_provider_change_hooks(provider_id, provider_type, umo) + self._notify_provider_changed(provider_id, provider_type, umo) async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" diff --git a/astrbot/core/utils/error_redaction.py 
b/astrbot/core/utils/error_redaction.py index 56a248db2a..9787ccc215 100644 --- a/astrbot/core/utils/error_redaction.py +++ b/astrbot/core/utils/error_redaction.py @@ -1,29 +1,47 @@ import re -_SECRET_PATTERNS = [ - re.compile( - r"(?i)\"(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\"\s*:\s*\"[^\"]+\"" - ), - re.compile(r"(?i)\"authorization\"\s*:\s*\"bearer\s+[^\"]+\""), - re.compile( - r"(?i)'(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)'\s*:\s*'[^']+'" - ), - re.compile(r"(?i)'authorization'\s*:\s*'bearer\s+[^']+'"), - re.compile( - r"(?i)\b(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)\s*=\s*[^&'\" ]+" - ), - re.compile(r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+"), - re.compile(r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"), - re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+"), - re.compile(r"\bsk-[A-Za-z0-9]{16,}\b"), -] +_SECRET_KEYS = ( + r"(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)" +) + +_JSON_FIELD_PATTERN = re.compile( + rf"(?i)(['\"])({_SECRET_KEYS})\1\s*:\s*(['\"])[^'\"]+\3" +) +_AUTH_JSON_FIELD_PATTERN = re.compile( + r"(?i)(['\"])authorization\1\s*:\s*(['\"])bearer\s+[^'\"]+\2" +) +_QUERY_FIELD_PATTERN = re.compile(rf"(?i)\b{_SECRET_KEYS}\s*=\s*[^&'\" ]+") +_QUERY_PARAM_PATTERN = re.compile( + r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+" +) +_AUTH_HEADER_PATTERN = re.compile( + r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+" +) +_BEARER_PATTERN = re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+") +_SK_PATTERN = re.compile(r"\bsk-[A-Za-z0-9]{16,}\b") + + +def _redact_json_like(text: str) -> str: + text = _JSON_FIELD_PATTERN.sub("[REDACTED]", text) + return _AUTH_JSON_FIELD_PATTERN.sub("[REDACTED]", text) + + +def _redact_query_like(text: str) -> str: + text = _QUERY_FIELD_PATTERN.sub("[REDACTED]", text) + return _QUERY_PARAM_PATTERN.sub("[REDACTED]", text) + + +def 
_redact_tokens(text: str) -> str: + text = _AUTH_HEADER_PATTERN.sub("[REDACTED]", text) + text = _BEARER_PATTERN.sub("[REDACTED]", text) + return _SK_PATTERN.sub("[REDACTED]", text) def redact_sensitive_text(text: str) -> str: - redacted = text - for pattern in _SECRET_PATTERNS: - redacted = pattern.sub("[REDACTED]", redacted) - return redacted + text = _redact_json_like(text) + text = _redact_query_like(text) + text = _redact_tokens(text) + return text def safe_error( From f139be7e5775bf74da63d60be52f0260e250d970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:35:07 +0900 Subject: [PATCH 21/28] refactor: simplify provider command model cache flow --- .../builtin_commands/commands/provider.py | 134 ++++++++---------- 1 file changed, 62 insertions(+), 72 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 7391a6826e..a898217de2 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,7 +3,6 @@ import asyncio import time from collections.abc import Sequence -from dataclasses import dataclass from typing import TYPE_CHECKING from astrbot import logger @@ -23,44 +22,13 @@ MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency" -@dataclass -class _ModelCacheEntry: - timestamp: float - models: list[str] - - -class _ModelListCache: - def __init__(self) -> None: - self._cache: dict[str, _ModelCacheEntry] = {} - - def invalidate(self, provider_id: str | None = None) -> None: - if provider_id is None: - self._cache.clear() - return - self._cache.pop(provider_id, None) - - def get_models(self, provider_id: str, *, ttl_seconds: float) -> list[str] | None: - if ttl_seconds <= 0: - return None - cached = self._cache.get(provider_id) - if not cached: - return None - if time.monotonic() - cached.timestamp > 
ttl_seconds: - self._cache.pop(provider_id, None) - return None - return list(cached.models) - - def set_models(self, provider_id: str, models: list[str]) -> None: - self._cache[provider_id] = _ModelCacheEntry( - timestamp=time.monotonic(), - models=list(models), - ) - - class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._model_cache = _ModelListCache() + self._model_cache: dict[str, tuple[float, list[str]]] = {} + self._register_provider_change_hook() + + def _register_provider_change_hook(self) -> None: set_change_callback = getattr( self.context.provider_manager, "set_provider_change_callback", @@ -79,7 +47,10 @@ def __init__(self, context: star.Context) -> None: def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: """Public hook for cache invalidation on external provider config changes.""" - self._model_cache.invalidate(provider_id) + if provider_id is None: + self._model_cache.clear() + return + self._model_cache.pop(provider_id, None) def _on_provider_manager_changed( self, @@ -90,29 +61,58 @@ def _on_provider_manager_changed( if provider_type == ProviderType.CHAT_COMPLETION: self.invalidate_provider_models_cache(provider_id) - def _get_numeric_provider_setting( - self, - umo: str | None, - key: str, - default: int | float, - cast: type[int] | type[float], - ) -> int | float: + def _get_cached_models( + self, provider_id: str, *, ttl_seconds: float + ) -> list[str] | None: + if ttl_seconds <= 0: + return None + entry = self._model_cache.get(provider_id) + if not entry: + return None + timestamp, models = entry + if time.monotonic() - timestamp > ttl_seconds: + self._model_cache.pop(provider_id, None) + return None + return list(models) + + def _set_cached_models(self, provider_id: str, models: list[str]) -> None: + self._model_cache[provider_id] = (time.monotonic(), list(models)) + + def _get_ttl_setting(self, umo: str | None) -> float: if not umo: - return default + return 
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT try: cfg = self.context.get_config(umo).get("provider_settings", {}) - raw = cfg.get(key) + raw = cfg.get(MODEL_LIST_CACHE_TTL_KEY) if raw is None: - return default - return cast(raw) + return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + return float(raw) except Exception as e: logger.debug( "读取 %s 失败,回退默认值 %r: %s", - key, - default, + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, safe_error("", e), ) - return default + return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + + def _get_lookup_concurrency(self, umo: str | None) -> int: + if not umo: + return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT + try: + cfg = self.context.get_config(umo).get("provider_settings", {}) + raw = cfg.get(MODEL_LOOKUP_MAX_CONCURRENCY_KEY) + if raw is None: + return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT + return int(raw) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + safe_error("", e), + ) + return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT def _resolve_model_name( self, @@ -144,6 +144,11 @@ def _match_qualified_suffix(req: str, cand: str) -> bool: return None + def _apply_model(self, prov: Provider, model_name: str) -> str: + prov.set_model(model_name) + self.invalidate_provider_models_cache(prov.meta().id) + return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]" + async def _get_provider_models( self, provider: Provider, @@ -152,21 +157,15 @@ async def _get_provider_models( umo: str | None = None, ) -> list[str]: provider_id = provider.meta().id - raw_ttl = self._get_numeric_provider_setting( - umo, - MODEL_LIST_CACHE_TTL_KEY, - MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, - float, - ) - ttl_seconds = max(float(raw_ttl), 0.0) + ttl_seconds = max(float(self._get_ttl_setting(umo)), 0.0) if use_cache: - cached = self._model_cache.get_models(provider_id, ttl_seconds=ttl_seconds) + cached = self._get_cached_models(provider_id, ttl_seconds=ttl_seconds) if 
cached is not None: return cached models = list(await provider.get_models()) if use_cache and ttl_seconds > 0: - self._model_cache.set_models(provider_id, models) + self._set_cached_models(provider_id, models) return models def _log_reachability_failure( @@ -223,12 +222,7 @@ async def _find_provider_for_model( return None, None failed_provider_errors: list[tuple[str, str]] = [] - raw_concurrency = self._get_numeric_provider_setting( - umo, - MODEL_LOOKUP_MAX_CONCURRENCY_KEY, - MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, - int, - ) + raw_concurrency = self._get_lookup_concurrency(umo) max_concurrency = min( max(int(raw_concurrency), 1), MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, @@ -488,11 +482,9 @@ async def _switch_model_by_name( matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: - prov.set_model(matched_model_name) - self.invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( - f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{matched_model_name}]", + self._apply_model(prov, matched_model_name) ), ) return @@ -590,11 +582,9 @@ async def model_ls( else: try: new_model = models[idx_or_name - 1] - prov.set_model(new_model) - self.invalidate_provider_models_cache(prov.meta().id) message.set_result( MessageEventResult().message( - f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]", + self._apply_model(prov, new_model) ), ) except Exception as e: From 086a11ed78520991096aa0f00811dc0e55d48320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:40:10 +0900 Subject: [PATCH 22/28] refactor: scope provider model cache by session --- .../builtin_commands/commands/provider.py | 134 +++++++++++------- 1 file changed, 86 insertions(+), 48 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index a898217de2..00058b47aa 100644 --- 
a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -2,8 +2,8 @@ import asyncio import time -from collections.abc import Sequence -from typing import TYPE_CHECKING +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, TypeVar from astrbot import logger from astrbot.api import star @@ -21,11 +21,15 @@ MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds" MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency" +T = TypeVar("T") + class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._model_cache: dict[str, tuple[float, list[str]]] = {} + self._model_cache: dict[ + tuple[str, str | None], tuple[float, tuple[str, ...]] + ] = {} self._register_provider_change_hook() def _register_provider_change_hook(self) -> None: @@ -45,12 +49,21 @@ def _register_provider_change_hook(self) -> None: if callable(register_change_hook): register_change_hook(self._on_provider_manager_changed) - def invalidate_provider_models_cache(self, provider_id: str | None = None) -> None: + def invalidate_provider_models_cache( + self, provider_id: str | None = None, *, umo: str | None = None + ) -> None: """Public hook for cache invalidation on external provider config changes.""" if provider_id is None: self._model_cache.clear() return - self._model_cache.pop(provider_id, None) + if umo is not None: + self._model_cache.pop((provider_id, umo), None) + return + stale_keys = [ + cache_key for cache_key in self._model_cache if cache_key[0] == provider_id + ] + for cache_key in stale_keys: + self._model_cache.pop(cache_key, None) def _on_provider_manager_changed( self, @@ -59,60 +72,73 @@ def _on_provider_manager_changed( umo: str | None, ) -> None: if provider_type == ProviderType.CHAT_COMPLETION: - self.invalidate_provider_models_cache(provider_id) + self.invalidate_provider_models_cache(provider_id, umo=umo) + + @staticmethod + def 
_cache_key(provider_id: str, umo: str | None) -> tuple[str, str | None]: + return provider_id, umo def _get_cached_models( - self, provider_id: str, *, ttl_seconds: float + self, provider_id: str, *, ttl_seconds: float, umo: str | None ) -> list[str] | None: if ttl_seconds <= 0: return None - entry = self._model_cache.get(provider_id) + entry = self._model_cache.get(self._cache_key(provider_id, umo)) if not entry: return None timestamp, models = entry if time.monotonic() - timestamp > ttl_seconds: - self._model_cache.pop(provider_id, None) + self._model_cache.pop(self._cache_key(provider_id, umo), None) return None return list(models) - def _set_cached_models(self, provider_id: str, models: list[str]) -> None: - self._model_cache[provider_id] = (time.monotonic(), list(models)) + def _set_cached_models( + self, provider_id: str, models: list[str], *, umo: str | None + ) -> None: + self._model_cache[self._cache_key(provider_id, umo)] = ( + time.monotonic(), + tuple(models), + ) - def _get_ttl_setting(self, umo: str | None) -> float: + def _get_provider_setting( + self, + umo: str | None, + key: str, + default: T, + cast: Callable[[object], T], + ) -> T: if not umo: - return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + return default try: cfg = self.context.get_config(umo).get("provider_settings", {}) - raw = cfg.get(MODEL_LIST_CACHE_TTL_KEY) + raw = cfg.get(key) if raw is None: - return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT - return float(raw) + return default + return cast(raw) except Exception as e: logger.debug( "读取 %s 失败,回退默认值 %r: %s", - MODEL_LIST_CACHE_TTL_KEY, - MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + key, + default, safe_error("", e), ) - return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + return default + + def _get_ttl_setting(self, umo: str | None) -> float: + return self._get_provider_setting( + umo, + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + float, + ) def _get_lookup_concurrency(self, umo: str | None) -> int: - if not umo: - return 
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT - try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - raw = cfg.get(MODEL_LOOKUP_MAX_CONCURRENCY_KEY) - if raw is None: - return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT - return int(raw) - except Exception as e: - logger.debug( - "读取 %s 失败,回退默认值 %r: %s", - MODEL_LOOKUP_MAX_CONCURRENCY_KEY, - MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, - safe_error("", e), - ) - return MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT + return self._get_provider_setting( + umo, + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + int, + ) def _resolve_model_name( self, @@ -135,18 +161,20 @@ def _resolve_model_name( # provider-qualified suffix match: # e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`. - def _match_qualified_suffix(req: str, cand: str) -> bool: - return cand.endswith(f"/{req}") or cand.endswith(f":{req}") - for candidate in models: - if _match_qualified_suffix(requested_norm, candidate.casefold()): + cand_norm = candidate.casefold() + if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith( + f":{requested_norm}" + ): return candidate return None - def _apply_model(self, prov: Provider, model_name: str) -> str: + def _apply_model( + self, prov: Provider, model_name: str, *, umo: str | None = None + ) -> str: prov.set_model(model_name) - self.invalidate_provider_models_cache(prov.meta().id) + self.invalidate_provider_models_cache(prov.meta().id, umo=umo) return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]" async def _get_provider_models( @@ -159,13 +187,17 @@ async def _get_provider_models( provider_id = provider.meta().id ttl_seconds = max(float(self._get_ttl_setting(umo)), 0.0) if use_cache: - cached = self._get_cached_models(provider_id, ttl_seconds=ttl_seconds) + cached = self._get_cached_models( + provider_id, + ttl_seconds=ttl_seconds, + umo=umo, + ) if cached is not None: return cached models = list(await provider.get_models()) if use_cache and ttl_seconds > 0: 
- self._set_cached_models(provider_id, models) + self._set_cached_models(provider_id, models, umo=umo) return models def _log_reachability_failure( @@ -484,7 +516,7 @@ async def _switch_model_by_name( if matched_model_name is not None: message.set_result( MessageEventResult().message( - self._apply_model(prov, matched_model_name) + self._apply_model(prov, matched_model_name, umo=umo) ), ) return @@ -508,8 +540,7 @@ async def _switch_model_by_name( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - target_prov.set_model(matched_target_model_name) - self.invalidate_provider_models_cache(target_prov.meta().id) + self._apply_model(target_prov, matched_target_model_name, umo=umo) message.set_result( MessageEventResult().message( f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。", @@ -584,7 +615,11 @@ async def model_ls( new_model = models[idx_or_name - 1] message.set_result( MessageEventResult().message( - self._apply_model(prov, new_model) + self._apply_model( + prov, + new_model, + umo=message.unified_msg_origin, + ) ), ) except Exception as e: @@ -626,7 +661,10 @@ async def key(self, message: AstrMessageEvent, index: int | None = None) -> None try: new_key = keys_data[index - 1] prov.set_key(new_key) - self.invalidate_provider_models_cache(prov.meta().id) + self.invalidate_provider_models_cache( + prov.meta().id, + umo=message.unified_msg_origin, + ) message.set_result(MessageEventResult().message("切换 Key 成功。")) except Exception as e: message.set_result( From 597690e3a849aaecf2120d058bee40db8073a6e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:44:45 +0900 Subject: [PATCH 23/28] fix: preserve redaction context and restore provider hooks --- .../builtin_commands/commands/provider.py | 6 ++- astrbot/core/provider/manager.py | 36 ++++++++------- astrbot/core/utils/error_redaction.py | 46 +++++++++++++------ 3 files changed, 58 insertions(+), 30 deletions(-) diff --git 
a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 00058b47aa..f5801c4eb1 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -263,7 +263,11 @@ async def _find_provider_for_model( batch_providers = all_providers[start : start + max_concurrency] batch_results = await asyncio.gather( *[ - self._get_provider_models(provider, umo=umo) + self._get_provider_models( + provider, + use_cache=False, + umo=umo, + ) for provider in batch_providers ], return_exceptions=True, diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 69d6b8489b..73b95dd168 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -73,22 +73,25 @@ def __init__( self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper - self._on_provider_changed: ( - Callable[[str, ProviderType, str | None], None] | None - ) = None + self._provider_change_hooks: list[ + Callable[[str, ProviderType, str | None], None] + ] = [] def set_provider_change_callback( self, cb: Callable[[str, ProviderType, str | None], None] | None, ) -> None: - self._on_provider_changed = cb + # Backward-compatible single-callback setter. + self._provider_change_hooks.clear() + if cb is not None: + self._provider_change_hooks.append(cb) def register_provider_change_hook( self, hook: Callable[[str, ProviderType, str | None], None], ) -> None: - # Backward-compatible wrapper for older call sites. 
-        self.set_provider_change_callback(hook)
+        if hook not in self._provider_change_hooks:
+            self._provider_change_hooks.append(hook)
 
     def _notify_provider_changed(
         self,
@@ -96,17 +99,18 @@ def _notify_provider_changed(
         provider_type: ProviderType,
         umo: str | None,
     ) -> None:
-        if not self._on_provider_changed:
+        if not self._provider_change_hooks:
             return
-        try:
-            self._on_provider_changed(provider_id, provider_type, umo)
-        except Exception as e:
-            logger.warning(
-                "调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s",
-                provider_id,
-                provider_type,
-                safe_error("", e),
-            )
+        for hook in list(self._provider_change_hooks):
+            try:
+                hook(provider_id, provider_type, umo)
+            except Exception as e:
+                logger.warning(
+                    "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s",
+                    provider_id,
+                    provider_type,
+                    safe_error("", e),
+                )
 
     @property
     def persona_configs(self) -> list:
diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py
index 9787ccc215..b8d75094d7 100644
--- a/astrbot/core/utils/error_redaction.py
+++ b/astrbot/core/utils/error_redaction.py
@@ -1,39 +1,59 @@
 import re
 
 _SECRET_KEYS = (
-    r"(api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)"
+    r"(?:api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)"
 )
 
 _JSON_FIELD_PATTERN = re.compile(
-    rf"(?i)(['\"])({_SECRET_KEYS})\1\s*:\s*(['\"])[^'\"]+\3"
+    rf"(?i)(?P<prefix>(?P<kq>['\"]){_SECRET_KEYS}(?P=kq)\s*:\s*)(?P<vq>['\"])(?P<value>[^'\"]+)(?P=vq)"
 )
 _AUTH_JSON_FIELD_PATTERN = re.compile(
-    r"(?i)(['\"])authorization\1\s*:\s*(['\"])bearer\s+[^'\"]+\2"
+    r"(?i)(?P<prefix>(?P<kq>['\"])authorization(?P=kq)\s*:\s*)(?P<vq>['\"])bearer\s+[^'\"]+(?P=vq)"
+)
+_QUERY_FIELD_PATTERN = re.compile(
+    rf"(?i)(?P<prefix>{_SECRET_KEYS}\s*=\s*)(?P<value>[^&'\" ]+)"
 )
-_QUERY_FIELD_PATTERN = re.compile(rf"(?i)\b{_SECRET_KEYS}\s*=\s*[^&'\" ]+")
 _QUERY_PARAM_PATTERN = re.compile(
-    r"(?i)([?&](?:api_?key|key|access_?token|auth_?token))=[^&'\" ]+"
+    r"(?i)(?P<prefix>[?&](?:api_?key|key|access_?token|auth_?token)=)(?P<value>[^&'\" ]+)"
 )
 _AUTH_HEADER_PATTERN = re.compile(
-    r"(?i)\bauthorization\s*:\s*bearer\s+[A-Za-z0-9._\-]+"
+    r"(?i)(?P<prefix>\bauthorization\s*:\s*bearer\s+)(?P<token>[A-Za-z0-9._\-]+)"
 )
-_BEARER_PATTERN = re.compile(r"(?i)\bbearer\s+[A-Za-z0-9._\-]+")
+_BEARER_PATTERN = re.compile(r"(?i)(?P<prefix>\bbearer\s+)(?P<token>[A-Za-z0-9._\-]+)")
 _SK_PATTERN = re.compile(r"\bsk-[A-Za-z0-9]{16,}\b")
 
 
+def _redact_json_field(match: re.Match[str]) -> str:
+    quote = match.group("vq")
+    return f"{match.group('prefix')}{quote}[REDACTED]{quote}"
+
+
+def _redact_auth_json_field(match: re.Match[str]) -> str:
+    quote = match.group("vq")
+    return f"{match.group('prefix')}{quote}Bearer [REDACTED]{quote}"
+
+
+def _redact_prefixed_value(match: re.Match[str]) -> str:
+    return f"{match.group('prefix')}[REDACTED]"
+
+
+def _redact_bearer_token(match: re.Match[str]) -> str:
+    return f"{match.group('prefix')}[REDACTED]"
+
+
 def _redact_json_like(text: str) -> str:
-    text = _JSON_FIELD_PATTERN.sub("[REDACTED]", text)
-    return _AUTH_JSON_FIELD_PATTERN.sub("[REDACTED]", text)
+    text = _JSON_FIELD_PATTERN.sub(_redact_json_field, text)
+    return _AUTH_JSON_FIELD_PATTERN.sub(_redact_auth_json_field, text)
 
 
 def _redact_query_like(text: str) -> str:
-    text = _QUERY_FIELD_PATTERN.sub("[REDACTED]", text)
-    return _QUERY_PARAM_PATTERN.sub("[REDACTED]", text)
+    text = _QUERY_FIELD_PATTERN.sub(_redact_prefixed_value, text)
+    return _QUERY_PARAM_PATTERN.sub(_redact_prefixed_value, text)
 
 
 def _redact_tokens(text: str) -> str:
-    text = _AUTH_HEADER_PATTERN.sub("[REDACTED]", text)
-    text = _BEARER_PATTERN.sub("[REDACTED]", text)
+    text = _AUTH_HEADER_PATTERN.sub(_redact_bearer_token, text)
+    text = _BEARER_PATTERN.sub(_redact_bearer_token, text)
     return _SK_PATTERN.sub("[REDACTED]", text)
 

From 9c5f31f4538606213ea5ed911288b05cb936aaa7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com>
Date: Sun, 1 Mar 2026 18:49:28 +0900
Subject: [PATCH
24/28] refactor: unify provider model lookup config flow --- .../builtin_commands/commands/provider.py | 128 +++++++++++------- 1 file changed, 79 insertions(+), 49 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index f5801c4eb1..80bef06611 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,6 +3,7 @@ import asyncio import time from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, TypeVar from astrbot import logger @@ -24,6 +25,13 @@ T = TypeVar("T") +@dataclass(frozen=True) +class _ModelLookupConfig: + umo: str | None + cache_ttl_seconds: float + max_concurrency: int + + class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context @@ -124,20 +132,26 @@ def _get_provider_setting( ) return default - def _get_ttl_setting(self, umo: str | None) -> float: - return self._get_provider_setting( - umo, - MODEL_LIST_CACHE_TTL_KEY, - MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, - float, + def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig: + ttl = self._get_provider_setting( + umo=umo, + key=MODEL_LIST_CACHE_TTL_KEY, + default=MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + cast=float, ) - - def _get_lookup_concurrency(self, umo: str | None) -> int: - return self._get_provider_setting( - umo, - MODEL_LOOKUP_MAX_CONCURRENCY_KEY, - MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, - int, + raw_concurrency = self._get_provider_setting( + umo=umo, + key=MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + default=MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + cast=int, + ) + return _ModelLookupConfig( + umo=umo, + cache_ttl_seconds=max(float(ttl), 0.0), + max_concurrency=min( + max(int(raw_concurrency), 1), + MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, + ), ) def _resolve_model_name( @@ -181,11 +195,12 @@ async def 
_get_provider_models( self, provider: Provider, *, + config: _ModelLookupConfig, use_cache: bool = True, - umo: str | None = None, ) -> list[str]: provider_id = provider.meta().id - ttl_seconds = max(float(self._get_ttl_setting(umo)), 0.0) + ttl_seconds = config.cache_ttl_seconds + umo = config.umo if use_cache: cached = self._get_cached_models( provider_id, @@ -236,8 +251,9 @@ async def _test_provider_capability(self, provider): async def _find_provider_for_model( self, model_name: str, + *, exclude_provider_id: str | None = None, - umo: str | None = None, + config: _ModelLookupConfig, ) -> tuple[Provider | None, str | None]: all_providers = [] for provider in self.context.get_all_providers(): @@ -254,34 +270,48 @@ async def _find_provider_for_model( return None, None failed_provider_errors: list[tuple[str, str]] = [] - raw_concurrency = self._get_lookup_concurrency(umo) - max_concurrency = min( - max(int(raw_concurrency), 1), - MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, - ) - for start in range(0, len(all_providers), max_concurrency): - batch_providers = all_providers[start : start + max_concurrency] - batch_results = await asyncio.gather( - *[ - self._get_provider_models( + semaphore = asyncio.Semaphore(config.max_concurrency) + + async def fetch_and_match( + provider: Provider, + ) -> tuple[Provider | None, str | None]: + async with semaphore: + try: + models = await self._get_provider_models( provider, + config=config, use_cache=False, - umo=umo, ) - for provider in batch_providers - ], - return_exceptions=True, - ) - for provider, result in zip(batch_providers, batch_results, strict=False): - if isinstance(result, asyncio.CancelledError): - raise result - provider_id = provider.meta().id - if isinstance(result, Exception): - failed_provider_errors.append((provider_id, safe_error("", result))) - continue - matched_model_name = self._resolve_model_name(model_name, result) - if matched_model_name is not None: - return provider, matched_model_name + except 
asyncio.CancelledError: + raise + except Exception as e: + failed_provider_errors.append( + (provider.meta().id, safe_error("", e)) + ) + return None, None + + matched_model_name = self._resolve_model_name(model_name, models) + if matched_model_name is None: + return None, None + return provider, matched_model_name + + results = await asyncio.gather( + *(fetch_and_match(provider) for provider in all_providers), + return_exceptions=True, + ) + for result in results: + if isinstance(result, asyncio.CancelledError): + raise result + if isinstance(result, Exception): + logger.debug( + "跨提供商查找模型 %s 发生异常: %s", + model_name, + safe_error("", result), + ) + continue + provider, matched_model_name = result + if provider is not None and matched_model_name is not None: + return provider, matched_model_name if failed_provider_errors and len(failed_provider_errors) == len(all_providers): failed_ids = ",".join( @@ -496,10 +526,11 @@ async def _switch_model_by_name( return umo = message.unified_msg_origin + config = self._get_model_lookup_config(umo) curr_provider_id = prov.meta().id try: - models = await self._get_provider_models(prov, umo=umo) + models = await self._get_provider_models(prov, config=config) except asyncio.CancelledError: raise except Exception as e: @@ -526,7 +557,9 @@ async def _switch_model_by_name( return target_prov, matched_target_model_name = await self._find_provider_for_model( - model_name, exclude_provider_id=curr_provider_id, umo=umo + model_name, + exclude_provider_id=curr_provider_id, + config=config, ) if target_prov is None or matched_target_model_name is None: @@ -571,13 +604,12 @@ async def model_ls( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return + config = self._get_model_lookup_config(message.unified_msg_origin) if idx_or_name is None: models = [] try: - models = await self._get_provider_models( - prov, umo=message.unified_msg_origin - ) + models = await self._get_provider_models(prov, config=config) except 
asyncio.CancelledError: raise except Exception as e: @@ -602,9 +634,7 @@ async def model_ls( elif isinstance(idx_or_name, int): models = [] try: - models = await self._get_provider_models( - prov, umo=message.unified_msg_origin - ) + models = await self._get_provider_models(prov, config=config) except asyncio.CancelledError: raise except Exception as e: From 161763e9610143ec681b607842de618977acbd31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:54:06 +0900 Subject: [PATCH 25/28] refactor: inline provider model cache access flow --- .../builtin_commands/commands/provider.py | 78 ++++++------------- 1 file changed, 23 insertions(+), 55 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 80bef06611..ac11b8bcc5 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -82,32 +82,6 @@ def _on_provider_manager_changed( if provider_type == ProviderType.CHAT_COMPLETION: self.invalidate_provider_models_cache(provider_id, umo=umo) - @staticmethod - def _cache_key(provider_id: str, umo: str | None) -> tuple[str, str | None]: - return provider_id, umo - - def _get_cached_models( - self, provider_id: str, *, ttl_seconds: float, umo: str | None - ) -> list[str] | None: - if ttl_seconds <= 0: - return None - entry = self._model_cache.get(self._cache_key(provider_id, umo)) - if not entry: - return None - timestamp, models = entry - if time.monotonic() - timestamp > ttl_seconds: - self._model_cache.pop(self._cache_key(provider_id, umo), None) - return None - return list(models) - - def _set_cached_models( - self, provider_id: str, models: list[str], *, umo: str | None - ) -> None: - self._model_cache[self._cache_key(provider_id, umo)] = ( - time.monotonic(), - tuple(models), - ) - def _get_provider_setting( self, umo: str | None, 
@@ -201,18 +175,18 @@ async def _get_provider_models( provider_id = provider.meta().id ttl_seconds = config.cache_ttl_seconds umo = config.umo - if use_cache: - cached = self._get_cached_models( - provider_id, - ttl_seconds=ttl_seconds, - umo=umo, - ) - if cached is not None: - return cached + cache_key = (provider_id, umo) + if use_cache and ttl_seconds > 0: + entry = self._model_cache.get(cache_key) + if entry: + timestamp, models = entry + if time.monotonic() - timestamp <= ttl_seconds: + return list(models) + self._model_cache.pop(cache_key, None) models = list(await provider.get_models()) if use_cache and ttl_seconds > 0: - self._set_cached_models(provider_id, models, umo=umo) + self._model_cache[cache_key] = (time.monotonic(), tuple(models)) return models def _log_reachability_failure( @@ -269,15 +243,12 @@ async def _find_provider_for_model( if not all_providers: return None, None - failed_provider_errors: list[tuple[str, str]] = [] semaphore = asyncio.Semaphore(config.max_concurrency) - async def fetch_and_match( - provider: Provider, - ) -> tuple[Provider | None, str | None]: + async def fetch_models(provider: Provider) -> list[str] | Exception: async with semaphore: try: - models = await self._get_provider_models( + return await self._get_provider_models( provider, config=config, use_cache=False, @@ -285,32 +256,29 @@ async def fetch_and_match( except asyncio.CancelledError: raise except Exception as e: - failed_provider_errors.append( - (provider.meta().id, safe_error("", e)) - ) - return None, None - - matched_model_name = self._resolve_model_name(model_name, models) - if matched_model_name is None: - return None, None - return provider, matched_model_name + return e results = await asyncio.gather( - *(fetch_and_match(provider) for provider in all_providers), + *(fetch_models(provider) for provider in all_providers), return_exceptions=True, ) - for result in results: + failed_provider_errors: list[tuple[str, str]] = [] + for provider, result in 
zip(all_providers, results, strict=False): if isinstance(result, asyncio.CancelledError): raise result if isinstance(result, Exception): + err = safe_error("", result) + failed_provider_errors.append((provider.meta().id, err)) logger.debug( - "跨提供商查找模型 %s 发生异常: %s", + "跨提供商查找模型 %s 获取 %s 模型列表失败: %s", model_name, - safe_error("", result), + provider.meta().id, + err, ) continue - provider, matched_model_name = result - if provider is not None and matched_model_name is not None: + + matched_model_name = self._resolve_model_name(model_name, result) + if matched_model_name is not None: return provider, matched_model_name if failed_provider_errors and len(failed_provider_errors) == len(all_providers): From 7d1f642ccc85eff2203f3595c3380ae628474bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 18:58:09 +0900 Subject: [PATCH 26/28] fix: align provider lookup cache and callback semantics --- .../builtin_commands/commands/provider.py | 3 ++- astrbot/core/provider/manager.py | 22 ++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index ac11b8bcc5..d713f9df91 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -228,6 +228,7 @@ async def _find_provider_for_model( *, exclude_provider_id: str | None = None, config: _ModelLookupConfig, + use_cache: bool = True, ) -> tuple[Provider | None, str | None]: all_providers = [] for provider in self.context.get_all_providers(): @@ -251,7 +252,7 @@ async def fetch_models(provider: Provider) -> list[str] | Exception: return await self._get_provider_models( provider, config=config, - use_cache=False, + use_cache=use_cache, ) except asyncio.CancelledError: raise diff --git a/astrbot/core/provider/manager.py 
b/astrbot/core/provider/manager.py index 73b95dd168..2359a81371 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -73,6 +73,9 @@ def __init__( self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper + self._provider_change_callback: ( + Callable[[str, ProviderType, str | None], None] | None + ) = None self._provider_change_hooks: list[ Callable[[str, ProviderType, str | None], None] ] = [] @@ -82,9 +85,8 @@ def set_provider_change_callback( cb: Callable[[str, ProviderType, str | None], None] | None, ) -> None: # Backward-compatible single-callback setter. - self._provider_change_hooks.clear() - if cb is not None: - self._provider_change_hooks.append(cb) + # This callback coexists with register_provider_change_hook subscriptions. + self._provider_change_callback = cb def register_provider_change_hook( self, @@ -99,9 +101,19 @@ def _notify_provider_changed( provider_type: ProviderType, umo: str | None, ) -> None: - if not self._provider_change_hooks: - return + if self._provider_change_callback is not None: + try: + self._provider_change_callback(provider_id, provider_type, umo) + except Exception as e: + logger.warning( + "调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s", + provider_id, + provider_type, + safe_error("", e), + ) for hook in list(self._provider_change_hooks): + if hook is self._provider_change_callback: + continue try: hook(provider_id, provider_type, umo) except Exception as e: From a9f1558a0e5aad1e294bfbeb2668f4afd37b17e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:03:35 +0900 Subject: [PATCH 27/28] refactor: centralize provider model fetch error handling --- .../builtin_commands/commands/provider.py | 157 +++++++++++------- astrbot/core/utils/error_redaction.py | 8 +- 2 files changed, 104 insertions(+), 61 deletions(-) diff 
--git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index d713f9df91..e4503c09e2 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -32,12 +32,49 @@ class _ModelLookupConfig: max_concurrency: int +class _ModelCache: + def __init__(self) -> None: + self._store: dict[tuple[str, str | None], tuple[float, tuple[str, ...]]] = {} + + def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None: + if ttl <= 0: + return None + entry = self._store.get((provider_id, umo)) + if not entry: + return None + timestamp, models = entry + if time.monotonic() - timestamp > ttl: + self._store.pop((provider_id, umo), None) + return None + return list(models) + + def set( + self, provider_id: str, umo: str | None, models: list[str], ttl: float + ) -> None: + if ttl <= 0: + return + self._store[(provider_id, umo)] = (time.monotonic(), tuple(models)) + + def invalidate( + self, provider_id: str | None = None, *, umo: str | None = None + ) -> None: + if provider_id is None: + self._store.clear() + return + if umo is not None: + self._store.pop((provider_id, umo), None) + return + stale_keys = [ + cache_key for cache_key in self._store if cache_key[0] == provider_id + ] + for cache_key in stale_keys: + self._store.pop(cache_key, None) + + class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context - self._model_cache: dict[ - tuple[str, str | None], tuple[float, tuple[str, ...]] - ] = {} + self._model_cache = _ModelCache() self._register_provider_change_hook() def _register_provider_change_hook(self) -> None: @@ -61,17 +98,7 @@ def invalidate_provider_models_cache( self, provider_id: str | None = None, *, umo: str | None = None ) -> None: """Public hook for cache invalidation on external provider config changes.""" - if provider_id is None: - self._model_cache.clear() 
- return - if umo is not None: - self._model_cache.pop((provider_id, umo), None) - return - stale_keys = [ - cache_key for cache_key in self._model_cache if cache_key[0] == provider_id - ] - for cache_key in stale_keys: - self._model_cache.pop(cache_key, None) + self._model_cache.invalidate(provider_id, umo=umo) def _on_provider_manager_changed( self, @@ -175,20 +202,43 @@ async def _get_provider_models( provider_id = provider.meta().id ttl_seconds = config.cache_ttl_seconds umo = config.umo - cache_key = (provider_id, umo) - if use_cache and ttl_seconds > 0: - entry = self._model_cache.get(cache_key) - if entry: - timestamp, models = entry - if time.monotonic() - timestamp <= ttl_seconds: - return list(models) - self._model_cache.pop(cache_key, None) + if use_cache: + cached = self._model_cache.get(provider_id, umo, ttl_seconds) + if cached is not None: + return cached models = list(await provider.get_models()) - if use_cache and ttl_seconds > 0: - self._model_cache[cache_key] = (time.monotonic(), tuple(models)) + if use_cache: + self._model_cache.set(provider_id, umo, models, ttl_seconds) return models + async def _get_models_or_reply_error( + self, + message: AstrMessageEvent, + prov: Provider, + config: _ModelLookupConfig, + *, + error_prefix: str, + disable_t2i: bool = False, + warning_log: str | None = None, + ) -> list[str] | None: + try: + return await self._get_provider_models(prov, config=config) + except asyncio.CancelledError: + raise + except Exception as e: + if warning_log is not None: + logger.warning( + warning_log, + prov.meta().id, + safe_error("", e), + ) + result = MessageEventResult().message(safe_error(error_prefix, e)) + if disable_t2i: + result = result.use_t2i(False) + message.set_result(result) + return None + def _log_reachability_failure( self, provider, @@ -498,22 +548,14 @@ async def _switch_model_by_name( config = self._get_model_lookup_config(umo) curr_provider_id = prov.meta().id - try: - models = await 
self._get_provider_models(prov, config=config) - except asyncio.CancelledError: - raise - except Exception as e: - err_msg = safe_error("", e) - logger.warning( - "获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", - curr_provider_id, - err_msg, - ) - message.set_result( - MessageEventResult().message( - safe_error("获取当前提供商模型列表失败: ", e) - ) - ) + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取当前提供商模型列表失败: ", + warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s", + ) + if models is None: return matched_model_name = self._resolve_model_name(model_name, models) @@ -576,17 +618,14 @@ async def model_ls( config = self._get_model_lookup_config(message.unified_msg_origin) if idx_or_name is None: - models = [] - try: - models = await self._get_provider_models(prov, config=config) - except asyncio.CancelledError: - raise - except Exception as e: - message.set_result( - MessageEventResult() - .message(safe_error("获取模型列表失败: ", e)) - .use_t2i(False), - ) + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取模型列表失败: ", + disable_t2i=True, + ) + if models is None: return parts = ["下面列出了此模型提供商可用模型:"] for i, model in enumerate(models, 1): @@ -601,15 +640,13 @@ async def model_ls( ret = "".join(parts) message.set_result(MessageEventResult().message(ret).use_t2i(False)) elif isinstance(idx_or_name, int): - models = [] - try: - models = await self._get_provider_models(prov, config=config) - except asyncio.CancelledError: - raise - except Exception as e: - message.set_result( - MessageEventResult().message(safe_error("获取模型列表失败: ", e)), - ) + models = await self._get_models_or_reply_error( + message, + prov, + config, + error_prefix="获取模型列表失败: ", + ) + if models is None: return if idx_or_name > len(models) or idx_or_name < 1: message.set_result(MessageEventResult().message("模型序号错误。")) diff --git a/astrbot/core/utils/error_redaction.py b/astrbot/core/utils/error_redaction.py index b8d75094d7..dcab07ac58 100644 --- 
a/astrbot/core/utils/error_redaction.py +++ b/astrbot/core/utils/error_redaction.py @@ -70,7 +70,13 @@ def safe_error( *, redact: bool = True, ) -> str: - text = str(error) + try: + text = str(error) + except Exception: + try: + text = repr(error) + except Exception: + text = "" if redact: text = redact_sensitive_text(text) return prefix + text From 343d277b207256e13154cc09973a3834c08eecf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Sun, 1 Mar 2026 19:08:24 +0900 Subject: [PATCH 28/28] refactor: simplify provider model cache and lookup flow --- .../builtin_commands/commands/provider.py | 135 ++++++++++-------- 1 file changed, 78 insertions(+), 57 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index e4503c09e2..b5ee75ca24 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -2,9 +2,9 @@ import asyncio import time -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING from astrbot import logger from astrbot.api import star @@ -21,8 +21,7 @@ MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16 MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds" MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency" - -T = TypeVar("T") +MODEL_CACHE_MAX_ENTRIES = 512 @dataclass(frozen=True) @@ -34,7 +33,7 @@ class _ModelLookupConfig: class _ModelCache: def __init__(self) -> None: - self._store: dict[tuple[str, str | None], tuple[float, tuple[str, ...]]] = {} + self._store: dict[tuple[str, str | None], tuple[float, list[str]]] = {} def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None: if ttl <= 0: @@ -46,14 +45,26 @@ def get(self, provider_id: str, umo: str | None, ttl: 
float) -> list[str] | None if time.monotonic() - timestamp > ttl: self._store.pop((provider_id, umo), None) return None - return list(models) + return models def set( self, provider_id: str, umo: str | None, models: list[str], ttl: float ) -> None: if ttl <= 0: return - self._store[(provider_id, umo)] = (time.monotonic(), tuple(models)) + self._store[(provider_id, umo)] = (time.monotonic(), list(models)) + self._evict_if_needed() + + def _evict_if_needed(self) -> None: + if len(self._store) <= MODEL_CACHE_MAX_ENTRIES: + return + # Drop oldest entries first when cache grows too large. + overflow = len(self._store) - MODEL_CACHE_MAX_ENTRIES + for key, _ in sorted( + self._store.items(), + key=lambda item: item[1][0], + )[:overflow]: + self._store.pop(key, None) def invalidate( self, provider_id: str | None = None, *, umo: str | None = None @@ -109,50 +120,58 @@ def _on_provider_manager_changed( if provider_type == ProviderType.CHAT_COMPLETION: self.invalidate_provider_models_cache(provider_id, umo=umo) - def _get_provider_setting( - self, - umo: str | None, - key: str, - default: T, - cast: Callable[[object], T], - ) -> T: + def _get_provider_settings(self, umo: str | None) -> dict: if not umo: - return default + return {} try: - cfg = self.context.get_config(umo).get("provider_settings", {}) - raw = cfg.get(key) - if raw is None: - return default - return cast(raw) + return self.context.get_config(umo).get("provider_settings", {}) or {} except Exception as e: logger.debug( - "读取 %s 失败,回退默认值 %r: %s", - key, - default, + "读取 provider_settings 失败,使用默认值: %s", safe_error("", e), ) - return default + return {} - def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig: - ttl = self._get_provider_setting( - umo=umo, - key=MODEL_LIST_CACHE_TTL_KEY, - default=MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, - cast=float, + def _get_model_cache_ttl(self, umo: str | None) -> float: + settings = self._get_provider_settings(umo) + raw = settings.get( + 
MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, ) - raw_concurrency = self._get_provider_setting( - umo=umo, - key=MODEL_LOOKUP_MAX_CONCURRENCY_KEY, - default=MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, - cast=int, + try: + return max(float(raw), 0.0) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + MODEL_LIST_CACHE_TTL_KEY, + MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT, + safe_error("", e), + ) + return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT + + def _get_model_lookup_concurrency(self, umo: str | None) -> int: + settings = self._get_provider_settings(umo) + raw = settings.get( + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, ) + try: + value = int(raw) + except Exception as e: + logger.debug( + "读取 %s 失败,回退默认值 %r: %s", + MODEL_LOOKUP_MAX_CONCURRENCY_KEY, + MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT, + safe_error("", e), + ) + value = MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT + return min(max(value, 1), MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND) + + def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig: return _ModelLookupConfig( umo=umo, - cache_ttl_seconds=max(float(ttl), 0.0), - max_concurrency=min( - max(int(raw_concurrency), 1), - MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND, - ), + cache_ttl_seconds=self._get_model_cache_ttl(umo), + max_concurrency=self._get_model_lookup_concurrency(umo), ) def _resolve_model_name( @@ -296,39 +315,41 @@ async def _find_provider_for_model( semaphore = asyncio.Semaphore(config.max_concurrency) - async def fetch_models(provider: Provider) -> list[str] | Exception: + async def fetch_models( + provider: Provider, + ) -> tuple[Provider, list[str] | None, str | None]: async with semaphore: try: - return await self._get_provider_models( + models = await self._get_provider_models( provider, config=config, use_cache=use_cache, ) + return provider, models, None except asyncio.CancelledError: raise except Exception as e: - return e + err = safe_error("", e) + logger.debug( + "跨提供商查找模型 
%s 获取 %s 模型列表失败: %s", + model_name, + provider.meta().id, + err, + ) + return provider, None, err results = await asyncio.gather( - *(fetch_models(provider) for provider in all_providers), - return_exceptions=True, + *(fetch_models(provider) for provider in all_providers) ) failed_provider_errors: list[tuple[str, str]] = [] - for provider, result in zip(all_providers, results, strict=False): - if isinstance(result, asyncio.CancelledError): - raise result - if isinstance(result, Exception): - err = safe_error("", result) + for provider, models, err in results: + if err is not None: failed_provider_errors.append((provider.meta().id, err)) - logger.debug( - "跨提供商查找模型 %s 获取 %s 模型列表失败: %s", - model_name, - provider.meta().id, - err, - ) + continue + if models is None: continue - matched_model_name = self._resolve_model_name(model_name, result) + matched_model_name = self._resolve_model_name(model_name, models) if matched_model_name is not None: return provider, matched_model_name