diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index a942c56e4a..cd5e4765a4 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -9,6 +9,7 @@ from typing import Literal, cast from urllib.parse import urlparse +import httpx from google import genai from google.genai import types from google.genai.errors import APIError @@ -82,13 +83,21 @@ def __init__( def _init_client(self) -> None: """初始化Gemini客户端""" proxy = self.provider_config.get("proxy", "") + client_kwargs = { + "timeout": self.timeout, + "trust_env": True, + } + if proxy: + client_kwargs["proxy"] = proxy http_options = types.HttpOptions( base_url=self.api_base, timeout=self.timeout * 1000, # 毫秒 ) + # issue #7564: Force google-genai to use httpx; its aiohttp error path can mask API errors. + self._httpx_async_client = httpx.AsyncClient(**client_kwargs) + http_options.httpx_async_client = self._httpx_async_client if proxy: - http_options.async_client_args = {"proxy": proxy} - logger.info(f"[Gemini] 使用代理: {proxy}") + logger.info("[Gemini] 使用代理") self.client = genai.Client( api_key=self.chosen_api_key, http_options=http_options, @@ -117,12 +126,12 @@ async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: if len(keys) > 0: self.set_key(random.choice(keys)) logger.info( - f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}...", + f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试...", ) await asyncio.sleep(1) return True logger.error( - f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...", + f"检测到 Key 异常({e.message}),且已没有可用的 Key。", ) raise Exception("达到了 Gemini 速率限制, 请稍后再试...") @@ -1070,3 +1079,5 @@ async def encode_image_bs64(self, image_url: str) -> str: async def terminate(self) -> None: if self.client: await self.client.aclose() + if self._httpx_async_client: + await self._httpx_async_client.aclose() diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py index 4db8e92bfe..d13fab3fb5 100644 --- a/tests/test_gemini_source.py +++ b/tests/test_gemini_source.py @@ -1,10 +1,114 @@ +from types import SimpleNamespace + import pytest +import astrbot.core.provider.sources.gemini_source as gemini_source_module from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI +def _make_provider_config(overrides: dict | None = None) -> dict: + config = { + "id": "test-gemini", + "type": "googlegenai_chat_completion", + "model": "gemini-2.5-pro", + "key": ["test-key"], + "timeout": 180, + "gm_safety_settings": {}, + } + if overrides: + config.update(overrides) + return config + + +class _FakeGeminiClient: + def __init__(self): + self.closed = False + + async def aclose(self): + self.closed = True + + +def test_gemini_client_forces_httpx_client_and_keeps_env_proxy(monkeypatch): + captured: dict[str, object] = {} + httpx_client = _FakeGeminiClient() + + def fake_httpx_client(**kwargs): + captured["httpx_client_kwargs"] = kwargs + return httpx_client + + def fake_client(api_key, http_options): + captured["api_key"] = api_key + captured["http_options"] = http_options + return SimpleNamespace(aio=SimpleNamespace()) + + monkeypatch.setenv("HTTPS_PROXY", "http://global-proxy.example:8080") + monkeypatch.setattr(gemini_source_module.httpx, "AsyncClient", fake_httpx_client) + monkeypatch.setattr(gemini_source_module.genai, "Client", fake_client) + + ProviderGoogleGenAI(_make_provider_config(), {}) + + http_options = captured["http_options"] + assert captured["api_key"] == "test-key" + assert captured["httpx_client_kwargs"] == {"timeout": 180, "trust_env": True} + assert http_options.httpx_async_client is httpx_client + + +def test_gemini_client_passes_proxy_to_httpx_client_without_logging_it(monkeypatch): + captured: dict[str, object] = {} + httpx_client = _FakeGeminiClient() + proxy = "socks5://user:secret@127.0.0.1:1080" + + def fake_httpx_client(**kwargs): + captured["httpx_client_kwargs"] = kwargs + return httpx_client + + def fake_client(api_key, http_options): + captured["http_options"] = http_options + return SimpleNamespace(aio=SimpleNamespace()) + + def fake_log(message): + captured["log_message"] = message + + monkeypatch.setattr(gemini_source_module.httpx, "AsyncClient", fake_httpx_client) + monkeypatch.setattr(gemini_source_module.genai, "Client", fake_client) + monkeypatch.setattr(gemini_source_module.logger, "info", fake_log) + + ProviderGoogleGenAI(_make_provider_config({"proxy": proxy}), {}) + + http_options = captured["http_options"] + assert captured["httpx_client_kwargs"] == { + "timeout": 180, + "trust_env": True, + "proxy": proxy, + } + assert http_options.httpx_async_client is httpx_client + assert "secret" not in captured["log_message"] + assert proxy not in captured["log_message"] + + +@pytest.mark.asyncio +async def test_gemini_api_key_error_log_does_not_include_key(monkeypatch): + captured: dict[str, str] = {} + api_key = "sensitive-api-key-value" + + def fake_log(message): + captured["message"] = message + + monkeypatch.setattr(gemini_source_module.logger, "error", fake_log) + + provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI) + provider.chosen_api_key = api_key + error = SimpleNamespace(code=429, message="quota exceeded") + + with pytest.raises(Exception, match="Gemini"): + await provider._handle_api_error(error, [api_key]) + + assert api_key not in captured["message"] + assert api_key[:12] not in captured["message"] + + def test_gemini_empty_output_raises_empty_model_output_error(): llm_response = LLMResponse(role="assistant")