Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion astrbot/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

from importlib import import_module
from typing import Any
from typing import TYPE_CHECKING, Any

from astrbot.core.message.message_event_result import (
EventResultType,
Expand Down Expand Up @@ -56,6 +56,18 @@
),
}

# Type-checking imports to satisfy static analyzers for __all__ exports
if TYPE_CHECKING:
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage

__all__ = [
"ContentSafetyCheckStage",
"EventResultType",
Expand Down
7 changes: 5 additions & 2 deletions astrbot/core/pipeline/context.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING

from astrbot.core.config import AstrBotConfig

from .context_utils import call_event_hook, call_handler

if TYPE_CHECKING:
from astrbot.core.star import PluginManager


@dataclass
class PipelineContext:
"""上下文对象,包含管道执行所需的上下文信息"""

astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: Any # 插件管理器对象
plugin_manager: PluginManager # 插件管理器对象
astrbot_config_id: str
call_handler = call_handler
call_event_hook = call_event_hook
2 changes: 0 additions & 2 deletions astrbot/core/star/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@

if TYPE_CHECKING:
from astrbot.core.cron.manager import CronJobManager
else:
CronJobManager = Any


class PlatformManagerProtocol(Protocol):
Expand Down
13 changes: 8 additions & 5 deletions astrbot/core/star/register/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
from typing import TYPE_CHECKING, Any

import docstring_parser

Expand All @@ -15,6 +15,9 @@
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools

if TYPE_CHECKING:
from astrbot.core.astr_agent_context import AstrAgentContext

from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
Expand Down Expand Up @@ -616,15 +619,15 @@ def llm_tool(self, *args, **kwargs):
kwargs["registering_agent"] = self
return register_llm_tool(*args, **kwargs)

def __init__(self, agent: Agent[Any]) -> None:
def __init__(self, agent: Agent[AstrAgentContext]) -> None:
self._agent = agent


def register_agent(
name: str,
instruction: str,
tools: list[str | FunctionTool] | None = None,
run_hooks: BaseAgentRunHooks[Any] | None = None,
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
"""注册一个 Agent

Expand All @@ -638,12 +641,12 @@ def register_agent(
tools_ = tools or []

def decorator(awaitable: Callable[..., Awaitable[Any]]):
AstrAgent = Agent[Any]
AstrAgent = Agent[AstrAgentContext]
agent = AstrAgent(
name=name,
instructions=instruction,
tools=tools_,
run_hooks=run_hooks or BaseAgentRunHooks[Any](),
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
)
handoff_tool = HandoffTool(agent=agent)
handoff_tool.handler = awaitable
Expand Down
257 changes: 256 additions & 1 deletion tests/fixtures/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
提供统一的测试辅助工具,减少测试代码重复。
"""

from typing import Any
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock

from astrbot.core.message.components import BaseMessageComponent
Expand Down Expand Up @@ -330,3 +333,255 @@ def create_mock_llm_response(
tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5),
)


# ============================================================
# 测试插件辅助函数
# ============================================================


@dataclass
class MockPluginConfig:
"""测试插件配置。
用于创建和管理测试用的模拟插件。
Attributes:
name: 插件名称
author: 作者
description: 描述
version: 版本
repo: 仓库 URL
main_code: main.py 的代码内容
requirements: 依赖列表
has_readme: 是否创建 README.md
readme_content: README.md 内容
"""

name: str = "test_plugin"
author: str = "Test Author"
description: str = "A test plugin for unit testing"
version: str = "1.0.0"
repo: str = "https://github.com/test/test_plugin"
main_code: str = ""
requirements: list[str] = field(default_factory=list)
has_readme: bool = True
readme_content: str = "# Test Plugin\n\nThis is a test plugin."


# 默认的插件主代码模板
DEFAULT_PLUGIN_MAIN_TEMPLATE = '''
from astrbot.api import star
class Main(star.Star):
"""测试插件主类。"""
def __init__(self, context):
super().__init__(context)
self.name = "{plugin_name}"
async def initialize(self):
"""初始化插件。"""
pass
async def terminate(self):
"""终止插件。"""
pass
'''


class MockPluginBuilder:
"""测试插件构建器。
用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。
Example:
# 创建一个简单的测试插件
builder = MockPluginBuilder(plugin_store_path)
plugin_dir = builder.create("my_test_plugin")
# 创建自定义配置的插件
config = MockPluginConfig(
name="custom_plugin",
version="2.0.0",
main_code="print('hello')",
)
plugin_dir = builder.create(config)
# 清理插件
builder.cleanup("my_test_plugin")
"""

def __init__(self, plugin_store_path: str | Path):
"""初始化构建器。
Args:
plugin_store_path: 插件存储路径 (通常是 data/plugins)
"""
self.plugin_store_path = Path(plugin_store_path)
self._created_plugins: set[str] = set()

def create(
self,
plugin_config: str | MockPluginConfig | None = None,
**kwargs,
) -> Path:
"""创建模拟插件。
Args:
plugin_config: 插件名称字符串、MockPluginConfig 对象或 None
**kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig
Returns:
Path: 创建的插件目录路径
"""
# 处理不同类型的输入
if plugin_config is None:
config = MockPluginConfig(**kwargs)
elif isinstance(plugin_config, str):
config = MockPluginConfig(name=plugin_config, **kwargs)
elif isinstance(plugin_config, MockPluginConfig):
config = plugin_config
else:
raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}")

