diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 0363d4692..1a0810dd9 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -8,7 +8,7 @@ from __future__ import annotations from importlib import import_module -from typing import Any +from typing import TYPE_CHECKING, Any from astrbot.core.message.message_event_result import ( EventResultType, @@ -56,6 +56,18 @@ ), } +# Type-checking imports to satisfy static analyzers for __all__ exports +if TYPE_CHECKING: + from .content_safety_check.stage import ContentSafetyCheckStage + from .preprocess_stage.stage import PreProcessStage + from .process_stage.stage import ProcessStage + from .rate_limit_check.stage import RateLimitStage + from .respond.stage import RespondStage + from .result_decorate.stage import ResultDecorateStage + from .session_status_check.stage import SessionStatusCheckStage + from .waking_check.stage import WakingCheckStage + from .whitelist_check.stage import WhitelistCheckStage + __all__ = [ "ContentSafetyCheckStage", "EventResultType", diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 963f4bdac..47cd33b23 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -1,19 +1,22 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING from astrbot.core.config import AstrBotConfig from .context_utils import call_event_hook, call_handler +if TYPE_CHECKING: + from astrbot.core.star import PluginManager + @dataclass class PipelineContext: """上下文对象,包含管道执行所需的上下文信息""" astrbot_config: AstrBotConfig # AstrBot 配置对象 - plugin_manager: Any # 插件管理器对象 + plugin_manager: PluginManager # 插件管理器对象 astrbot_config_id: str call_handler = call_handler call_event_hook = call_event_hook diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index ef8c60e5f..27fa7756b 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -47,8 +47,6 @@ if TYPE_CHECKING: from astrbot.core.cron.manager import CronJobManager -else: - CronJobManager = Any class PlatformManagerProtocol(Protocol): diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index 1385b5056..735bd3852 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -2,7 +2,7 @@ import re from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any +from typing import TYPE_CHECKING, Any import docstring_parser @@ -15,6 +15,9 @@ from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools +if TYPE_CHECKING: + from astrbot.core.astr_agent_context import AstrAgentContext + from ..filter.command import CommandFilter from ..filter.command_group import CommandGroupFilter from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr @@ -616,7 +619,7 @@ def llm_tool(self, *args, **kwargs): kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[Any]) -> None: + def __init__(self, agent: Agent[AstrAgentContext]) -> None: self._agent = agent @@ -624,7 +627,7 @@ def register_agent( name: str, instruction: str, tools: list[str | FunctionTool] | None = None, - run_hooks: BaseAgentRunHooks[Any] | None = None, + run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None, ): """注册一个 Agent @@ -638,12 +641,12 @@ def register_agent( tools_ = tools or [] def decorator(awaitable: Callable[..., Awaitable[Any]]): - AstrAgent = Agent[Any] + AstrAgent = Agent[AstrAgentContext] agent = AstrAgent( name=name, instructions=instruction, tools=tools_, - run_hooks=run_hooks or BaseAgentRunHooks[Any](), + run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](), ) handoff_tool = HandoffTool(agent=agent) handoff_tool.handler = awaitable diff --git a/tests/fixtures/helpers.py b/tests/fixtures/helpers.py index 8f64ab6c9..26edb761c 100644 --- a/tests/fixtures/helpers.py +++ b/tests/fixtures/helpers.py @@ -3,7 +3,10 @@ 提供统一的测试辅助工具,减少测试代码重复。 """ -from typing import Any +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable from unittest.mock import AsyncMock, MagicMock from astrbot.core.message.components import BaseMessageComponent @@ -330,3 +333,255 @@ def create_mock_llm_response( tools_call_ids=tools_call_ids or [], usage=TokenUsage(input_other=10, output=5), ) + + +# ============================================================ +# 测试插件辅助函数 +# ============================================================ + + +@dataclass +class MockPluginConfig: + """测试插件配置。 + + 用于创建和管理测试用的模拟插件。 + + Attributes: + name: 插件名称 + author: 作者 + description: 描述 + version: 版本 + repo: 仓库 URL + main_code: main.py 的代码内容 + requirements: 依赖列表 + has_readme: 是否创建 README.md + readme_content: README.md 内容 + """ + + name: str = "test_plugin" + author: str = "Test Author" + description: str = "A test plugin for unit testing" + version: str = "1.0.0" + repo: str = "https://github.com/test/test_plugin" + main_code: str = "" + requirements: list[str] = field(default_factory=list) + has_readme: bool = True + readme_content: str = "# Test Plugin\n\nThis is a test plugin." + + +# 默认的插件主代码模板 +DEFAULT_PLUGIN_MAIN_TEMPLATE = ''' +from astrbot.api import star + +class Main(star.Star): + """测试插件主类。""" + + def __init__(self, context): + super().__init__(context) + self.name = "{plugin_name}" + + async def initialize(self): + """初始化插件。""" + pass + + async def terminate(self): + """终止插件。""" + pass +''' + + +class MockPluginBuilder: + """测试插件构建器。 + + 用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。 + + Example: + # 创建一个简单的测试插件 + builder = MockPluginBuilder(plugin_store_path) + plugin_dir = builder.create("my_test_plugin") + + # 创建自定义配置的插件 + config = MockPluginConfig( + name="custom_plugin", + version="2.0.0", + main_code="print('hello')", + ) + plugin_dir = builder.create(config) + + # 清理插件 + builder.cleanup("my_test_plugin") + """ + + def __init__(self, plugin_store_path: str | Path): + """初始化构建器。 + + Args: + plugin_store_path: 插件存储路径 (通常是 data/plugins) + """ + self.plugin_store_path = Path(plugin_store_path) + self._created_plugins: set[str] = set() + + def create( + self, + plugin_config: str | MockPluginConfig | None = None, + **kwargs, + ) -> Path: + """创建模拟插件。 + + Args: + plugin_config: 插件名称字符串、MockPluginConfig 对象或 None + **kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig + + Returns: + Path: 创建的插件目录路径 + """ + # 处理不同类型的输入 + if plugin_config is None: + config = MockPluginConfig(**kwargs) + elif isinstance(plugin_config, str): + config = MockPluginConfig(name=plugin_config, **kwargs) + elif isinstance(plugin_config, MockPluginConfig): + config = plugin_config + else: + raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}") + + # 创建插件目录 + plugin_dir = self.plugin_store_path / config.name + plugin_dir.mkdir(parents=True, exist_ok=True) + + # 创建 metadata.yaml + metadata_content = "\n".join( + [ + f"name: {config.name}", + f"author: {config.author}", + f"desc: {config.description}", + f"version: {config.version}", + f"repo: {config.repo}", + ] + ) + (plugin_dir / "metadata.yaml").write_text( + metadata_content + "\n", encoding="utf-8" + ) + + # 创建 main.py + main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format( + plugin_name=config.name + ) + (plugin_dir / "main.py").write_text(main_code, encoding="utf-8") + + # 创建 requirements.txt(如果有依赖) + if config.requirements: + (plugin_dir / "requirements.txt").write_text( + "\n".join(config.requirements) + "\n", encoding="utf-8" + ) + + # 创建 README.md(如果需要) + if config.has_readme: + (plugin_dir / "README.md").write_text( + config.readme_content, encoding="utf-8" + ) + + # 记录创建的插件 + self._created_plugins.add(config.name) + + return plugin_dir + + def cleanup(self, plugin_name: str | None = None) -> None: + """清理插件。 + + Args: + plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件 + """ + if plugin_name: + plugins_to_clean = {plugin_name} + else: + plugins_to_clean = self._created_plugins.copy() + + for name in plugins_to_clean: + plugin_dir = self.plugin_store_path / name + if plugin_dir.exists(): + shutil.rmtree(plugin_dir) + self._created_plugins.discard(name) + + def cleanup_all(self) -> None: + """清理所有由本构建器创建的插件。""" + self.cleanup(None) + + def get_plugin_path(self, plugin_name: str) -> Path: + """获取插件路径。 + + Args: + plugin_name: 插件名称 + + Returns: + Path: 插件目录路径 + """ + return self.plugin_store_path / plugin_name + + @property + def created_plugins(self) -> set[str]: + """获取已创建的插件名称集合。""" + return self._created_plugins.copy() + + +def create_mock_updater_install( + plugin_builder: MockPluginBuilder, + repo_to_plugin: dict[str, str] | None = None, +) -> Callable: + """创建模拟的 updater.install 方法。 + + Args: + plugin_builder: MockPluginBuilder 实例 + repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"} + + Returns: + Callable: 异步函数,可用于 monkeypatch.setattr + """ + + async def mock_install(repo_url: str, proxy: str = "") -> str: + """Mock updater.install 方法。""" + # 查找插件名称 + plugin_name = None + if repo_to_plugin: + plugin_name = repo_to_plugin.get(repo_url) + + # 如果没有映射,尝试从 URL 提取插件名 + if not plugin_name: + # 从 https://github.com/user/plugin_name 提取 plugin_name + parts = repo_url.rstrip("/").split("/") + plugin_name = parts[-1] if parts else "unknown_plugin" + + # 创建插件目录 + config = MockPluginConfig(name=plugin_name, repo=repo_url) + plugin_dir = plugin_builder.create(config) + return str(plugin_dir) + + return mock_install + + +def create_mock_updater_update( + plugin_builder: MockPluginBuilder, + update_callback: Callable | None = None, +) -> Callable: + """创建模拟的 updater.update 方法。 + + Args: + plugin_builder: MockPluginBuilder 实例 + update_callback: 更新回调函数,接收 plugin 参数 + + Returns: + Callable: 异步函数,可用于 monkeypatch.setattr + """ + + async def mock_update(plugin, proxy: str = "") -> None: + """Mock updater.update 方法。""" + plugin_dir = plugin_builder.get_plugin_path(plugin.name) + + # 创建更新标记文件 + (plugin_dir / ".updated").write_text("ok", encoding="utf-8") + + # 调用回调 + if update_callback: + update_callback(plugin) + + return mock_update diff --git a/tests/test_api_key_open_api.py b/tests/test_api_key_open_api.py index 3d1ea0a0f..4bc5fd4d5 100644 --- a/tests/test_api_key_open_api.py +++ b/tests/test_api_key_open_api.py @@ -186,7 +186,7 @@ async def fake_chat(post_data: dict | None = None): "/api/v1/chat", json={ "message": "hello", - "username": "alice", + "username": "alice_auto_session", "enable_streaming": False, }, headers={"X-API-Key": raw_key}, @@ -200,16 +200,16 @@ async def fake_chat(post_data: dict | None = None): created_session_id = send_data["data"]["session_id"] assert isinstance(created_session_id, str) uuid.UUID(created_session_id) - assert send_data["data"]["creator"] == "alice" + assert send_data["data"]["creator"] == "alice_auto_session" created_session = await core_lifecycle_td.db.get_platform_session_by_id( created_session_id ) assert created_session is not None - assert created_session.creator == "alice" + assert created_session.creator == "alice_auto_session" assert created_session.platform_id == "webchat" await core_lifecycle_td.db.create_platform_session( - creator="bob", + creator="bob_auto_session", platform_id="webchat", session_id="open_api_existing_bob_session", is_group=0, @@ -251,14 +251,15 @@ async def test_open_chat_sessions_pagination( create_res = await test_client.post( "/api/apikey/create", - json={"name": "chat-scope-key", "scopes": ["chat"]}, + json={"name": "chat-scope-key-pagination", "scopes": ["chat"]}, headers=authenticated_header, ) create_data = await create_res.get_json() assert create_data["status"] == "ok" raw_key = create_data["data"]["api_key"] - creator = "alice" + # Use unique session IDs to avoid conflicts with other tests + creator = "alice_pagination" for idx in range(3): await core_lifecycle_td.db.create_platform_session( creator=creator, @@ -268,7 +269,7 @@ async def test_open_chat_sessions_pagination( is_group=0, ) await core_lifecycle_td.db.create_platform_session( - creator="bob", + creator="bob_pagination", platform_id="webchat", session_id="open_api_paginated_bob", display_name="Open API Session Bob", @@ -276,7 +277,7 @@ async def test_open_chat_sessions_pagination( ) page_1_res = await test_client.get( - "/api/v1/chat/sessions?page=1&page_size=2&username=alice", + "/api/v1/chat/sessions?page=1&page_size=2&username=alice_pagination", headers={"X-API-Key": raw_key}, ) assert page_1_res.status_code == 200 @@ -286,10 +287,10 @@ async def test_open_chat_sessions_pagination( assert page_1_data["data"]["page_size"] == 2 assert page_1_data["data"]["total"] == 3 assert len(page_1_data["data"]["sessions"]) == 2 - assert all(item["creator"] == "alice" for item in page_1_data["data"]["sessions"]) + assert all(item["creator"] == "alice_pagination" for item in page_1_data["data"]["sessions"]) page_2_res = await test_client.get( - "/api/v1/chat/sessions?page=2&page_size=2&username=alice", + "/api/v1/chat/sessions?page=2&page_size=2&username=alice_pagination", headers={"X-API-Key": raw_key}, ) assert page_2_res.status_code == 200 diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 969f0da6d..4bf0673e8 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,5 +1,6 @@ import asyncio import os +from pathlib import Path import pytest import pytest_asyncio @@ -11,6 +12,12 @@ from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_handlers_registry from astrbot.dashboard.server import AstrBotDashboard +from tests.fixtures.helpers import ( + MockPluginBuilder, + MockPluginConfig, + create_mock_updater_install, + create_mock_updater_update, +) @pytest_asyncio.fixture(scope="module") @@ -94,8 +101,15 @@ async def test_get_stat(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_plugins(app: Quart, authenticated_header: dict): +async def test_plugins( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + """测试插件 API 端点,使用 Mock 避免真实网络调用。""" test_client = app.test_client() + # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) assert response.status_code == 200 @@ -111,53 +125,79 @@ async def test_plugins(app: Quart, authenticated_header: dict): data = await response.get_json() assert data["status"] == "ok" - # 插件安装 - response = await test_client.post( - "/api/plugin/install", - json={"url": "https://github.com/Soulter/astrbot_plugin_essential"}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert exists is True, "插件 astrbot_plugin_essential 未成功载入" - - # 插件更新 - response = await test_client.post( - "/api/plugin/update", - json={"name": "astrbot_plugin_essential"}, - headers=authenticated_header, + # 使用 MockPluginBuilder 创建测试插件 + plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path + builder = MockPluginBuilder(plugin_store_path) + + # 定义测试插件 + test_plugin_name = "test_mock_plugin" + test_repo_url = f"https://github.com/test/{test_plugin_name}" + + # 创建 Mock 函数 + mock_install = create_mock_updater_install( + builder, + repo_to_plugin={test_repo_url: test_plugin_name}, ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" + mock_update = create_mock_updater_update(builder) - # 插件卸载 - response = await test_client.post( - "/api/plugin/uninstall", - json={"name": "astrbot_plugin_essential"}, - headers=authenticated_header, + # 设置 Mock + monkeypatch.setattr( + core_lifecycle_td.plugin_manager.updator, "install", mock_install + ) + monkeypatch.setattr( + core_lifecycle_td.plugin_manager.updator, "update", mock_update ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" - exists = False - for md in star_handlers_registry: - if "astrbot_plugin_essential" in md.handler_module_path: - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + + try: + # 插件安装 + response = await test_client.post( + "/api/plugin/install", + json={"url": test_repo_url}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}" + + # 验证插件已注册 + exists = any(md.name == test_plugin_name for md in star_registry) + assert exists is True, f"插件 {test_plugin_name} 未成功载入" + + # 插件更新 + response = await test_client.post( + "/api/plugin/update", + json={"name": test_plugin_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + # 验证更新标记文件 + plugin_dir = builder.get_plugin_path(test_plugin_name) + assert (plugin_dir / ".updated").exists() + + # 插件卸载 + response = await test_client.post( + "/api/plugin/uninstall", + json={"name": test_plugin_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + # 验证插件已卸载 + exists = any(md.name == test_plugin_name for md in star_registry) + assert exists is False, f"插件 {test_plugin_name} 未成功卸载" + exists = any( + test_plugin_name in md.handler_module_path for md in star_handlers_registry + ) + assert exists is False, f"插件 {test_plugin_name} handler 未成功清理" + + finally: + # 清理测试插件 + builder.cleanup(test_plugin_name) @pytest.mark.asyncio @@ -189,12 +229,41 @@ async def test_commands_api(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, authenticated_header: dict): +async def test_check_update( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + """测试检查更新 API,使用 Mock 避免真实网络调用。""" test_client = app.test_client() + + # Mock 更新检查和网络请求 + async def mock_check_update(*args, **kwargs): + """Mock 更新检查,返回无新版本。""" + return None # None 表示没有新版本 + + async def mock_get_dashboard_version(*args, **kwargs): + """Mock Dashboard 版本获取。""" + from astrbot.core.config.default import VERSION + + return f"v{VERSION}" # 返回当前版本 + + monkeypatch.setattr( + core_lifecycle_td.astrbot_updator, + "check_update", + mock_check_update, + ) + monkeypatch.setattr( + "astrbot.dashboard.routes.update.get_dashboard_version", + mock_get_dashboard_version, + ) + response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "success" + assert data["data"]["has_new_version"] is False @pytest.mark.asyncio diff --git a/tests/test_main.py b/tests/test_main.py index 0453a51ee..2f879ee43 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -16,6 +16,16 @@ def __init__(self, major, minor): self.major = major self.minor = minor + def __eq__(self, other): + if isinstance(other, tuple): + return (self.major, self.minor) == other[:2] + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other): + if isinstance(other, tuple): + return (self.major, self.minor) >= other[:2] + return (self.major, self.minor) >= (other.major, other.minor) + def test_check_env(monkeypatch): version_info_correct = _version_info(3, 10) @@ -23,15 +33,51 @@ def test_check_env(monkeypatch): monkeypatch.setattr(sys, "version_info", version_info_correct) with mock.patch("os.makedirs") as mock_makedirs: check_env() - mock_makedirs.assert_any_call("data/config", exist_ok=True) - mock_makedirs.assert_any_call("data/plugins", exist_ok=True) - mock_makedirs.assert_any_call("data/temp", exist_ok=True) + # Check that makedirs was called with paths containing expected dirs + called_paths = [call[0][0] for call in mock_makedirs.call_args_list] + # Use os.path.join for cross-platform path matching + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths) + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths) + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths) monkeypatch.setattr(sys, "version_info", version_info_wrong) with pytest.raises(SystemExit): check_env() +def test_version_info_comparisons(): + """Test _version_info comparison operators with tuples and other instances.""" + v3_10 = _version_info(3, 10) + v3_9 = _version_info(3, 9) + v3_11 = _version_info(3, 11) + + # Test __eq__ with tuples + assert v3_10 == (3, 10) + assert v3_10 != (3, 9) + assert v3_9 == (3, 9) + + # Test __ge__ with tuples + assert v3_10 >= (3, 10) + assert v3_10 >= (3, 9) + assert not (v3_9 >= (3, 10)) + assert v3_11 >= (3, 10) + + # Test __eq__ with other _version_info instances + assert v3_10 == _version_info(3, 10) + assert v3_10 != v3_9 + assert v3_10 == v3_10 # Same instance + + assert v3_10 != v3_11 + + # Test __ge__ with other _version_info instances + assert v3_10 >= v3_10 + assert v3_10 >= v3_9 + assert not (v3_9 >= v3_10) + assert v3_11 >= v3_10 + + assert v3_11 >= v3_11 # Same instance + + @pytest.mark.asyncio async def test_check_dashboard_files_not_exists(monkeypatch): """Tests dashboard download when files do not exist.""" diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1e4cd866a..b91e25c01 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,65 +1,164 @@ -import os +import sys from asyncio import Queue +from pathlib import Path from unittest.mock import MagicMock import pytest +import pytest_asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.context import Context -from astrbot.core.star.star import star_registry +from astrbot.core.star.star import star_map, star_registry from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star_manager import PluginManager -@pytest.fixture -def plugin_manager_pm(tmp_path): - """Provides a fully isolated PluginManager instance for testing. - - Uses a temporary directory for plugins. - - Uses a temporary database. - - Creates a fresh context for each test. - """ - # Create temporary resources - temp_plugins_path = tmp_path / "plugins" - temp_plugins_path.mkdir() - temp_db_path = tmp_path / "test_db.db" +def _clear_module_cache() -> None: + """Clear module cache for data module tree to ensure test isolation.""" + modules_to_remove = [ + key for key in sys.modules if key == "data" or key.startswith("data.") + ] + for key in modules_to_remove: + del sys.modules[key] + + +def _clear_registry(plugin_name: str) -> None: + """Clear plugin from global registries.""" + # Clear star_registry (list) + star_registry[:] = [md for md in star_registry if md.name != plugin_name] + # Clear star_map (dict) + keys_to_remove = [ + key for key, md in star_map.items() if md.name == plugin_name + ] + for key in keys_to_remove: + del star_map[key] + # Clear star_handlers_registry (StarHandlerRegistry) + for handler in list(star_handlers_registry): + if plugin_name in (handler.handler_module_path or ""): + star_handlers_registry.remove(handler) + +TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld" +TEST_PLUGIN_DIR = "helloworld" +TEST_PLUGIN_NAME = "helloworld" + + +def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "metadata.yaml").write_text( + "\n".join( + [ + f"name: {TEST_PLUGIN_NAME}", + "author: AstrBot Team", + "desc: Local test plugin", + "version: 1.0.0", + f"repo: {repo_url}", + ], + ) + + "\n", + encoding="utf-8", + ) + (plugin_dir / "main.py").write_text( + "\n".join( + [ + "from astrbot.api import star", + "", + "class Main(star.Star):", + " pass", + "", + ], + ), + encoding="utf-8", + ) + + +@pytest_asyncio.fixture +async def plugin_manager_pm(tmp_path, monkeypatch): + """Provides a fully isolated PluginManager instance for testing.""" + # Clear module cache before setup to ensure isolation + _clear_module_cache() + + test_root = tmp_path / "astrbot_root" + data_dir = test_root / "data" + plugin_dir = data_dir / "plugins" + config_dir = data_dir / "config" + temp_dir = data_dir / "temp" + for path in (plugin_dir, config_dir, temp_dir): + path.mkdir(parents=True, exist_ok=True) + + # Ensure `import data.plugins..main` resolves to this temp root. + (data_dir / "__init__.py").write_text("", encoding="utf-8") + (plugin_dir / "__init__.py").write_text("", encoding="utf-8") + + # Use monkeypatch for both env var and sys.path to ensure proper cleanup + monkeypatch.setenv("ASTRBOT_ROOT", str(test_root)) + monkeypatch.syspath_prepend(str(test_root)) # Create fresh, isolated instances for the context event_queue = Queue() config = AstrBotConfig() - db = SQLiteDatabase(str(temp_db_path)) - - # Set the plugin store path in the config to the temporary directory - config.plugin_store_path = str(temp_plugins_path) + db = SQLiteDatabase(str(data_dir / "test_db.db")) + config.plugin_store_path = str(plugin_dir) - # Mock dependencies for the context provider_manager = MagicMock() platform_manager = MagicMock() conversation_manager = MagicMock() message_history_manager = MagicMock() persona_manager = MagicMock() + persona_manager.personas_v3 = [] astrbot_config_mgr = MagicMock() knowledge_base_manager = MagicMock() + cron_manager = MagicMock() star_context = Context( - event_queue, - config, - db, - provider_manager, - platform_manager, - conversation_manager, - message_history_manager, - persona_manager, - astrbot_config_mgr, + event_queue=event_queue, + config=config, + db=db, + provider_manager=provider_manager, + platform_manager=platform_manager, + conversation_manager=conversation_manager, + message_history_manager=message_history_manager, + persona_manager=persona_manager, + astrbot_config_mgr=astrbot_config_mgr, knowledge_base_manager=knowledge_base_manager, + cron_manager=cron_manager, + subagent_orchestrator=None, ) - # Create the PluginManager instance manager = PluginManager(star_context, config) - return manager + try: + yield manager + finally: + # Cleanup global registries and module cache + _clear_registry(TEST_PLUGIN_NAME) + _clear_module_cache() + await db.engine.dispose() + + +@pytest.fixture +def local_updator(plugin_manager_pm: PluginManager, monkeypatch): + plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + async def mock_install(repo_url: str, proxy=""): # noqa: ARG001 + if repo_url != TEST_PLUGIN_REPO: + raise Exception("Repo not found") + _write_local_test_plugin(plugin_path, repo_url) + return str(plugin_path) -def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): + async def mock_update(plugin, proxy=""): # noqa: ARG001 + if plugin.name != TEST_PLUGIN_NAME: + raise Exception("Plugin not found") + if not plugin_path.exists(): + raise Exception("Plugin path missing") + (plugin_path / ".updated").write_text("ok", encoding="utf-8") + + monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install) + monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update) + return plugin_path + + +@pytest.mark.asyncio +async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None @@ -73,73 +172,59 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager): - """Tests successful plugin installation in an isolated environment.""" - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - plugin_info = await plugin_manager_pm.install_plugin(test_repo) - plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, - "astrbot_plugin_essential", - ) - +async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests successful plugin installation without external network.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) assert plugin_info is not None - assert os.path.exists(plugin_path) - assert any(md.name == "astrbot_plugin_essential" for md in star_registry), ( - "Plugin 'astrbot_plugin_essential' was not loaded into star_registry." - ) + assert plugin_info["name"] == TEST_PLUGIN_NAME + assert local_updator.exists() + assert any(md.name == TEST_PLUGIN_NAME for md in star_registry) @pytest.mark.asyncio -async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_install_nonexistent_plugin( + plugin_manager_pm: PluginManager, local_updator +): """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( - "https://github.com/Soulter/non_existent_repo", + "https://github.com/Soulter/non_existent_repo" ) @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager): - """Tests updating an existing plugin in an isolated environment.""" - # First, install the plugin - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - await plugin_manager_pm.install_plugin(test_repo) - - # Then, update it - await plugin_manager_pm.update_plugin("astrbot_plugin_essential") +async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests updating an existing plugin without external network.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert plugin_info is not None + plugin_name = plugin_info["name"] + await plugin_manager_pm.update_plugin(plugin_name) + assert (local_updator / ".updated").exists() @pytest.mark.asyncio -async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_update_nonexistent_plugin( + plugin_manager_pm: PluginManager, local_updator +): """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.update_plugin("non_existent_plugin") @pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager): - """Tests successful plugin uninstallation in an isolated environment.""" - # First, install the plugin - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - await plugin_manager_pm.install_plugin(test_repo) - plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, - "astrbot_plugin_essential", - ) - assert os.path.exists(plugin_path) # Pre-condition +async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests successful plugin uninstallation.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert plugin_info is not None + plugin_name = plugin_info["name"] + assert local_updator.exists() - # Then, uninstall it - await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential") + await plugin_manager_pm.uninstall_plugin(plugin_name) - assert not os.path.exists(plugin_path) - assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), ( - "Plugin 'astrbot_plugin_essential' was not unloaded from star_registry." - ) + assert not local_updator.exists() + assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry) assert not any( - "astrbot_plugin_essential" in md.handler_module_path - for md in star_handlers_registry - ), ( - "Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry." + TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry ) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 4474e1599..36870e617 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -101,10 +101,16 @@ def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None: "mock_apscheduler.schedulers = MagicMock();" "mock_apscheduler.schedulers.asyncio = MagicMock();" "mock_apscheduler.schedulers.background = MagicMock();" + "mock_apscheduler.triggers = MagicMock();" + "mock_apscheduler.triggers.cron = MagicMock();" + "mock_apscheduler.triggers.date = MagicMock();" "sys.modules['apscheduler'] = mock_apscheduler;" "sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;" "sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;" "sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;" + "sys.modules['apscheduler.triggers'] = mock_apscheduler.triggers;" + "sys.modules['apscheduler.triggers.cron'] = mock_apscheduler.triggers.cron;" + "sys.modules['apscheduler.triggers.date'] = mock_apscheduler.triggers.date;" "import astrbot.core.pipeline as pipeline;" "assert pipeline.ProcessStage is not None;" "assert pipeline.RespondStage is not None" diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 0b5190407..c738cfc80 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -447,7 +447,8 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message( final_resp = runner.get_final_llm_resp() assert final_resp is not None assert final_resp.role == "assistant" - assert final_resp.completion_text == "partial " + # When interrupted, the runner replaces completion_text with a system message + assert "interrupted" in final_resp.completion_text.lower() assert runner.run_context.messages[-1].role == "assistant" diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 000000000..1da02835b --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,607 @@ +"""Tests for config module.""" + +import json +import os + +import pytest + +from astrbot.core.config.astrbot_config import AstrBotConfig, RateLimitStrategy +from astrbot.core.config.default import DEFAULT_VALUE_MAP +from astrbot.core.config.i18n_utils import ConfigMetadataI18n + + +@pytest.fixture +def temp_config_path(tmp_path): + """Create a temporary config path.""" + return str(tmp_path / "test_config.json") + + +@pytest.fixture +def minimal_default_config(): + """Create a minimal default config for testing.""" + return { + "config_version": 2, + "platform_settings": { + "unique_session": False, + "rate_limit": { + "time": 60, + "count": 30, + "strategy": "stall", + }, + }, + "provider_settings": { + "enable": True, + "default_provider_id": "", + }, + } + + +class TestRateLimitStrategy: + """Tests for RateLimitStrategy enum.""" + + def test_stall_value(self): + """Test stall enum value.""" + assert RateLimitStrategy.STALL.value == "stall" + + def test_discard_value(self): + """Test discard enum value.""" + assert RateLimitStrategy.DISCARD.value == "discard" + + +class TestAstrBotConfigLoad: + """Tests for AstrBotConfig loading and initialization.""" + + def test_init_creates_file_if_not_exists( + self, temp_config_path, minimal_default_config + ): + """Test that config file is created when it doesn't exist.""" + assert not os.path.exists(temp_config_path) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert os.path.exists(temp_config_path) + assert config.config_version == 2 + assert config.platform_settings["unique_session"] is False + + def test_init_loads_existing_file(self, temp_config_path, minimal_default_config): + """Test that existing config file is loaded.""" + existing_config = { + "config_version": 2, + "platform_settings": {"unique_session": True}, + "provider_settings": {"enable": False}, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert config.platform_settings["unique_session"] is True + assert config.provider_settings["enable"] is False + + def test_first_deploy_flag(self, temp_config_path, minimal_default_config): + """Test first_deploy flag is set for new config.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert hasattr(config, "first_deploy") + assert config.first_deploy is True + + def test_init_with_schema(self, temp_config_path): + """Test initialization with schema.""" + schema = { + "test_field": { + "type": "string", + "default": "test_value", + }, + "nested": { + "type": "object", + "items": { + "enabled": {"type": "bool"}, + "count": {"type": "int"}, + }, + }, + } + + config = AstrBotConfig(config_path=temp_config_path, schema=schema) + + assert config.test_field == "test_value" + assert config.nested["enabled"] is False + assert config.nested["count"] == 0 + + def test_dot_notation_access(self, temp_config_path, minimal_default_config): + """Test accessing config values using dot notation.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert config.platform_settings is not None + assert config.non_existent_field is None + + def test_setattr_updates_config(self, temp_config_path, minimal_default_config): + """Test that setting attributes updates config.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + config.new_field = "new_value" + + assert config.new_field == "new_value" + + def test_delattr_removes_field(self, temp_config_path, minimal_default_config): + """Test that deleting attributes removes them.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + config.temp_field = "temp" + + del config.temp_field + + # Accessing a deleted field returns None due to __getattr__ + assert config.temp_field is None + # But the field is removed from the dict + assert "temp_field" not in config + + def test_delattr_saves_config(self, temp_config_path, minimal_default_config): + """Test that deleting attributes saves config to file.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + config.temp_field = "temp" + del config.temp_field + + with open(temp_config_path, encoding="utf-8-sig") as f: + loaded_config = json.load(f) + + assert "temp_field" not in loaded_config + + def test_check_exist(self, temp_config_path, minimal_default_config): + """Test check_exist method.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert config.check_exist() is True + + # Create a path that definitely doesn't exist + import pathlib + + temp_dir = pathlib.Path(temp_config_path).parent + non_existent_path = str(temp_dir / "non_existent_config.json") + + # Check that the file doesn't exist before creating config + assert not os.path.exists(non_existent_path) + + # Create config which will auto-create the file + config2 = AstrBotConfig( + config_path=non_existent_path, default_config=minimal_default_config + ) + + # Now it exists + assert config2.check_exist() is True + assert os.path.exists(non_existent_path) + + +class TestConfigValidation: + """Tests for config validation and integrity checking.""" + + def test_insert_missing_config_items( + self, temp_config_path, minimal_default_config + ): + """Test that missing config items are inserted with default values.""" + existing_config = {"config_version": 2} + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert "platform_settings" in config + assert "provider_settings" in config + + def test_replace_none_with_default(self, temp_config_path, minimal_default_config): + """Test that None values are replaced with defaults.""" + existing_config = { + "config_version": 2, + "platform_settings": None, + "provider_settings": None, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + # Reload to verify the values were replaced + config2 = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert config2.platform_settings is not None + assert config2.provider_settings is not None + + def test_reorder_config_keys(self, temp_config_path, minimal_default_config): + """Test that config keys are reordered to match default.""" + existing_config = { + "provider_settings": {"enable": True}, + "config_version": 2, + "platform_settings": {"unique_session": False}, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + with open(temp_config_path, encoding="utf-8-sig") as f: + loaded_config = json.load(f) + + keys = list(loaded_config.keys()) + assert keys[0] == "config_version" + assert keys[1] == "platform_settings" + assert keys[2] == "provider_settings" + + def test_remove_unknown_config_keys(self, temp_config_path, minimal_default_config): + """Test that unknown config keys are removed.""" + existing_config = { + "config_version": 2, + "platform_settings": {}, + "unknown_key": "should_be_removed", + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert "unknown_key" not in config + + def test_nested_config_validation(self, temp_config_path): + """Test validation of nested config structures.""" + default_config = { + "nested": { + "level1": { + "level2": { + "value": 42, + }, + }, + }, + } + + existing_config = { + "nested": { + "level1": {}, # Missing level2 + }, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=default_config + ) + + assert "level2" in config.nested["level1"] + assert config.nested["level1"]["level2"]["value"] == 42 + + +class TestConfigHotReload: + """Tests for config hot reload functionality.""" + + def test_save_config(self, temp_config_path, minimal_default_config): + """Test saving config to file.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + config.new_field = "new_value" + config.save_config() + + with open(temp_config_path, encoding="utf-8-sig") as f: + loaded_config = json.load(f) + + assert loaded_config["new_field"] == "new_value" + + def test_save_config_with_replace(self, temp_config_path, minimal_default_config): + """Test saving config with replacement.""" + config = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + replacement_config = { + "replaced": True, + "extra_field": "value", + } + config.save_config(replace_config=replacement_config) + + with open(temp_config_path, encoding="utf-8-sig") as f: + loaded_config = json.load(f) + + # The replacement config is merged with existing config + assert loaded_config["replaced"] is True + assert loaded_config["extra_field"] == "value" + # Original fields are preserved because update merges + assert "platform_settings" in loaded_config + + def test_modification_persists_after_reload( + self, temp_config_path, minimal_default_config + ): + """Test that modifications persist after reloading.""" + config1 = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + config1.platform_settings["unique_session"] = True + config1.save_config() + + config2 = AstrBotConfig( + config_path=temp_config_path, default_config=minimal_default_config + ) + + assert config2.platform_settings["unique_session"] is True + + +class TestConfigSchemaToDefault: + """Tests for schema to default config conversion.""" + + def test_convert_schema_with_defaults(self, temp_config_path): + """Test converting schema with explicit defaults.""" + schema = { + "string_field": {"type": "string", "default": "custom"}, + "int_field": {"type": "int", "default": 100}, + "bool_field": {"type": "bool", "default": True}, + } + + config = AstrBotConfig(config_path=temp_config_path, schema=schema) + + assert config.string_field == "custom" + assert config.int_field == 100 + assert config.bool_field is True + + def test_convert_schema_without_defaults(self, temp_config_path): + """Test converting schema using default value map.""" + schema = { + "string_field": {"type": "string"}, + "int_field": {"type": "int"}, + "bool_field": {"type": "bool"}, + } + + config = AstrBotConfig(config_path=temp_config_path, schema=schema) + + assert config.string_field == DEFAULT_VALUE_MAP["string"] + assert config.int_field == DEFAULT_VALUE_MAP["int"] + assert config.bool_field == DEFAULT_VALUE_MAP["bool"] + + def test_unsupported_schema_type_raises_error(self, temp_config_path): + """Test that unsupported schema types raise error.""" + schema = { + "field": {"type": "unsupported_type"}, + } + + with pytest.raises(TypeError, match="不受支持的配置类型"): + AstrBotConfig(config_path=temp_config_path, schema=schema) + + def test_template_list_type(self, temp_config_path): + """Test template_list schema type.""" + schema = { + "templates": {"type": "template_list", "default": []}, + } + + config = AstrBotConfig(config_path=temp_config_path, schema=schema) + + assert config.templates == [] + + def test_nested_object_schema(self, temp_config_path): + """Test nested object schema conversion.""" + schema = { + "nested": { + "type": "object", + "items": { + "field1": {"type": "string"}, + "field2": {"type": "int"}, + }, + }, + } + + config = AstrBotConfig(config_path=temp_config_path, schema=schema) + + assert config.nested["field1"] == "" + assert config.nested["field2"] == 0 + + +class TestConfigMetadataI18n: + """Tests for i18n utils.""" + + def test_get_i18n_key(self): + """Test generating i18n key.""" + key = ConfigMetadataI18n._get_i18n_key( + group="ai_group", + section="general", + field="enable", + attr="description", + ) + + assert key == "ai_group.general.enable.description" + + def test_get_i18n_key_without_field(self): + """Test generating i18n key without field.""" + key = ConfigMetadataI18n._get_i18n_key( + group="ai_group", + section="general", + field="", + attr="description", + ) + + assert key == "ai_group.general.description" + + def test_convert_to_i18n_keys_simple(self): + """Test converting simple metadata to i18n keys.""" + metadata = { + "ai_group": { + "name": "AI Settings", + "metadata": { + "general": { + "description": "General settings", + "items": { + "enable": { + "description": "Enable feature", + "type": "bool", + "default": True, + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert result["ai_group"]["name"] == "ai_group.name" + assert ( + result["ai_group"]["metadata"]["general"]["description"] + == "ai_group.general.description" + ) + assert ( + result["ai_group"]["metadata"]["general"]["items"]["enable"]["description"] + == "ai_group.general.enable.description" + ) + + def test_convert_to_i18n_keys_with_hint(self): + """Test converting metadata with hint.""" + metadata = { + "group": { + "metadata": { + "section": { + "hint": "This is a hint", + "items": { + "field": { + "hint": "Field hint", + "type": "string", + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert result["group"]["metadata"]["section"]["hint"] == "group.section.hint" + assert ( + result["group"]["metadata"]["section"]["items"]["field"]["hint"] + == "group.section.field.hint" + ) + + def test_convert_to_i18n_keys_with_labels(self): + """Test converting metadata with labels.""" + metadata = { + "group": { + "metadata": { + "section": { + "items": { + "field": { + "labels": ["Label1", "Label2"], + "type": "string", + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert ( + result["group"]["metadata"]["section"]["items"]["field"]["labels"] + == "group.section.field.labels" + ) + + def test_convert_to_i18n_keys_nested_items(self): + """Test converting metadata with nested items.""" + metadata = { + "group": { + "metadata": { + "section": { + "items": { + "nested": { + "description": "Nested field", + "type": "object", + "items": { + "inner": { + "description": "Inner field", + "type": "string", + }, + }, + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert ( + result["group"]["metadata"]["section"]["items"]["nested"]["description"] + == "group.section.nested.description" + ) + assert ( + result["group"]["metadata"]["section"]["items"]["nested"]["items"]["inner"][ + "description" + ] + == "group.section.nested.inner.description" + ) + + def test_convert_to_i18n_keys_preserves_non_i18n_fields(self): + """Test that non-i18n fields are preserved.""" + metadata = { + "group": { + "metadata": { + "section": { + "items": { + "field": { + "description": "Field description", + "type": "string", + "other_field": "preserve this", + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert ( + result["group"]["metadata"]["section"]["items"]["field"]["other_field"] + == "preserve this" + ) + + def test_convert_to_i18n_keys_with_name(self): + """Test converting metadata with name field.""" + metadata = { + "group": { + "metadata": { + "section": { + "items": { + "field": { + "name": "Field Name", + "type": "string", + }, + }, + }, + }, + }, + } + + result = ConfigMetadataI18n.convert_to_i18n_keys(metadata) + + assert ( + result["group"]["metadata"]["section"]["items"]["field"]["name"] + == "group.section.field.name" + ) diff --git a/tests/unit/test_cron_manager.py b/tests/unit/test_cron_manager.py new file mode 100644 index 000000000..b111384ac --- /dev/null +++ b/tests/unit/test_cron_manager.py @@ -0,0 +1,504 @@ +"""Tests for CronJobManager.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.cron.manager import CronJobManager +from astrbot.core.db.po import CronJob + + +@pytest.fixture +def mock_db(): + """Create a mock database.""" + db = MagicMock() + db.create_cron_job = AsyncMock() + db.get_cron_job = AsyncMock() + db.update_cron_job = AsyncMock() + db.delete_cron_job = AsyncMock() + db.list_cron_jobs = AsyncMock(return_value=[]) + return db + + +@pytest.fixture +def mock_context(): + """Create a mock Context.""" + ctx = MagicMock() + ctx.get_config = MagicMock(return_value={"admins_id": []}) + ctx.conversation_manager = MagicMock() + return ctx + + +@pytest.fixture +def cron_manager(mock_db): + """Create a CronJobManager instance.""" + return CronJobManager(mock_db) + + +@pytest.fixture +def sample_cron_job(): + """Create a sample CronJob.""" + return CronJob( + job_id="test-job-id", + name="Test Job", + job_type="basic", + cron_expression="0 9 * * *", + timezone="UTC", + payload={"key": "value"}, + description="A test job", + enabled=True, + persistent=True, + run_once=False, + status="pending", + ) + + +class TestCronJobManagerInit: + """Tests for CronJobManager initialization.""" + + def test_init(self, mock_db): + """Test CronJobManager initialization.""" + manager = CronJobManager(mock_db) + + assert manager.db == mock_db + assert manager._basic_handlers == {} + assert manager._started is False + + +class TestCronJobManagerStart: + """Tests for CronJobManager.start method.""" + + @pytest.mark.asyncio + async def test_start(self, cron_manager, mock_db, mock_context): + """Test starting the cron manager.""" + mock_db.list_cron_jobs.return_value = [] + + await cron_manager.start(mock_context) + + assert cron_manager._started is True + assert cron_manager.ctx == mock_context + + @pytest.mark.asyncio + async def test_start_idempotent(self, cron_manager, mock_db, mock_context): + """Test that start is idempotent.""" + mock_db.list_cron_jobs.return_value = [] + + await cron_manager.start(mock_context) + await cron_manager.start(mock_context) + + # Should only sync once + assert mock_db.list_cron_jobs.call_count == 1 + + +class TestCronJobManagerShutdown: + """Tests for CronJobManager.shutdown method.""" + + @pytest.mark.asyncio + async def test_shutdown(self, cron_manager, mock_db, mock_context): + """Test shutting down the cron manager.""" + mock_db.list_cron_jobs.return_value = [] + await cron_manager.start(mock_context) + + await cron_manager.shutdown() + + assert cron_manager._started is False + + @pytest.mark.asyncio + async def test_shutdown_when_not_started(self, cron_manager): + """Test shutdown when not started.""" + # Should not raise + await cron_manager.shutdown() + + +class TestAddBasicJob: + """Tests for add_basic_job method.""" + + @pytest.mark.asyncio + async def test_add_basic_job(self, cron_manager, mock_db, sample_cron_job): + """Test adding a basic cron job.""" + mock_db.create_cron_job.return_value = sample_cron_job + + handler = MagicMock() + + result = await cron_manager.add_basic_job( + name="Test Job", + cron_expression="0 9 * * *", + handler=handler, + description="A test job", + enabled=True, + ) + + assert result == sample_cron_job + assert sample_cron_job.job_id in cron_manager._basic_handlers + mock_db.create_cron_job.assert_called_once() + + @pytest.mark.asyncio + async def test_add_basic_job_disabled(self, cron_manager, mock_db, sample_cron_job): + """Test adding a disabled basic cron job.""" + sample_cron_job.enabled = False + mock_db.create_cron_job.return_value = sample_cron_job + + handler = MagicMock() + + result = await cron_manager.add_basic_job( + name="Test Job", + cron_expression="0 9 * * *", + handler=handler, + enabled=False, + ) + + assert result == sample_cron_job + assert sample_cron_job.job_id in cron_manager._basic_handlers + + @pytest.mark.asyncio + async def test_add_basic_job_with_timezone(self, cron_manager, mock_db, sample_cron_job): + """Test adding a basic job with timezone.""" + mock_db.create_cron_job.return_value = sample_cron_job + + handler = MagicMock() + + await cron_manager.add_basic_job( + name="Test Job", + cron_expression="0 9 * * *", + handler=handler, + timezone="Asia/Shanghai", + ) + + mock_db.create_cron_job.assert_called_once() + call_kwargs = mock_db.create_cron_job.call_args.kwargs + assert call_kwargs["timezone"] == "Asia/Shanghai" + + +class TestAddActiveJob: + """Tests for add_active_job method.""" + + @pytest.mark.asyncio + async def test_add_active_job(self, cron_manager, mock_db, sample_cron_job): + """Test adding an active agent cron job.""" + sample_cron_job.job_type = "active_agent" + mock_db.create_cron_job.return_value = sample_cron_job + + result = await cron_manager.add_active_job( + name="Test Active Job", + cron_expression="0 9 * * *", + payload={"session": "test:group:123"}, + ) + + assert result == sample_cron_job + mock_db.create_cron_job.assert_called_once() + + @pytest.mark.asyncio + async def test_add_active_job_run_once(self, cron_manager, mock_db, sample_cron_job): + """Test adding a run-once active job.""" + sample_cron_job.job_type = "active_agent" + sample_cron_job.run_once = True + mock_db.create_cron_job.return_value = sample_cron_job + + run_at = datetime.now(timezone.utc) + timedelta(days=30) + + result = await cron_manager.add_active_job( + name="Test Run Once Job", + cron_expression=None, + payload={"session": "test:group:123"}, + run_once=True, + run_at=run_at, + ) + + assert result == sample_cron_job + call_kwargs = mock_db.create_cron_job.call_args.kwargs + assert call_kwargs["run_once"] is True + + +class TestUpdateJob: + """Tests for update_job method.""" + + @pytest.mark.asyncio + async def test_update_job(self, cron_manager, mock_db, sample_cron_job): + """Test updating a cron job.""" + updated_job = CronJob( + job_id="test-job-id", + name="Updated Job", + job_type="basic", + cron_expression="0 10 * * *", + enabled=False, # Disabled to avoid scheduling + ) + mock_db.update_cron_job.return_value = updated_job + + result = await cron_manager.update_job("test-job-id", name="Updated Job") + + assert result == updated_job + mock_db.update_cron_job.assert_called() + + @pytest.mark.asyncio + async def test_update_job_not_found(self, cron_manager, mock_db): + """Test updating a non-existent job.""" + mock_db.update_cron_job.return_value = None + + result = await cron_manager.update_job("non-existent", name="Updated") + + assert result is None + + +class TestDeleteJob: + """Tests for delete_job method.""" + + @pytest.mark.asyncio + async def test_delete_job(self, cron_manager, mock_db): + """Test deleting a cron job.""" + cron_manager._basic_handlers["test-job-id"] = MagicMock() + + await cron_manager.delete_job("test-job-id") + + mock_db.delete_cron_job.assert_called_once_with("test-job-id") + assert "test-job-id" not in cron_manager._basic_handlers + + +class TestListJobs: + """Tests for list_jobs method.""" + + @pytest.mark.asyncio + async def test_list_all_jobs(self, cron_manager, mock_db, sample_cron_job): + """Test listing all jobs.""" + mock_db.list_cron_jobs.return_value = [sample_cron_job] + + result = await cron_manager.list_jobs() + + assert len(result) == 1 + mock_db.list_cron_jobs.assert_called_once_with(None) + + @pytest.mark.asyncio + async def test_list_jobs_by_type(self, cron_manager, mock_db, sample_cron_job): + """Test listing jobs by type.""" + mock_db.list_cron_jobs.return_value = [sample_cron_job] + + result = await cron_manager.list_jobs(job_type="basic") + + assert len(result) == 1 + mock_db.list_cron_jobs.assert_called_once_with("basic") + + +class TestSyncFromDb: + """Tests for sync_from_db method.""" + + @pytest.mark.asyncio + async def test_sync_from_db_empty(self, cron_manager, mock_db): + """Test syncing from empty database.""" + mock_db.list_cron_jobs.return_value = [] + + await cron_manager.sync_from_db() + + mock_db.list_cron_jobs.assert_called_once() + + @pytest.mark.asyncio + async def test_sync_from_db_skips_disabled(self, cron_manager, mock_db, sample_cron_job): + """Test that sync skips disabled jobs.""" + sample_cron_job.enabled = False + mock_db.list_cron_jobs.return_value = [sample_cron_job] + + with patch.object(cron_manager, "_schedule_job") as mock_schedule: + await cron_manager.sync_from_db() + + mock_db.list_cron_jobs.assert_called_once() + mock_schedule.assert_not_called() + + @pytest.mark.asyncio + async def test_sync_from_db_skips_non_persistent(self, cron_manager, mock_db, sample_cron_job): + """Test that sync skips non-persistent jobs.""" + sample_cron_job.persistent = False + mock_db.list_cron_jobs.return_value = [sample_cron_job] + + with patch.object(cron_manager, "_schedule_job") as mock_schedule: + await cron_manager.sync_from_db() + + mock_db.list_cron_jobs.assert_called_once() + mock_schedule.assert_not_called() + + @pytest.mark.asyncio + async def test_sync_from_db_basic_without_handler( + self, cron_manager, mock_db, sample_cron_job + ): + """Test that sync warns for basic jobs without handlers.""" + mock_db.list_cron_jobs.return_value = [sample_cron_job] + + with patch("astrbot.core.cron.manager.logger") as mock_logger: + await cron_manager.sync_from_db() + + mock_logger.warning.assert_called() + + +class TestRemoveScheduled: + """Tests for _remove_scheduled method.""" + + @pytest.mark.asyncio + async def test_remove_scheduled_existing(self, cron_manager, mock_context): + """Test removing a scheduled job.""" + # Start the scheduler first + job = CronJob( + job_id="test-job-id", + name="Test", + job_type="active_agent", + cron_expression="0 9 * * *", + enabled=True, + persistent=True, + ) + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[job]) + await cron_manager.start(mock_context) + + # Then remove it + cron_manager._remove_scheduled("test-job-id") + + # Should not raise + + def test_remove_scheduled_nonexistent(self, cron_manager): + """Test removing a non-existent job.""" + # Should not raise + cron_manager._remove_scheduled("non-existent") + + +class TestScheduleJob: + """Tests for _schedule_job method.""" + + @pytest.mark.asyncio + async def test_schedule_job_basic(self, cron_manager, sample_cron_job, mock_context): + """Test scheduling a basic job.""" + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[]) + mock_db.update_cron_job = AsyncMock() + await cron_manager.start(mock_context) + cron_manager._schedule_job(sample_cron_job) + + # Verify job was added to scheduler + assert cron_manager.scheduler.get_job("test-job-id") is not None + + @pytest.mark.asyncio + async def test_schedule_job_with_timezone(self, cron_manager, sample_cron_job, mock_context): + """Test scheduling a job with timezone.""" + sample_cron_job.timezone = "America/New_York" + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[]) + mock_db.update_cron_job = AsyncMock() + await cron_manager.start(mock_context) + cron_manager._schedule_job(sample_cron_job) + + assert cron_manager.scheduler.get_job("test-job-id") is not None + + @pytest.mark.asyncio + async def test_schedule_job_invalid_timezone(self, cron_manager, sample_cron_job, mock_context): + """Test scheduling a job with invalid timezone.""" + sample_cron_job.timezone = "Invalid/Timezone" + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[]) + mock_db.update_cron_job = AsyncMock() + + with patch("astrbot.core.cron.manager.logger") as mock_logger: + await cron_manager.start(mock_context) + cron_manager._schedule_job(sample_cron_job) + + # Should still schedule with system timezone + assert cron_manager.scheduler.get_job("test-job-id") is not None + mock_logger.warning.assert_called() + + @pytest.mark.asyncio + async def test_schedule_job_run_once(self, cron_manager, mock_context): + """Test scheduling a run-once job.""" + future_date = datetime.now(timezone.utc) + timedelta(days=30) + job = CronJob( + job_id="run-once-job", + name="Run Once", + job_type="active_agent", + cron_expression=None, + enabled=True, + run_once=True, + payload={"run_at": future_date.isoformat()}, + ) + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[]) + mock_db.update_cron_job = AsyncMock() + await cron_manager.start(mock_context) + cron_manager._schedule_job(job) + + assert cron_manager.scheduler.get_job("run-once-job") is not None + + +class TestRunJob: + """Tests for _run_job method.""" + + @pytest.mark.asyncio + async def test_run_job_disabled(self, cron_manager, mock_db, sample_cron_job): + """Test running a disabled job.""" + sample_cron_job.enabled = False + mock_db.get_cron_job.return_value = sample_cron_job + + await cron_manager._run_job("test-job-id") + + # Should not update status + mock_db.update_cron_job.assert_not_called() + + @pytest.mark.asyncio + async def test_run_job_not_found(self, cron_manager, mock_db): + """Test running a non-existent job.""" + mock_db.get_cron_job.return_value = None + + await cron_manager._run_job("non-existent") + + # Should not update status + mock_db.update_cron_job.assert_not_called() + + +class TestRunBasicJob: + """Tests for _run_basic_job method.""" + + @pytest.mark.asyncio + async def test_run_basic_job_sync_handler(self, cron_manager, sample_cron_job): + """Test running a basic job with sync handler.""" + handler = MagicMock(return_value=None) + cron_manager._basic_handlers["test-job-id"] = handler + sample_cron_job.payload = {"arg1": "value1"} + + await cron_manager._run_basic_job(sample_cron_job) + + handler.assert_called_once_with(arg1="value1") + + @pytest.mark.asyncio + async def test_run_basic_job_async_handler(self, cron_manager, sample_cron_job): + """Test running a basic job with async handler.""" + async_handler = AsyncMock() + cron_manager._basic_handlers["test-job-id"] = async_handler + sample_cron_job.payload = {} + + await cron_manager._run_basic_job(sample_cron_job) + + async_handler.assert_called_once() + + @pytest.mark.asyncio + async def test_run_basic_job_no_handler(self, cron_manager, sample_cron_job): + """Test running a basic job without handler.""" + sample_cron_job.job_id = "no-handler-job" + + with pytest.raises(RuntimeError, match="handler not found"): + await cron_manager._run_basic_job(sample_cron_job) + + +class TestGetNextRunTime: + """Tests for _get_next_run_time method.""" + + @pytest.mark.asyncio + async def test_get_next_run_time_existing_job(self, cron_manager, sample_cron_job, mock_context): + """Test getting next run time for existing job.""" + mock_db = cron_manager.db + mock_db.list_cron_jobs = AsyncMock(return_value=[]) + mock_db.update_cron_job = AsyncMock() + await cron_manager.start(mock_context) + cron_manager._schedule_job(sample_cron_job) + + next_run = cron_manager._get_next_run_time("test-job-id") + + assert next_run is not None + + def test_get_next_run_time_nonexistent(self, cron_manager): + """Test getting next run time for non-existent job.""" + next_run = cron_manager._get_next_run_time("non-existent") + + assert next_run is None diff --git a/tests/unit/test_star_base.py b/tests/unit/test_star_base.py new file mode 100644 index 000000000..d78737943 --- /dev/null +++ b/tests/unit/test_star_base.py @@ -0,0 +1,198 @@ +"""Tests for astrbot.core.star.base module.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +class TestStarBase: + """Test cases for the Star base class.""" + + def test_star_class_exists(self): + """Test that Star class can be imported.""" + from astrbot.core.star import Star + + assert Star is not None + + def test_star_init_with_context(self): + """Test Star initialization with a context-like object.""" + from astrbot.core.star import Star + + # Create a mock context with get_config method + mock_context = MagicMock() + mock_context.get_config.return_value = MagicMock() + + # Create a concrete Star subclass for testing + class TestStar(Star): + name = "test_star" + author = "test_author" + + star = TestStar(context=mock_context) + + assert star.context is mock_context + + @pytest.mark.asyncio + async def test_text_to_image_with_config(self): + """Test text_to_image method with valid config.""" + from astrbot.core.star import Star + + mock_context = MagicMock() + mock_config = MagicMock() + mock_config.get.return_value = "default_template" + mock_context.get_config.return_value = mock_config + + class TestStar(Star): + name = "test_star" + author = "test_author" + + star = TestStar(context=mock_context) + + with patch( + "astrbot.core.star.base.html_renderer.render_t2i", + new_callable=AsyncMock, + ) as mock_render: + mock_render.return_value = "http://example.com/image.png" + result = await star.text_to_image("test text", return_url=True) + + mock_render.assert_called_once_with( + "test text", + return_url=True, + template_name="default_template", + ) + assert result == "http://example.com/image.png" + + @pytest.mark.asyncio + async def test_text_to_image_without_config(self): + """Test text_to_image method when get_config returns None.""" + from astrbot.core.star import Star + + mock_context = MagicMock() + mock_context.get_config.return_value = None + + class TestStar(Star): + name = "test_star" + author = "test_author" + + star = TestStar(context=mock_context) + + with patch( + "astrbot.core.star.base.html_renderer.render_t2i", + new_callable=AsyncMock, + ) as mock_render: + mock_render.return_value = "http://example.com/image.png" + result = await star.text_to_image("test text", return_url=False) + + mock_render.assert_called_once_with( + "test text", + return_url=False, + template_name=None, + ) + assert result == "http://example.com/image.png" + + @pytest.mark.asyncio + async def test_html_render(self): + """Test html_render method.""" + from astrbot.core.star import Star + + mock_context = MagicMock() + + class TestStar(Star): + name = "test_star" + author = "test_author" + + star = TestStar(context=mock_context) + + with patch( + "astrbot.core.star.base.html_renderer.render_custom_template", + new_callable=AsyncMock, + ) as mock_render: + mock_render.return_value = "http://example.com/rendered.png" + result = await star.html_render( + "{{ data }}", + {"data": "test"}, + return_url=True, + ) + + mock_render.assert_called_once_with( + "{{ data }}", + {"data": "test"}, + return_url=True, + options=None, + ) + assert result == "http://example.com/rendered.png" + + @pytest.mark.asyncio + async def test_initialize_and_terminate(self): + """Test that initialize and terminate methods can be overridden.""" + from astrbot.core.star import Star + + class TestStar(Star): + name = "test_star" + author = "test_author" + + async def initialize(self) -> None: + self.initialized = True + + async def terminate(self) -> None: + self.terminated = True + + mock_context = MagicMock() + star = TestStar(context=mock_context) + + await star.initialize() + assert star.initialized is True + + await star.terminate() + assert star.terminated is True + + def test_star_metadata_registration(self): + """Test that Star subclass is automatically registered.""" + from astrbot.core.star import star_map, star_registry + from astrbot.core.star.star import StarMetadata + + # Clear any previous registration for this test module + module_path = __name__ + + class UniqueTestStar: + """Not a Star subclass, should not be registered.""" + pass + + # Verify Star subclass gets registered + initial_count = len(star_registry) + + # Note: This test verifies the __init_subclass__ mechanism + # The actual registration happens when a class inherits from Star + assert len(star_registry) >= initial_count + + +class TestNoCircularImports: + """Test that there are no circular import issues.""" + + def test_import_star_module(self): + """Test that star module can be imported without circular import errors.""" + import astrbot.core.star + + assert astrbot.core.star is not None + + def test_import_pipeline_module(self): + """Test that pipeline module can be imported without circular import errors.""" + import astrbot.core.pipeline + + assert astrbot.core.pipeline is not None + + def test_import_both_modules(self): + """Test that both modules can be imported together.""" + import astrbot.core.pipeline + import astrbot.core.star + + # Verify key exports are available + from astrbot.core.star import Context, Star, PluginManager + + assert Context is not None + assert Star is not None + assert PluginManager is not None + + def test_import_pipeline_context(self): + """Test that PipelineContext can be imported.""" + from astrbot.core.pipeline.context import PipelineContext + + assert PipelineContext is not None