Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Literal, cast
from urllib.parse import urlparse

import httpx
from google import genai
from google.genai import types
from google.genai.errors import APIError
Expand Down Expand Up @@ -82,13 +83,21 @@ def __init__(
def _init_client(self) -> None:
"""初始化Gemini客户端"""
proxy = self.provider_config.get("proxy", "")
client_kwargs = {
"timeout": self.timeout,
"trust_env": True,
}
if proxy:
client_kwargs["proxy"] = proxy
http_options = types.HttpOptions(
base_url=self.api_base,
timeout=self.timeout * 1000, # 毫秒
)
# issue #7564: Force google-genai to use httpx; its aiohttp error path can mask API errors.
self._httpx_async_client = httpx.AsyncClient(**client_kwargs)
http_options.httpx_async_client = self._httpx_async_client
if proxy:
http_options.async_client_args = {"proxy": proxy}
logger.info(f"[Gemini] 使用代理: {proxy}")
logger.info("[Gemini] 使用代理")
self.client = genai.Client(
api_key=self.chosen_api_key,
http_options=http_options,
Expand Down Expand Up @@ -117,12 +126,12 @@ async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
if len(keys) > 0:
self.set_key(random.choice(keys))
logger.info(
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}...",
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试...",
)
await asyncio.sleep(1)
return True
logger.error(
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
f"检测到 Key 异常({e.message}),且已没有可用的 Key。",
)
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")

Expand Down Expand Up @@ -1070,3 +1079,5 @@ async def encode_image_bs64(self, image_url: str) -> str:
async def terminate(self) -> None:
if self.client:
await self.client.aclose()
if self._httpx_async_client:
await self._httpx_async_client.aclose()
104 changes: 104 additions & 0 deletions tests/test_gemini_source.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,114 @@
from types import SimpleNamespace

import pytest

import astrbot.core.provider.sources.gemini_source as gemini_source_module
from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI


def _make_provider_config(overrides: dict | None = None) -> dict:
config = {
"id": "test-gemini",
"type": "googlegenai_chat_completion",
"model": "gemini-2.5-pro",
"key": ["test-key"],
"timeout": 180,
"gm_safety_settings": {},
}
if overrides:
config.update(overrides)
return config


class _FakeGeminiClient:
def __init__(self):
self.closed = False

async def aclose(self):
self.closed = True


def test_gemini_client_forces_httpx_client_and_keeps_env_proxy(monkeypatch):
captured: dict[str, object] = {}
httpx_client = _FakeGeminiClient()

def fake_httpx_client(**kwargs):
captured["httpx_client_kwargs"] = kwargs
return httpx_client

def fake_client(api_key, http_options):
captured["api_key"] = api_key
captured["http_options"] = http_options
return SimpleNamespace(aio=SimpleNamespace())

monkeypatch.setenv("HTTPS_PROXY", "http://global-proxy.example:8080")
monkeypatch.setattr(gemini_source_module.httpx, "AsyncClient", fake_httpx_client)
monkeypatch.setattr(gemini_source_module.genai, "Client", fake_client)

ProviderGoogleGenAI(_make_provider_config(), {})

http_options = captured["http_options"]
assert captured["api_key"] == "test-key"
assert captured["httpx_client_kwargs"] == {"timeout": 180, "trust_env": True}
assert http_options.httpx_async_client is httpx_client


def test_gemini_client_passes_proxy_to_httpx_client_without_logging_it(monkeypatch):
captured: dict[str, object] = {}
httpx_client = _FakeGeminiClient()
proxy = "socks5://user:secret@127.0.0.1:1080"

def fake_httpx_client(**kwargs):
captured["httpx_client_kwargs"] = kwargs
return httpx_client

def fake_client(api_key, http_options):
captured["http_options"] = http_options
return SimpleNamespace(aio=SimpleNamespace())

def fake_log(message):
captured["log_message"] = message

monkeypatch.setattr(gemini_source_module.httpx, "AsyncClient", fake_httpx_client)
monkeypatch.setattr(gemini_source_module.genai, "Client", fake_client)
monkeypatch.setattr(gemini_source_module.logger, "info", fake_log)

ProviderGoogleGenAI(_make_provider_config({"proxy": proxy}), {})

http_options = captured["http_options"]
assert captured["httpx_client_kwargs"] == {
"timeout": 180,
"trust_env": True,
"proxy": proxy,
}
assert http_options.httpx_async_client is httpx_client
assert "secret" not in captured["log_message"]
assert proxy not in captured["log_message"]


@pytest.mark.asyncio
async def test_gemini_api_key_error_log_does_not_include_key(monkeypatch):
captured: dict[str, str] = {}
api_key = "sensitive-api-key-value"

def fake_log(message):
captured["message"] = message

monkeypatch.setattr(gemini_source_module.logger, "error", fake_log)

provider = ProviderGoogleGenAI.__new__(ProviderGoogleGenAI)
provider.chosen_api_key = api_key
error = SimpleNamespace(code=429, message="quota exceeded")

with pytest.raises(Exception, match="Gemini"):
await provider._handle_api_error(error, [api_key])

assert api_key not in captured["message"]
assert api_key[:12] not in captured["message"]


def test_gemini_empty_output_raises_empty_model_output_error():
llm_response = LLMResponse(role="assistant")

Expand Down
Loading