# 创建插件目录
plugin_dir = self.plugin_store_path / config.name
plugin_dir.mkdir(parents=True, exist_ok=True)

# 创建 metadata.yaml
metadata_content = "\n".join(
[
f"name: {config.name}",
f"author: {config.author}",
f"desc: {config.description}",
f"version: {config.version}",
f"repo: {config.repo}",
]
)
(plugin_dir / "metadata.yaml").write_text(
metadata_content + "\n", encoding="utf-8"
)

# 创建 main.py
main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format(
plugin_name=config.name
)
(plugin_dir / "main.py").write_text(main_code, encoding="utf-8")

# 创建 requirements.txt(如果有依赖)
if config.requirements:
(plugin_dir / "requirements.txt").write_text(
"\n".join(config.requirements) + "\n", encoding="utf-8"
)

# 创建 README.md(如果需要)
if config.has_readme:
(plugin_dir / "README.md").write_text(
config.readme_content, encoding="utf-8"
)

# 记录创建的插件
self._created_plugins.add(config.name)

return plugin_dir

def cleanup(self, plugin_name: str | None = None) -> None:
"""清理插件。
Args:
plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件
"""
if plugin_name:
plugins_to_clean = {plugin_name}
else:
plugins_to_clean = self._created_plugins.copy()

for name in plugins_to_clean:
plugin_dir = self.plugin_store_path / name
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
self._created_plugins.discard(name)

def cleanup_all(self) -> None:
"""清理所有由本构建器创建的插件。"""
self.cleanup(None)

def get_plugin_path(self, plugin_name: str) -> Path:
"""获取插件路径。
Args:
plugin_name: 插件名称
Returns:
Path: 插件目录路径
"""
return self.plugin_store_path / plugin_name

@property
def created_plugins(self) -> set[str]:
"""获取已创建的插件名称集合。"""
return self._created_plugins.copy()


def create_mock_updater_install(
plugin_builder: MockPluginBuilder,
repo_to_plugin: dict[str, str] | None = None,
) -> Callable:
"""创建模拟的 updater.install 方法。
Args:
plugin_builder: MockPluginBuilder 实例
repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"}
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""

async def mock_install(repo_url: str, proxy: str = "") -> str:
"""Mock updater.install 方法。"""
# 查找插件名称
plugin_name = None
if repo_to_plugin:
plugin_name = repo_to_plugin.get(repo_url)

# 如果没有映射,尝试从 URL 提取插件名
if not plugin_name:
# 从 https://github.com/user/plugin_name 提取 plugin_name
parts = repo_url.rstrip("/").split("/")
plugin_name = parts[-1] if parts else "unknown_plugin"

# 创建插件目录
config = MockPluginConfig(name=plugin_name, repo=repo_url)
plugin_dir = plugin_builder.create(config)
return str(plugin_dir)

return mock_install


def create_mock_updater_update(
plugin_builder: MockPluginBuilder,
update_callback: Callable | None = None,
) -> Callable:
"""创建模拟的 updater.update 方法。
Args:
plugin_builder: MockPluginBuilder 实例
update_callback: 更新回调函数,接收 plugin 参数
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""

async def mock_update(plugin, proxy: str = "") -> None:
"""Mock updater.update 方法。"""
plugin_dir = plugin_builder.get_plugin_path(plugin.name)

# 创建更新标记文件
(plugin_dir / ".updated").write_text("ok", encoding="utf-8")

# 调用回调
if update_callback:
update_callback(plugin)

return mock_update
Loading