diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml index 5be935ebce..97a035e7cf 100644 --- a/.github/workflows/dashboard_ci.yml +++ b/.github/workflows/dashboard_ci.yml @@ -44,7 +44,7 @@ jobs: !dist/**/*.md - name: Create GitHub Release - if: github.event_name == 'push' + if: github.event_name == 'push' && github.repository == 'AstrBotDevs/AstrBot' uses: ncipollo/release-action@v1 with: tag: release-${{ github.sha }} diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 18c8d49269..d5865a1ab3 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -23,7 +23,7 @@ jobs: uses: actions/checkout@v6 with: fetch-depth: 1 - fetch-tag: true + fetch-tags: true - name: Check for new commits today if: github.event_name == 'schedule' @@ -121,7 +121,7 @@ jobs: uses: actions/checkout@v6 with: fetch-depth: 1 - fetch-tag: true + fetch-tags: true - name: Get latest tag (only on manual trigger) id: get-latest-tag diff --git a/astrbot/__init__.py b/astrbot/__init__.py index 73d64f303f..a13ca37a20 100644 --- a/astrbot/__init__.py +++ b/astrbot/__init__.py @@ -1,3 +1,9 @@ -from .core.log import LogManager +def __getattr__(name: str): + """延迟初始化 logger,避免 CLI 客户端导入时触发 core 全量初始化""" + if name == "logger": + from .core.log import LogManager -logger = LogManager.GetLogger(log_name="astrbot") + _logger = LogManager.GetLogger(log_name="astrbot") + globals()["logger"] = _logger + return _logger + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/astrbot/cli/__main__.py b/astrbot/cli/__main__.py index 40c46de79d..e773b994f4 100644 --- a/astrbot/cli/__main__.py +++ b/astrbot/cli/__main__.py @@ -5,7 +5,7 @@ import click from . import __version__ -from .commands import conf, init, plug, run +from .commands import conf, init, plug, restart, run, stop logo_tmpl = r""" ___ _______.___________..______ .______ ______ .___________. @@ -51,6 +51,8 @@ def help(command_name: str | None) -> None: cli.add_command(init) cli.add_command(run) +cli.add_command(restart) +cli.add_command(stop) cli.add_command(help) cli.add_command(plug) cli.add_command(conf) diff --git a/astrbot/cli/client/__init__.py b/astrbot/cli/client/__init__.py new file mode 100644 index 0000000000..9ef8b3c4e6 --- /dev/null +++ b/astrbot/cli/client/__init__.py @@ -0,0 +1,4 @@ +"""AstrBot CLI Client - Socket客户端工具 + +用于通过Unix Socket或TCP Socket与AstrBot CLIPlatformAdapter通信 +""" diff --git a/astrbot/cli/client/__main__.py b/astrbot/cli/client/__main__.py new file mode 100644 index 0000000000..2a2a60b51a --- /dev/null +++ b/astrbot/cli/client/__main__.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +AstrBot CLI Client - 跨平台Socket客户端 + +支持Unix Socket和TCP Socket连接到CLIPlatformAdapter + +用法: + astr "你好" + astr "/help" + echo "你好" | astr +""" + +import io +import sys + +import click + +# Windows UTF-8 输出支持(仅在非测试环境下替换,避免与 pytest capture 冲突) +if sys.platform == "win32" and "pytest" not in sys.modules: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace") + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace") + + +EPILOG = """ +命令总览 (所有命令均支持 -j/--json 输出原始 JSON): +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + + [发送消息] + astr 直接发送消息(隐式调用 send) + astr send 显式发送消息给 AstrBot + astr send -t 60 设置超时(秒),默认 30 + echo "msg" | astr 从管道读取消息 + + [会话管理] astr conv <子命令> + astr conv ls [page] 列出所有对话(可翻页) + astr conv new 创建新对话并切换到该对话 + astr conv switch 按序号切换对话(序号见 conv ls) + astr conv del 删除当前对话 + astr conv rename 重命名当前对话 + astr conv reset 清除当前对话的 LLM 上下文 + astr conv history [page] 查看当前对话的聊天记录 + + [插件管理] astr plugin <子命令> + astr plugin ls 列出已安装插件及状态 + astr plugin on 启用指定插件 + astr plugin off 禁用指定插件 + astr plugin help [name] 查看插件帮助(省略 name 则查看全部) + + [LLM 配置] + astr provider [index] 查看 Provider 列表 / 按序号切换 + astr model [index|name] 查看模型列表 / 按序号或名称切换 + astr key [index] 查看 API Key 列表 / 按序号切换 + + [快捷命令] + astr help 查看 AstrBot 服务端内置指令帮助 + astr sid 查看当前会话 ID 和管理员 ID + astr t2i 开关文字转图片(会话级别) + astr tts 开关文字转语音(会话级别) + + [日志查看] + astr log 读取最近 100 行日志(直接读文件) + astr log --lines 50 指定行数 + astr log --level ERROR 按级别过滤 (DEBUG/INFO/WARNING/ERROR) + astr log --pattern "关键词" 按关键词过滤(--regex 启用正则) + astr log --socket 通过 Socket 从服务端获取日志 + + [调试工具] + astr ping [-c N] 测试连通性和延迟(-c 指定次数) + astr status 查看连接配置、Token、服务状态 + astr test echo 发送消息并查看完整回环响应 + astr test plugin 测试插件命令(发送 / ) + 例: astr test plugin probe cpu + → 实际发送 /probe cpu + + [函数工具] astr tool <子命令> + astr tool ls 列出所有注册的函数工具 + astr tool ls -o plugin 按来源过滤(plugin/mcp/builtin) + astr tool info 查看工具详细信息和参数 + astr tool call [json] 调用工具,例: astr tool call my_func '{"k":"v"}' + + [跨会话浏览] astr session <子命令> + astr session ls 列出所有会话(跨平台:QQ/TG/微信/CLI…) + astr session ls -P qq 按平台过滤(-q 搜索关键词) + astr session convs 查看该会话下的对话列表 + astr session history 查看聊天记录(-c 指定对话,默认当前) + + [批量执行] + astr batch 从文件逐行读取并执行命令 + (# 开头为注释,空行跳过) + +兼容旧用法: astr --log = astr log | astr -j "msg" = astr send -j "msg" + +连接: 自动读取 data/.cli_connection 和 data/.cli_token + 需在 AstrBot 根目录运行,或设置 ASTRBOT_ROOT 环境变量 +""" + + +class RawEpilogGroup(click.Group): + """保留 epilog 原始格式的 Group,同时支持默认子命令路由""" + + def format_epilog(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: + if self.epilog: + formatter.write("\n") + for line in self.epilog.split("\n"): + formatter.write(line + "\n") + + # send 子命令的 option 前缀,用于识别 astr -j "你好" 等旧用法 + _send_opts = {"-j", "--json", "-t", "--timeout", "-s", "--socket"} + # --log 旧用法映射到 log 子命令 + _log_flag = {"--log"} + + def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]: + if args: + first = args[0] + if first in self._log_flag: + # astr --log ... → astr log ... + args = ["log"] + args[1:] + elif first not in self.commands: + if not first.startswith("-") or first in self._send_opts: + # astr 你好 / astr -j "你好" → astr send ... + args = ["send"] + args + return super().parse_args(ctx, args) + + +@click.group( + cls=RawEpilogGroup, + invoke_without_command=True, + epilog=EPILOG, +) +@click.pass_context +def main(ctx: click.Context) -> None: + """AstrBot CLI Client - 与 AstrBot 交互的命令行工具""" + if ctx.invoked_subcommand is None: + # 无子命令时,检查 stdin 是否有管道输入 + if not sys.stdin.isatty(): + message = sys.stdin.read().strip() + if message: + from .commands.send import do_send + + do_send(message, None, 30.0, False) + return + click.echo(ctx.get_help()) + + +# 注册所有子命令 +from .commands import register_commands # noqa: E402 + +register_commands(main) + + +if __name__ == "__main__": + main() diff --git a/astrbot/cli/client/commands/__init__.py b/astrbot/cli/client/commands/__init__.py new file mode 100644 index 0000000000..a59fecbd80 --- /dev/null +++ b/astrbot/cli/client/commands/__init__.py @@ -0,0 +1,97 @@ +"""命令注册模块 - 将所有子命令注册到主 CLI group""" + +import click + +from .conv import conv +from .debug import ping, status, test +from .log import log +from .plugin import plugin +from .provider import key, model, provider +from .send import send +from .session import session +from .tool import tool + + +def register_commands(group): + """将所有子命令注册到 CLI group + + Args: + group: click.Group 实例 + """ + # 核心命令 + group.add_command(send) + group.add_command(log) + + # 会话管理 + group.add_command(conv) + + # 跨会话浏览 + group.add_command(session) + + # 插件管理 + group.add_command(plugin) + + # Provider/Model/Key + group.add_command(provider) + group.add_command(model) + group.add_command(key) + + # 调试工具 + group.add_command(ping) + group.add_command(status) + group.add_command(test) + + # 函数工具管理 + group.add_command(tool) + + # 快捷别名(独立命令,映射到 send /cmd) + _register_aliases(group) + + +def _register_aliases(group): + """注册快捷别名命令""" + from .. import connection, output + + @group.command(name="help", help="查看 AstrBot 内置命令帮助") + @click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") + def help_cmd(use_json): + response = connection.send_message("/help") + output.output_response(response, use_json) + + @group.command(name="sid", help="查看当前会话 ID 和管理员 ID") + @click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") + def sid_cmd(use_json): + response = connection.send_message("/sid") + output.output_response(response, use_json) + + @group.command(name="t2i", help="开关文字转图片(会话级别)") + @click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") + def t2i_cmd(use_json): + response = connection.send_message("/t2i") + output.output_response(response, use_json) + + @group.command(name="tts", help="开关文字转语音(会话级别)") + @click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") + def tts_cmd(use_json): + response = connection.send_message("/tts") + output.output_response(response, use_json) + + @group.command(name="batch", help="从文件批量执行命令") + @click.argument("file", type=click.Path(exists=True)) + @click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") + def batch_cmd(file, use_json): + """从文件逐行读取并执行命令 + + \b + 示例: + astr batch commands.txt 批量执行文件中的命令 + """ + with open(file, encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line or line.startswith("#"): + continue + click.echo(f"[{line_num}] > {line}") + response = connection.send_message(line) + output.output_response(response, use_json) + click.echo("") diff --git a/astrbot/cli/client/commands/conv.py b/astrbot/cli/client/commands/conv.py new file mode 100644 index 0000000000..98f44c02c0 --- /dev/null +++ b/astrbot/cli/client/commands/conv.py @@ -0,0 +1,73 @@ +"""会话管理命令组 - astr conv""" + +import click + +from ..connection import send_message +from ..output import output_response + + +@click.group(help="会话管理 (子命令: ls/new/switch/del/rename/reset/history)") +def conv() -> None: + """会话管理命令组""" + + +@conv.command(name="ls", help="列出当前会话的所有对话") +@click.argument("page", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_ls(page: str, use_json: bool) -> None: + """列出对话""" + cmd = "/ls" if not page else f"/ls {page}" + response = send_message(cmd) + output_response(response, use_json) + + +@conv.command(name="new", help="创建新对话") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_new(use_json: bool) -> None: + """创建新对话""" + response = send_message("/new") + output_response(response, use_json) + + +@conv.command(name="switch", help="按序号切换对话") +@click.argument("index", type=int) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_switch(index: int, use_json: bool) -> None: + """按序号切换对话""" + response = send_message(f"/switch {index}") + output_response(response, use_json) + + +@conv.command(name="del", help="删除当前对话") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_del(use_json: bool) -> None: + """删除当前对话""" + response = send_message("/del") + output_response(response, use_json) + + +@conv.command(name="rename", help="重命名当前对话") +@click.argument("name") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_rename(name: str, use_json: bool) -> None: + """重命名当前对话""" + response = send_message(f"/rename {name}") + output_response(response, use_json) + + +@conv.command(name="reset", help="重置当前 LLM 会话") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_reset(use_json: bool) -> None: + """重置当前 LLM 会话""" + response = send_message("/reset") + output_response(response, use_json) + + +@conv.command(name="history", help="查看对话记录") +@click.argument("page", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def conv_history(page: str, use_json: bool) -> None: + """查看对话记录""" + cmd = "/history" if not page else f"/history {page}" + response = send_message(cmd) + output_response(response, use_json) diff --git a/astrbot/cli/client/commands/debug.py b/astrbot/cli/client/commands/debug.py new file mode 100644 index 0000000000..5f279716db --- /dev/null +++ b/astrbot/cli/client/commands/debug.py @@ -0,0 +1,135 @@ +"""调试工具命令 - astr ping / astr status / astr test""" + +import json +import os +import time + +import click + +from ..connection import ( + get_data_path, + load_auth_token, + load_connection_info, + send_message, +) +from ..output import output_response + + +@click.command(help="测试与 AstrBot 的连通性和延迟") +@click.option("-c", "--count", default=1, type=int, help="测试次数(默认 1)") +def ping(count: int) -> None: + """测试连通性和延迟 + + \b + 示例: + astr ping 单次测试 + astr ping -c 3 测试 3 次 + """ + for i in range(count): + start = time.time() + response = send_message("/help") + elapsed = (time.time() - start) * 1000 + + if response.get("status") == "success": + click.echo(f"pong: {elapsed:.0f}ms") + else: + error = response.get("error", "Unknown error") + click.echo(f"failed: {error}", err=True) + raise SystemExit(1) + + +@click.command(help="查看 AstrBot 连接状态") +def status() -> None: + """查看 AstrBot 连接状态 + + 检查连接文件、token、以及服务可达性 + """ + data_dir = get_data_path() + + # 检查连接文件 + connection_info = load_connection_info(data_dir) + if connection_info is not None: + conn_type = connection_info.get("type", "unknown") + if conn_type == "unix": + path = connection_info.get("path", "N/A") + click.echo("连接类型: Unix Socket") + click.echo(f"路径: {path}") + click.echo(f"文件存在: {os.path.exists(path)}") + elif conn_type == "tcp": + host = connection_info.get("host", "N/A") + port = connection_info.get("port", "N/A") + click.echo("连接类型: TCP Socket") + click.echo(f"地址: {host}:{port}") + else: + click.echo(f"连接类型: {conn_type} (未知)") + else: + click.echo("连接文件: 未找到 (.cli_connection)") + + # 检查 token + token = load_auth_token() + if token: + click.echo(f"Token: 已配置 ({token[:8]}...)") + else: + click.echo("Token: 未配置") + + # 测试连通性 + click.echo("---") + start = time.time() + response = send_message("/help") + elapsed = (time.time() - start) * 1000 + + if response.get("status") == "success": + click.echo(f"服务状态: 在线 ({elapsed:.0f}ms)") + else: + error = response.get("error", "Unknown error") + click.echo(f"服务状态: 离线 ({error})") + + +@click.group(help="测试工具 (子命令: echo/plugin)") +def test() -> None: + """测试工具命令组""" + + +@test.command(name="echo", help="发送消息并验证回环") +@click.argument("message", nargs=-1, required=True) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def test_echo(message: tuple[str, ...], use_json: bool) -> None: + """发送消息验证回环 + + \b + 示例: + astr test echo Hello 发送 Hello 并查看响应 + """ + msg = " ".join(message) + response = send_message(msg) + + if use_json: + click.echo(json.dumps(response, ensure_ascii=False, indent=2)) + else: + if response.get("status") == "success": + click.echo(f"发送: {msg}") + click.echo(f"响应: {response.get('response', '')}") + else: + error = response.get("error", "Unknown error") + click.echo(f"发送: {msg}") + click.echo(f"错误: {error}", err=True) + raise SystemExit(1) + + +@test.command(name="plugin", help="测试插件命令(发送 / )") +@click.argument("name") +@click.argument("input_text", nargs=-1, required=True) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def test_plugin(name: str, input_text: tuple[str, ...], use_json: bool) -> None: + """测试插件命令 + + name 是插件注册的命令名(非插件名),会拼接为 / 发送。 + + \b + 示例: + astr test plugin probe cpu → 发送 /probe cpu + astr test plugin help → 发送 /help + """ + msg = f"/{name} {' '.join(input_text)}" + response = send_message(msg) + output_response(response, use_json) diff --git a/astrbot/cli/client/commands/log.py b/astrbot/cli/client/commands/log.py new file mode 100644 index 0000000000..3ead7d58cc --- /dev/null +++ b/astrbot/cli/client/commands/log.py @@ -0,0 +1,131 @@ +"""log 命令 - 获取 AstrBot 日志""" + +import os +import re + +import click + +from ..connection import get_data_path, get_logs + + +@click.command(help="获取 AstrBot 日志") +@click.option( + "--lines", default=100, type=int, help="返回的日志行数(默认 100,最大 1000)" +) +@click.option( + "--level", default="", help="按级别过滤 (DEBUG/INFO/WARNING/ERROR/CRITICAL)" +) +@click.option("--pattern", default="", help="按模式过滤(子串匹配)") +@click.option("--regex", is_flag=True, help="使用正则表达式匹配 pattern") +@click.option( + "--socket", + "use_socket", + is_flag=True, + help="通过 Socket 连接 AstrBot 获取日志(需要 AstrBot 运行)", +) +@click.option( + "-t", "--timeout", default=30.0, type=float, help="超时时间(仅 Socket 模式)" +) +def log( + lines: int, + level: str, + pattern: str, + regex: bool, + use_socket: bool, + timeout: float, +) -> None: + """获取 AstrBot 日志 + + \b + 示例: + astr log # 直接读取日志文件(默认) + astr log --lines 50 # 获取最近 50 行 + astr log --level ERROR # 只显示 ERROR 级别 + astr log --pattern "plugin" # 匹配包含 "plugin" 的日志 + astr log --pattern "ERRO|WARN" --regex # 使用正则表达式 + astr log --socket # 通过 Socket 连接 AstrBot 获取 + """ + if use_socket: + response = get_logs(None, timeout, lines, level, pattern, regex) + if response.get("status") == "success": + formatted = response.get("response", "") + click.echo(formatted) + else: + error = response.get("error", "Unknown error") + click.echo(f"Error: {error}", err=True) + raise SystemExit(1) + else: + _read_log_from_file(lines, level, pattern, regex) + + +def _read_log_from_file(lines: int, level: str, pattern: str, use_regex: bool) -> None: + """直接从日志文件读取 + + Args: + lines: 返回的日志行数 + level: 日志级别过滤 + pattern: 模式过滤 + use_regex: 是否使用正则表达式 + """ + LEVEL_MAP = { + "DEBUG": "DEBUG", + "INFO": "INFO", + "WARNING": "WARN", + "WARN": "WARN", + "ERROR": "ERRO", + "CRITICAL": "CRIT", + } + + level_filter = LEVEL_MAP.get(level.upper(), level.upper()) + + log_path = os.path.join(get_data_path(), "logs", "astrbot.log") + + if not os.path.exists(log_path): + click.echo( + f"Error: 日志文件未找到: {log_path}", + err=True, + ) + click.echo( + "提示: 请在配置中启用 log_file_enable 来记录日志到文件,或使用不带 --file 的方式连接 AstrBot", + err=True, + ) + raise SystemExit(1) + + try: + with open(log_path, encoding="utf-8", errors="ignore") as f: + all_lines = f.readlines() + + logs = [] + for line in reversed(all_lines): + if not line.strip(): + continue + + if level_filter: + if not re.search(rf"\[{level_filter}\]", line): + continue + + if pattern: + if use_regex: + try: + if not re.search(pattern, line): + continue + except re.error: + if pattern not in line: + continue + else: + if pattern not in line: + continue + + logs.append(line.rstrip()) + + if len(logs) >= lines: + break + + logs.reverse() + + for log_line in logs: + click.echo(log_line) + + except OSError as e: + click.echo(f"Error: 读取日志文件失败: {e}", err=True) + raise SystemExit(1) diff --git a/astrbot/cli/client/commands/plugin.py b/astrbot/cli/client/commands/plugin.py new file mode 100644 index 0000000000..09a9b981c0 --- /dev/null +++ b/astrbot/cli/client/commands/plugin.py @@ -0,0 +1,47 @@ +"""插件管理命令组 - astr plugin""" + +import click + +from ..connection import send_message +from ..output import output_response + + +@click.group(help="插件管理 (子命令: ls/on/off/help)") +def plugin() -> None: + """插件管理命令组""" + + +@plugin.command(name="ls", help="列出已安装插件") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def plugin_ls(use_json: bool) -> None: + """列出已安装插件""" + response = send_message("/plugin ls") + output_response(response, use_json) + + +@plugin.command(name="on", help="启用插件") +@click.argument("name") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def plugin_on(name: str, use_json: bool) -> None: + """启用插件""" + response = send_message(f"/plugin on {name}") + output_response(response, use_json) + + +@plugin.command(name="off", help="禁用插件") +@click.argument("name") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def plugin_off(name: str, use_json: bool) -> None: + """禁用插件""" + response = send_message(f"/plugin off {name}") + output_response(response, use_json) + + +@plugin.command(name="help", help="获取插件帮助") +@click.argument("name", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def plugin_help(name: str, use_json: bool) -> None: + """获取插件帮助""" + cmd = "/plugin help" if not name else f"/plugin help {name}" + response = send_message(cmd) + output_response(response, use_json) diff --git a/astrbot/cli/client/commands/provider.py b/astrbot/cli/client/commands/provider.py new file mode 100644 index 0000000000..acd6d9a694 --- /dev/null +++ b/astrbot/cli/client/commands/provider.py @@ -0,0 +1,55 @@ +"""Provider/Model/Key 管理命令 - astr provider / astr model / astr key""" + +import click + +from ..connection import send_message +from ..output import output_response + + +@click.command(help="查看/切换 Provider") +@click.argument("index", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def provider(index: str, use_json: bool) -> None: + """查看/切换 Provider + + \b + 示例: + astr provider 查看当前 Provider 列表 + astr provider 1 切换到 Provider 1 + """ + cmd = "/provider" if not index else f"/provider {index}" + response = send_message(cmd) + output_response(response, use_json) + + +@click.command(help="查看/切换模型") +@click.argument("index_or_name", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def model(index_or_name: str, use_json: bool) -> None: + """查看/切换模型 + + \b + 示例: + astr model 查看当前模型列表 + astr model 1 切换到模型 1 + astr model gpt-4 按名称切换模型 + """ + cmd = "/model" if not index_or_name else f"/model {index_or_name}" + response = send_message(cmd) + output_response(response, use_json) + + +@click.command(help="查看/切换 Key") +@click.argument("index", default="", required=False) +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON") +def key(index: str, use_json: bool) -> None: + """查看/切换 Key + + \b + 示例: + astr key 查看当前 Key 列表 + astr key 1 切换到 Key 1 + """ + cmd = "/key" if not index else f"/key {index}" + response = send_message(cmd) + output_response(response, use_json) diff --git a/astrbot/cli/client/commands/send.py b/astrbot/cli/client/commands/send.py new file mode 100644 index 0000000000..e39310f375 --- /dev/null +++ b/astrbot/cli/client/commands/send.py @@ -0,0 +1,47 @@ +"""send 命令 - 发送消息给 AstrBot""" + +import sys + +import click + +from ..connection import send_message +from ..output import fix_git_bash_path, output_response + + +@click.command(help="发送消息给 AstrBot") +@click.argument("message", nargs=-1) +@click.option("-s", "--socket", "socket_path", default=None, help="Unix socket 路径") +@click.option("-t", "--timeout", default=30.0, type=float, help="超时时间(秒)") +@click.option("-j", "--json", "use_json", is_flag=True, help="输出原始 JSON 响应") +def send( + message: tuple[str, ...], socket_path: str | None, timeout: float, use_json: bool +) -> None: + """发送消息给 AstrBot + + \b + 示例: + astr send 你好 + astr send /help + astr send plugin ls + echo "你好" | astr send + """ + if message: + msg = " ".join(message) + msg = fix_git_bash_path(msg) + elif not sys.stdin.isatty(): + msg = sys.stdin.read().strip() + else: + click.echo("Error: 请提供消息内容", err=True) + raise SystemExit(1) + + if not msg: + click.echo("Error: 消息内容为空", err=True) + raise SystemExit(1) + + do_send(msg, socket_path, timeout, use_json) + + +def do_send(msg: str, socket_path: str | None, timeout: float, use_json: bool) -> None: + """执行消息发送并输出结果""" + response = send_message(msg, socket_path, timeout) + output_response(response, use_json) diff --git a/astrbot/cli/client/commands/session.py b/astrbot/cli/client/commands/session.py new file mode 100644 index 0000000000..3042a2fb79 --- /dev/null +++ b/astrbot/cli/client/commands/session.py @@ -0,0 +1,171 @@ +"""跨会话浏览命令组 - astr session""" + +import json +import sys + +import click + +from ..connection import ( + get_session_history, + list_session_conversations, + list_sessions, +) + + +@click.group(help="跨会话浏览 — 查看任意平台的会话和聊天记录 (ls/convs/history)") +def session() -> None: + """跨会话浏览命令组""" + + +@session.command(name="ls", help="列出所有会话(跨平台:QQ/TG/微信/CLI…)") +@click.option("--page", "-p", type=int, default=1, help="页码 (默认 1)") +@click.option("--size", "-s", type=int, default=20, help="每页数量 (默认 20)") +@click.option("--platform", "-P", type=str, default=None, help="按平台过滤") +@click.option("--search", "-q", type=str, default=None, help="搜索关键词") +@click.option("-j", "--json-output", "use_json", is_flag=True, help="输出原始 JSON") +def session_ls( + page: int, size: int, platform: str | None, search: str | None, use_json: bool +) -> None: + """列出所有会话""" + resp = list_sessions( + page=page, + page_size=size, + platform=platform, + search_query=search, + ) + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + if use_json: + click.echo(json.dumps(resp, ensure_ascii=False, indent=2)) + return + + sessions = resp.get("sessions", []) + if not sessions: + click.echo("没有找到会话。") + return + + click.echo(f"{'#':<4} {'会话 ID':<45} {'当前对话标题':<20} {'人设'}") + click.echo("-" * 90) + for i, s in enumerate(sessions, start=(page - 1) * size + 1): + sid = s.get("session_id", "?") + title = s.get("title") or "(无标题)" + persona = s.get("persona_name") or "-" + # 截断过长的字段 + if len(sid) > 43: + sid = sid[:40] + "..." + if len(title) > 18: + title = title[:15] + "..." + click.echo(f"{i:<4} {sid:<45} {title:<20} {persona}") + + total = resp.get("total", 0) + total_pages = resp.get("total_pages", 0) + click.echo(f"\n共 {total} 个会话,第 {page}/{total_pages} 页") + + +@session.command(name="convs", help="查看指定会话下的对话列表") +@click.argument("session_id") +@click.option("--page", "-p", type=int, default=1, help="页码 (默认 1)") +@click.option("--size", "-s", type=int, default=20, help="每页数量 (默认 20)") +@click.option("-j", "--json-output", "use_json", is_flag=True, help="输出原始 JSON") +def session_convs(session_id: str, page: int, size: int, use_json: bool) -> None: + """查看指定会话的对话列表""" + resp = list_session_conversations( + session_id=session_id, + page=page, + page_size=size, + ) + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + if use_json: + click.echo(json.dumps(resp, ensure_ascii=False, indent=2)) + return + + convs = resp.get("conversations", []) + if not convs: + click.echo(f"会话 {session_id} 没有对话。") + return + + click.echo(f"会话: {session_id}") + click.echo(f"当前对话: {resp.get('current_cid', '无')}") + click.echo("") + click.echo(f"{'#':<4} {'对话 ID':<38} {'标题':<20} {'Token':<8} {'当前'}") + click.echo("-" * 80) + for i, c in enumerate(convs, start=(page - 1) * size + 1): + cid = c.get("cid", "?") + title = c.get("title") or "(无标题)" + token = c.get("token_usage", 0) + is_curr = "*" if c.get("is_current") else "" + if len(title) > 18: + title = title[:15] + "..." + click.echo(f"{i:<4} {cid:<38} {title:<20} {token:<8} {is_curr}") + + total = resp.get("total", 0) + total_pages = resp.get("total_pages", 0) + click.echo(f"\n共 {total} 个对话,第 {page}/{total_pages} 页") + + +@session.command(name="history", help="查看聊天记录(用户/AI 交替显示)") +@click.argument("session_id") +@click.option( + "-c", "--conversation-id", type=str, default=None, help="对话 ID (默认当前对话)" +) +@click.option("--page", "-p", type=int, default=1, help="页码 (默认 1)") +@click.option("--size", "-s", type=int, default=10, help="每页数量 (默认 10)") +@click.option("-j", "--json-output", "use_json", is_flag=True, help="输出原始 JSON") +def session_history( + session_id: str, + conversation_id: str | None, + page: int, + size: int, + use_json: bool, +) -> None: + """查看指定会话的聊天记录""" + resp = get_session_history( + session_id=session_id, + conversation_id=conversation_id, + page=page, + page_size=size, + ) + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + if use_json: + click.echo(json.dumps(resp, ensure_ascii=False, indent=2)) + return + + _render_history(resp, session_id, page) + + +def _render_history(resp: dict, session_id: str, page: int) -> None: + """简洁地渲染聊天记录:用户/AI 交替显示。""" + history = resp.get("history", []) + total_pages = resp.get("total_pages", 0) + cid = resp.get("conversation_id") + + click.echo(f"会话: {session_id}") + click.echo(f"对话: {cid or '(无)'} 页码: {page}/{total_pages}") + click.echo("-" * 60) + + if not history: + click.echo("(无聊天记录)") + return + + for msg in history: + # 新格式:msg 是 {"role": "user"|"assistant", "text": "..."} + if isinstance(msg, dict): + role = msg.get("role", "?") + text = msg.get("text", "") + label = "You" if role == "user" else "AI" + click.echo(f"{label}: {text}") + else: + # 兼容旧格式(纯字符串) + click.echo(msg) + click.echo() diff --git a/astrbot/cli/client/commands/tool.py b/astrbot/cli/client/commands/tool.py new file mode 100644 index 0000000000..ebd01075a9 --- /dev/null +++ b/astrbot/cli/client/commands/tool.py @@ -0,0 +1,133 @@ +"""函数工具管理命令组 - astr tool""" + +import json +import sys + +import click + +from ..connection import call_tool, list_tools + + +@click.group(help="函数工具管理 (子命令: ls/info/call)") +def tool() -> None: + """函数工具管理命令组""" + + +@tool.command(name="ls", help="列出所有注册的函数工具") +@click.option( + "--origin", "-o", type=str, default="", help="按来源过滤: plugin/mcp/builtin" +) +@click.option("-j", "--json-output", "use_json", is_flag=True, help="输出原始 JSON") +def tool_ls(origin: str, use_json: bool) -> None: + """列出所有注册的函数工具""" + resp = list_tools() + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + tools = resp.get("tools", []) + if not tools: + raw = resp.get("response", "") + if raw: + try: + tools = json.loads(raw) + except (json.JSONDecodeError, TypeError): + pass + + if origin: + tools = [ + t + for t in tools + if t.get("origin") == origin or t.get("origin_name") == origin + ] + + if use_json: + click.echo(json.dumps(tools, ensure_ascii=False, indent=2)) + return + + if not tools: + click.echo("没有注册的函数工具。") + return + + click.echo(f"{'名称':<25} {'来源':<10} {'来源名':<18} {'状态':<6} {'描述'}") + click.echo("-" * 90) + for t in tools: + name = t.get("name", "?") + ori = t.get("origin", "?") + ori_name = t.get("origin_name", "?") + active = "启用" if t.get("active", True) else "停用" + desc = (t.get("description") or "")[:40] + click.echo(f"{name:<25} {ori:<10} {ori_name:<18} {active:<6} {desc}") + + click.echo(f"\n共 {len(tools)} 个工具") + + +@tool.command(name="info", help="查看工具详细信息") +@click.argument("name") +def tool_info(name: str) -> None: + """查看工具详细信息""" + resp = list_tools() + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + tools = resp.get("tools", []) + if not tools: + raw = resp.get("response", "") + if raw: + try: + tools = json.loads(raw) + except (json.JSONDecodeError, TypeError): + pass + + matched = [t for t in tools if t.get("name") == name] + if not matched: + click.echo(f"未找到工具: {name}") + sys.exit(1) + + t = matched[0] + click.echo(f"名称: {t.get('name')}") + click.echo(f"描述: {t.get('description', '无')}") + click.echo(f"来源: {t.get('origin', '?')} ({t.get('origin_name', '?')})") + click.echo(f"状态: {'启用' if t.get('active', True) else '停用'}") + + params = t.get("parameters") + if params: + click.echo("参数:") + props = params.get("properties", {}) + required = params.get("required", []) + for pname, pinfo in props.items(): + req_mark = "*" if pname in required else " " + ptype = pinfo.get("type", "any") + pdesc = pinfo.get("description", "") + click.echo(f" {req_mark} {pname} ({ptype}): {pdesc}") + + +@tool.command(name="call", help="调用指定的函数工具") +@click.argument("name") +@click.argument("args_json", required=False, default="{}") +@click.option("-t", "--timeout", type=float, default=60.0, help="超时时间(秒)") +def tool_call(name: str, args_json: str, timeout: float) -> None: + """调用指定的函数工具 + + ARGS_JSON: JSON 格式的参数,例如 '{"key": "value"}' + """ + try: + tool_args = json.loads(args_json) + except json.JSONDecodeError as e: + click.echo(f"[ERROR] 参数 JSON 格式错误: {e}", err=True) + sys.exit(1) + + if not isinstance(tool_args, dict): + click.echo("[ERROR] 参数必须是 JSON 对象", err=True) + sys.exit(1) + + resp = call_tool(name, tool_args, timeout=timeout) + + if resp.get("status") != "success": + click.echo(f"[ERROR] {resp.get('error', 'Unknown error')}", err=True) + sys.exit(1) + + click.echo(resp.get("response", "(无返回值)")) diff --git a/astrbot/cli/client/connection.py b/astrbot/cli/client/connection.py new file mode 100644 index 0000000000..1703985957 --- /dev/null +++ b/astrbot/cli/client/connection.py @@ -0,0 +1,424 @@ +"""连接管理模块 - 路径/token/socket/发送 + +从 __main__.py 提取的连接相关功能,不导入 astrbot 框架。 +""" + +import json +import os +import socket +import sys +import uuid + + +def get_data_path() -> str: + """获取数据目录路径 + + 优先级: + 1. 环境变量 ASTRBOT_ROOT + 2. 源码安装目录(通过 __file__ 获取) + 3. 当前工作目录 + """ + if root := os.environ.get("ASTRBOT_ROOT"): + return os.path.join(root, "data") + + source_root = os.path.realpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..") + ) + data_dir = os.path.join(source_root, "data") + + if os.path.exists(data_dir): + return data_dir + + return os.path.join(os.path.realpath(os.getcwd()), "data") + + +def get_temp_path() -> str: + """获取临时目录路径,兼容容器和非容器环境""" + if root := os.environ.get("ASTRBOT_ROOT"): + return os.path.join(root, "data", "temp") + + source_root = os.path.realpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../..") + ) + temp_dir = os.path.join(source_root, "data", "temp") + if os.path.isdir(os.path.join(source_root, "data")): + return temp_dir + + import tempfile + + return tempfile.gettempdir() + + +def load_auth_token() -> str: + """从密钥文件加载认证token + + Returns: + token字符串,如果文件不存在则返回空字符串 + """ + token_file = os.path.join(get_data_path(), ".cli_token") + try: + with open(token_file, encoding="utf-8") as f: + return f.read().strip() + except FileNotFoundError: + return "" + except Exception: + return "" + + +def load_connection_info(data_dir: str) -> dict | None: + """加载连接信息 + + 从.cli_connection文件读取Socket连接信息 + + Args: + data_dir: 数据目录路径 + + Returns: + 连接信息字典,如果文件不存在则返回None + """ + connection_file = os.path.join(data_dir, ".cli_connection") + try: + with open(connection_file, encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return None + except json.JSONDecodeError as e: + print( + f"[ERROR] Invalid JSON in connection file: {connection_file}", + file=sys.stderr, + ) + print(f"[ERROR] {e}", file=sys.stderr) + return None + except Exception as e: + print( + f"[ERROR] Failed to load connection info: {e}", + file=sys.stderr, + ) + return None + + +def connect_to_server(connection_info: dict, timeout: float = 30.0) -> socket.socket: + """连接到服务器 + + 根据连接信息类型选择Unix Socket或TCP Socket连接 + + Args: + connection_info: 连接信息字典 + timeout: 超时时间(秒) + + Returns: + socket连接对象 + + Raises: + ValueError: 无效的连接类型 + ConnectionError: 连接失败 + """ + socket_type = connection_info.get("type") + + if socket_type == "unix": + socket_path = connection_info.get("path") + if not socket_path: + raise ValueError("Unix socket path is missing in connection info") + + try: + client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client_socket.settimeout(timeout) + client_socket.connect(socket_path) + return client_socket + except FileNotFoundError: + raise ConnectionError( + f"Socket file not found: {socket_path}. Is AstrBot running?" + ) + except ConnectionRefusedError: + raise ConnectionError( + "Connection refused. Is AstrBot running in socket mode?" + ) + except Exception as e: + raise ConnectionError(f"Unix socket connection error: {e}") + + elif socket_type == "tcp": + host = connection_info.get("host") + port = connection_info.get("port") + if not host or not port: + raise ValueError("TCP host or port is missing in connection info") + + try: + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.settimeout(timeout) + client_socket.connect((host, port)) + return client_socket + except ConnectionRefusedError: + raise ConnectionError( + f"Connection refused to {host}:{port}. Is AstrBot running?" + ) + except TimeoutError: + raise ConnectionError(f"Connection timeout to {host}:{port}") + except Exception as e: + raise ConnectionError(f"TCP socket connection error: {e}") + + else: + raise ValueError( + f"Invalid socket type: {socket_type}. Expected 'unix' or 'tcp'" + ) + + +def _receive_json_response(client_socket: socket.socket) -> dict: + """从 socket 接收并解析 JSON 响应 + + Args: + client_socket: socket连接对象 + + Returns: + 解析后的响应字典 + """ + response_data = b"" + while True: + chunk = client_socket.recv(4096) + if not chunk: + break + response_data += chunk + try: + return json.loads(response_data.decode("utf-8", errors="replace")) + except json.JSONDecodeError: + continue + + return json.loads(response_data.decode("utf-8", errors="replace")) + + +def _get_connected_socket( + socket_path: str | None = None, timeout: float = 30.0 +) -> socket.socket: + """获取已连接的 socket + + Args: + socket_path: Unix socket路径(向后兼容) + timeout: 超时时间(秒) + + Returns: + 已连接的 socket 对象 + + Raises: + ValueError, ConnectionError: 连接失败时 + """ + data_dir = get_data_path() + connection_info = load_connection_info(data_dir) + + if connection_info is not None: + return connect_to_server(connection_info, timeout) + + if socket_path is None: + socket_path = os.path.join(get_temp_path(), "astrbot.sock") + + fallback_info = {"type": "unix", "path": socket_path} + return connect_to_server(fallback_info, timeout) + + +def send_message( + message: str, socket_path: str | None = None, timeout: float = 30.0 +) -> dict: + """发送消息到AstrBot并获取响应 + + Args: + message: 要发送的消息 + socket_path: Unix socket路径(仅用于向后兼容) + timeout: 超时时间(秒) + + Returns: + 响应字典 + """ + auth_token = load_auth_token() + + request = {"message": message, "request_id": str(uuid.uuid4())} + if auth_token: + request["auth_token"] = auth_token + + try: + client_socket = _get_connected_socket(socket_path, timeout) + except (ValueError, ConnectionError) as e: + return {"status": "error", "error": str(e)} + except Exception as e: + return {"status": "error", "error": f"Connection error: {e}"} + + try: + request_data = json.dumps(request, ensure_ascii=False).encode("utf-8") + client_socket.sendall(request_data) + return _receive_json_response(client_socket) + except TimeoutError: + return {"status": "error", "error": "Request timeout"} + except Exception as e: + return {"status": "error", "error": f"Communication error: {e}"} + finally: + client_socket.close() + + +def get_logs( + socket_path: str | None = None, + timeout: float = 30.0, + lines: int = 100, + level: str = "", + pattern: str = "", + use_regex: bool = False, +) -> dict: + """获取AstrBot日志 + + Args: + socket_path: Socket路径 + timeout: 超时时间 + lines: 返回的日志行数 + level: 日志级别过滤 + pattern: 模式过滤 + use_regex: 是否使用正则表达式 + + Returns: + 响应字典 + """ + auth_token = load_auth_token() + + request = { + "action": "get_logs", + "request_id": str(uuid.uuid4()), + "lines": lines, + "level": level, + "pattern": pattern, + "regex": use_regex, + } + if auth_token: + request["auth_token"] = auth_token + + try: + client_socket = _get_connected_socket(socket_path, timeout) + except (ValueError, ConnectionError) as e: + return {"status": "error", "error": str(e)} + except Exception as e: + return {"status": "error", "error": f"Connection error: {e}"} + + try: + request_data = json.dumps(request, ensure_ascii=False).encode("utf-8") + client_socket.sendall(request_data) + return _receive_json_response(client_socket) + except TimeoutError: + return {"status": "error", "error": "Request timeout"} + except Exception as e: + return {"status": "error", "error": f"Communication error: {e}"} + finally: + client_socket.close() + + +def _send_action_request( + action: str, + extra_fields: dict | None = None, + socket_path: str | None = None, + timeout: float = 30.0, +) -> dict: + """发送 action 请求的通用方法""" + auth_token = load_auth_token() + + request: dict = {"action": action, "request_id": str(uuid.uuid4())} + if auth_token: + request["auth_token"] = auth_token + if extra_fields: + request.update(extra_fields) + + try: + client_socket = _get_connected_socket(socket_path, timeout) + except (ValueError, ConnectionError) as e: + return {"status": "error", "error": str(e)} + except Exception as e: + return {"status": "error", "error": f"Connection error: {e}"} + + try: + request_data = json.dumps(request, ensure_ascii=False).encode("utf-8") + client_socket.sendall(request_data) + return _receive_json_response(client_socket) + except TimeoutError: + return {"status": "error", "error": "Request timeout"} + except Exception as e: + return {"status": "error", "error": f"Communication error: {e}"} + finally: + client_socket.close() + + +def list_tools(socket_path: str | None = None, timeout: float = 30.0) -> dict: + """列出所有注册的函数工具""" + return _send_action_request("list_tools", socket_path=socket_path, timeout=timeout) + + +def call_tool( + tool_name: str, + tool_args: dict | None = None, + socket_path: str | None = None, + timeout: float = 60.0, +) -> dict: + """调用指定的函数工具""" + return _send_action_request( + "call_tool", + extra_fields={"tool_name": tool_name, "tool_args": tool_args or {}}, + socket_path=socket_path, + timeout=timeout, + ) + + +def list_sessions( + page: int = 1, + page_size: int = 20, + platform: str | None = None, + search_query: str | None = None, + socket_path: str | None = None, + timeout: float = 30.0, +) -> dict: + """列出所有会话""" + fields: dict = {"page": page, "page_size": page_size} + if platform: + fields["platform"] = platform + if search_query: + fields["search_query"] = search_query + return _send_action_request( + "list_sessions", + extra_fields=fields, + socket_path=socket_path, + timeout=timeout, + ) + + +def list_session_conversations( + session_id: str, + page: int = 1, + page_size: int = 20, + socket_path: str | None = None, + timeout: float = 30.0, +) -> dict: + """列出指定会话的所有对话""" + return _send_action_request( + "list_session_conversations", + extra_fields={ + "session_id": session_id, + "page": page, + "page_size": page_size, + }, + socket_path=socket_path, + timeout=timeout, + ) + + +def get_session_history( + session_id: str, + conversation_id: str | None = None, + page: int = 1, + page_size: int = 10, + socket_path: str | None = None, + timeout: float = 30.0, +) -> dict: + """获取指定会话的聊天记录""" + fields: dict = { + "session_id": session_id, + "page": page, + "page_size": page_size, + } + if conversation_id: + fields["conversation_id"] = conversation_id + return _send_action_request( + "get_session_history", + extra_fields=fields, + socket_path=socket_path, + timeout=timeout, + ) diff --git a/astrbot/cli/client/output.py b/astrbot/cli/client/output.py new file mode 100644 index 0000000000..178ca783c5 --- /dev/null +++ b/astrbot/cli/client/output.py @@ -0,0 +1,84 @@ +"""输出格式化模块 - 响应格式化与输出 + +从 __main__.py 提取的输出相关功能。 +""" + +import json +import re + +import click + + +def format_response(response: dict) -> str: + """格式化响应输出 + + 处理: + 1. 分段回复(每行一句) + 2. 图片占位符 + + Args: + response: 响应字典 + + Returns: + 格式化后的字符串 + """ + if response.get("status") != "success": + return "" + + text = response.get("response", "") + images = response.get("images", []) + image_count = len(images) + + lines = text.split("\n") + + if image_count > 0: + if image_count == 1: + lines.append("[图片]") + else: + lines.append(f"[{image_count}张图片]") + + return "\n".join(lines) + + +def fix_git_bash_path(message: str) -> str: + """修复 Git Bash 路径转换问题 + + Git Bash (MSYS2) 会把 /plugin ls 转换为 C:/Program Files/Git/plugin ls + 检测并还原原始命令 + + Args: + message: 被转换后的消息 + + Returns: + 修复后的消息 + """ + pattern = r"[A-Z]:/(Program Files/Git|msys[0-9]+/[^/]+)/([^/]+)" + match = re.match(pattern, message) + + if match: + command = match.group(2) + rest = message[match.end() :].lstrip() + if rest: + return f"/{command} {rest}" + return f"/{command}" + + return message + + +def output_response(response: dict, use_json: bool) -> None: + """统一输出响应 + + Args: + response: 响应字典 + use_json: 是否输出原始JSON + """ + if use_json: + click.echo(json.dumps(response, ensure_ascii=False, indent=2)) + else: + if response.get("status") == "success": + formatted = format_response(response) + click.echo(formatted) + else: + error = response.get("error", "Unknown error") + click.echo(f"Error: {error}", err=True) + raise SystemExit(1) diff --git a/astrbot/cli/commands/__init__.py b/astrbot/cli/commands/__init__.py index 1d3e0bca2f..f3f56c1da8 100644 --- a/astrbot/cli/commands/__init__.py +++ b/astrbot/cli/commands/__init__.py @@ -1,6 +1,8 @@ from .cmd_conf import conf from .cmd_init import init from .cmd_plug import plug +from .cmd_restart import restart from .cmd_run import run +from .cmd_stop import stop -__all__ = ["conf", "init", "plug", "run"] +__all__ = ["conf", "init", "plug", "restart", "run", "stop"] diff --git a/astrbot/cli/commands/cmd_restart.py b/astrbot/cli/commands/cmd_restart.py new file mode 100644 index 0000000000..d498e0f9c1 --- /dev/null +++ b/astrbot/cli/commands/cmd_restart.py @@ -0,0 +1,324 @@ +import asyncio +import os +import signal +import subprocess +import sys +import time +from pathlib import Path + +import click +from filelock import FileLock, Timeout + +from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root + + +async def run_astrbot(astrbot_root: Path): + """运行 AstrBot""" + from astrbot.core import LogBroker, LogManager, db_helper, logger + from astrbot.core.initial_loader import InitialLoader + + await check_dashboard(astrbot_root / "data") + + log_broker = LogBroker() + LogManager.set_queue_handler(logger, log_broker) + db = db_helper + + core_lifecycle = InitialLoader(db, log_broker) + + await core_lifecycle.start() + + +def launch_in_new_window( + astrbot_root: Path, + reload: bool, + port: str | None, +) -> None: + """在新窗口启动 AstrBot(仅 Windows)""" + python_exe = sys.executable + + # 构建命令,添加 --no-window 标志让新窗口在当前窗口运行 + cmd = [python_exe, "-m", "astrbot.cli", "run", "--no-window"] + + if reload: + cmd.append("-r") + + if port: + cmd.extend(["-p", port]) + + # 设置环境变量 + env = os.environ.copy() + env["ASTRBOT_CLI"] = "1" + env["ASTRBOT_ROOT"] = str(astrbot_root) + + if port: + env["DASHBOARD_PORT"] = port + + if reload: + env["ASTRBOT_RELOAD"] = "1" + + if sys.platform == "win32": + # Windows: 使用 CREATE_NEW_CONSOLE 在新窗口启动,环境变量通过 env 直接传递 + CREATE_NEW_CONSOLE = 0x00000010 + subprocess.Popen( + cmd, + env=env, + cwd=str(astrbot_root), + creationflags=CREATE_NEW_CONSOLE, + ) + elif sys.platform == "darwin": + # macOS: 使用 osascript 打开新的 Terminal 窗口 + cmd_str = " ".join(cmd) + script = f""" + tell application "Terminal" + do script "cd {astrbot_root} && {cmd_str}" + activate + end tell + """ + subprocess.Popen(["osascript", "-e", script], env=env) + else: + # Linux: 使用 gnome-terminal 或 xterm + cmd_str = " ".join(cmd) + for term_cmd in ["gnome-terminal", "xterm", "konsole", "xfce4-terminal"]: + try: + if term_cmd == "gnome-terminal": + subprocess.Popen( + [ + term_cmd, + "--", + "bash", + "-c", + f"cd {astrbot_root} && {cmd_str}; exec bash", + ], + env=env, + ) + else: + subprocess.Popen( + [ + term_cmd, + "-e", + "bash", + "-c", + f"cd {astrbot_root} && {cmd_str}; exec bash", + ], + env=env, + ) + break + except FileNotFoundError: + continue + else: + raise click.ClickException( + "无法找到终端模拟器,请手动安装 gnome-terminal 或 xterm" + ) + + +def find_and_kill_astrbot_processes(astrbot_root: Path) -> bool: + """查找并终止正在运行的 AstrBot 进程 + + Returns: + bool: 是否成功终止了进程 + """ + killed = False + current_pid = os.getpid() + + if sys.platform == "win32": + # Windows: 使用 wmic 获取进程命令行,精确匹配 AstrBot 进程 + import subprocess + + try: + # 使用 wmic 获取所有 python.exe 进程的命令行 + result = subprocess.run( + [ + "wmic", + "process", + "where", + "name='python.exe'", + "get", + "processid,commandline", + "/format:csv", + ], + capture_output=True, + text=True, + timeout=10, + ) + + # 解析输出并终止相关进程 + for line in result.stdout.split("\n"): + if not line.strip() or "CommandLine" in line: + continue + + parts = line.split(",") + if len(parts) >= 3: + _, cmdline, pid_str = ( + parts[0].strip('"'), + parts[1].strip('"'), + parts[2].strip('"'), + ) + + try: + pid = int(pid_str) + + # 跳过当前进程 + if pid == current_pid: + continue + + # 只终止包含 astrbot 的进程 + # 匹配: astrbot, astrbot.exe, astrbot run, -m astrbot 等 + cmdline_lower = cmdline.lower() + if "astrbot" in cmdline_lower or "astrbot.exe" in cmdline_lower: + subprocess.run( + ["taskkill", "/F", "/T", "/PID", str(pid)], + capture_output=True, + timeout=5, + ) + click.echo(f"已终止进程: {pid}") + killed = True + except (ValueError, subprocess.TimeoutExpired): + continue + except (subprocess.CalledProcessError, FileNotFoundError, Exception) as e: + click.echo(f"查找进程时出错: {e}") + + else: + # Unix/Linux/macOS: 使用 ps 和 kill + import subprocess + + try: + # 使用 ps 获取完整命令行,精确匹配 + result = subprocess.run( + ["ps", "aux"], + capture_output=True, + text=True, + timeout=10, + ) + + for line in result.stdout.split("\n"): + # 跳过标题行 + if line.startswith("USER"): + continue + + # 检查是否是 python 进程且包含 astrbot + if "python" in line.lower() and "astrbot" in line.lower(): + parts = line.split(None, 10) # 最多分割10次,保留完整命令行 + if len(parts) >= 2: + try: + pid = int(parts[1]) + + # 跳过当前进程 + if pid == current_pid: + continue + + os.kill(pid, signal.SIGTERM) + click.echo(f"已发送 SIGTERM 到进程: {pid}") + killed = True + except (ValueError, ProcessLookupError): + continue + except Exception as e: + click.echo(f"查找进程时出错: {e}") + + return killed + + +@click.option("--reload", "-r", is_flag=True, help="启用插件自动重载") +@click.option("--port", "-p", help="Dashboard 端口", required=False, type=str) +@click.option( + "--force", + "-f", + is_flag=True, + help="强制重启,即使无法清理锁文件也尝试启动", +) +@click.option( + "--wait-time", + type=float, + default=3.0, + help="等待进程退出的时间(秒)", +) +@click.option( + "--no-window", + is_flag=True, + help="在当前窗口重启(仅 Windows)", +) +@click.command() +def restart( + reload: bool, + port: str, + force: bool, + wait_time: float, + no_window: bool, +) -> None: + """重启 AstrBot(Windows 默认新窗口,Linux/macOS 当前窗口)""" + try: + os.environ["ASTRBOT_CLI"] = "1" + astrbot_root = get_astrbot_root() + + if not check_astrbot_root(astrbot_root): + raise click.ClickException( + f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", + ) + + os.environ["ASTRBOT_ROOT"] = str(astrbot_root) + sys.path.insert(0, str(astrbot_root)) + + if port: + os.environ["DASHBOARD_PORT"] = port + + if reload: + os.environ["ASTRBOT_RELOAD"] = "1" + + lock_file = astrbot_root / "astrbot.lock" + + # 尝试获取锁,如果成功说明没有实例在运行 + lock = FileLock(lock_file, timeout=1) + + try: + lock.acquire() + lock.release() + except Timeout: + # 锁文件存在,有实例在运行 + click.echo("检测到正在运行的实例,正在停止...") + + # 1. 先尝试通过查找并终止进程 + killed = find_and_kill_astrbot_processes(astrbot_root) + + if killed: + click.echo(f"等待 {wait_time} 秒以确保进程退出...") + time.sleep(wait_time) + + # 2. 尝试删除锁文件(可能需要重试) + max_retries = 5 + for i in range(max_retries): + try: + lock_file.unlink(missing_ok=True) + break + except PermissionError: + if i < max_retries - 1: + time.sleep(1) + else: + # 最后一次尝试:使用 FileLock 强制获取 + try: + force_lock = FileLock(lock_file, timeout=1) + force_lock.acquire(force=True) + force_lock.release() + lock_file.unlink(missing_ok=True) + except Exception: + pass + + # 重新启动 + # Windows: 默认在新窗口启动(除非指定 --no-window) + # Linux/macOS: 始终在当前窗口运行 + if sys.platform == "win32" and not no_window: + launch_in_new_window(astrbot_root, reload, port) + click.echo("[OK] AstrBot 已在新窗口中重启") + return + + # 在当前窗口运行(Linux/macOS 默认,Windows 指定 --no-window) + lock = FileLock(lock_file, timeout=5) + with lock: + asyncio.run(run_astrbot(astrbot_root)) + + except KeyboardInterrupt: + click.echo("\nAstrBot 已关闭...") + except Timeout: + raise click.ClickException( + "无法获取锁文件,请使用 --force 参数强制重启", + ) + except Exception as e: + raise click.ClickException(f"运行时出现错误: {e}") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 23665dff3d..cb6dc82a47 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -1,5 +1,6 @@ import asyncio import os +import subprocess import sys import traceback from pathlib import Path @@ -26,20 +27,116 @@ async def run_astrbot(astrbot_root: Path) -> None: await core_lifecycle.start() +def launch_in_new_window( + astrbot_root: Path, + reload: bool, + port: str | None, +) -> None: + """在新窗口启动 AstrBot(仅 Windows)""" + python_exe = sys.executable + + # 构建命令,添加 --no-window 标志表示在当前窗口运行 + cmd = [python_exe, "-m", "astrbot.cli", "run", "--no-window"] + + if reload: + cmd.append("-r") + + if port: + cmd.extend(["-p", port]) + + # 设置环境变量 + env = os.environ.copy() + env["ASTRBOT_CLI"] = "1" + env["ASTRBOT_ROOT"] = str(astrbot_root) + + if port: + env["DASHBOARD_PORT"] = port + + if reload: + env["ASTRBOT_RELOAD"] = "1" + + if sys.platform == "win32": + # Windows: 使用 start 命令开新窗口 + cmd_str = " ".join(f'"{c}"' if " " in str(c) else str(c) for c in cmd) + ps_script = f'Start-Process powershell -ArgumentList "-NoExit", "-Command", "cd {astrbot_root}; {cmd_str}" -WindowStyle Normal' + subprocess.Popen( + ["powershell", "-Command", ps_script], + env=env, + shell=False, + ) + elif sys.platform == "darwin": + # macOS: 使用 osascript 打开新的 Terminal 窗口 + cmd_str = " ".join(cmd) + script = f""" + tell application "Terminal" + do script "cd {astrbot_root} && {cmd_str}" + activate + end tell + """ + subprocess.Popen(["osascript", "-e", script], env=env) + else: + # Linux: 使用 gnome-terminal 或 xterm + cmd_str = " ".join(cmd) + for term_cmd in ["gnome-terminal", "xterm", "konsole", "xfce4-terminal"]: + try: + if term_cmd == "gnome-terminal": + subprocess.Popen( + [ + term_cmd, + "--", + "bash", + "-c", + f"cd {astrbot_root} && {cmd_str}; exec bash", + ], + env=env, + ) + else: + subprocess.Popen( + [ + term_cmd, + "-e", + "bash", + "-c", + f"cd {astrbot_root} && {cmd_str}; exec bash", + ], + env=env, + ) + break + except FileNotFoundError: + continue + else: + raise click.ClickException( + "无法找到终端模拟器,请手动安装 gnome-terminal 或 xterm" + ) + + @click.option("--reload", "-r", is_flag=True, help="插件自动重载") @click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str) +@click.option( + "--new-window", + is_flag=True, + help="在新窗口启动(仅 Windows/macOS/Linux 桌面环境)", +) +@click.option("--no-window", is_flag=True, hidden=True, help="内部使用:防止递归开窗口") @click.command() -def run(reload: bool, port: str) -> None: - """运行 AstrBot""" - try: - os.environ["ASTRBOT_CLI"] = "1" - astrbot_root = get_astrbot_root() +def run(reload: bool, port: str, new_window: bool, no_window: bool) -> None: + """运行 AstrBot(默认当前窗口)""" + os.environ["ASTRBOT_CLI"] = "1" + astrbot_root = get_astrbot_root() - if not check_astrbot_root(astrbot_root): - raise click.ClickException( - f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", - ) + if not check_astrbot_root(astrbot_root): + raise click.ClickException( + f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init", + ) + # 仅在明确指定 --new-window 且非内部调用时才在新窗口启动 + if new_window and not no_window: + launch_in_new_window(astrbot_root, reload, port) + click.echo("[OK] AstrBot 已在新窗口中启动") + return + + # 默认在当前窗口运行 + try: os.environ["ASTRBOT_ROOT"] = str(astrbot_root) sys.path.insert(0, str(astrbot_root)) diff --git a/astrbot/cli/commands/cmd_stop.py b/astrbot/cli/commands/cmd_stop.py new file mode 100644 index 0000000000..414bf17f73 --- /dev/null +++ b/astrbot/cli/commands/cmd_stop.py @@ -0,0 +1,150 @@ +import os +import signal +import subprocess +import sys +import time + +import click + +from ..utils import check_astrbot_root, get_astrbot_root + + +def find_and_kill_astrbot_processes(astrbot_root: str) -> bool: + """查找并终止正在运行的 AstrBot 进程 + + Returns: + bool: 是否成功终止了进程 + """ + killed = False + current_pid = os.getpid() + + if sys.platform == "win32": + # Windows: 使用 wmic 获取进程命令行,精确匹配 AstrBot 进程 + try: + result = subprocess.run( + [ + "wmic", + "process", + "where", + "name='python.exe'", + "get", + "processid,commandline", + "/format:csv", + ], + capture_output=True, + text=True, + timeout=10, + ) + + for line in result.stdout.split("\n"): + if not line.strip() or "CommandLine" in line: + continue + + parts = line.split(",") + if len(parts) >= 3: + _, cmdline, pid_str = ( + parts[0].strip('"'), + parts[1].strip('"'), + parts[2].strip('"'), + ) + + try: + pid = int(pid_str) + + if pid == current_pid: + continue + + cmdline_lower = cmdline.lower() + if "astrbot" in cmdline_lower or "astrbot.exe" in cmdline_lower: + subprocess.run( + ["taskkill", "/F", "/T", "/PID", str(pid)], + capture_output=True, + timeout=5, + ) + click.echo(f"已终止进程: {pid}") + killed = True + except (ValueError, subprocess.TimeoutExpired): + continue + except (subprocess.CalledProcessError, FileNotFoundError, Exception) as e: + click.echo(f"查找进程时出错: {e}") + + else: + # Unix/Linux/macOS: 使用 ps 和 kill + try: + result = subprocess.run( + ["ps", "aux"], + capture_output=True, + text=True, + timeout=10, + ) + + for line in result.stdout.split("\n"): + if line.startswith("USER"): + continue + + if "python" in line.lower() and "astrbot" in line.lower(): + parts = line.split(None, 10) + if len(parts) >= 2: + try: + pid = int(parts[1]) + + if pid == current_pid: + continue + + os.kill(pid, signal.SIGTERM) + click.echo(f"已发送 SIGTERM 到进程: {pid}") + killed = True + except (ValueError, ProcessLookupError): + continue + except Exception as e: + click.echo(f"查找进程时出错: {e}") + + return killed + + +@click.option( + "--wait-time", + type=float, + default=3.0, + help="等待进程退出的时间(秒)", +) +@click.option( + "--force", + "-f", + is_flag=True, + help="强制终止,删除锁文件", +) +@click.command() +def stop(wait_time: float, force: bool) -> None: + """停止 AstrBot 进程""" + try: + os.environ["ASTRBOT_CLI"] = "1" + astrbot_root = get_astrbot_root() + + if not check_astrbot_root(astrbot_root): + raise click.ClickException( + f"{astrbot_root}不是有效的 AstrBot 根目录", + ) + + # 查找并终止进程 + killed = find_and_kill_astrbot_processes(astrbot_root) + + if killed: + click.echo(f"等待 {wait_time} 秒以确保进程退出...") + time.sleep(wait_time) + + # 删除锁文件 + if force: + lock_file = astrbot_root / "astrbot.lock" + try: + lock_file.unlink(missing_ok=True) + click.echo("已删除锁文件") + except Exception as e: + click.echo(f"删除锁文件失败: {e}") + + click.echo("[OK] AstrBot 已停止") + else: + click.echo("未找到正在运行的 AstrBot 进程") + + except Exception as e: + raise click.ClickException(f"停止时出现错误: {e}") diff --git a/astrbot/cli/utils/basic.py b/astrbot/cli/utils/basic.py index 5dbe290065..503a9713f8 100644 --- a/astrbot/cli/utils/basic.py +++ b/astrbot/cli/utils/basic.py @@ -15,8 +15,37 @@ def check_astrbot_root(path: str | Path) -> bool: def get_astrbot_root() -> Path: - """获取Astrbot根目录路径""" - return Path.cwd() + """获取 AstrBot 根目录路径 + + 查找顺序: + 1. 环境变量 ASTRBOT_ROOT + 2. 通过包安装路径定位(editable install / 源码目录) + 3. 从当前目录向上查找包含 .astrbot 标记的目录 + 4. 回退到当前工作目录 + """ + # 1. 环境变量 + import os + + env_root = os.environ.get("ASTRBOT_ROOT") + if env_root: + p = Path(env_root) + if check_astrbot_root(p): + return p + + # 2. 通过包安装路径定位(editable install 场景) + # __file__ 在 astrbot/cli/utils/basic.py,向上 4 级到达项目根目录 + source_root = Path(__file__).resolve().parent.parent.parent.parent + if check_astrbot_root(source_root): + return source_root + + # 3. 向上查找 .astrbot 标记 + current = Path.cwd() + for parent in [current, *current.parents]: + if (parent / ".astrbot").exists(): + return parent + + # 4. 回退到当前目录 + return current async def check_dashboard(astrbot_root: Path) -> None: diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 8475009d3f..9510f76cb0 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -15,7 +15,6 @@ def __init__( tool_description: str | None = None, **kwargs, ) -> None: - # Avoid passing duplicate `description` to the FunctionTool dataclass. # Some call sites (e.g. SubAgentOrchestrator) pass `description` via kwargs # to override what the main agent sees, while we also compute a default diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 6fcb3608c3..7400d41d71 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -12,11 +12,21 @@ from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 +# 模块级单例引用,供无法通过依赖注入获取实例的模块使用(如CLI适配器) +_global_instance: "ConversationManager | None" = None + + +def get_conversation_manager() -> "ConversationManager | None": + """获取全局 ConversationManager 实例""" + return _global_instance + class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase) -> None: + global _global_instance + _global_instance = self self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index ffb9c5c99c..07017d79a0 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -90,6 +90,10 @@ async def execute(self, event: AstrMessageEvent) -> None: if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent): await event.send(None) + # 通知事件管道已完成(鸭子类型,供需要收集完整响应的适配器使用) + if hasattr(event, "finalize"): + await event.finalize() + logger.debug("pipeline 执行完毕。") finally: active_event_registry.unregister(event) diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py index ea9c55228e..8e66da921f 100644 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -44,6 +44,10 @@ async def process( # WebChat 豁免 return + if event.get_platform_name() == "cli": + # CLI 平台豁免(用于测试) + return + # 检查是否在白名单 if self.wl_ignore_admin_on_group: if ( diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 0238779dad..3c2ef01ba9 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -180,6 +180,10 @@ async def load_platform(self, platform_config: dict) -> None: from .sources.line.line_adapter import ( LinePlatformAdapter, # noqa: F401 ) + case "cli": + from .sources.cli.cli_adapter import ( + CLIPlatformAdapter, # noqa: F401 + ) except (ImportError, ModuleNotFoundError) as e: logger.error( f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", @@ -292,6 +296,16 @@ async def terminate(self) -> None: def get_insts(self): return self.platform_insts + @staticmethod + def _get_platform_id(inst) -> str: + """安全获取平台ID,兼容dict和dataclass类型的config""" + config = getattr(inst, "config", None) + if config is None: + return "unknown" + if isinstance(config, dict): + return config.get("id", "unknown") + return getattr(config, "platform_id", getattr(config, "id", "unknown")) + def get_all_stats(self) -> dict: """获取所有平台的统计信息 @@ -317,7 +331,7 @@ def get_all_stats(self) -> dict: logger.warning(f"获取平台统计信息失败: {e}") stats_list.append( { - "id": getattr(inst, "config", {}).get("id", "unknown"), + "id": self._get_platform_id(inst), "type": "unknown", "status": "unknown", "error_count": 0, diff --git a/astrbot/core/platform/sources/cli/__init__.py b/astrbot/core/platform/sources/cli/__init__.py new file mode 100644 index 0000000000..a8532db5e4 --- /dev/null +++ b/astrbot/core/platform/sources/cli/__init__.py @@ -0,0 +1,12 @@ +"""CLI Platform Adapter Module + +命令行模拟器平台适配器,用于快速测试AstrBot插件。 +""" + +from .cli_adapter import CLIPlatformAdapter +from .cli_event import CLIMessageEvent + +__all__ = [ + "CLIPlatformAdapter", + "CLIMessageEvent", +] diff --git a/astrbot/core/platform/sources/cli/cli_adapter.py b/astrbot/core/platform/sources/cli/cli_adapter.py new file mode 100644 index 0000000000..f5fb284f1a --- /dev/null +++ b/astrbot/core/platform/sources/cli/cli_adapter.py @@ -0,0 +1,377 @@ +"""CLI平台适配器 + +编排层:组合各模块实现CLI测试功能。 +""" + +import asyncio +import json +import os +import secrets +import time +from collections.abc import Awaitable +from typing import Any + +from astrbot import logger +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import Platform, PlatformMetadata +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path + +from ...register import register_platform_adapter +from .cli_event import MessageConverter +from .file_handler import FileHandler +from .socket_handler import ( + SocketClientHandler, + SocketModeHandler, + write_connection_info, +) +from .socket_server import create_socket_server, detect_platform + +# ------------------------------------------------------------------ +# Token管理 +# ------------------------------------------------------------------ + + +class TokenManager: + """Token管理器""" + + TOKEN_FILE = ".cli_token" + + def __init__(self): + self._token: str | None = None + self._token_file = os.path.join(get_astrbot_data_path(), self.TOKEN_FILE) + + @property + def token(self) -> str | None: + if self._token is None: + self._token = self._ensure_token() + return self._token + + def _ensure_token(self) -> str | None: + try: + if os.path.exists(self._token_file): + with open(self._token_file, encoding="utf-8") as f: + token = f.read().strip() + if token: + logger.info("[CLI] Authentication token loaded from file") + return token + + token = secrets.token_urlsafe(32) + with open(self._token_file, "w", encoding="utf-8") as f: + f.write(token) + try: + os.chmod(self._token_file, 0o600) + except OSError: + pass + logger.info(f"[CLI] Generated new authentication token: {token}") + logger.info(f"[CLI] Token saved to: {self._token_file}") + return token + except Exception as e: + logger.error(f"[CLI] Failed to ensure token: {e}") + logger.warning("[CLI] Authentication disabled due to token error") + return None + + def validate(self, provided_token: str) -> bool: + if not self.token: + return True + if not provided_token: + logger.warning("[CLI] Request rejected: missing auth_token") + return False + if provided_token != self.token: + logger.warning( + f"[CLI] Request rejected: invalid auth_token (length={len(provided_token)})" + ) + return False + return True + + +# ------------------------------------------------------------------ +# 会话管理 +# ------------------------------------------------------------------ + + +class SessionManager: + """会话管理器""" + + CLEANUP_INTERVAL = 10 + + def __init__(self, ttl: int = 30, enabled: bool = False): + self.ttl = ttl + self.enabled = enabled + self._timestamps: dict[str, float] = {} + self._cleanup_task: asyncio.Task | None = None + self._running = False + + def register(self, session_id: str) -> None: + if not self.enabled: + return + if session_id not in self._timestamps: + self._timestamps[session_id] = time.time() + logger.debug( + f"[CLI] Created isolated session: {session_id}, TTL={self.ttl}s" + ) + + def touch(self, session_id: str) -> None: + if self.enabled and session_id in self._timestamps: + self._timestamps[session_id] = time.time() + + def is_expired(self, session_id: str) -> bool: + if not self.enabled: + return False + timestamp = self._timestamps.get(session_id) + if timestamp is None: + return True + return time.time() - timestamp > self.ttl + + def start_cleanup_task(self) -> None: + if not self.enabled: + return + self._running = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info(f"[CLI] Session cleanup task started, TTL={self.ttl}s") + + async def stop_cleanup_task(self) -> None: + self._running = False + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + async def _cleanup_loop(self) -> None: + while self._running: + try: + await asyncio.sleep(self.CLEANUP_INTERVAL) + if not self.enabled: + continue + current_time = time.time() + expired = [ + sid + for sid, ts in list(self._timestamps.items()) + if current_time - ts > self.ttl + ] + for session_id in expired: + logger.info(f"[CLI] Cleaning expired session: {session_id}") + self._timestamps.pop(session_id, None) + if expired: + logger.info(f"[CLI] Cleaned {len(expired)} expired sessions") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"[CLI] Session cleanup error: {e}") + logger.info("[CLI] Session cleanup task stopped") + + +# ------------------------------------------------------------------ +# 配置加载 +# ------------------------------------------------------------------ + + +def _load_config(platform_config: dict, platform_settings: dict | None = None) -> dict: + """加载配置,合并配置文件覆盖""" + config_filename = platform_config.get("config_file", "cli_config.json") + config_path = os.path.join(get_astrbot_data_path(), config_filename) + + if os.path.exists(config_path): + try: + with open(config_path, encoding="utf-8") as f: + file_config = json.load(f) + logger.info(f"[CLI] Loaded config from {config_path}") + if "platform_config" in file_config: + merged = platform_config.copy() + merged.update(file_config["platform_config"]) + platform_config = merged + except Exception as e: + logger.warning(f"[CLI] Failed to load config from {config_path}: {e}") + + return platform_config + + +# ------------------------------------------------------------------ +# CLI平台适配器 +# ------------------------------------------------------------------ + + +@register_platform_adapter( + "cli", + "CLI测试器,用于快速测试和调试插件,构建快速反馈循环", + default_config_tmpl={ + "type": "cli", + "enable": False, + "mode": "socket", + "socket_type": "auto", + "socket_path": None, + "tcp_host": "127.0.0.1", + "tcp_port": 0, + "whitelist": [], + "use_isolated_sessions": False, + "session_ttl": 30, + }, + support_streaming_message=False, +) +class CLIPlatformAdapter(Platform): + """CLI平台适配器""" + + def __init__( + self, + platform_config: dict, + platform_settings: dict, + event_queue: asyncio.Queue, + ) -> None: + super().__init__(platform_config, event_queue) + + # 加载配置 + cfg = _load_config(platform_config, platform_settings) + self.mode = cfg.get("mode", "socket") + self.socket_type = cfg.get("socket_type", "auto") + self.socket_path = cfg.get("socket_path") or os.path.join( + get_astrbot_temp_path(), "astrbot.sock" + ) + self.tcp_host = cfg.get("tcp_host", "127.0.0.1") + self.tcp_port = cfg.get("tcp_port", 0) + self.input_file = cfg.get("input_file") or os.path.join( + get_astrbot_temp_path(), "astrbot_cli", "input.txt" + ) + self.output_file = cfg.get("output_file") or os.path.join( + get_astrbot_temp_path(), "astrbot_cli", "output.txt" + ) + self.poll_interval = cfg.get("poll_interval", 1.0) + self.use_isolated_sessions = cfg.get("use_isolated_sessions", False) + self.session_ttl = cfg.get("session_ttl", 30) + self.whitelist = cfg.get("whitelist", []) + self.platform_id = cfg.get("id", "cli") + + # 初始化模块 + self.token_manager = TokenManager() + self.session_manager = SessionManager( + ttl=self.session_ttl, enabled=self.use_isolated_sessions + ) + self.message_converter = MessageConverter() + + # 平台元数据 + self.metadata = PlatformMetadata( + name="cli", + description="命令行模拟器", + id=self.platform_id, + support_streaming_message=False, + ) + + # 运行状态 + self._running = False + self._output_queue: asyncio.Queue = asyncio.Queue() + self._handler = None + + logger.info(f"[CLI] Adapter initialized, mode={self.mode}") + + def run(self) -> Awaitable[Any]: + return self._run_loop() + + async def _run_loop(self) -> None: + self._running = True + self.session_manager.start_cleanup_task() + + try: + if self.mode == "socket": + await self._run_socket_mode() + elif self.mode == "file": + await self._run_file_mode() + else: + await self._run_socket_mode() + finally: + self._running = False + await self.session_manager.stop_cleanup_task() + + async def _run_socket_mode(self) -> None: + platform_info = detect_platform() + server = create_socket_server( + platform_info, + { + "socket_type": self.socket_type, + "socket_path": self.socket_path, + "tcp_host": self.tcp_host, + "tcp_port": self.tcp_port, + }, + self.token_manager.token, + ) + + client_handler = SocketClientHandler( + token_manager=self.token_manager, + message_converter=self.message_converter, + session_manager=self.session_manager, + platform_meta=self.metadata, + output_queue=self._output_queue, + event_committer=self.commit_event, + use_isolated_sessions=self.use_isolated_sessions, + data_path=get_astrbot_data_path(), + ) + + self._handler = SocketModeHandler( + server=server, + client_handler=client_handler, + connection_info_writer=write_connection_info, + data_path=get_astrbot_data_path(), + ) + + await self._handler.run() + + async def _run_file_mode(self) -> None: + self._handler = FileHandler( + input_file=self.input_file, + output_file=self.output_file, + poll_interval=self.poll_interval, + message_converter=self.message_converter, + platform_meta=self.metadata, + output_queue=self._output_queue, + event_committer=self.commit_event, + ) + await self._handler.run() + + async def send_by_session( + self, + session: MessageSesion, + message_chain: MessageChain, + ) -> None: + await self._output_queue.put(message_chain) + await super().send_by_session(session, message_chain) + + def meta(self) -> PlatformMetadata: + return self.metadata + + def unified_webhook(self) -> bool: + return False + + def get_stats(self) -> dict: + meta = self.meta() + meta_info = { + "id": meta.id, + "name": meta.name, + "display_name": meta.adapter_display_name or meta.name, + "description": meta.description, + "support_streaming_message": meta.support_streaming_message, + "support_proactive_message": meta.support_proactive_message, + } + return { + "id": meta.id or self.platform_id, + "type": meta.name, + "display_name": meta.adapter_display_name or meta.name, + "status": self._status.value, + "started_at": self._started_at.isoformat() if self._started_at else None, + "error_count": len(self._errors), + "last_error": { + "message": self.last_error.message, + "timestamp": self.last_error.timestamp.isoformat(), + "traceback": self.last_error.traceback, + } + if self.last_error + else None, + "unified_webhook": False, + "meta": meta_info, + } + + async def terminate(self) -> None: + self._running = False + if self._handler: + self._handler.stop() + await self.session_manager.stop_cleanup_task() + logger.info("[CLI] Adapter terminated") diff --git a/astrbot/core/platform/sources/cli/cli_event.py b/astrbot/core/platform/sources/cli/cli_event.py new file mode 100644 index 0000000000..57c877ff52 --- /dev/null +++ b/astrbot/core/platform/sources/cli/cli_event.py @@ -0,0 +1,364 @@ +"""CLI消息事件模块 + +处理CLI平台的消息事件、消息转换和图片处理。 +""" + +import asyncio +import base64 +import os +import tempfile +import uuid +from collections.abc import AsyncGenerator +from typing import Any + +from astrbot import logger +from astrbot.core.message.components import Image, Plain +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import AstrBotMessage, MessageMember, MessageType +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +# ------------------------------------------------------------------ +# 消息转换 +# ------------------------------------------------------------------ + + +class MessageConverter: + """将文本输入转换为AstrBotMessage对象""" + + def __init__( + self, + default_session_id: str = "cli_session", + user_id: str = "cli_user", + user_nickname: str = "CLI User", + ): + self.default_session_id = default_session_id + self.user_id = user_id + self.user_nickname = user_nickname + + def convert( + self, + text: str, + request_id: str | None = None, + use_isolated_session: bool = False, + ) -> AstrBotMessage: + """将文本转换为AstrBotMessage""" + message = AstrBotMessage() + message.self_id = "cli_bot" + message.message_str = text + message.message = [Plain(text)] + message.type = MessageType.FRIEND_MESSAGE + message.message_id = str(uuid.uuid4()) + + if use_isolated_session and request_id: + message.session_id = f"cli_session_{request_id}" + else: + message.session_id = self.default_session_id + + message.sender = MessageMember( + user_id=self.user_id, + nickname=self.user_nickname, + ) + message.raw_message = None + return message + + +# ------------------------------------------------------------------ +# 图片处理 +# ------------------------------------------------------------------ + + +def preprocess_chain(message_chain: MessageChain) -> None: + """预处理消息链:将本地文件图片转换为base64""" + for comp in message_chain.chain: + if isinstance(comp, Image) and comp.file and comp.file.startswith("file:///"): + file_path = comp.file[8:] + try: + if os.path.exists(file_path): + with open(file_path, "rb") as f: + data = f.read() + comp.file = f"base64://{base64.b64encode(data).decode('utf-8')}" + except Exception as e: + logger.error(f"[CLI] Failed to read image file {file_path}: {e}") + + +def extract_images(message_chain: MessageChain) -> list[dict]: + """从消息链提取图片信息,返回字典列表""" + images = [] + for comp in message_chain.chain: + if isinstance(comp, Image) and comp.file: + image_info = _process_image(comp.file) + images.append(image_info) + return images + + +def _process_image(file_ref: str) -> dict: + """处理单个图片引用,返回字典""" + if file_ref.startswith("http"): + return {"type": "url", "url": file_ref} + + if file_ref.startswith("file:///"): + return _process_local_file(file_ref[8:]) + + if file_ref.startswith("base64://"): + return _process_base64(file_ref[9:]) + + return {"type": "unknown"} + + +def _process_local_file(file_path: str) -> dict: + """处理本地文件""" + result: dict[str, Any] = {"type": "file", "path": file_path} + try: + if os.path.exists(file_path): + with open(file_path, "rb") as f: + data = f.read() + result["base64_data"] = base64.b64encode(data).decode("utf-8") + result["size"] = len(data) + else: + result["error"] = "Failed to read file" + except Exception as e: + result["error"] = str(e) + return result + + +def _process_base64(base64_data: str) -> dict: + """处理base64数据""" + try: + data = base64.b64decode(base64_data) + temp_dir = get_astrbot_temp_path() + os.makedirs(temp_dir, exist_ok=True) + temp_file = tempfile.NamedTemporaryFile( + delete=False, suffix=".png", dir=temp_dir + ) + temp_file.write(data) + temp_file.close() + return {"type": "file", "path": temp_file.name, "size": len(data)} + except Exception as e: + return {"type": "base64", "error": str(e)} + + +# ------------------------------------------------------------------ +# 向后兼容:ImageProcessor 和 ImageInfo +# ------------------------------------------------------------------ + + +class ImageInfo: + """图片信息(向后兼容)""" + + def __init__( + self, type: str, url=None, path=None, base64_data=None, size=None, error=None + ): + self.type = type + self.url = url + self.path = path + self.base64_data = base64_data + self.size = size + self.error = error + + def to_dict(self) -> dict: + result = {"type": self.type} + if self.url: + result["url"] = self.url + if self.path: + result["path"] = self.path + if self.base64_data: + result["base64_data"] = self.base64_data + if self.size: + result["size"] = self.size + if self.error: + result["error"] = self.error + return result + + +class ImageProcessor: + """图片处理器(向后兼容门面)""" + + @staticmethod + def local_file_to_base64(file_path: str) -> str | None: + try: + if os.path.exists(file_path): + with open(file_path, "rb") as f: + data = f.read() + return base64.b64encode(data).decode("utf-8") + except Exception as e: + logger.error(f"[CLI] Failed to read file {file_path}: {e}") + return None + + @staticmethod + def base64_to_temp_file(base64_data: str) -> str | None: + try: + data = base64.b64decode(base64_data) + temp_dir = get_astrbot_temp_path() + os.makedirs(temp_dir, exist_ok=True) + temp_file = tempfile.NamedTemporaryFile( + delete=False, suffix=".png", dir=temp_dir + ) + temp_file.write(data) + temp_file.close() + return temp_file.name + except Exception: + return None + + @staticmethod + def preprocess_chain(message_chain: MessageChain) -> None: + preprocess_chain(message_chain) + + @staticmethod + def extract_images(message_chain: MessageChain) -> list[ImageInfo]: + raw_images = extract_images(message_chain) + return [ + ImageInfo( + type=img.get("type", "unknown"), + url=img.get("url"), + path=img.get("path"), + base64_data=img.get("base64_data"), + size=img.get("size"), + error=img.get("error"), + ) + for img in raw_images + ] + + @staticmethod + def image_info_to_dict(image_info: ImageInfo) -> dict: + return image_info.to_dict() + + +# ------------------------------------------------------------------ +# 响应构建器(向后兼容) +# ------------------------------------------------------------------ + + +class ResponseBuilder: + """JSON响应构建器(向后兼容)""" + + @staticmethod + def build_success( + message_chain: MessageChain, + request_id: str, + extra: dict[str, Any] | None = None, + ) -> str: + import json + + response_text = message_chain.get_plain_text() + images = ImageProcessor.extract_images(message_chain) + result = { + "status": "success", + "response": response_text, + "images": [img.to_dict() for img in images], + "request_id": request_id, + } + if extra: + result.update(extra) + return json.dumps(result, ensure_ascii=False) + + @staticmethod + def build_error( + error_msg: str, + request_id: str | None = None, + error_code: str | None = None, + ) -> str: + import json + + result: dict[str, Any] = {"status": "error", "error": error_msg} + if request_id: + result["request_id"] = request_id + if error_code: + result["error_code"] = error_code + return json.dumps(result, ensure_ascii=False) + + +# ------------------------------------------------------------------ +# CLI消息事件 +# ------------------------------------------------------------------ + + +class CLIMessageEvent(AstrMessageEvent): + """CLI消息事件 + + Socket模式下收集管道中所有send()调用的消息,在管道完成(finalize)后统一返回。 + """ + + MAX_BUFFER_SIZE = 100 + + def __init__( + self, + message_str: str, + message_obj: AstrBotMessage, + platform_meta: PlatformMetadata, + session_id: str, + output_queue: asyncio.Queue, + response_future: asyncio.Future = None, + ): + super().__init__( + message_str=message_str, + message_obj=message_obj, + platform_meta=platform_meta, + session_id=session_id, + ) + self.output_queue = output_queue + self.response_future = response_future + self.send_buffer = None + + async def send(self, message_chain: MessageChain) -> dict[str, Any]: + await super().send(message_chain) + + if self.response_future is not None and not self.response_future.done(): + preprocess_chain(message_chain) + + if not self.send_buffer: + self.send_buffer = message_chain + logger.debug("[CLI] First send: buffer initialized") + else: + current_size = len(self.send_buffer.chain) + new_size = len(message_chain.chain) + if current_size + new_size > self.MAX_BUFFER_SIZE: + logger.warning( + f"[CLI] Buffer size limit reached ({current_size} + {new_size} > {self.MAX_BUFFER_SIZE}), truncating" + ) + available = self.MAX_BUFFER_SIZE - current_size + if available > 0: + self.send_buffer.chain.extend(message_chain.chain[:available]) + else: + self.send_buffer.chain.extend(message_chain.chain) + logger.debug( + f"[CLI] Appended to buffer, total: {len(self.send_buffer.chain)}" + ) + else: + await self.output_queue.put(message_chain) + + return {"success": True} + + async def send_streaming( + self, + generator: AsyncGenerator[MessageChain, None], + use_fallback: bool = False, + ) -> None: + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + + if not buffer: + return + + buffer.squash_plain() + await self.send(buffer) + await super().send_streaming(generator, use_fallback) + + async def reply(self, message_chain: MessageChain) -> dict[str, Any]: + return await self.send(message_chain) + + async def finalize(self) -> None: + if self.response_future and not self.response_future.done(): + if self.send_buffer: + self.response_future.set_result(self.send_buffer) + logger.debug( + f"[CLI] Pipeline done, response set with {len(self.send_buffer.chain)} components" + ) + else: + self.response_future.set_result(None) + logger.debug("[CLI] Pipeline done, no response to send") diff --git a/astrbot/core/platform/sources/cli/file_handler.py b/astrbot/core/platform/sources/cli/file_handler.py new file mode 100644 index 0000000000..87f50a3821 --- /dev/null +++ b/astrbot/core/platform/sources/cli/file_handler.py @@ -0,0 +1,125 @@ +"""文件轮询模式处理器""" + +import asyncio +import datetime +import os +from collections.abc import Callable +from typing import TYPE_CHECKING + +from astrbot import logger +from astrbot.core.message.message_event_result import MessageChain + +if TYPE_CHECKING: + from astrbot.core.platform.platform_metadata import PlatformMetadata + + from .cli_event import CLIMessageEvent + + +class FileHandler: + """文件轮询模式处理器""" + + def __init__( + self, + input_file: str, + output_file: str, + poll_interval: float, + message_converter, + platform_meta: "PlatformMetadata", + output_queue: asyncio.Queue, + event_committer: Callable[["CLIMessageEvent"], None], + ): + self.input_file = input_file + self.output_file = output_file + self.poll_interval = poll_interval + self.message_converter = message_converter + self.platform_meta = platform_meta + self.output_queue = output_queue + self.event_committer = event_committer + self._running = False + + async def run(self) -> None: + self._running = True + self._ensure_directories() + logger.info( + f"[CLI] File mode: input={self.input_file}, output={self.output_file}" + ) + + output_task = asyncio.create_task(self._output_loop()) + try: + await self._poll_loop() + finally: + self._running = False + output_task.cancel() + try: + await output_task + except asyncio.CancelledError: + pass + + def stop(self) -> None: + self._running = False + + def _ensure_directories(self) -> None: + for path in (self.input_file, self.output_file): + dir_path = os.path.dirname(path) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(self.input_file): + with open(self.input_file, "w") as f: + f.write("") + + async def _poll_loop(self) -> None: + while self._running: + commands = self._read_commands() + for cmd in commands: + if cmd: + await self._handle_command(cmd) + await asyncio.sleep(self.poll_interval) + + def _read_commands(self) -> list[str]: + try: + if not os.path.exists(self.input_file): + return [] + with open(self.input_file, encoding="utf-8") as f: + content = f.read().strip() + if not content: + return [] + with open(self.input_file, "w", encoding="utf-8") as f: + f.write("") + return [line.strip() for line in content.split("\n") if line.strip()] + except Exception as e: + logger.error(f"[CLI] Failed to read input file: {e}") + return [] + + async def _handle_command(self, text: str) -> None: + from .cli_event import CLIMessageEvent + + message = self.message_converter.convert(text) + message_event = CLIMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.platform_meta, + session_id=message.session_id, + output_queue=self.output_queue, + ) + self.event_committer(message_event) + + async def _output_loop(self) -> None: + while self._running: + try: + message_chain = await asyncio.wait_for( + self.output_queue.get(), timeout=0.5 + ) + self._write_response(message_chain) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + def _write_response(self, message_chain: MessageChain) -> None: + try: + text = message_chain.get_plain_text() + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + with open(self.output_file, "a", encoding="utf-8") as f: + f.write(f"[{timestamp}] Bot: {text}\n") + except Exception as e: + logger.error(f"[CLI] Failed to write output file: {e}") diff --git a/astrbot/core/platform/sources/cli/socket_handler.py b/astrbot/core/platform/sources/cli/socket_handler.py new file mode 100644 index 0000000000..3f7d5cfbf2 --- /dev/null +++ b/astrbot/core/platform/sources/cli/socket_handler.py @@ -0,0 +1,736 @@ +"""Socket处理器模块 + +处理Socket客户端连接和Socket模式的生命周期管理。 +""" + +import asyncio +import json +import os +import re +import tempfile +import traceback +import uuid +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from astrbot import logger +from astrbot.core.message.message_event_result import MessageChain + +if TYPE_CHECKING: + from astrbot.core.platform.platform_metadata import PlatformMetadata + + from .cli_event import CLIMessageEvent + + +# ------------------------------------------------------------------ +# 连接信息写入 +# ------------------------------------------------------------------ + + +def write_connection_info(connection_info: dict[str, Any], data_dir: str) -> None: + """写入连接信息到文件,供客户端读取""" + if not isinstance(connection_info, dict): + raise ValueError("connection_info must be a dict") + + conn_type = connection_info.get("type") + if conn_type not in ("unix", "tcp"): + raise ValueError(f"Invalid type: {conn_type}, must be 'unix' or 'tcp'") + if conn_type == "unix" and "path" not in connection_info: + raise ValueError("Unix socket requires 'path' field") + if conn_type == "tcp" and ( + "host" not in connection_info or "port" not in connection_info + ): + raise ValueError("TCP socket requires 'host' and 'port' fields") + + target_path = os.path.join(data_dir, ".cli_connection") + + try: + fd, temp_path = tempfile.mkstemp( + dir=data_dir, prefix=".cli_connection.", suffix=".tmp" + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(connection_info, f, indent=2) + try: + os.chmod(temp_path, 0o600) + except OSError: + pass + os.replace(temp_path, target_path) + except Exception: + if os.path.exists(temp_path): + os.remove(temp_path) + raise + except Exception as e: + logger.error(f"[CLI] Failed to write connection info: {e}") + raise + + +# ------------------------------------------------------------------ +# 响应构建(从message/response_builder.py内联) +# ------------------------------------------------------------------ + + +def _build_success_response( + message_chain: MessageChain, + request_id: str, + images: list[dict], + extra: dict[str, Any] | None = None, +) -> str: + """构建成功响应JSON""" + result = { + "status": "success", + "response": message_chain.get_plain_text(), + "images": images, + "request_id": request_id, + } + if extra: + result.update(extra) + return json.dumps(result, ensure_ascii=False) + + +def _build_error_response( + error_msg: str, + request_id: str | None = None, + error_code: str | None = None, +) -> str: + """构建错误响应JSON""" + result: dict[str, Any] = {"status": "error", "error": error_msg} + if request_id: + result["request_id"] = request_id + if error_code: + result["error_code"] = error_code + return json.dumps(result, ensure_ascii=False) + + +# ------------------------------------------------------------------ +# Socket客户端处理器 +# ------------------------------------------------------------------ + + +class SocketClientHandler: + """处理单个Socket客户端连接""" + + RECV_BUFFER_SIZE = 4096 + MAX_REQUEST_SIZE = 1024 * 1024 # 1MB + RESPONSE_TIMEOUT = 120.0 + + def __init__( + self, + token_manager, + message_converter, + session_manager, + platform_meta: "PlatformMetadata", + output_queue: asyncio.Queue, + event_committer: Callable[["CLIMessageEvent"], None], + use_isolated_sessions: bool = False, + data_path: str | None = None, + ): + self.token_manager = token_manager + self.message_converter = message_converter + self.session_manager = session_manager + self.platform_meta = platform_meta + self.output_queue = output_queue + self.event_committer = event_committer + self.use_isolated_sessions = use_isolated_sessions + self.data_path = data_path or os.path.join(os.getcwd(), "data") + + async def handle(self, client_socket) -> None: + """处理单个客户端连接""" + try: + loop = asyncio.get_running_loop() + + data = await self._recv_with_limit(loop, client_socket) + if not data: + return + + request = self._parse_request(data) + if request is None: + await self._send_response( + loop, client_socket, _build_error_response("Invalid JSON format") + ) + return + + request_id = request.get("request_id", str(uuid.uuid4())) + auth_token = request.get("auth_token", "") + action = request.get("action", "") + + if not self.token_manager.validate(auth_token): + error_msg = ( + "Unauthorized: missing token" + if not auth_token + else "Unauthorized: invalid token" + ) + await self._send_response( + loop, + client_socket, + _build_error_response(error_msg, request_id, "AUTH_FAILED"), + ) + return + + if action == "get_logs": + response = await self._get_logs(request, request_id) + elif action == "list_tools": + response = self._list_tools(request_id) + elif action == "call_tool": + response = await self._call_tool(request, request_id) + elif action == "list_sessions": + response = await self._list_sessions(request, request_id) + elif action == "list_session_conversations": + response = await self._list_session_conversations(request, request_id) + elif action == "get_session_history": + response = await self._get_session_history(request, request_id) + else: + message_text = request.get("message", "") + response = await self._process_message(message_text, request_id) + + await self._send_response(loop, client_socket, response) + + except Exception as e: + logger.error(f"[CLI] Socket handler error: {e}", exc_info=True) + finally: + try: + client_socket.close() + except Exception as e: + logger.warning(f"[CLI] Failed to close socket: {e}") + + async def _recv_with_limit(self, loop, client_socket) -> bytes: + """接收数据,带大小限制""" + chunks = [] + total_size = 0 + + while True: + chunk = await loop.sock_recv(client_socket, self.RECV_BUFFER_SIZE) + if not chunk: + break + + total_size += len(chunk) + if total_size > self.MAX_REQUEST_SIZE: + logger.warning(f"[CLI] Request too large: {total_size} bytes") + return b"" + + chunks.append(chunk) + if chunk.rstrip().endswith(b"}"): + break + + return b"".join(chunks) + + def _parse_request(self, data: bytes) -> dict | None: + try: + return json.loads(data.decode("utf-8")) + except json.JSONDecodeError: + return None + + async def _send_response(self, loop, client_socket, response: str) -> None: + await loop.sock_sendall(client_socket, response.encode("utf-8")) + + async def _process_message(self, message_text: str, request_id: str) -> str: + """处理消息并返回JSON响应""" + from .cli_event import CLIMessageEvent, extract_images + + response_future = asyncio.Future() + + message = self.message_converter.convert( + message_text, + request_id=request_id, + use_isolated_session=self.use_isolated_sessions, + ) + + self.session_manager.register(message.session_id) + + message_event = CLIMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.platform_meta, + session_id=message.session_id, + output_queue=self.output_queue, + response_future=response_future, + ) + + self.event_committer(message_event) + + try: + message_chain = await asyncio.wait_for( + response_future, timeout=self.RESPONSE_TIMEOUT + ) + if message_chain is None: + return _build_success_response(MessageChain([]), request_id, []) + images = extract_images(message_chain) + return _build_success_response(message_chain, request_id, images) + except asyncio.TimeoutError: + return _build_error_response("Request timeout", request_id, "TIMEOUT") + + async def _get_logs(self, request: dict, request_id: str) -> str: + """获取日志""" + LEVEL_MAP = { + "DEBUG": "DEBUG", + "INFO": "INFO", + "WARNING": "WARN", + "WARN": "WARN", + "ERROR": "ERRO", + "CRITICAL": "CRIT", + } + + try: + lines = min(request.get("lines", 100), 1000) + level_filter = request.get("level", "").upper() + level_filter = LEVEL_MAP.get(level_filter, level_filter) + pattern = request.get("pattern", "") + use_regex = request.get("regex", False) + + log_path = os.path.join(self.data_path, "logs", "astrbot.log") + + if not os.path.exists(log_path): # noqa: ASYNC240 + return json.dumps( + { + "status": "success", + "response": "", + "message": "日志文件未找到。请在配置中启用 log_file_enable 来记录日志到文件。", + "request_id": request_id, + }, + ensure_ascii=False, + ) + + logs = [] + try: + with open(log_path, encoding="utf-8", errors="ignore") as f: + all_lines = f.readlines() + + for line in reversed(all_lines): + if not line.strip(): + continue + if level_filter and not re.search(rf"\[{level_filter}\]", line): + continue + if pattern: + if use_regex: + try: + if not re.search(pattern, line): + continue + except re.error: + if pattern not in line: + continue + else: + if pattern not in line: + continue + logs.append(line.rstrip()) + if len(logs) >= lines: + break + except OSError as e: + logger.warning(f"[CLI] Failed to read log file: {e}") + return _build_error_response( + f"Failed to read log file: {e}", request_id + ) + + logs.reverse() + log_text = "\n".join(logs) + return json.dumps( + { + "status": "success", + "response": log_text, + "message": f"Retrieved {len(logs)} log lines", + "request_id": request_id, + }, + ensure_ascii=False, + ) + + except Exception as e: + logger.exception("[CLI] Error getting logs") + return _build_error_response(f"Error getting logs: {e}", request_id) + + # ------------------------------------------------------------------ + # 函数工具管理(CLI专属) + # ------------------------------------------------------------------ + + def _list_tools(self, request_id: str) -> str: + """列出所有注册的函数工具""" + from astrbot.core.provider.func_tool_manager import get_func_tool_manager + + tool_mgr = get_func_tool_manager() + if tool_mgr is None: + return _build_error_response("FunctionToolManager 未初始化", request_id) + + tools = [] + for tool in tool_mgr.func_list: + # 判断来源 + origin = "unknown" + origin_name = "unknown" + try: + from astrbot.core.agent.mcp_client import MCPTool + from astrbot.core.star.star import star_map + + if isinstance(tool, MCPTool): + origin = "mcp" + origin_name = tool.mcp_server_name + elif tool.handler_module_path and star_map.get( + tool.handler_module_path + ): + origin = "plugin" + origin_name = star_map[tool.handler_module_path].name + else: + origin = "builtin" + origin_name = "builtin" + except Exception: + pass + + tools.append( + { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + "active": tool.active, + "origin": origin, + "origin_name": origin_name, + } + ) + + return json.dumps( + { + "status": "success", + "response": json.dumps(tools, ensure_ascii=False, indent=2), + "tools": tools, + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + + async def _call_tool(self, request: dict, request_id: str) -> str: + """调用指定的函数工具""" + from astrbot.core.provider.func_tool_manager import get_func_tool_manager + + tool_mgr = get_func_tool_manager() + if tool_mgr is None: + return _build_error_response("FunctionToolManager 未初始化", request_id) + + tool_name = request.get("tool_name", "") + tool_args = request.get("tool_args", {}) + + if not tool_name: + return _build_error_response("缺少 tool_name 参数", request_id) + + tool = tool_mgr.get_func(tool_name) + if tool is None: + return _build_error_response(f"未找到工具: {tool_name}", request_id) + + if not tool.active: + return _build_error_response(f"工具 {tool_name} 当前已停用", request_id) + + if tool.handler is None: + return _build_error_response( + f"工具 {tool_name} 没有可调用的处理函数", request_id + ) + + try: + # 构造一个最小化的 event 用于工具调用 + from .cli_event import CLIMessageEvent + + response_future = asyncio.Future() + message = self.message_converter.convert( + f"/tool call {tool_name}", + request_id=request_id, + use_isolated_session=self.use_isolated_sessions, + ) + event = CLIMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.platform_meta, + session_id=message.session_id, + output_queue=self.output_queue, + response_future=response_future, + ) + + result = await tool.handler(event, **tool_args) + result_text = str(result) if result is not None else "(无返回值)" + + return json.dumps( + { + "status": "success", + "response": result_text, + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + except Exception as e: + logger.error(f"[CLI] Tool call error: {tool_name}: {e}") + return _build_error_response( + f"调用工具 {tool_name} 失败: {e}\n{traceback.format_exc()[-300:]}", + request_id, + ) + + # ------------------------------------------------------------------ + # 跨会话浏览(CLI专属) + # ------------------------------------------------------------------ + + async def _list_sessions(self, request: dict, request_id: str) -> str: + """列出所有会话""" + from astrbot.core.conversation_mgr import get_conversation_manager + + conv_mgr = get_conversation_manager() + if conv_mgr is None: + return _build_error_response("ConversationManager 未初始化", request_id) + + try: + page = request.get("page", 1) + page_size = request.get("page_size", 20) + platform = request.get("platform") or None + search_query = request.get("search_query") or None + + sessions, total = await conv_mgr.db.get_session_conversations( + page=page, + page_size=page_size, + search_query=search_query, + platform=platform, + ) + + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 + + return json.dumps( + { + "status": "success", + "sessions": sessions, + "total": total, + "page": page, + "page_size": page_size, + "total_pages": total_pages, + "response": f"共 {total} 个会话,第 {page}/{total_pages} 页", + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + except Exception as e: + logger.exception("[CLI] Error listing sessions") + return _build_error_response(f"列出会话失败: {e}", request_id) + + async def _list_session_conversations(self, request: dict, request_id: str) -> str: + """列出指定会话的所有对话""" + from astrbot.core.conversation_mgr import get_conversation_manager + + conv_mgr = get_conversation_manager() + if conv_mgr is None: + return _build_error_response("ConversationManager 未初始化", request_id) + + session_id = request.get("session_id", "") + if not session_id: + return _build_error_response("缺少 session_id 参数", request_id) + + try: + page = request.get("page", 1) + page_size = request.get("page_size", 20) + + conversations = await conv_mgr.get_conversations( + unified_msg_origin=session_id, + ) + + # 手动分页 + total = len(conversations) + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 + start = (page - 1) * page_size + end = start + page_size + paged = conversations[start:end] + + convs_data = [] + curr_cid = await conv_mgr.get_curr_conversation_id(session_id) + for conv in paged: + convs_data.append( + { + "cid": conv.cid, + "title": conv.title or "(无标题)", + "persona_id": conv.persona_id, + "created_at": conv.created_at, + "updated_at": conv.updated_at, + "token_usage": conv.token_usage, + "is_current": conv.cid == curr_cid, + } + ) + + return json.dumps( + { + "status": "success", + "conversations": convs_data, + "total": total, + "page": page, + "page_size": page_size, + "total_pages": total_pages, + "current_cid": curr_cid, + "response": f"会话 {session_id} 共 {total} 个对话,第 {page}/{total_pages} 页", + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + except Exception as e: + logger.exception("[CLI] Error listing session conversations") + return _build_error_response(f"列出会话对话失败: {e}", request_id) + + async def _get_session_history(self, request: dict, request_id: str) -> str: + """获取指定会话的聊天记录""" + from astrbot.core.conversation_mgr import get_conversation_manager + + conv_mgr = get_conversation_manager() + if conv_mgr is None: + return _build_error_response("ConversationManager 未初始化", request_id) + + session_id = request.get("session_id", "") + if not session_id: + return _build_error_response("缺少 session_id 参数", request_id) + + try: + conversation_id = request.get("conversation_id") or None + page = request.get("page", 1) + page_size = request.get("page_size", 10) + + # 如果未指定 conversation_id,获取当前对话 + if not conversation_id: + conversation_id = await conv_mgr.get_curr_conversation_id(session_id) + + if not conversation_id: + return json.dumps( + { + "status": "success", + "history": [], + "total_pages": 0, + "page": page, + "conversation_id": None, + "response": f"会话 {session_id} 没有活跃的对话", + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + + conversation = await conv_mgr.get_conversation(session_id, conversation_id) + if not conversation: + return json.dumps( + { + "status": "success", + "history": [], + "total_pages": 0, + "page": page, + "conversation_id": conversation_id, + "session_id": session_id, + "response": "(无记录)", + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + + raw_history = json.loads(conversation.history) + + # 构建简洁的消息对列表,每对是 {"role": ..., "text": ...} + messages = [] + for record in raw_history: + role = record.get("role", "") + if role not in ("user", "assistant"): + continue + text = _extract_content_text(record) + messages.append({"role": role, "text": text}) + + # 分页(按消息条数) + total = len(messages) + total_pages = (total + page_size - 1) // page_size if total > 0 else 0 + start = (page - 1) * page_size + end = start + page_size + paged = messages[start:end] + + return json.dumps( + { + "status": "success", + "history": paged, + "total": total, + "total_pages": total_pages, + "page": page, + "conversation_id": conversation_id, + "session_id": session_id, + "response": "", + "images": [], + "request_id": request_id, + }, + ensure_ascii=False, + ) + except Exception as e: + logger.exception("[CLI] Error getting session history") + return _build_error_response(f"获取聊天记录失败: {e}", request_id) + + +def _extract_content_text(record: dict) -> str: + """从 OpenAI 格式的消息记录中提取纯文本,图片用 [图片] 占位。""" + content = record.get("content") + + # content 是字符串(最常见情况) + if isinstance(content, str): + return content + + # content 是 list(多部分内容,可能含图片) + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict): + part_type = part.get("type", "") + if part_type == "text": + parts.append(part.get("text", "")) + elif part_type in ("image_url", "image"): + parts.append("[图片]") + else: + parts.append(f"[{part_type}]") + elif isinstance(part, str): + parts.append(part) + return " ".join(parts) if parts else "" + + # content 为 None(tool_calls 等情况) + if content is None: + if "tool_calls" in record: + names = [] + for tc in record.get("tool_calls", []): + fn = tc.get("function", {}) + names.append(fn.get("name", "?")) + return f"[调用工具: {', '.join(names)}]" + return "" + + return str(content) + + +# ------------------------------------------------------------------ +# Socket模式处理器 +# ------------------------------------------------------------------ + + +class SocketModeHandler: + """管理Socket服务器的生命周期""" + + def __init__( + self, + server, + client_handler: SocketClientHandler, + connection_info_writer: Callable[[dict, str], None], + data_path: str, + ): + self.server = server + self.client_handler = client_handler + self.connection_info_writer = connection_info_writer + self.data_path = data_path + self._running = False + + async def run(self) -> None: + self._running = True + try: + await self.server.start() + logger.info(f"[CLI] Socket server started: {type(self.server).__name__}") + + connection_info = self.server.get_connection_info() + self.connection_info_writer(connection_info, self.data_path) + + while self._running: + try: + client_socket, _ = await self.server.accept_connection() + asyncio.create_task(self.client_handler.handle(client_socket)) + except Exception as e: + if self._running: + logger.error(f"[CLI] Socket accept error: {e}") + await asyncio.sleep(0.1) + finally: + await self.server.stop() + + def stop(self) -> None: + self._running = False diff --git a/astrbot/core/platform/sources/cli/socket_server.py b/astrbot/core/platform/sources/cli/socket_server.py new file mode 100644 index 0000000000..f532656d3b --- /dev/null +++ b/astrbot/core/platform/sources/cli/socket_server.py @@ -0,0 +1,249 @@ +"""Socket服务器模块 + +提供Unix Socket和TCP Socket服务器实现,以及平台检测和工厂函数。 +""" + +import asyncio +import os +import platform as platform_mod +import socket +import sys +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Literal + +from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +# ------------------------------------------------------------------ +# 平台检测 +# ------------------------------------------------------------------ + + +@dataclass +class PlatformInfo: + """平台信息""" + + os_type: Literal["windows", "linux", "darwin"] + python_version: tuple[int, int, int] + supports_unix_socket: bool + + +def detect_platform() -> PlatformInfo: + """检测当前平台信息""" + system = platform_mod.system() + if system == "Windows": + os_type = "windows" + elif system == "Linux": + os_type = "linux" + elif system == "Darwin": + os_type = "darwin" + else: + os_type = "linux" + + vi = sys.version_info + python_version = ( + (vi.major, vi.minor, vi.micro) + if hasattr(vi, "major") + else (vi[0], vi[1], vi[2]) + ) + + supports_unix = os_type in ("linux", "darwin") + if os_type == "windows" and python_version >= (3, 9, 0): + try: + if hasattr(socket, "AF_UNIX"): + test_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + test_sock.close() + supports_unix = True + except (OSError, AttributeError): + pass + + info = PlatformInfo( + os_type=os_type, + python_version=python_version, + supports_unix_socket=supports_unix, + ) + logger.info( + f"[CLI] Platform: {info.os_type}, Python {info.python_version}, unix_socket={info.supports_unix_socket}" + ) + return info + + +# ------------------------------------------------------------------ +# Socket服务器抽象基类 +# ------------------------------------------------------------------ + + +class AbstractSocketServer(ABC): + """Socket服务器抽象基类""" + + @abstractmethod + async def start(self) -> None: + pass + + @abstractmethod + async def stop(self) -> None: + pass + + @abstractmethod + async def accept_connection(self) -> tuple[Any, Any]: + pass + + @abstractmethod + def get_connection_info(self) -> dict: + pass + + +# ------------------------------------------------------------------ +# TCP Socket服务器 +# ------------------------------------------------------------------ + + +class TCPSocketServer(AbstractSocketServer): + """TCP Socket服务器,用于Windows或显式指定TCP的场景""" + + def __init__( + self, host: str = "127.0.0.1", port: int = 0, auth_token: str | None = None + ): + self.host = host + self.port = port + self.auth_token = auth_token + self.server_socket: socket.socket | None = None + self.actual_port: int = port + self._is_running = False + + async def start(self) -> None: + if self._is_running: + raise RuntimeError("Server is already running") + + try: + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server_socket.bind((self.host, self.port)) + self.actual_port = self.server_socket.getsockname()[1] + self.server_socket.listen(5) + self.server_socket.setblocking(False) + self._is_running = True + logger.info(f"[CLI] TCP server listening on {self.host}:{self.actual_port}") + except Exception: + if self.server_socket: + self.server_socket.close() + self.server_socket = None + raise + + async def stop(self) -> None: + if not self._is_running and self.server_socket is None: + return + if self.server_socket: + self.server_socket.close() + self.server_socket = None + self._is_running = False + logger.info("[CLI] TCP server stopped") + + async def accept_connection(self) -> tuple[Any, Any]: + if not self._is_running or self.server_socket is None: + raise RuntimeError("Server is not running") + loop = asyncio.get_running_loop() + return await loop.sock_accept(self.server_socket) + + def get_connection_info(self) -> dict: + return {"type": "tcp", "host": self.host, "port": self.actual_port} + + +# ------------------------------------------------------------------ +# Unix Socket服务器 +# ------------------------------------------------------------------ + + +class UnixSocketServer(AbstractSocketServer): + """Unix Domain Socket服务器""" + + def __init__(self, socket_path: str, auth_token: str | None = None) -> None: + if not socket_path: + raise ValueError("socket_path cannot be empty") + self.socket_path = socket_path + self.auth_token = auth_token + self._server_socket: socket.socket | None = None + self._running = False + + async def start(self) -> None: + if self._running: + raise RuntimeError("Server is already running") + + if os.path.exists(self.socket_path): + os.remove(self.socket_path) + + self._server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._server_socket.bind(self.socket_path) + os.chmod(self.socket_path, 0o600) + self._server_socket.listen(5) + self._server_socket.setblocking(False) + self._running = True + logger.info(f"[CLI] Unix socket server listening on {self.socket_path}") + + async def stop(self) -> None: + self._running = False + if self._server_socket is not None: + try: + self._server_socket.close() + except Exception as e: + logger.error(f"[CLI] Failed to close socket: {e}") + if os.path.exists(self.socket_path): + try: + os.remove(self.socket_path) + except Exception as e: + logger.error(f"[CLI] Failed to remove socket file: {e}") + self._server_socket = None + logger.info("[CLI] Unix socket server stopped") + + async def accept_connection(self) -> tuple[Any, Any]: + if not self._running or self._server_socket is None: + raise RuntimeError("Server is not started") + loop = asyncio.get_running_loop() + return await loop.sock_accept(self._server_socket) + + def get_connection_info(self) -> dict: + return {"type": "unix", "path": self.socket_path} + + +# ------------------------------------------------------------------ +# 工厂函数 +# ------------------------------------------------------------------ + + +def create_socket_server( + platform_info: PlatformInfo, config: dict, auth_token: str | None +) -> AbstractSocketServer: + """根据平台和配置创建Socket服务器""" + socket_type = config.get("socket_type", "auto") + + if socket_type == "tcp": + use_tcp = True + elif socket_type == "unix": + use_tcp = False + elif socket_type == "auto": + use_tcp = ( + platform_info.os_type == "windows" + and not platform_info.supports_unix_socket + ) + else: + logger.warning( + f"[CLI] Invalid socket_type '{socket_type}', falling back to auto" + ) + use_tcp = ( + platform_info.os_type == "windows" + and not platform_info.supports_unix_socket + ) + + if use_tcp: + host = config.get("tcp_host", "127.0.0.1") + port = config.get("tcp_port", 0) + server = TCPSocketServer(host=host, port=port, auth_token=auth_token) + else: + socket_path = config.get("socket_path") or os.path.join( + get_astrbot_temp_path(), "astrbot.sock" + ) + server = UnixSocketServer(socket_path=socket_path, auth_token=auth_token) + + logger.info(f"[CLI] Created {server.__class__.__name__}") + return server diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 106b42cc5b..9aa0814943 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -103,8 +103,19 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" +# 模块级单例引用,供无法通过依赖注入获取实例的模块使用(如CLI适配器) +_global_instance: FunctionToolManager | None = None + + +def get_func_tool_manager() -> FunctionToolManager | None: + """获取全局 FunctionToolManager 实例""" + return _global_instance + + class FunctionToolManager: def __init__(self) -> None: + global _global_instance + _global_instance = self self.func_list: list[FuncTool] = [] self.mcp_client_dict: dict[str, MCPClient] = {} """MCP 服务列表""" diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 874bc19db7..487bad62e9 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -81,7 +81,13 @@ def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None: 平台适配器实例,未找到则返回 None """ for platform in self.platform_manager.platform_insts: - if platform.config.get("webhook_uuid") == webhook_uuid: + config = platform.config + uuid_val = ( + config.get("webhook_uuid") + if isinstance(config, dict) + else getattr(config, "webhook_uuid", None) + ) + if uuid_val == webhook_uuid: if platform.unified_webhook(): return platform return None diff --git a/pyproject.toml b/pyproject.toml index 9e421c3038..da9e0dab21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ dev = [ [project.scripts] astrbot = "astrbot.cli.__main__:cli" +astr = "astrbot.cli.client.__main__:main" [tool.ruff] exclude = ["astrbot/core/utils/t2i/local_strategy.py", "astrbot/api/all.py", "tests"] @@ -110,6 +111,10 @@ reportMissingImports = false include = ["astrbot"] exclude = ["dashboard", "node_modules", "dist", "data", "tests"] +[tool.pytest.ini_options] +asyncio_mode = "strict" +addopts = "--ignore=tests/test_cli" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 0000000000..fc247f7499 --- /dev/null +++ b/tests/test_cli/__init__.py @@ -0,0 +1 @@ +"""CLI模块单元测试""" diff --git a/tests/test_cli/conftest.py b/tests/test_cli/conftest.py new file mode 100644 index 0000000000..6c803f372a --- /dev/null +++ b/tests/test_cli/conftest.py @@ -0,0 +1,37 @@ +"""CLI测试共享fixtures""" + +import asyncio +from unittest.mock import MagicMock + +import pytest + + +@pytest.fixture +def event_loop(): + """创建事件循环""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +def mock_platform_meta(): + """创建模拟的PlatformMetadata""" + meta = MagicMock() + meta.name = "cli" + return meta + + +@pytest.fixture +def mock_output_queue(): + """创建模拟的输出队列""" + return asyncio.Queue() + + +@pytest.fixture +def mock_message_chain(): + """创建模拟的MessageChain""" + chain = MagicMock() + chain.chain = [] + chain.get_plain_text.return_value = "Test response" + return chain diff --git a/tests/test_cli/test_client_commands.py b/tests/test_cli/test_client_commands.py new file mode 100644 index 0000000000..8790b58be7 --- /dev/null +++ b/tests/test_cli/test_client_commands.py @@ -0,0 +1,742 @@ +"""CLI Client 命令模块单元测试 + +使用 click.testing.CliRunner 测试 CLI 命令的参数解析和消息映射。 +""" + +import json +from unittest.mock import patch + +from click.testing import CliRunner + +from astrbot.cli.client.__main__ import main + + +def _mock_send(response_text="OK", status="success"): + """创建 mock send_message 返回指定响应""" + return {"status": status, "response": response_text, "images": []} + + +def _mock_send_error(error_text="Connection error"): + """创建 mock send_message 返回错误""" + return {"status": "error", "error": error_text} + + +class TestSendCommand: + """send 命令测试""" + + @patch("astrbot.cli.client.commands.send.send_message") + def test_basic_send(self, mock_send): + """基本消息发送""" + mock_send.return_value = _mock_send("你好!") + runner = CliRunner() + result = runner.invoke(main, ["send", "你好"]) + + assert result.exit_code == 0 + assert "你好!" in result.output + mock_send.assert_called_once_with("你好", None, 30.0) + + @patch("astrbot.cli.client.commands.send.send_message") + def test_send_with_json(self, mock_send): + """JSON 输出""" + mock_send.return_value = _mock_send("hello") + runner = CliRunner() + result = runner.invoke(main, ["send", "-j", "hello"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert output["status"] == "success" + + @patch("astrbot.cli.client.commands.send.send_message") + def test_send_multi_word(self, mock_send): + """多个词拼接""" + mock_send.return_value = _mock_send("response") + runner = CliRunner() + result = runner.invoke(main, ["send", "hello", "world"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("hello world", None, 30.0) + + @patch("astrbot.cli.client.commands.send.send_message") + def test_implicit_send(self, mock_send): + """astr 你好 隐式路由到 send""" + mock_send.return_value = _mock_send("response") + runner = CliRunner() + result = runner.invoke(main, ["你好"]) + + assert result.exit_code == 0 + mock_send.assert_called_once() + + @patch("astrbot.cli.client.commands.send.send_message") + def test_implicit_json_flag(self, mock_send): + """astr -j "test" 隐式路由到 send -j""" + mock_send.return_value = _mock_send("response") + runner = CliRunner() + result = runner.invoke(main, ["-j", "test"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert output["status"] == "success" + + @patch("astrbot.cli.client.commands.send.send_message") + def test_send_error(self, mock_send): + """发送错误时退出码为 1""" + mock_send.return_value = _mock_send_error("Connection refused") + runner = CliRunner() + result = runner.invoke(main, ["send", "hello"]) + + assert result.exit_code == 1 + + def test_send_no_message(self): + """无消息内容时报错""" + runner = CliRunner() + result = runner.invoke(main, ["send"]) + + assert result.exit_code == 1 + + @patch("astrbot.cli.client.commands.send.send_message") + def test_send_with_timeout(self, mock_send): + """自定义超时时间""" + mock_send.return_value = _mock_send("ok") + runner = CliRunner() + result = runner.invoke(main, ["send", "-t", "60", "hello"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("hello", None, 60.0) + + @patch("astrbot.cli.client.commands.send.send_message") + def test_pipe_input(self, mock_send): + """管道输入""" + mock_send.return_value = _mock_send("piped") + runner = CliRunner() + result = runner.invoke(main, ["send"], input="hello from pipe") + + assert result.exit_code == 0 + mock_send.assert_called_once_with("hello from pipe", None, 30.0) + + +class TestLogCommand: + """log 命令测试""" + + @patch("astrbot.cli.client.commands.log._read_log_from_file") + def test_log_default(self, mock_read): + """默认读取文件日志""" + runner = CliRunner() + result = runner.invoke(main, ["log"]) + + assert result.exit_code == 0 + mock_read.assert_called_once_with(100, "", "", False) + + @patch("astrbot.cli.client.commands.log._read_log_from_file") + def test_log_with_options(self, mock_read): + """带选项读取日志""" + runner = CliRunner() + result = runner.invoke( + main, ["log", "--lines", "50", "--level", "ERROR", "--pattern", "test"] + ) + + assert result.exit_code == 0 + mock_read.assert_called_once_with(50, "ERROR", "test", False) + + @patch("astrbot.cli.client.commands.log._read_log_from_file") + def test_log_regex(self, mock_read): + """正则匹配日志""" + runner = CliRunner() + result = runner.invoke(main, ["log", "--pattern", "ERR|WARN", "--regex"]) + + assert result.exit_code == 0 + mock_read.assert_called_once_with(100, "", "ERR|WARN", True) + + @patch("astrbot.cli.client.commands.log._read_log_from_file") + def test_log_compat_flag(self, mock_read): + """--log 兼容旧用法""" + runner = CliRunner() + result = runner.invoke(main, ["--log"]) + + assert result.exit_code == 0 + mock_read.assert_called_once() + + @patch("astrbot.cli.client.commands.log.get_logs") + def test_log_socket_mode(self, mock_get_logs): + """Socket 模式获取日志""" + mock_get_logs.return_value = { + "status": "success", + "response": "log line 1\nlog line 2", + } + runner = CliRunner() + result = runner.invoke(main, ["log", "--socket"]) + + assert result.exit_code == 0 + assert "log line 1" in result.output + + +class TestConvCommand: + """conv 命令组测试""" + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_ls(self, mock_send): + """列出对话""" + mock_send.return_value = _mock_send("对话列表...") + runner = CliRunner() + result = runner.invoke(main, ["conv", "ls"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/ls") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_ls_page(self, mock_send): + """带页码列出对话""" + mock_send.return_value = _mock_send("第2页") + runner = CliRunner() + result = runner.invoke(main, ["conv", "ls", "2"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/ls 2") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_new(self, mock_send): + """创建新对话""" + mock_send.return_value = _mock_send("已创建") + runner = CliRunner() + result = runner.invoke(main, ["conv", "new"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/new") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_switch(self, mock_send): + """切换对话""" + mock_send.return_value = _mock_send("已切换") + runner = CliRunner() + result = runner.invoke(main, ["conv", "switch", "3"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/switch 3") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_del(self, mock_send): + """删除对话""" + mock_send.return_value = _mock_send("已删除") + runner = CliRunner() + result = runner.invoke(main, ["conv", "del"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/del") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_rename(self, mock_send): + """重命名对话""" + mock_send.return_value = _mock_send("已重命名") + runner = CliRunner() + result = runner.invoke(main, ["conv", "rename", "新名称"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/rename 新名称") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_reset(self, mock_send): + """重置 LLM 会话""" + mock_send.return_value = _mock_send("已重置") + runner = CliRunner() + result = runner.invoke(main, ["conv", "reset"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/reset") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_history(self, mock_send): + """查看对话记录""" + mock_send.return_value = _mock_send("记录...") + runner = CliRunner() + result = runner.invoke(main, ["conv", "history"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/history") + + @patch("astrbot.cli.client.commands.conv.send_message") + def test_conv_history_page(self, mock_send): + """带页码查看记录""" + mock_send.return_value = _mock_send("第2页") + runner = CliRunner() + result = runner.invoke(main, ["conv", "history", "2"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/history 2") + + +class TestPluginCommand: + """plugin 命令组测试""" + + @patch("astrbot.cli.client.commands.plugin.send_message") + def test_plugin_ls(self, mock_send): + """列出插件""" + mock_send.return_value = _mock_send("插件列表") + runner = CliRunner() + result = runner.invoke(main, ["plugin", "ls"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/plugin ls") + + @patch("astrbot.cli.client.commands.plugin.send_message") + def test_plugin_on(self, mock_send): + """启用插件""" + mock_send.return_value = _mock_send("已启用") + runner = CliRunner() + result = runner.invoke(main, ["plugin", "on", "myplugin"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/plugin on myplugin") + + @patch("astrbot.cli.client.commands.plugin.send_message") + def test_plugin_off(self, mock_send): + """禁用插件""" + mock_send.return_value = _mock_send("已禁用") + runner = CliRunner() + result = runner.invoke(main, ["plugin", "off", "myplugin"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/plugin off myplugin") + + @patch("astrbot.cli.client.commands.plugin.send_message") + def test_plugin_help(self, mock_send): + """获取插件帮助""" + mock_send.return_value = _mock_send("帮助信息") + runner = CliRunner() + result = runner.invoke(main, ["plugin", "help", "myplugin"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/plugin help myplugin") + + @patch("astrbot.cli.client.commands.plugin.send_message") + def test_plugin_help_no_name(self, mock_send): + """获取通用插件帮助""" + mock_send.return_value = _mock_send("通用帮助") + runner = CliRunner() + result = runner.invoke(main, ["plugin", "help"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/plugin help") + + +class TestProviderModelKey: + """provider/model/key 命令测试""" + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_provider_list(self, mock_send): + """查看 Provider 列表""" + mock_send.return_value = _mock_send("provider list") + runner = CliRunner() + result = runner.invoke(main, ["provider"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/provider") + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_provider_switch(self, mock_send): + """切换 Provider""" + mock_send.return_value = _mock_send("switched") + runner = CliRunner() + result = runner.invoke(main, ["provider", "2"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/provider 2") + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_model_list(self, mock_send): + """查看模型列表""" + mock_send.return_value = _mock_send("model list") + runner = CliRunner() + result = runner.invoke(main, ["model"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/model") + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_model_switch(self, mock_send): + """切换模型""" + mock_send.return_value = _mock_send("switched") + runner = CliRunner() + result = runner.invoke(main, ["model", "gpt-4"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/model gpt-4") + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_key_list(self, mock_send): + """查看 Key 列表""" + mock_send.return_value = _mock_send("key list") + runner = CliRunner() + result = runner.invoke(main, ["key"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/key") + + @patch("astrbot.cli.client.commands.provider.send_message") + def test_key_switch(self, mock_send): + """切换 Key""" + mock_send.return_value = _mock_send("switched") + runner = CliRunner() + result = runner.invoke(main, ["key", "1"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/key 1") + + +class TestDebugCommands: + """调试命令测试""" + + @patch("astrbot.cli.client.commands.debug.send_message") + def test_ping(self, mock_send): + """ping 测试""" + mock_send.return_value = _mock_send("help text") + runner = CliRunner() + result = runner.invoke(main, ["ping"]) + + assert result.exit_code == 0 + assert "pong" in result.output + + @patch("astrbot.cli.client.commands.debug.send_message") + def test_ping_count(self, mock_send): + """多次 ping""" + mock_send.return_value = _mock_send("help text") + runner = CliRunner() + result = runner.invoke(main, ["ping", "-c", "3"]) + + assert result.exit_code == 0 + assert result.output.count("pong") == 3 + + @patch("astrbot.cli.client.commands.debug.send_message") + @patch("astrbot.cli.client.commands.debug.load_auth_token", return_value="tok123") + @patch( + "astrbot.cli.client.commands.debug.load_connection_info", + return_value={"type": "tcp", "host": "127.0.0.1", "port": 12345}, + ) + def test_status(self, mock_conn, mock_token, mock_send): + """status 命令""" + mock_send.return_value = _mock_send("help text") + runner = CliRunner() + result = runner.invoke(main, ["status"]) + + assert result.exit_code == 0 + assert "TCP" in result.output + assert "127.0.0.1" in result.output + assert "在线" in result.output + + @patch("astrbot.cli.client.commands.debug.send_message") + def test_test_echo(self, mock_send): + """test echo 命令""" + mock_send.return_value = _mock_send("echo response") + runner = CliRunner() + result = runner.invoke(main, ["test", "echo", "hello"]) + + assert result.exit_code == 0 + assert "hello" in result.output + assert "echo response" in result.output + + @patch("astrbot.cli.client.commands.debug.send_message") + def test_test_plugin(self, mock_send): + """test plugin 命令""" + mock_send.return_value = _mock_send("plugin response") + runner = CliRunner() + result = runner.invoke(main, ["test", "plugin", "hello", "world"]) + + assert result.exit_code == 0 + mock_send.assert_called_once_with("/hello world") + + +class TestAliasCommands: + """快捷别名命令测试""" + + @patch("astrbot.cli.client.connection.send_message") + def test_help_alias(self, mock_send): + """help 别名""" + mock_send.return_value = _mock_send("help text") + runner = CliRunner() + result = runner.invoke(main, ["help"]) + + assert result.exit_code == 0 + mock_send.assert_called_with("/help") + + @patch("astrbot.cli.client.connection.send_message") + def test_sid_alias(self, mock_send): + """sid 别名""" + mock_send.return_value = _mock_send("session_123") + runner = CliRunner() + result = runner.invoke(main, ["sid"]) + + assert result.exit_code == 0 + mock_send.assert_called_with("/sid") + + @patch("astrbot.cli.client.connection.send_message") + def test_t2i_alias(self, mock_send): + """t2i 别名""" + mock_send.return_value = _mock_send("toggled") + runner = CliRunner() + result = runner.invoke(main, ["t2i"]) + + assert result.exit_code == 0 + mock_send.assert_called_with("/t2i") + + @patch("astrbot.cli.client.connection.send_message") + def test_tts_alias(self, mock_send): + """tts 别名""" + mock_send.return_value = _mock_send("toggled") + runner = CliRunner() + result = runner.invoke(main, ["tts"]) + + assert result.exit_code == 0 + mock_send.assert_called_with("/tts") + + +class TestBatchCommand: + """batch 命令测试""" + + @patch("astrbot.cli.client.connection.send_message") + def test_batch(self, mock_send, tmp_path): + """批量执行""" + mock_send.return_value = _mock_send("ok") + + batch_file = tmp_path / "commands.txt" + batch_file.write_text("hello\n# comment\n/help\n\n/plugin ls\n") + + runner = CliRunner() + result = runner.invoke(main, ["batch", str(batch_file)]) + + assert result.exit_code == 0 + assert mock_send.call_count == 3 + mock_send.assert_any_call("hello") + mock_send.assert_any_call("/help") + mock_send.assert_any_call("/plugin ls") + + +class TestBackwardCompatibility: + """向后兼容性测试""" + + @patch("astrbot.cli.client.commands.send.send_message") + def test_astr_hello(self, mock_send): + """astr 你好 → astr send 你好""" + mock_send.return_value = _mock_send("hi") + runner = CliRunner() + result = runner.invoke(main, ["你好"]) + + assert result.exit_code == 0 + mock_send.assert_called_once() + + @patch("astrbot.cli.client.commands.log._read_log_from_file") + def test_astr_log_flag(self, mock_read): + """astr --log → astr log""" + runner = CliRunner() + result = runner.invoke(main, ["--log"]) + + assert result.exit_code == 0 + mock_read.assert_called_once() + + def test_help_output(self): + """帮助输出包含新命令""" + runner = CliRunner() + result = runner.invoke(main, ["--help"]) + + assert result.exit_code == 0 + assert "send" in result.output + assert "log" in result.output + assert "conv" in result.output + assert "plugin" in result.output + assert "provider" in result.output + assert "ping" in result.output + + +class TestSessionCommand: + """session 命令组测试""" + + @patch("astrbot.cli.client.commands.session.list_sessions") + def test_session_ls(self, mock_list): + """列出所有会话""" + mock_list.return_value = { + "status": "success", + "sessions": [ + { + "session_id": "cli:FriendMessage:cli_session", + "conversation_id": "conv-123", + "title": "测试对话", + "persona_id": None, + "persona_name": None, + } + ], + "total": 1, + "page": 1, + "page_size": 20, + "total_pages": 1, + "response": "共 1 个会话", + } + runner = CliRunner() + result = runner.invoke(main, ["session", "ls"]) + + assert result.exit_code == 0 + assert "cli_session" in result.output + mock_list.assert_called_once() + + @patch("astrbot.cli.client.commands.session.list_sessions") + def test_session_ls_with_platform(self, mock_list): + """按平台过滤会话""" + mock_list.return_value = { + "status": "success", + "sessions": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0, + "response": "共 0 个会话", + } + runner = CliRunner() + result = runner.invoke(main, ["session", "ls", "-P", "qq"]) + + assert result.exit_code == 0 + mock_list.assert_called_once_with( + page=1, page_size=20, platform="qq", search_query=None + ) + + @patch("astrbot.cli.client.commands.session.list_sessions") + def test_session_ls_json(self, mock_list): + """JSON 输出""" + mock_list.return_value = { + "status": "success", + "sessions": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0, + "response": "共 0 个会话", + } + runner = CliRunner() + result = runner.invoke(main, ["session", "ls", "-j"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert output["status"] == "success" + + @patch("astrbot.cli.client.commands.session.list_sessions") + def test_session_ls_error(self, mock_list): + """错误响应""" + mock_list.return_value = {"status": "error", "error": "未初始化"} + runner = CliRunner() + result = runner.invoke(main, ["session", "ls"]) + + assert result.exit_code == 1 + + @patch("astrbot.cli.client.commands.session.list_session_conversations") + def test_session_convs(self, mock_convs): + """查看指定会话的对话列表""" + mock_convs.return_value = { + "status": "success", + "conversations": [ + { + "cid": "conv-abc", + "title": "测试对话", + "persona_id": None, + "created_at": 1700000000, + "updated_at": 1700000000, + "token_usage": 100, + "is_current": True, + } + ], + "total": 1, + "page": 1, + "page_size": 20, + "total_pages": 1, + "current_cid": "conv-abc", + "response": "共 1 个对话", + } + runner = CliRunner() + result = runner.invoke( + main, ["session", "convs", "cli:FriendMessage:cli_session"] + ) + + assert result.exit_code == 0 + assert "conv-abc" in result.output + assert "测试对话" in result.output + + @patch("astrbot.cli.client.commands.session.list_session_conversations") + def test_session_convs_json(self, mock_convs): + """对话列表 JSON 输出""" + mock_convs.return_value = { + "status": "success", + "conversations": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0, + "current_cid": None, + "response": "共 0 个对话", + } + runner = CliRunner() + result = runner.invoke(main, ["session", "convs", "test_session", "-j"]) + + assert result.exit_code == 0 + output = json.loads(result.output) + assert output["status"] == "success" + + @patch("astrbot.cli.client.commands.session.get_session_history") + def test_session_history(self, mock_history): + """查看指定会话的聊天记录""" + mock_history.return_value = { + "status": "success", + "history": [ + {"role": "user", "text": "你好"}, + {"role": "assistant", "text": "你好!"}, + ], + "total_pages": 1, + "page": 1, + "conversation_id": "conv-abc", + "session_id": "cli:FriendMessage:cli_session", + "response": "", + } + runner = CliRunner() + result = runner.invoke( + main, ["session", "history", "cli:FriendMessage:cli_session"] + ) + + assert result.exit_code == 0 + assert "You: 你好" in result.output + assert "AI: 你好!" in result.output + + @patch("astrbot.cli.client.commands.session.get_session_history") + def test_session_history_with_cid(self, mock_history): + """指定对话 ID 查看聊天记录""" + mock_history.return_value = { + "status": "success", + "history": [], + "total_pages": 0, + "page": 1, + "conversation_id": "conv-xyz", + "session_id": "test_session", + "response": "(无记录)", + } + runner = CliRunner() + result = runner.invoke( + main, ["session", "history", "test_session", "-c", "conv-xyz"] + ) + + assert result.exit_code == 0 + mock_history.assert_called_once_with( + session_id="test_session", + conversation_id="conv-xyz", + page=1, + page_size=10, + ) + + @patch("astrbot.cli.client.commands.session.list_sessions") + def test_session_ls_empty(self, mock_list): + """空会话列表""" + mock_list.return_value = { + "status": "success", + "sessions": [], + "total": 0, + "page": 1, + "page_size": 20, + "total_pages": 0, + "response": "共 0 个会话", + } + runner = CliRunner() + result = runner.invoke(main, ["session", "ls"]) + + assert result.exit_code == 0 + assert "没有找到会话" in result.output diff --git a/tests/test_cli/test_client_connection.py b/tests/test_cli/test_client_connection.py new file mode 100644 index 0000000000..c21acca9e7 --- /dev/null +++ b/tests/test_cli/test_client_connection.py @@ -0,0 +1,267 @@ +"""CLI Client 连接模块单元测试""" + +import json +import os +from unittest.mock import MagicMock, patch + +import pytest + +from astrbot.cli.client.connection import ( + _receive_json_response, + connect_to_server, + get_data_path, + get_logs, + get_temp_path, + load_auth_token, + load_connection_info, + send_message, +) + + +class TestGetDataPath: + """get_data_path 路径解析测试""" + + def test_env_var_priority(self, tmp_path): + """环境变量 ASTRBOT_ROOT 优先""" + with patch.dict(os.environ, {"ASTRBOT_ROOT": str(tmp_path)}): + result = get_data_path() + assert result == os.path.join(str(tmp_path), "data") + + def test_fallback_to_source_root(self): + """回退到源码安装目录""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("ASTRBOT_ROOT", None) + result = get_data_path() + assert result.endswith("data") + + def test_no_env_returns_data_suffix(self): + """返回路径以 data 结尾""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("ASTRBOT_ROOT", None) + result = get_data_path() + assert os.path.basename(result) == "data" + + +class TestGetTempPath: + """get_temp_path 测试""" + + def test_env_var_priority(self, tmp_path): + """环境变量优先""" + with patch.dict(os.environ, {"ASTRBOT_ROOT": str(tmp_path)}): + result = get_temp_path() + assert result == os.path.join(str(tmp_path), "data", "temp") + + +class TestLoadAuthToken: + """load_auth_token 测试""" + + def test_token_found(self, tmp_path): + """正确读取 token""" + token_file = tmp_path / "data" / ".cli_token" + token_file.parent.mkdir(parents=True) + token_file.write_text("test_token_123") + + with patch( + "astrbot.cli.client.connection.get_data_path", + return_value=str(tmp_path / "data"), + ): + assert load_auth_token() == "test_token_123" + + def test_token_not_found(self, tmp_path): + """token 文件不存在返回空字符串""" + with patch( + "astrbot.cli.client.connection.get_data_path", + return_value=str(tmp_path / "nonexistent"), + ): + assert load_auth_token() == "" + + def test_token_strip_whitespace(self, tmp_path): + """去除 token 两端空白""" + token_file = tmp_path / "data" / ".cli_token" + token_file.parent.mkdir(parents=True) + token_file.write_text(" token_with_spaces \n") + + with patch( + "astrbot.cli.client.connection.get_data_path", + return_value=str(tmp_path / "data"), + ): + assert load_auth_token() == "token_with_spaces" + + +class TestLoadConnectionInfo: + """load_connection_info 测试""" + + def test_unix_connection(self, tmp_path): + """读取 Unix Socket 配置""" + conn_file = tmp_path / ".cli_connection" + conn_file.write_text(json.dumps({"type": "unix", "path": "/tmp/test.sock"})) + + result = load_connection_info(str(tmp_path)) + assert result == {"type": "unix", "path": "/tmp/test.sock"} + + def test_tcp_connection(self, tmp_path): + """读取 TCP 配置""" + conn_file = tmp_path / ".cli_connection" + conn_file.write_text( + json.dumps({"type": "tcp", "host": "127.0.0.1", "port": 12345}) + ) + + result = load_connection_info(str(tmp_path)) + assert result == {"type": "tcp", "host": "127.0.0.1", "port": 12345} + + def test_file_not_found(self, tmp_path): + """文件不存在返回 None""" + result = load_connection_info(str(tmp_path)) + assert result is None + + def test_invalid_json(self, tmp_path): + """无效 JSON 返回 None""" + conn_file = tmp_path / ".cli_connection" + conn_file.write_text("not json") + + result = load_connection_info(str(tmp_path)) + assert result is None + + +class TestConnectToServer: + """connect_to_server 测试""" + + def test_invalid_socket_type(self): + """无效连接类型抛出 ValueError""" + with pytest.raises(ValueError, match="Invalid socket type"): + connect_to_server({"type": "invalid"}) + + def test_unix_missing_path(self): + """Unix Socket 缺少路径抛出 ValueError""" + with pytest.raises(ValueError, match="path is missing"): + connect_to_server({"type": "unix"}) + + def test_tcp_missing_host(self): + """TCP 缺少 host 抛出 ValueError""" + with pytest.raises(ValueError, match="host or port is missing"): + connect_to_server({"type": "tcp", "host": "", "port": 1234}) + + def test_tcp_missing_port(self): + """TCP 缺少 port 抛出 ValueError""" + with pytest.raises(ValueError, match="host or port is missing"): + connect_to_server({"type": "tcp", "host": "127.0.0.1"}) + + +class TestReceiveJsonResponse: + """_receive_json_response 测试""" + + def test_single_chunk(self): + """单次接收完整 JSON""" + mock_socket = MagicMock() + data = json.dumps({"status": "success", "response": "hello"}).encode("utf-8") + mock_socket.recv.side_effect = [data, b""] + + result = _receive_json_response(mock_socket) + assert result["status"] == "success" + assert result["response"] == "hello" + + def test_multi_chunk(self): + """分多次接收 JSON""" + mock_socket = MagicMock() + data = json.dumps({"status": "success", "response": "hello"}).encode("utf-8") + # 分两块发送 + mid = len(data) // 2 + mock_socket.recv.side_effect = [data[:mid], data[mid:], b""] + + result = _receive_json_response(mock_socket) + assert result["status"] == "success" + + +class TestSendMessage: + """send_message 测试""" + + @patch("astrbot.cli.client.connection._get_connected_socket") + @patch("astrbot.cli.client.connection.load_auth_token", return_value="") + def test_success(self, mock_token, mock_socket): + """成功发送消息""" + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + response_data = json.dumps({"status": "success", "response": "hi"}).encode( + "utf-8" + ) + mock_sock.recv.side_effect = [response_data, b""] + + result = send_message("hello") + assert result["status"] == "success" + assert result["response"] == "hi" + mock_sock.close.assert_called_once() + + @patch("astrbot.cli.client.connection._get_connected_socket") + @patch("astrbot.cli.client.connection.load_auth_token", return_value="token123") + def test_with_auth_token(self, mock_token, mock_socket): + """带 token 发送""" + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + response_data = json.dumps({"status": "success"}).encode("utf-8") + mock_sock.recv.side_effect = [response_data, b""] + + send_message("hello") + + # 验证 sendall 包含 auth_token + sent_data = mock_sock.sendall.call_args[0][0] + sent_json = json.loads(sent_data) + assert sent_json["auth_token"] == "token123" + + @patch( + "astrbot.cli.client.connection._get_connected_socket", + side_effect=ConnectionError("refused"), + ) + @patch("astrbot.cli.client.connection.load_auth_token", return_value="") + def test_connection_error(self, mock_token, mock_socket): + """连接失败返回错误""" + result = send_message("hello") + assert result["status"] == "error" + assert "refused" in result["error"] + + @patch("astrbot.cli.client.connection._get_connected_socket") + @patch("astrbot.cli.client.connection.load_auth_token", return_value="") + def test_timeout(self, mock_token, mock_socket): + """超时返回错误""" + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + mock_sock.sendall.side_effect = TimeoutError() + + result = send_message("hello") + assert result["status"] == "error" + assert "timeout" in result["error"].lower() + mock_sock.close.assert_called_once() + + +class TestGetLogs: + """get_logs 测试""" + + @patch("astrbot.cli.client.connection._get_connected_socket") + @patch("astrbot.cli.client.connection.load_auth_token", return_value="") + def test_success(self, mock_token, mock_socket): + """成功获取日志""" + mock_sock = MagicMock() + mock_socket.return_value = mock_sock + response_data = json.dumps( + {"status": "success", "response": "log lines..."} + ).encode("utf-8") + mock_sock.recv.side_effect = [response_data, b""] + + result = get_logs(lines=50, level="ERROR") + assert result["status"] == "success" + + # 验证请求参数 + sent_data = mock_sock.sendall.call_args[0][0] + sent_json = json.loads(sent_data) + assert sent_json["action"] == "get_logs" + assert sent_json["lines"] == 50 + assert sent_json["level"] == "ERROR" + + @patch( + "astrbot.cli.client.connection._get_connected_socket", + side_effect=ConnectionError("no server"), + ) + @patch("astrbot.cli.client.connection.load_auth_token", return_value="") + def test_connection_error(self, mock_token, mock_socket): + """连接失败返回错误""" + result = get_logs() + assert result["status"] == "error" diff --git a/tests/test_cli/test_client_e2e.py b/tests/test_cli/test_client_e2e.py new file mode 100644 index 0000000000..2a2db8d989 --- /dev/null +++ b/tests/test_cli/test_client_e2e.py @@ -0,0 +1,414 @@ +"""CLI Client 端到端测试 + +不使用 mock,直接通过真实 socket 连接到运行中的 AstrBot 服务端。 +测试前提:AstrBot 已启动并开启 CLI 平台适配器(socket 模式)。 + +设计原则: + - 零重复:每个命令在整个文件中只发送一次 + - 行为测试与命令测试分离:行为测试(稳定性/并发)不重复验证命令内容 + - 纯函数用 fake 数据,不调服务端 + +运行方式(已从默认 pytest 排除,需手动指定): + pytest tests/test_cli/test_client_e2e.py -v --override-ini="addopts=" +""" + +import os +import time + +import pytest + +from astrbot.cli.client.connection import ( + call_tool, + get_data_path, + get_logs, + get_session_history, + list_session_conversations, + list_sessions, + list_tools, + load_auth_token, + load_connection_info, + send_message, +) +from astrbot.cli.client.output import format_response + + +def _server_reachable() -> bool: + """检查服务端是否可达(整个文件唯一的探测调用)""" + try: + resp = send_message("/help", timeout=10.0) + return resp.get("status") == "success" + except Exception: + return False + + +pytestmark = [ + pytest.mark.skipif( + not _server_reachable(), + reason="AstrBot 服务端未运行,跳过端到端测试", + ), + pytest.mark.e2e, +] + + +# ============================================================ +# 第一层:连接基础设施(纯本地检查,0 次服务端调用) +# ============================================================ + + +class TestConnectionInfra: + """验证本地配置文件:数据目录、连接信息、Token""" + + def test_data_path_exists(self): + data_dir = get_data_path() + assert os.path.isdir(data_dir), f"数据目录不存在: {data_dir}" + + def test_connection_info_valid(self): + data_dir = get_data_path() + info = load_connection_info(data_dir) + assert info is not None, ".cli_connection 不存在" + assert info["type"] in ("unix", "tcp"), f"未知连接类型: {info['type']}" + if info["type"] == "tcp": + assert "host" in info and isinstance(info["port"], int) + elif info["type"] == "unix": + assert "path" in info + + def test_auth_token_configured(self): + token = load_auth_token() + assert token and len(token) > 8, "Token 未配置或过短" + + +# ============================================================ +# 第二层:命令管道(每个命令只调用一次) +# +# 服务端调用:/help, /sid ×2, /model, /key, /plugin ls, +# /plugin help builtin_commands = 8 次 +# ============================================================ + + +class TestCommandPipeline: + """每个内置命令只测一次,一次性验证结构+内容+认证""" + + def test_help(self): + """/help — 响应结构、内容、延迟、认证(最全面的单命令测试)""" + start = time.time() + resp = send_message("/help") + elapsed = time.time() - start + + assert resp["status"] == "success" + assert "response" in resp and "images" in resp and "request_id" in resp + assert isinstance(resp["images"], list) + assert len(resp["request_id"]) > 0 + assert "/help" in resp["response"] + assert elapsed < 10.0, f"延迟过大: {elapsed:.2f}s" + assert resp.get("error_code") != "AUTH_FAILED" + + def test_sid_and_consistency(self): + """/sid — 内容正确 + 两次调用返回相同结果(会话一致性)""" + resp1 = send_message("/sid") + resp2 = send_message("/sid") + assert resp1["status"] == "success" and resp2["status"] == "success" + text = resp1["response"] + assert "cli_session" in text or "cli_user" in text or "UMO" in text + assert resp1["response"] == resp2["response"], "会话信息不一致" + + def test_model(self): + """/model — 模型列表""" + resp = send_message("/model") + assert resp["status"] == "success" + assert resp["response"] or resp["images"] + + def test_key(self): + """/key — Key 信息""" + resp = send_message("/key") + assert resp["status"] == "success" + text = resp["response"] + assert "Key" in text or "key" in text.lower() or "当前" in text + + def test_plugin_ls(self): + """/plugin ls — 插件列表""" + resp = send_message("/plugin ls") + assert resp["status"] == "success" + assert "插件" in resp["response"] or "plugin" in resp["response"].lower() + + def test_plugin_help(self): + """/plugin help — 指定插件帮助""" + resp = send_message("/plugin help builtin_commands") + assert resp["status"] == "success" + text = resp["response"] + assert "指令" in text or "帮助" in text or "help" in text.lower() + + +# ============================================================ +# 第三层:会话管理生命周期 +# +# 服务端调用:/new, /rename, /ls, /reset, /history, /del, +# /new, /switch 1, /del = 9 次 +# ============================================================ + + +class TestSessionManagement: + """会话操作:创建、重命名、列表、重置、历史、删除、切换""" + + def test_full_lifecycle(self): + """new → rename → ls → reset → history → del""" + resp = send_message("/new") + assert resp["status"] == "success" + + resp = send_message("/rename e2e_lifecycle_test") + assert resp["status"] == "success" + + resp = send_message("/ls") + assert resp["status"] == "success" + assert "e2e_lifecycle_test" in resp["response"] + + resp = send_message("/reset") + assert resp["status"] == "success" + + resp = send_message("/history") + assert resp["status"] == "success" + + resp = send_message("/del") + assert resp["status"] == "success" + + def test_switch(self): + """new → switch → del""" + send_message("/new") + resp = send_message("/switch 1") + assert resp["status"] == "success" + assert "切换" in resp["response"] + send_message("/del") + + +# ============================================================ +# 第四层:日志子系统 +# +# 服务端调用:get_logs ×3 = 3 次 +# ============================================================ + + +class TestLogSubsystem: + """日志获取、级别过滤、模式过滤""" + + def test_get_logs(self): + resp = get_logs(lines=10) + assert resp["status"] == "success" + assert "response" in resp + + def test_level_filter(self): + resp = get_logs(lines=50, level="INFO") + assert resp["status"] == "success" + text = resp.get("response", "") + for line in text.strip().split("\n"): + if line.strip(): + assert "[INFO]" in line, f"非 INFO 日志: {line}" + + def test_pattern_filter(self): + resp = get_logs(lines=50, pattern="CLI") + assert resp["status"] == "success" + text = resp.get("response", "") + for line in text.strip().split("\n"): + if line.strip(): + assert "CLI" in line or "cli" in line + + +# ============================================================ +# 第4.5层:函数工具管理(通过 socket action 协议,2 次服务端调用) +# ============================================================ + + +class TestFunctionTools: + """测试 list_tools 和 call_tool socket action""" + + def test_list_tools(self): + """列出所有注册的函数工具""" + resp = list_tools() + assert resp["status"] == "success" + # tools 可能在 tools 字段或 response 字段 + tools = resp.get("tools", []) + if not tools: + import json + + raw = resp.get("response", "") + if raw: + tools = json.loads(raw) + # 应该是列表类型 + assert isinstance(tools, list) + # 每个工具应有 name 字段 + for t in tools: + assert "name" in t + + def test_call_nonexistent_tool(self): + """调用不存在的工具应返回错误""" + resp = call_tool("__nonexistent_tool_xyz__") + assert resp["status"] == "error" + assert "未找到" in resp.get("error", "") + + +# ============================================================ +# 第5层:跨会话浏览(通过 socket action 协议) +# +# 服务端调用:list_sessions ×2, list_session_conversations ×1, +# get_session_history ×2 = 5 次 +# ============================================================ + + +class TestCrossSessionBrowse: + """跨会话浏览功能端到端测试""" + + def test_list_sessions(self): + """列出所有会话 — 响应结构完整""" + resp = list_sessions() + assert resp["status"] == "success" + assert "sessions" in resp + assert isinstance(resp["sessions"], list) + assert "total" in resp + assert isinstance(resp["total"], int) + assert "total_pages" in resp + # 至少应有 CLI 管理员会话 + if resp["total"] > 0: + s = resp["sessions"][0] + assert "session_id" in s + assert "conversation_id" in s + + def test_list_sessions_pagination(self): + """会话列表分页参数正确回传""" + resp = list_sessions(page=1, page_size=5) + assert resp["status"] == "success" + assert resp["page"] == 1 + assert resp["page_size"] == 5 + assert "total_pages" in resp + + def test_list_sessions_platform_filter(self): + """按平台过滤 — 不崩溃,结果与总数一致""" + resp = list_sessions(platform="cli") + assert resp["status"] == "success" + for s in resp["sessions"]: + assert s["session_id"].startswith("cli:") + + def test_list_sessions_search(self): + """搜索过滤 — 不崩溃""" + resp = list_sessions(search_query="cli_session") + assert resp["status"] == "success" + + def test_list_session_conversations(self): + """列出 CLI 管理员会话的对话列表""" + resp = list_session_conversations("cli:FriendMessage:cli_session") + assert resp["status"] == "success" + assert "conversations" in resp + assert isinstance(resp["conversations"], list) + assert "current_cid" in resp + # 每个对话应有 cid/title/is_current + for c in resp["conversations"]: + assert "cid" in c + assert "title" in c + assert "is_current" in c + + def test_list_session_conversations_empty(self): + """不存在的会话 — 返回空对话列表,不报错""" + resp = list_session_conversations("nonexistent:FriendMessage:no_one") + assert resp["status"] == "success" + assert resp["conversations"] == [] + + def test_get_session_history_admin(self): + """获取管理员 CLI 会话的聊天记录 — 新格式验证""" + resp = get_session_history("cli:FriendMessage:cli_session") + assert resp["status"] == "success" + assert "history" in resp + assert isinstance(resp["history"], list) + assert "total_pages" in resp + assert "total" in resp + # 验证消息格式(每条是 dict 含 role/text) + for msg in resp["history"]: + assert isinstance(msg, dict) + assert "role" in msg + assert msg["role"] in ("user", "assistant") + assert "text" in msg + + def test_get_session_history_pagination(self): + """聊天记录分页""" + resp = get_session_history("cli:FriendMessage:cli_session", page=1, page_size=2) + assert resp["status"] == "success" + assert resp["page"] == 1 + assert len(resp["history"]) <= 2 + + def test_get_session_history_nonexistent(self): + """获取不存在会话的聊天记录(应返回空记录)""" + resp = get_session_history("nonexistent:FriendMessage:no_session") + assert resp["status"] == "success" + assert resp["history"] == [] + + +# ============================================================ + + +class TestClientOutput: + """format_response 纯函数测试,使用 fake 数据""" + + def test_format_success(self): + fake_ok = { + "status": "success", + "response": "测试内容", + "images": [], + "request_id": "fake-id", + } + assert len(format_response(fake_ok)) > 0 + + def test_format_with_images(self): + fake_img = { + "status": "success", + "response": "", + "images": ["base64data"], + "request_id": "fake-id", + } + formatted = format_response(fake_img) + assert "图片" in formatted or len(formatted) > 0 + + def test_format_error(self): + fake_err = {"status": "error", "error": "test"} + assert format_response(fake_err) == "" + + +# ============================================================ +# 第六层:边界条件(每个 case 唯一,3 次服务端调用) +# ============================================================ + + +class TestEdgeCases: + """缺少参数、无效参数、不存在的命令""" + + def test_empty_command_args(self): + resp = send_message("/switch") + assert resp["status"] == "success" + + def test_invalid_switch_index(self): + resp = send_message("/switch 99999") + assert resp["status"] == "success" + + def test_unknown_slash_command(self): + resp = send_message("/nonexistent_cmd_xyz_123") + assert resp["status"] == "success" + + +# ============================================================ +# 第七层:健壮性(测试行为而非命令内容) +# +# 用 /ls 做轻量探测(不与上方命令测试重复验证内容) +# 服务端调用:/ls ×9 = 9 次 +# ============================================================ + + +class TestRobustness: + """并发稳定性和响应隔离(只验证机制,不验证命令内容)""" + + def test_rapid_fire_no_mixing(self): + """4次快速请求,request_id 全部唯一""" + responses = [send_message("/ls") for _ in range(4)] + for r in responses: + assert r["status"] == "success" + ids = [r["request_id"] for r in responses] + assert len(set(ids)) == len(ids), "request_id 不唯一" + + def test_stability(self): + """5次请求至少4次成功""" + success = sum(1 for _ in range(5) if send_message("/ls")["status"] == "success") + assert success >= 4, f"稳定性不足: {success}/5" diff --git a/tests/test_cli/test_e2e.py b/tests/test_cli/test_e2e.py new file mode 100644 index 0000000000..32a74c095c --- /dev/null +++ b/tests/test_cli/test_e2e.py @@ -0,0 +1,293 @@ +"""CLI端到端测试 - 验证完整消息处理流程""" + +import asyncio +import json +from unittest.mock import MagicMock, patch + +import pytest + + +class TestCLIEndToEnd: + """CLI端到端测试类""" + + @pytest.fixture + def mock_context(self): + """创建模拟的上下文""" + ctx = MagicMock() + ctx.register_platform = MagicMock() + return ctx + + @pytest.fixture + def mock_config(self): + """创建模拟的配置""" + return { + "id": "cli_test", + "enable": True, + "mode": "socket", + "socket_type": "tcp", + "tcp_port": 0, + "session_ttl": 30, + "use_isolated_sessions": False, + } + + @pytest.mark.asyncio + async def test_message_converter_to_event_flow(self): + """测试消息转换到事件的完整流程""" + from astrbot.core.platform.platform_metadata import PlatformMetadata + from astrbot.core.platform.sources.cli.cli_event import ( + CLIMessageEvent, + MessageConverter, + ) + + # 1. 创建消息转换器 + converter = MessageConverter() + + # 2. 转换输入消息 + message = converter.convert("Hello, AstrBot!") + + # 3. 验证消息结构 + assert message.message_str == "Hello, AstrBot!" + assert message.session_id == "cli_session" + assert message.sender.user_id == "cli_user" + + # 4. 创建事件 + platform_meta = PlatformMetadata( + name="cli", description="CLI Platform", id="cli_test" + ) + output_queue = asyncio.Queue() + + event = CLIMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=platform_meta, + session_id=message.session_id, + output_queue=output_queue, + ) + + # 5. 验证事件属性 + assert event.message_str == "Hello, AstrBot!" + assert event.session_id == "cli_session" + + @pytest.mark.asyncio + async def test_response_builder_with_message_chain(self): + """测试响应构建器处理消息链""" + from astrbot.core.message.components import Image, Plain + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + # 1. 创建消息链 + chain = MessageChain() + chain.chain = [ + Plain("Hello!"), + Image(file="https://example.com/image.png"), + ] + chain.get_plain_text = MagicMock(return_value="Hello!") + + # 2. 构建响应 + response = ResponseBuilder.build_success(chain, "req123") + result = json.loads(response) + + # 3. 验证响应结构 + assert result["status"] == "success" + assert result["response"] == "Hello!" + assert result["request_id"] == "req123" + assert len(result["images"]) == 1 + assert result["images"][0]["type"] == "url" + assert result["images"][0]["url"] == "https://example.com/image.png" + + @pytest.mark.asyncio + async def test_session_lifecycle(self): + """测试会话生命周期""" + from astrbot.core.platform.sources.cli.cli_adapter import ( + SessionManager, + ) + + # 1. 创建会话管理器(启用) + manager = SessionManager(ttl=30, enabled=True) + + # 2. 注册会话 + manager.register("session_1") + manager.register("session_2") + + # 3. 验证会话存在(通过检查是否过期) + assert manager.is_expired("session_1") is False + assert manager.is_expired("session_2") is False + + # 4. 验证未注册的会话被视为过期 + assert manager.is_expired("nonexistent") is True + + @pytest.mark.asyncio + async def test_token_validation_flow(self): + """测试Token验证流程""" + import tempfile + + from astrbot.core.platform.sources.cli.cli_adapter import TokenManager + + # 使用临时目录避免影响真实token文件 + with tempfile.TemporaryDirectory() as tmpdir: + with patch( + "astrbot.core.platform.sources.cli.cli_adapter.get_astrbot_data_path" + ) as mock_path: + mock_path.return_value = tmpdir + + # 1. 创建Token管理器(会自动生成token) + manager = TokenManager() + token = manager.token + + # 2. 验证token已生成 + assert token is not None + assert len(token) > 0 + + # 3. 验证正确Token + assert manager.validate(token) is True + + # 4. 验证错误Token + assert manager.validate("wrong_token") is False + + # 5. 验证空Token + assert manager.validate("") is False + + @pytest.mark.asyncio + async def test_cli_event_send_to_queue(self): + """测试CLI事件发送到队列""" + from astrbot.core.message.components import Plain + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.platform_metadata import PlatformMetadata + from astrbot.core.platform.sources.cli.cli_event import ( + CLIMessageEvent, + MessageConverter, + ) + + # 1. 使用MessageConverter创建真实的消息对象 + converter = MessageConverter() + message_obj = converter.convert("Test") + + platform_meta = PlatformMetadata( + name="cli", description="CLI Platform", id="cli_test" + ) + output_queue = asyncio.Queue() + + event = CLIMessageEvent( + message_str="Test", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="test_session", + output_queue=output_queue, + ) + + # 2. 创建响应消息链 + response_chain = MessageChain() + response_chain.chain = [Plain("Response")] + + # 3. 发送响应(无response_future时直接放入队列) + result = await event.send(response_chain) + + # 4. 验证结果 + assert result["success"] is True + + # 5. 验证队列中有消息 + queued_message = await output_queue.get() + assert queued_message == response_chain + + @pytest.mark.asyncio + async def test_image_processor_pipeline(self): + """测试图片处理管道""" + import base64 + import os + import tempfile + + from astrbot.core.message.components import Image, Plain + from astrbot.core.platform.sources.cli.cli_event import ( + ImageProcessor, + preprocess_chain, + ) + + # ImageExtractor和ChainPreprocessor已合并到ImageProcessor和preprocess_chain + + # 1. 创建临时图片文件 + with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f: + f.write(b"fake image data") + temp_path = f.name + + try: + # 2. 测试本地文件转base64 + base64_data = ImageProcessor.local_file_to_base64(temp_path) + assert base64_data == base64.b64encode(b"fake image data").decode("utf-8") + + # 3. 创建混合消息链 + chain = MagicMock() + chain.chain = [ + Plain("Hello"), + Image(file="https://example.com/remote.png"), + Image(file=f"file:///{temp_path}"), + ] + + # 4. 提取图片信息 + images = ImageProcessor.extract_images(chain) + assert len(images) == 2 + assert images[0].type == "url" + assert images[0].url == "https://example.com/remote.png" + + # 5. 预处理消息链(本地文件转base64) + local_image = Image(file=f"file:///{temp_path}") + preprocess_mock = MagicMock() + preprocess_mock.chain = [local_image] + + preprocess_chain(preprocess_mock) + + # 验证本地文件已转换为base64 + assert local_image.file.startswith("base64://") + + finally: + os.unlink(temp_path) + + @pytest.mark.asyncio + async def test_error_response_building(self): + """测试错误响应构建""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + # 1. 基本错误 + response = ResponseBuilder.build_error("Something went wrong") + result = json.loads(response) + assert result["status"] == "error" + assert result["error"] == "Something went wrong" + + # 2. 带request_id的错误 + response = ResponseBuilder.build_error("Auth failed", request_id="req123") + result = json.loads(response) + assert result["request_id"] == "req123" + + # 3. 带错误代码的错误 + response = ResponseBuilder.build_error( + "Unauthorized", request_id="req123", error_code="AUTH_FAILED" + ) + result = json.loads(response) + assert result["error_code"] == "AUTH_FAILED" + + @pytest.mark.asyncio + async def test_isolated_session_creation(self): + """测试隔离会话创建""" + from astrbot.core.platform.sources.cli.cli_event import MessageConverter + + converter = MessageConverter() + + # 1. 不启用隔离 + msg1 = converter.convert("Test", request_id="req1", use_isolated_session=False) + assert msg1.session_id == "cli_session" + + # 2. 启用隔离 + msg2 = converter.convert("Test", request_id="req2", use_isolated_session=True) + assert msg2.session_id == "cli_session_req2" + + # 3. 启用隔离但无request_id + msg3 = converter.convert("Test", request_id=None, use_isolated_session=True) + assert msg3.session_id == "cli_session" + + # 4. 不同request_id产生不同session + msg4 = converter.convert("Test", request_id="req3", use_isolated_session=True) + assert msg4.session_id == "cli_session_req3" + assert msg2.session_id != msg4.session_id diff --git a/tests/test_cli/test_image_processor.py b/tests/test_cli/test_image_processor.py new file mode 100644 index 0000000000..e8785fbc60 --- /dev/null +++ b/tests/test_cli/test_image_processor.py @@ -0,0 +1,217 @@ +"""ImageProcessor 单元测试""" + +import base64 +import os +import tempfile +from unittest.mock import MagicMock, patch + + +class TestImageProcessorBase64: + """base64 编解码测试""" + + def test_encode(self): + """测试 base64 编码""" + data = b"Hello, World!" + encoded = base64.b64encode(data).decode("utf-8") + assert encoded == base64.b64encode(data).decode("utf-8") + + def test_decode(self): + """测试 base64 解码""" + original = b"Hello, World!" + encoded = base64.b64encode(original).decode("utf-8") + decoded = base64.b64decode(encoded) + assert decoded == original + + +class TestImageProcessorFileIO: + """文件读写测试""" + + def test_read_existing_file(self): + """测试读取存在的文件""" + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"test content") + temp_path = f.name + + try: + with open(temp_path, "rb") as f: + data = f.read() + assert data == b"test content" + finally: + os.unlink(temp_path) + + def test_read_nonexistent_file(self): + """测试读取不存在的文件""" + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + result = ImageProcessor.local_file_to_base64("/nonexistent/path/file.png") + assert result is None + + def test_base64_to_temp_file(self): + """测试 base64 写入临时文件""" + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + with patch( + "astrbot.core.platform.sources.cli.cli_event.get_astrbot_temp_path" + ) as mock_temp: + mock_temp.return_value = tempfile.gettempdir() + + data = b"test image data" + base64_data = base64.b64encode(data).decode("utf-8") + temp_path = ImageProcessor.base64_to_temp_file(base64_data) + + assert temp_path is not None + assert os.path.exists(temp_path) + + with open(temp_path, "rb") as f: + assert f.read() == data + + os.unlink(temp_path) + + +class TestImageInfo: + """ImageInfo 测试类""" + + def test_to_dict_url(self): + """测试 URL 类型转字典""" + from astrbot.core.platform.sources.cli.cli_event import ImageInfo + + info = ImageInfo(type="url", url="https://example.com/image.png") + result = info.to_dict() + + assert result["type"] == "url" + assert result["url"] == "https://example.com/image.png" + + def test_to_dict_file(self): + """测试文件类型转字典""" + from astrbot.core.platform.sources.cli.cli_event import ImageInfo + + info = ImageInfo(type="file", path="/path/to/image.png", size=1024) + result = info.to_dict() + + assert result["type"] == "file" + assert result["path"] == "/path/to/image.png" + assert result["size"] == 1024 + + def test_to_dict_with_error(self): + """测试带错误信息转字典""" + from astrbot.core.platform.sources.cli.cli_event import ImageInfo + + info = ImageInfo(type="file", error="Failed to read") + result = info.to_dict() + + assert result["error"] == "Failed to read" + + +class TestImageExtractor: + """ImageExtractor 测试类""" + + def test_extract_url_image(self): + """测试提取 URL 图片""" + from astrbot.core.message.components import Image + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + chain = MagicMock() + chain.chain = [Image(file="https://example.com/image.png")] + + images = ImageProcessor.extract_images(chain) + + assert len(images) == 1 + assert images[0].type == "url" + assert images[0].url == "https://example.com/image.png" + + def test_extract_empty_chain(self): + """测试提取空消息链""" + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + chain = MagicMock() + chain.chain = [] + + images = ImageProcessor.extract_images(chain) + assert len(images) == 0 + + def test_extract_mixed_components(self): + """测试提取混合组件""" + from astrbot.core.message.components import Image, Plain + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + chain = MagicMock() + chain.chain = [ + Plain("Hello"), + Image(file="https://example.com/1.png"), + Plain("World"), + Image(file="https://example.com/2.png"), + ] + + images = ImageProcessor.extract_images(chain) + + assert len(images) == 2 + assert images[0].url == "https://example.com/1.png" + assert images[1].url == "https://example.com/2.png" + + +class TestChainPreprocessor: + """ChainPreprocessor 测试类""" + + def test_preprocess_local_file(self): + """测试预处理本地文件图片""" + from astrbot.core.message.components import Image + from astrbot.core.platform.sources.cli.cli_event import preprocess_chain + + # 创建临时图片文件 + with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f: + f.write(b"fake image data") + temp_path = f.name + + try: + chain = MagicMock() + image = Image(file=f"file:///{temp_path}") + chain.chain = [image] + + preprocess_chain(chain) + + # 验证已转换为 base64 + assert image.file.startswith("base64://") + base64_data = image.file[9:] + decoded = base64.b64decode(base64_data) + assert decoded == b"fake image data" + finally: + os.unlink(temp_path) + + def test_preprocess_url_unchanged(self): + """测试 URL 图片不变""" + from astrbot.core.message.components import Image + from astrbot.core.platform.sources.cli.cli_event import preprocess_chain + + chain = MagicMock() + image = Image(file="https://example.com/image.png") + chain.chain = [image] + + preprocess_chain(chain) + + # URL 应保持不变 + assert image.file == "https://example.com/image.png" + + +class TestImageProcessor: + """ImageProcessor 门面测试类""" + + def test_local_file_to_base64(self): + """测试本地文件转 base64""" + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(b"test data") + temp_path = f.name + + try: + result = ImageProcessor.local_file_to_base64(temp_path) + assert result == base64.b64encode(b"test data").decode("utf-8") + finally: + os.unlink(temp_path) + + def test_local_file_to_base64_nonexistent(self): + """测试不存在的文件""" + from astrbot.core.platform.sources.cli.cli_event import ImageProcessor + + result = ImageProcessor.local_file_to_base64("/nonexistent/file.png") + assert result is None diff --git a/tests/test_cli/test_message_converter.py b/tests/test_cli/test_message_converter.py new file mode 100644 index 0000000000..98ee0ca6d8 --- /dev/null +++ b/tests/test_cli/test_message_converter.py @@ -0,0 +1,91 @@ +"""MessageConverter 单元测试""" + + +import pytest + + +class TestMessageConverter: + """MessageConverter 测试类""" + + @pytest.fixture + def converter(self): + """创建 MessageConverter 实例""" + from astrbot.core.platform.sources.cli.cli_event import MessageConverter + + return MessageConverter() + + def test_convert_basic_text(self, converter): + """测试基本文本转换""" + message = converter.convert("Hello, World!") + + assert message.message_str == "Hello, World!" + assert message.self_id == "cli_bot" + assert message.session_id == "cli_session" + assert message.sender.user_id == "cli_user" + assert message.sender.nickname == "CLI User" + + def test_convert_with_request_id_no_isolation(self, converter): + """测试带 request_id 但不启用隔离""" + message = converter.convert( + "Test", request_id="req123", use_isolated_session=False + ) + + # 不启用隔离时,使用默认 session_id + assert message.session_id == "cli_session" + + def test_convert_with_isolated_session(self, converter): + """测试启用会话隔离""" + message = converter.convert( + "Test", request_id="req123", use_isolated_session=True + ) + + # 启用隔离时,session_id 包含 request_id + assert message.session_id == "cli_session_req123" + + def test_convert_isolated_without_request_id(self, converter): + """测试启用隔离但无 request_id""" + message = converter.convert("Test", request_id=None, use_isolated_session=True) + + # 无 request_id 时,使用默认 session_id + assert message.session_id == "cli_session" + + def test_convert_message_has_id(self, converter): + """测试消息有唯一 ID""" + message1 = converter.convert("Test1") + message2 = converter.convert("Test2") + + assert message1.message_id is not None + assert message2.message_id is not None + assert message1.message_id != message2.message_id + + def test_convert_message_has_plain_component(self, converter): + """测试消息包含 Plain 组件""" + from astrbot.core.message.components import Plain + + message = converter.convert("Hello") + + assert len(message.message) == 1 + assert isinstance(message.message[0], Plain) + assert message.message[0].text == "Hello" + + def test_custom_default_session_id(self): + """测试自定义默认 session_id""" + from astrbot.core.platform.sources.cli.cli_event import MessageConverter + + converter = MessageConverter(default_session_id="custom_session") + message = converter.convert("Test") + + assert message.session_id == "custom_session" + + def test_custom_user_info(self): + """测试自定义用户信息""" + from astrbot.core.platform.sources.cli.cli_event import MessageConverter + + converter = MessageConverter( + user_id="custom_user", + user_nickname="Custom User", + ) + message = converter.convert("Test") + + assert message.sender.user_id == "custom_user" + assert message.sender.nickname == "Custom User" diff --git a/tests/test_cli/test_response_builder.py b/tests/test_cli/test_response_builder.py new file mode 100644 index 0000000000..90c3f4b3be --- /dev/null +++ b/tests/test_cli/test_response_builder.py @@ -0,0 +1,115 @@ +"""ResponseBuilder 单元测试""" + +import json +from unittest.mock import MagicMock + +import pytest + + +class TestResponseBuilder: + """ResponseBuilder 测试类""" + + @pytest.fixture + def mock_message_chain(self): + """创建模拟的 MessageChain""" + chain = MagicMock() + chain.get_plain_text.return_value = "Hello, World!" + chain.chain = [] + return chain + + def test_build_success_basic(self, mock_message_chain): + """测试构建基本成功响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + response = ResponseBuilder.build_success(mock_message_chain, "req123") + result = json.loads(response) + + assert result["status"] == "success" + assert result["response"] == "Hello, World!" + assert result["request_id"] == "req123" + assert result["images"] == [] + + def test_build_success_with_extra(self, mock_message_chain): + """测试构建带额外字段的成功响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + extra = {"custom_field": "custom_value"} + response = ResponseBuilder.build_success(mock_message_chain, "req123", extra) + result = json.loads(response) + + assert result["custom_field"] == "custom_value" + + def test_build_error_basic(self): + """测试构建基本错误响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + response = ResponseBuilder.build_error("Something went wrong") + result = json.loads(response) + + assert result["status"] == "error" + assert result["error"] == "Something went wrong" + assert "request_id" not in result + + def test_build_error_with_request_id(self): + """测试构建带 request_id 的错误响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + response = ResponseBuilder.build_error("Error", request_id="req123") + result = json.loads(response) + + assert result["request_id"] == "req123" + + def test_build_error_with_error_code(self): + """测试构建带错误代码的错误响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + response = ResponseBuilder.build_error( + "Unauthorized", request_id="req123", error_code="AUTH_FAILED" + ) + result = json.loads(response) + + assert result["error_code"] == "AUTH_FAILED" + + def test_build_success_with_url_image(self): + """测试构建带 URL 图片的成功响应""" + from astrbot.core.message.components import Image + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + chain = MagicMock() + chain.get_plain_text.return_value = "Image response" + + # 创建 URL 图片组件 + image = Image(file="https://example.com/image.png") + chain.chain = [image] + + response = ResponseBuilder.build_success(chain, "req123") + result = json.loads(response) + + assert len(result["images"]) == 1 + assert result["images"][0]["type"] == "url" + assert result["images"][0]["url"] == "https://example.com/image.png" + + def test_build_success_chinese_text(self, mock_message_chain): + """测试构建中文文本响应""" + from astrbot.core.platform.sources.cli.cli_event import ( + ResponseBuilder, + ) + + mock_message_chain.get_plain_text.return_value = "你好,世界!" + + response = ResponseBuilder.build_success(mock_message_chain, "req123") + result = json.loads(response) + + assert result["response"] == "你好,世界!" diff --git a/tests/test_cli/test_session_actions.py b/tests/test_cli/test_session_actions.py new file mode 100644 index 0000000000..86e9fe5bb1 --- /dev/null +++ b/tests/test_cli/test_session_actions.py @@ -0,0 +1,338 @@ +"""socket_handler 跨会话浏览 action 单元测试 + +测试 SocketClientHandler 中的 _list_sessions、_list_session_conversations、 +_get_session_history 三个方法。使用 mock 替代真实的 ConversationManager。 +""" + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +def _make_handler(): + """创建最小化的 SocketClientHandler 用于测试""" + from astrbot.core.platform.sources.cli.socket_handler import ( + SocketClientHandler, + ) + + handler = SocketClientHandler( + token_manager=MagicMock(), + message_converter=MagicMock(), + session_manager=MagicMock(), + platform_meta=MagicMock(), + output_queue=MagicMock(), + event_committer=MagicMock(), + ) + return handler + + +class TestListSessions: + """_list_sessions 方法测试""" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_sessions_success(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + mock_mgr.db.get_session_conversations = AsyncMock( + return_value=( + [ + { + "session_id": "cli:FriendMessage:cli_session", + "conversation_id": "conv-1", + "title": "测试", + "persona_id": None, + "persona_name": None, + } + ], + 1, + ) + ) + mock_get_mgr.return_value = mock_mgr + + result = await handler._list_sessions({"page": 1, "page_size": 20}, "req-1") + data = json.loads(result) + + assert data["status"] == "success" + assert len(data["sessions"]) == 1 + assert data["total"] == 1 + assert data["sessions"][0]["session_id"] == "cli:FriendMessage:cli_session" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_sessions_with_platform(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + mock_mgr.db.get_session_conversations = AsyncMock(return_value=([], 0)) + mock_get_mgr.return_value = mock_mgr + + result = await handler._list_sessions( + {"page": 1, "page_size": 10, "platform": "qq"}, "req-2" + ) + data = json.loads(result) + + assert data["status"] == "success" + assert data["sessions"] == [] + assert data["total"] == 0 + mock_mgr.db.get_session_conversations.assert_called_once_with( + page=1, page_size=10, search_query=None, platform="qq" + ) + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_sessions_not_initialized(self, mock_get_mgr): + handler = _make_handler() + mock_get_mgr.return_value = None + + result = await handler._list_sessions({}, "req-3") + data = json.loads(result) + + assert data["status"] == "error" + assert "未初始化" in data["error"] + + +class TestListSessionConversations: + """_list_session_conversations 方法测试""" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_convs_success(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + + # Mock conversation objects + mock_conv = MagicMock() + mock_conv.cid = "conv-abc" + mock_conv.title = "测试对话" + mock_conv.persona_id = None + mock_conv.created_at = 1700000000 + mock_conv.updated_at = 1700000000 + mock_conv.token_usage = 150 + + mock_mgr.get_conversations = AsyncMock(return_value=[mock_conv]) + mock_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-abc") + mock_get_mgr.return_value = mock_mgr + + result = await handler._list_session_conversations( + {"session_id": "cli:FriendMessage:cli_session", "page": 1, "page_size": 20}, + "req-4", + ) + data = json.loads(result) + + assert data["status"] == "success" + assert len(data["conversations"]) == 1 + assert data["conversations"][0]["cid"] == "conv-abc" + assert data["conversations"][0]["is_current"] is True + assert data["current_cid"] == "conv-abc" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_convs_missing_session_id(self, mock_get_mgr): + handler = _make_handler() + mock_get_mgr.return_value = MagicMock() + + result = await handler._list_session_conversations({"page": 1}, "req-5") + data = json.loads(result) + + assert data["status"] == "error" + assert "session_id" in data["error"] + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_list_convs_pagination(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + + # Create 3 mock conversations + convs = [] + for i in range(3): + c = MagicMock() + c.cid = f"conv-{i}" + c.title = f"对话 {i}" + c.persona_id = None + c.created_at = 1700000000 + c.updated_at = 1700000000 + c.token_usage = 0 + convs.append(c) + + mock_mgr.get_conversations = AsyncMock(return_value=convs) + mock_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-0") + mock_get_mgr.return_value = mock_mgr + + # Page 1 with size 2 + result = await handler._list_session_conversations( + {"session_id": "test", "page": 1, "page_size": 2}, "req-6" + ) + data = json.loads(result) + + assert data["total"] == 3 + assert data["total_pages"] == 2 + assert len(data["conversations"]) == 2 + + +class TestGetSessionHistory: + """_get_session_history 方法测试""" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_success(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + mock_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-abc") + + # Mock conversation with raw history + mock_conv = MagicMock() + mock_conv.history = json.dumps( + [ + {"role": "user", "content": "你好"}, + {"role": "assistant", "content": "你好!"}, + ] + ) + mock_mgr.get_conversation = AsyncMock(return_value=mock_conv) + mock_get_mgr.return_value = mock_mgr + + result = await handler._get_session_history( + {"session_id": "cli:FriendMessage:cli_session", "page": 1}, + "req-7", + ) + data = json.loads(result) + + assert data["status"] == "success" + assert len(data["history"]) == 2 + assert data["history"][0]["role"] == "user" + assert data["history"][0]["text"] == "你好" + assert data["history"][1]["role"] == "assistant" + assert data["history"][1]["text"] == "你好!" + assert data["conversation_id"] == "conv-abc" + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_with_image(self, mock_get_mgr): + """图片内容应被替换为 [图片]""" + handler = _make_handler() + mock_mgr = MagicMock() + + mock_conv = MagicMock() + mock_conv.history = json.dumps( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "看这张图"}, + {"type": "image_url", "image_url": {"url": "http://..."}}, + ], + }, + {"role": "assistant", "content": "这是一只猫"}, + ] + ) + mock_mgr.get_conversation = AsyncMock(return_value=mock_conv) + mock_get_mgr.return_value = mock_mgr + + result = await handler._get_session_history( + {"session_id": "test", "conversation_id": "conv-img", "page": 1}, + "req-img", + ) + data = json.loads(result) + + assert data["status"] == "success" + assert "[图片]" in data["history"][0]["text"] + assert "看这张图" in data["history"][0]["text"] + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_with_tool_calls(self, mock_get_mgr): + """tool_calls 应被替换为 [调用工具: name]""" + handler = _make_handler() + mock_mgr = MagicMock() + + mock_conv = MagicMock() + mock_conv.history = json.dumps( + [ + {"role": "user", "content": "查天气"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + {"function": {"name": "get_weather", "arguments": "{}"}} + ], + }, + ] + ) + mock_mgr.get_conversation = AsyncMock(return_value=mock_conv) + mock_get_mgr.return_value = mock_mgr + + result = await handler._get_session_history( + {"session_id": "test", "conversation_id": "conv-tc", "page": 1}, + "req-tc", + ) + data = json.loads(result) + + assert data["status"] == "success" + assert "[调用工具: get_weather]" in data["history"][1]["text"] + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_with_conv_id(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + + mock_conv = MagicMock() + mock_conv.history = json.dumps([{"role": "user", "content": "test"}]) + mock_mgr.get_conversation = AsyncMock(return_value=mock_conv) + mock_get_mgr.return_value = mock_mgr + + result = await handler._get_session_history( + { + "session_id": "test_session", + "conversation_id": "conv-xyz", + "page": 1, + }, + "req-8", + ) + data = json.loads(result) + + assert data["status"] == "success" + assert data["conversation_id"] == "conv-xyz" + mock_mgr.get_conversation.assert_called_once_with("test_session", "conv-xyz") + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_no_active_conversation(self, mock_get_mgr): + handler = _make_handler() + mock_mgr = MagicMock() + mock_mgr.get_curr_conversation_id = AsyncMock(return_value=None) + mock_get_mgr.return_value = mock_mgr + + result = await handler._get_session_history( + {"session_id": "no_conv_session", "page": 1}, "req-9" + ) + data = json.loads(result) + + assert data["status"] == "success" + assert data["history"] == [] + assert "没有活跃的对话" in data["response"] + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_missing_session_id(self, mock_get_mgr): + handler = _make_handler() + mock_get_mgr.return_value = MagicMock() + + result = await handler._get_session_history({"page": 1}, "req-10") + data = json.loads(result) + + assert data["status"] == "error" + assert "session_id" in data["error"] + + @pytest.mark.asyncio + @patch("astrbot.core.conversation_mgr.get_conversation_manager") + async def test_get_history_not_initialized(self, mock_get_mgr): + handler = _make_handler() + mock_get_mgr.return_value = None + + result = await handler._get_session_history({"session_id": "test"}, "req-11") + data = json.loads(result) + + assert data["status"] == "error" + assert "未初始化" in data["error"] diff --git a/tests/test_cli/test_token_manager.py b/tests/test_cli/test_token_manager.py new file mode 100644 index 0000000000..0464adbbaa --- /dev/null +++ b/tests/test_cli/test_token_manager.py @@ -0,0 +1,106 @@ +"""TokenManager 单元测试""" + +import os +import tempfile +from unittest.mock import patch + +import pytest + + +class TestTokenManager: + """TokenManager 测试类""" + + @pytest.fixture + def temp_data_path(self): + """创建临时数据目录""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + @pytest.fixture + def token_manager(self, temp_data_path): + """创建 TokenManager 实例""" + with patch( + "astrbot.core.platform.sources.cli.cli_adapter.get_astrbot_data_path", + return_value=temp_data_path, + ): + from astrbot.core.platform.sources.cli.cli_adapter import TokenManager + + return TokenManager() + + def test_generate_new_token(self, token_manager, temp_data_path): + """测试首次生成 Token""" + token = token_manager.token + assert token is not None + assert len(token) > 0 + + # 验证 Token 文件已创建 + token_file = os.path.join(temp_data_path, ".cli_token") + assert os.path.exists(token_file) + + def test_load_existing_token(self, temp_data_path): + """测试加载已存在的 Token""" + # 预先写入 Token + token_file = os.path.join(temp_data_path, ".cli_token") + expected_token = "test_token_12345" + with open(token_file, "w", encoding="utf-8") as f: + f.write(expected_token) + + with patch( + "astrbot.core.platform.sources.cli.cli_adapter.get_astrbot_data_path", + return_value=temp_data_path, + ): + from astrbot.core.platform.sources.cli.cli_adapter import TokenManager + + manager = TokenManager() + assert manager.token == expected_token + + def test_validate_correct_token(self, token_manager): + """测试验证正确的 Token""" + token = token_manager.token + assert token_manager.validate(token) is True + + def test_validate_wrong_token(self, token_manager): + """测试验证错误的 Token""" + _ = token_manager.token # 确保 Token 已生成 + assert token_manager.validate("wrong_token") is False + + def test_validate_empty_token(self, token_manager): + """测试验证空 Token""" + _ = token_manager.token # 确保 Token 已生成 + assert token_manager.validate("") is False + + def test_validate_without_server_token(self, temp_data_path): + """测试服务器无 Token 时跳过验证""" + with patch( + "astrbot.core.platform.sources.cli.cli_adapter.get_astrbot_data_path", + return_value=temp_data_path, + ): + from astrbot.core.platform.sources.cli.cli_adapter import TokenManager + + manager = TokenManager() + # 模拟 _ensure_token 返回 None(Token 生成失败场景) + with patch.object(manager, "_ensure_token", return_value=None): + manager._token = None # 重置缓存 + + # 无 Token 时应跳过验证 + assert manager.validate("any_token") is True + + def test_regenerate_empty_token_file(self, temp_data_path): + """测试空 Token 文件时重新生成""" + # 创建空 Token 文件 + token_file = os.path.join(temp_data_path, ".cli_token") + with open(token_file, "w", encoding="utf-8") as f: + f.write("") + + with patch( + "astrbot.core.platform.sources.cli.cli_adapter.get_astrbot_data_path", + return_value=temp_data_path, + ): + from astrbot.core.platform.sources.cli.cli_adapter import TokenManager + + manager = TokenManager() + token = manager.token + + # 应该生成新 Token + assert token is not None + assert len(token) > 0