Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions astrbot/builtin_stars/builtin_commands/commands/provider.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import re
from typing import TYPE_CHECKING

from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType

if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider


class ProviderCommands:
def __init__(self, context: star.Context) -> None:
Expand Down Expand Up @@ -44,6 +48,36 @@ async def _test_provider_capability(self, provider):
)
return False, err_code, err_reason

async def _find_provider_for_model(
self, model_name: str, exclude_provider_id: str | None = None
) -> tuple["Provider" | None, str | None]:
"""在所有 LLM 提供商中查找包含指定模型的提供商。返回 (provider, provider_id) 或 (None, None)。"""
all_providers = self.context.get_all_providers()
results = await asyncio.gather(
*[p.get_models() for p in all_providers],
return_exceptions=True,
)
for provider, result in zip(all_providers, results):
if isinstance(result, BaseException):
logger.warning(
"跨提供商查找模型时,提供商 %s 的 get_models() 失败: %s",
provider.meta().id,
result,
)
continue
provider_id = provider.meta().id
if exclude_provider_id and provider_id == exclude_provider_id:
continue
if model_name in result:
return provider, provider_id
if results and all(isinstance(r, BaseException) for r in results):
logger.error(
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败,请检查配置或网络",
model_name,
len(all_providers),
)
return None, None

async def provider(
self,
event: AstrMessageEvent,
Expand Down Expand Up @@ -258,7 +292,7 @@ async def model_ls(
curr_model = prov.get_model() or "无"
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换。"
)

ret = "".join(parts)
Expand All @@ -268,8 +302,9 @@ async def model_ls(
try:
models = await prov.get_models()
except BaseException as e:
err_msg = api_key_pattern.sub("key=***", str(e))
message.set_result(
MessageEventResult().message("获取模型列表失败: " + str(e)),
MessageEventResult().message("获取模型列表失败: " + err_msg),
)
return
if idx_or_name > len(models) or idx_or_name < 1:
Expand All @@ -278,20 +313,60 @@ async def model_ls(
try:
new_model = models[idx_or_name - 1]
prov.set_model(new_model)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
),
)
except BaseException as e:
err_msg = api_key_pattern.sub("key=***", str(e))
message.set_result(
MessageEventResult().message("切换模型未知错误: " + str(e)),
MessageEventResult().message("切换模型未知错误: " + err_msg),
)
return
else:
# 字符串:模型名,需智能解析是否跨提供商
model_name = idx_or_name
umo = message.unified_msg_origin
curr_provider_id = prov.meta().id

# 1. 检查当前提供商
models = []
try:
models = await prov.get_models()
except BaseException:
models = []
if model_name in models:
prov.set_model(model_name)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
f"切换模型成功。当前提供商: [{curr_provider_id}] 当前模型: [{model_name}]",
),
)
else:
prov.set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型到 {prov.get_model()}。"),
return

# 2. 在其他提供商中查找
target_prov, target_id = await self._find_provider_for_model(
model_name, exclude_provider_id=curr_provider_id
)
if target_prov and target_id:
await self.context.provider_manager.set_provider(
provider_id=target_id,
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
target_prov.set_model(model_name)
message.set_result(
MessageEventResult().message(
f"检测到模型 [{model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
),
)
else:
message.set_result(
MessageEventResult().message(
f"模型 [{model_name}] 未在任何已配置的提供商中找到。请使用 /provider 切换到目标提供商,或确认模型名正确。",
),
)

async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
Expand Down