diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py new file mode 100644 index 000000000..43072ec1c --- /dev/null +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -0,0 +1,465 @@ +import json +import mimetypes +import shutil +import uuid +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path +from typing import Any + +from astrbot.core.db.po import Attachment +from astrbot.core.message.components import ( + File, + Image, + Json, + Plain, + Record, + Reply, + Video, +) +from astrbot.core.message.message_event_result import MessageChain + +AttachmentGetter = Callable[[str], Awaitable[Attachment | None]] +AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]] +ReplyHistoryGetter = Callable[ + [Any], + Awaitable[tuple[list[dict], str | None, str | None] | None], +] + +MEDIA_PART_TYPES = {"image", "record", "file", "video"} + + +def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]: + return [{k: v for k, v in part.items() if k != "path"} for part in message_parts] + + +def webchat_message_parts_have_content(message_parts: list[dict]) -> bool: + return any( + part.get("type") in ("plain", "image", "record", "file", "video") + and (part.get("text") or part.get("attachment_id") or part.get("filename")) + for part in message_parts + ) + + +async def parse_webchat_message_parts( + message_parts: list, + *, + strict: bool = False, + include_empty_plain: bool = False, + verify_media_path_exists: bool = True, + reply_history_getter: ReplyHistoryGetter | None = None, + current_depth: int = 0, + max_reply_depth: int = 0, + cast_reply_id_to_str: bool = True, +) -> tuple[list, list[str], bool]: + """Parse webchat message parts into components/text parts. + + Returns: + tuple[list, list[str], bool]: + (components, plain_text_parts, has_non_reply_content) + """ + components = [] + text_parts: list[str] = [] + has_content = False + + for part in message_parts: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text or include_empty_plain: + components.append(Plain(text=text)) + text_parts.append(text) + if text: + has_content = True + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + + reply_chain = [] + reply_message_str = str(part.get("selected_text", "")) + sender_id = None + sender_name = None + + if reply_message_str: + reply_chain = [Plain(text=reply_message_str)] + elif ( + reply_history_getter + and current_depth < max_reply_depth + and message_id is not None + ): + reply_info = await reply_history_getter(message_id) + if reply_info: + reply_parts, sender_id, sender_name = reply_info + ( + reply_chain, + reply_text_parts, + _, + ) = await parse_webchat_message_parts( + reply_parts, + strict=strict, + include_empty_plain=include_empty_plain, + verify_media_path_exists=verify_media_path_exists, + reply_history_getter=reply_history_getter, + current_depth=current_depth + 1, + max_reply_depth=max_reply_depth, + cast_reply_id_to_str=cast_reply_id_to_str, + ) + reply_message_str = "".join(reply_text_parts) + + reply_id = str(message_id) if cast_reply_id_to_str else message_id + components.append( + Reply( + id=reply_id, + message_str=reply_message_str, + chain=reply_chain, + sender_id=sender_id, + sender_nickname=sender_name, + ) + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + path = part.get("path") + if not path: + if strict: + raise ValueError(f"{part_type} part missing path") + continue + + file_path = Path(str(path)) + if verify_media_path_exists and not file_path.exists(): + if strict: + raise ValueError(f"file not found: {file_path!s}") + continue + + file_path_str = ( + str(file_path.resolve()) if verify_media_path_exists else str(file_path) + ) + has_content = True + if part_type == "image": + components.append(Image.fromFileSystem(file_path_str)) + elif part_type == "record": + components.append(Record.fromFileSystem(file_path_str)) + elif part_type == "video": + components.append(Video.fromFileSystem(file_path_str)) + else: + filename = str(part.get("filename", "")).strip() or file_path.name + components.append(File(name=filename, file=file_path_str)) + + return components, text_parts, has_content + + +async def build_webchat_message_parts( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = False, +) -> list[dict]: + if isinstance(message_payload, str): + text = message_payload.strip() + return [{"type": "plain", "text": text}] if text else [] + + if not isinstance(message_payload, list): + if strict: + raise ValueError("message must be a string or list") + return [] + + message_parts: list[dict] = [] + for part in message_payload: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + message_parts.append({"type": "plain", "text": text}) + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + message_parts.append( + { + "type": "reply", + "message_id": message_id, + "selected_text": str(part.get("selected_text", "")), + } + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + attachment_id = part.get("attachment_id") + if not attachment_id: + if strict: + raise ValueError(f"{part_type} part missing attachment_id") + continue + + attachment = await get_attachment_by_id(str(attachment_id)) + if not attachment: + if strict: + raise ValueError(f"attachment not found: {attachment_id}") + continue + + attachment_path = Path(attachment.path) + message_parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": attachment_path.name, + "path": str(attachment_path), + } + ) + + return message_parts + + +def webchat_message_parts_to_message_chain( + message_parts: list[dict], + *, + strict: bool = False, +) -> MessageChain: + components = [] + has_content = False + + for part in message_parts: + if not isinstance(part, dict): + if strict: + raise ValueError("message part must be an object") + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "plain": + text = str(part.get("text", "")) + if text: + components.append(Plain(text=text)) + has_content = True + continue + + if part_type == "reply": + message_id = part.get("message_id") + if message_id is None: + if strict: + raise ValueError("reply part missing message_id") + continue + components.append( + Reply( + id=str(message_id), + message_str=str(part.get("selected_text", "")), + chain=[], + ) + ) + continue + + if part_type not in MEDIA_PART_TYPES: + if strict: + raise ValueError(f"unsupported message part type: {part_type}") + continue + + path = part.get("path") + if not path: + if strict: + raise ValueError(f"{part_type} part missing path") + continue + + file_path = Path(str(path)) + if not file_path.exists(): + if strict: + raise ValueError(f"file not found: {file_path!s}") + continue + + file_path_str = str(file_path.resolve()) + has_content = True + if part_type == "image": + components.append(Image.fromFileSystem(file_path_str)) + elif part_type == "record": + components.append(Record.fromFileSystem(file_path_str)) + elif part_type == "video": + components.append(Video.fromFileSystem(file_path_str)) + else: + filename = str(part.get("filename", "")).strip() or file_path.name + components.append(File(name=filename, file=file_path_str)) + + if strict and (not components or not has_content): + raise ValueError("Message content is empty (reply only is not allowed)") + + return MessageChain(chain=components) + + +async def build_message_chain_from_payload( + message_payload: str | list, + *, + get_attachment_by_id: AttachmentGetter, + strict: bool = True, +) -> MessageChain: + message_parts = await build_webchat_message_parts( + message_payload, + get_attachment_by_id=get_attachment_by_id, + strict=strict, + ) + components, _, has_content = await parse_webchat_message_parts( + message_parts, + strict=strict, + ) + if strict and (not components or not has_content): + raise ValueError("Message content is empty (reply only is not allowed)") + return MessageChain(chain=components) + + +async def create_attachment_part_from_existing_file( + filename: str, + *, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, + fallback_dirs: Sequence[str | Path] = (), +) -> dict | None: + basename = Path(filename).name + candidate_paths = [Path(attachments_dir) / basename] + candidate_paths.extend(Path(p) / basename for p in fallback_dirs) + + file_path = next((path for path in candidate_paths if path.exists()), None) + if not file_path: + return None + + mime_type, _ = mimetypes.guess_type(str(file_path)) + attachment = await insert_attachment( + str(file_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": file_path.name, + } + + +async def message_chain_to_storage_message_parts( + message_chain: MessageChain, + *, + insert_attachment: AttachmentInserter, + attachments_dir: str | Path, +) -> list[dict]: + target_dir = Path(attachments_dir) + target_dir.mkdir(parents=True, exist_ok=True) + + parts: list[dict] = [] + for comp in message_chain.chain: + if isinstance(comp, Plain): + if comp.text: + parts.append({"type": "plain", "text": comp.text}) + continue + + if isinstance(comp, Json): + parts.append( + {"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)} + ) + continue + + if isinstance(comp, Image): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="image", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Record): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="record", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, Video): + file_path = await comp.convert_to_file_path() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="video", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + ) + if attachment_part: + parts.append(attachment_part) + continue + + if isinstance(comp, File): + file_path = await comp.get_file() + attachment_part = await _copy_file_to_attachment_part( + file_path=file_path, + attach_type="file", + insert_attachment=insert_attachment, + attachments_dir=target_dir, + display_name=comp.name, + ) + if attachment_part: + parts.append(attachment_part) + continue + + return parts + + +async def _copy_file_to_attachment_part( + *, + file_path: str, + attach_type: str, + insert_attachment: AttachmentInserter, + attachments_dir: Path, + display_name: str | None = None, +) -> dict | None: + src_path = Path(file_path) + if not src_path.exists() or not src_path.is_file(): + return None + + suffix = src_path.suffix + target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}" + shutil.copy2(src_path, target_path) + + mime_type, _ = mimetypes.guess_type(target_path.name) + attachment = await insert_attachment( + str(target_path), + attach_type, + mime_type or "application/octet-stream", + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": display_name or src_path.name, + } diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 047417aaa..54718fefb 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -3,12 +3,12 @@ import time import uuid from collections.abc import Callable, Coroutine +from pathlib import Path from typing import Any from astrbot import logger from astrbot.core import db_helper from astrbot.core.db.po import PlatformMessageHistory -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( AstrBotMessage, @@ -21,10 +21,23 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path from ...register import register_platform_adapter +from .message_parts_helper import ( + message_chain_to_storage_message_parts, + parse_webchat_message_parts, +) from .webchat_event import WebChatMessageEvent from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class QueueListener: def __init__( self, @@ -57,13 +70,15 @@ def __init__( self.settings = platform_settings self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + self.attachments_dir = Path(get_astrbot_data_path()) / "attachments" os.makedirs(self.imgs_dir, exist_ok=True) + self.attachments_dir.mkdir(parents=True, exist_ok=True) self.metadata = PlatformMetadata( name="webchat", description="webchat", id="webchat", - support_proactive_message=False, + support_proactive_message=True, ) self._shutdown_event = asyncio.Event() self._webchat_queue_mgr = webchat_queue_mgr @@ -73,10 +88,67 @@ async def send_by_session( session: MessageSesion, message_chain: MessageChain, ) -> None: - message_id = f"active_{str(uuid.uuid4())}" - await WebChatMessageEvent._send(message_id, message_chain, session.session_id) + conversation_id = _extract_conversation_id(session.session_id) + active_request_ids = self._webchat_queue_mgr.list_back_request_ids( + conversation_id + ) + subscription_request_ids = [ + req_id for req_id in active_request_ids if req_id.startswith("ws_sub_") + ] + target_request_ids = subscription_request_ids or active_request_ids + + if target_request_ids: + for request_id in target_request_ids: + await WebChatMessageEvent._send( + request_id, + message_chain, + session.session_id, + ) + else: + message_id = f"active_{uuid.uuid4()!s}" + await WebChatMessageEvent._send( + message_id, + message_chain, + session.session_id, + ) + + should_persist = ( + bool(subscription_request_ids) + or not active_request_ids + or all(req_id.startswith("active_") for req_id in active_request_ids) + ) + if should_persist: + try: + await self._save_proactive_message(conversation_id, message_chain) + except Exception as e: + logger.error( + f"[WebChatAdapter] Failed to save proactive message: {e}", + exc_info=True, + ) + await super().send_by_session(session, message_chain) + async def _save_proactive_message( + self, + conversation_id: str, + message_chain: MessageChain, + ) -> None: + message_parts = await message_chain_to_storage_message_parts( + message_chain, + insert_attachment=db_helper.insert_attachment, + attachments_dir=self.attachments_dir, + ) + if not message_parts: + return + + await db_helper.insert_platform_message_history( + platform_id="webchat", + user_id=conversation_id, + content={"type": "bot", "message": message_parts}, + sender_id="bot", + sender_name="bot", + ) + async def _get_message_history( self, message_id: int ) -> PlatformMessageHistory | None: @@ -98,72 +170,30 @@ async def _parse_message_parts( Returns: tuple[list, list[str]]: (消息组件列表, 纯文本列表) """ - components = [] - text_parts = [] - - for part in message_parts: - part_type = part.get("type") - if part_type == "plain": - text = part.get("text", "") - components.append(Plain(text=text)) - text_parts.append(text) - elif part_type == "reply": - message_id = part.get("message_id") - reply_chain = [] - reply_message_str = part.get("selected_text", "") - sender_id = None - sender_name = None - - if reply_message_str: - reply_chain = [Plain(text=reply_message_str)] - - # recursively get the content of the referenced message, if selected_text is empty - if not reply_message_str and depth < max_depth and message_id: - history = await self._get_message_history(message_id) - if history and history.content: - reply_parts = history.content.get("message", []) - if isinstance(reply_parts, list): - ( - reply_chain, - reply_text_parts, - ) = await self._parse_message_parts( - reply_parts, - depth=depth + 1, - max_depth=max_depth, - ) - reply_message_str = "".join(reply_text_parts) - sender_id = history.sender_id - sender_name = history.sender_name - - components.append( - Reply( - id=message_id, - chain=reply_chain, - message_str=reply_message_str, - sender_id=sender_id, - sender_nickname=sender_name, - ) - ) - elif part_type == "image": - path = part.get("path") - if path: - components.append(Image.fromFileSystem(path)) - elif part_type == "record": - path = part.get("path") - if path: - components.append(Record.fromFileSystem(path)) - elif part_type == "file": - path = part.get("path") - if path: - filename = part.get("filename") or ( - os.path.basename(path) if path else "file" - ) - components.append(File(name=filename, file=path)) - elif part_type == "video": - path = part.get("path") - if path: - components.append(Video.fromFileSystem(path)) + async def get_reply_parts( + message_id: Any, + ) -> tuple[list[dict], str | None, str | None] | None: + history = await self._get_message_history(message_id) + if not history or not history.content: + return None + + reply_parts = history.content.get("message", []) + if not isinstance(reply_parts, list): + return None + + return reply_parts, history.sender_id, history.sender_name + + components, text_parts, _ = await parse_webchat_message_parts( + message_parts, + strict=False, + include_empty_plain=True, + verify_media_path_exists=False, + reply_history_getter=get_reply_parts, + current_depth=depth, + max_reply_depth=max_depth, + cast_reply_id_to_str=False, + ) return components, text_parts async def convert_message(self, data: tuple) -> AstrBotMessage: diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index a680f7617..b7da864aa 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -14,6 +14,15 @@ attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") +def _extract_conversation_id(session_id: str) -> str: + """Extract raw webchat conversation id from event/session id.""" + if session_id.startswith("webchat!"): + parts = session_id.split("!", 2) + if len(parts) == 3: + return parts[2] + return session_id + + class WebChatMessageEvent(AstrMessageEvent): def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) @@ -27,7 +36,7 @@ async def _send( streaming: bool = False, ) -> str | None: request_id = str(message_id) - conversation_id = session_id.split("!")[-1] + conversation_id = _extract_conversation_id(session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, @@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None: reasoning_content = "" message_id = self.message_obj.message_id request_id = str(message_id) - conversation_id = self.session_id.split("!")[-1] + conversation_id = _extract_conversation_id(self.session_id) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( request_id, conversation_id, diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index fd35e837c..f3ade1589 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str): if task is not None: task.cancel() + def list_back_request_ids(self, conversation_id: str) -> list[str]: + """List active back-queue request IDs for a conversation.""" + return list(self._conversation_back_requests.get(conversation_id, set())) + def has_queue(self, conversation_id: str) -> bool: """Check if a queue exists for the given conversation ID""" return conversation_id in self.queues diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 1235dd381..0602cc074 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,6 +1,5 @@ import asyncio import json -import mimetypes import os import re import uuid @@ -14,6 +13,12 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from astrbot.core.utils.active_event_registry import active_event_registry from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -166,83 +171,24 @@ async def post_file(self): ) async def _build_user_message_parts(self, message: str | list) -> list[dict]: - """构建用户消息的部分列表 - - Args: - message: 文本消息 (str) 或消息段列表 (list) - """ - parts = [] - - if isinstance(message, list): - for part in message: - part_type = part.get("type") - if part_type == "plain": - parts.append({"type": "plain", "text": part.get("text", "")}) - elif part_type == "reply": - parts.append( - { - "type": "reply", - "message_id": part.get("message_id"), - "selected_text": part.get("selected_text", ""), - } - ) - elif attachment_id := part.get("attachment_id"): - attachment = await self.db.get_attachment_by_id(attachment_id) - if attachment: - parts.append( - { - "type": attachment.type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(attachment.path), - "path": attachment.path, # will be deleted - } - ) - return parts - - if message: - parts.append({"type": "plain", "text": message}) - - return parts + """构建用户消息的部分列表。""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) async def _create_attachment_from_file( self, filename: str, attach_type: str ) -> dict | None: - """从本地文件创建 attachment 并返回消息部分 - - 用于处理 bot 回复中的媒体文件 - - Args: - filename: 存储的文件名 - attach_type: 附件类型 (image, record, file, video) - """ - basename = os.path.basename(filename) - candidate_paths = [ - os.path.join(self.attachments_dir, basename), - os.path.join(self.legacy_img_dir, basename), - ] - file_path = next((p for p in candidate_paths if os.path.exists(p)), None) - if not file_path: - return None - - # guess mime type - mime_type, _ = mimetypes.guess_type(filename) - if not mime_type: - mime_type = "application/octet-stream" - - # insert attachment - attachment = await self.db.insert_attachment( - path=file_path, - type=attach_type, - mime_type=mime_type, + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], ) - if not attachment: - return None - - return { - "type": attach_type, - "attachment_id": attachment.attachment_id, - "filename": os.path.basename(file_path), - } def _extract_web_search_refs( self, accumulated_text: str, accumulated_parts: list @@ -356,21 +302,6 @@ async def chat(self, post_data: dict | None = None): selected_model = post_data.get("selected_model") enable_streaming = post_data.get("enable_streaming", True) - # 检查消息是否为空 - if isinstance(message, list): - has_content = any( - part.get("type") in ("plain", "image", "record", "file", "video") - for part in message - ) - if not has_content: - return ( - Response() - .error("Message content is empty (reply only is not allowed)") - .__dict__ - ) - elif not message: - return Response().error("Message are both empty").__dict__ - if not session_id: return Response().error("session_id is empty").__dict__ @@ -378,6 +309,12 @@ async def chat(self, post_data: dict | None = None): # 构建用户消息段(包含 path 用于传递给 adapter) message_parts = await self._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) message_id = str(uuid.uuid4()) back_queue = webchat_queue_mgr.get_or_create_back_queue( @@ -583,10 +520,7 @@ async def stream(): ), ) - message_parts_for_storage = [] - for part in message_parts: - part_copy = {k: v for k, v in part.items() if k != "path"} - message_parts_for_storage.append(part_copy) + message_parts_for_storage = strip_message_parts_path_fields(message_parts) await self.platform_history_mgr.insert( platform_id="webchat", diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8c922ab69..25438565e 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -1,6 +1,7 @@ import asyncio import json import os +import re import time import uuid import wave @@ -10,9 +11,16 @@ from quart import websocket from astrbot import logger +from astrbot.core import sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_webchat_message_parts, + create_attachment_part_from_existing_file, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path from .route import Route, RouteContext @@ -30,6 +38,9 @@ def __init__(self, session_id: str, username: str) -> None: self.audio_frames: list[bytes] = [] self.current_stamp: str | None = None self.temp_audio_path: str | None = None + self.chat_subscriptions: dict[str, str] = {} + self.chat_subscription_tasks: dict[str, asyncio.Task] = {} + self.ws_send_lock = asyncio.Lock() def start_speaking(self, stamp: str) -> None: """开始说话""" @@ -106,13 +117,26 @@ def __init__( self.core_lifecycle = core_lifecycle self.db = db self.plugin_manager = core_lifecycle.plugin_manager + self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.sessions: dict[str, LiveChatSession] = {} + self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") + self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.attachments_dir, exist_ok=True) # 注册 WebSocket 路由 self.app.websocket("/api/live_chat/ws")(self.live_chat_ws) + self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws) async def live_chat_ws(self) -> None: - """Live Chat WebSocket 处理器""" + """Legacy Live Chat WebSocket 处理器(默认 ct=live)""" + await self._unified_ws_loop(force_ct="live") + + async def unified_chat_ws(self) -> None: + """Unified Chat WebSocket 处理器(支持 ct=live/chat)""" + await self._unified_ws_loop(force_ct=None) + + async def _unified_ws_loop(self, force_ct: str | None = None) -> None: + """统一 WebSocket 循环""" # WebSocket 不能通过 header 传递 token,需要从 query 参数获取 # 注意:WebSocket 上下文使用 websocket.args 而不是 request.args token = websocket.args.get("token") @@ -140,7 +164,11 @@ async def live_chat_ws(self) -> None: try: while True: message = await websocket.receive_json() - await self._handle_message(live_session, message) + ct = force_ct or message.get("ct", "live") + if ct == "chat": + await self._handle_chat_message(live_session, message) + else: + await self._handle_message(live_session, message) except Exception as e: logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True) @@ -148,10 +176,488 @@ async def live_chat_ws(self) -> None: finally: # 清理会话 if session_id in self.sessions: + await self._cleanup_chat_subscriptions(live_session) live_session.cleanup() del self.sessions[session_id] logger.info(f"[Live Chat] WebSocket 连接关闭: {username}") + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分。""" + return await create_attachment_part_from_existing_file( + filename, + attach_type=attach_type, + insert_attachment=self.db.insert_attachment, + attachments_dir=self.attachments_dir, + fallback_dirs=[self.legacy_img_dir], + ) + + def _extract_web_search_refs( + self, accumulated_text: str, accumulated_parts: list + ) -> dict: + """从消息中提取 web_search 引用。""" + supported = ["web_search_tavily", "web_search_bocha"] + web_search_results = {} + tool_call_parts = [ + p + for p in accumulated_parts + if p.get("type") == "tool_call" and p.get("tool_calls") + ] + + for part in tool_call_parts: + for tool_call in part["tool_calls"]: + if tool_call.get("name") not in supported or not tool_call.get( + "result" + ): + continue + try: + result_data = json.loads(tool_call["result"]) + for item in result_data.get("results", []): + if idx := item.get("index"): + web_search_results[idx] = { + "url": item.get("url"), + "title": item.get("title"), + "snippet": item.get("snippet"), + } + except (json.JSONDecodeError, KeyError): + pass + + if not web_search_results: + return {} + + ref_indices = { + m.strip() for m in re.findall(r"(.*?)", accumulated_text) + } + + used_refs = [] + for ref_index in ref_indices: + if ref_index not in web_search_results: + continue + payload = {"index": ref_index, **web_search_results[ref_index]} + if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]): + payload["favicon"] = favicon + used_refs.append(payload) + + return {"used": used_refs} if used_refs else {} + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + agent_stats: dict, + refs: dict, + ): + """保存 bot 消息到历史记录。""" + bot_message_parts = [] + bot_message_parts.extend(media_parts) + if text: + bot_message_parts.append({"type": "plain", "text": text}) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + if agent_stats: + new_his["agent_stats"] = agent_stats + if refs: + new_his["refs"] = refs + + return await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + + async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None: + async with session.ws_send_lock: + await websocket.send_json(payload) + + async def _forward_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + request_id: str, + ) -> None: + back_queue = webchat_queue_mgr.get_or_create_back_queue( + request_id, chat_session_id + ) + try: + while True: + result = await back_queue.get() + if not result: + continue + await self._send_chat_payload(session, {"ct": "chat", **result}) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error( + f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}", + exc_info=True, + ) + finally: + webchat_queue_mgr.remove_back_queue(request_id) + if session.chat_subscriptions.get(chat_session_id) == request_id: + session.chat_subscriptions.pop(chat_session_id, None) + session.chat_subscription_tasks.pop(chat_session_id, None) + + async def _ensure_chat_subscription( + self, + session: LiveChatSession, + chat_session_id: str, + ) -> str: + existing_request_id = session.chat_subscriptions.get(chat_session_id) + existing_task = session.chat_subscription_tasks.get(chat_session_id) + if existing_request_id and existing_task and not existing_task.done(): + return existing_request_id + + request_id = f"ws_sub_{uuid.uuid4().hex}" + session.chat_subscriptions[chat_session_id] = request_id + task = asyncio.create_task( + self._forward_chat_subscription(session, chat_session_id, request_id), + name=f"chat_ws_sub_{chat_session_id}", + ) + session.chat_subscription_tasks[chat_session_id] = task + return request_id + + async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None: + tasks = list(session.chat_subscription_tasks.values()) + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + for request_id in list(session.chat_subscriptions.values()): + webchat_queue_mgr.remove_back_queue(request_id) + session.chat_subscriptions.clear() + session.chat_subscription_tasks.clear() + + async def _handle_chat_message( + self, session: LiveChatSession, message: dict + ) -> None: + """处理 Chat Mode 消息(ct=chat)""" + msg_type = message.get("t") + + if msg_type == "bind": + chat_session_id = message.get("session_id") + if not isinstance(chat_session_id, str) or not chat_session_id: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "session_id is required", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + request_id = await self._ensure_chat_subscription(session, chat_session_id) + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "session_bound", + "session_id": chat_session_id, + "message_id": request_id, + }, + ) + return + + if msg_type == "interrupt": + session.should_interrupt = True + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "INTERRUPTED", + "code": "INTERRUPTED", + }, + ) + return + + if msg_type != "send": + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"Unsupported message type: {msg_type}", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + if session.is_processing: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Session is busy", + "code": "PROCESSING_ERROR", + }, + ) + return + + payload = message.get("message") + session_id = message.get("session_id") or session.session_id + message_id = message.get("message_id") or str(uuid.uuid4()) + selected_provider = message.get("selected_provider") + selected_model = message.get("selected_model") + selected_stt_provider = message.get("selected_stt_provider") + selected_tts_provider = message.get("selected_tts_provider") + persona_prompt = message.get("persona_prompt") + show_reasoning = message.get("show_reasoning") + enable_streaming = message.get("enable_streaming", True) + + if not isinstance(payload, list): + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "message must be list", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + message_parts = await self._build_chat_message_parts(payload) + has_content = webchat_message_parts_have_content(message_parts) + if not has_content: + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": "Message content is empty", + "code": "INVALID_MESSAGE_FORMAT", + }, + ) + return + + await self._ensure_chat_subscription(session, session_id) + + session.is_processing = True + session.should_interrupt = False + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + session.username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "selected_stt_provider": selected_stt_provider, + "selected_tts_provider": selected_tts_provider, + "persona_prompt": persona_prompt, + "show_reasoning": show_reasoning, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ), + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=session.username, + sender_name=session.username, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + + while True: + if session.should_interrupt: + session.should_interrupt = False + break + + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + if result.get("message_id") and result.get("message_id") != message_id: + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + if chain_type == "agent_stats": + try: + parsed_agent_stats = json.loads(result_text) + agent_stats = parsed_agent_stats + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "agent_stats", + "data": parsed_agent_stats, + }, + ) + except Exception: + pass + continue + + outgoing = {"ct": "chat", **result} + await self._send_chat_payload(session, outgoing) + + if msg_type == "plain": + if chain_type == "tool_call": + try: + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + except Exception: + pass + elif chain_type == "tool_call_result": + try: + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + { + "type": "tool_call", + "tool_calls": [tool_calls[tc_id]], + } + ) + tool_calls.pop(tc_id, None) + except Exception: + pass + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self._create_attachment_from_file(filename, "image") + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self._create_attachment_from_file(filename, "record") + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "file") + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0] + part = await self._create_attachment_from_file(filename, "video") + if part: + accumulated_parts.append(part) + + should_save = False + if msg_type == "end": + should_save = bool( + accumulated_parts + or accumulated_text + or accumulated_reasoning + or refs + or agent_stats + ) + elif (streaming and msg_type == "complete") or not streaming: + if chain_type not in ( + "tool_call", + "tool_call_result", + "agent_stats", + ): + should_save = True + + if should_save: + try: + refs = self._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"[Live Chat] Failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await self._send_chat_payload( + session, + { + "ct": "chat", + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + }, + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + + if msg_type == "end": + break + + except Exception as e: + logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True) + await self._send_chat_payload( + session, + { + "ct": "chat", + "t": "error", + "data": f"处理失败: {str(e)}", + "code": "PROCESSING_ERROR", + }, + ) + finally: + session.is_processing = False + webchat_queue_mgr.remove_back_queue(message_id) + + async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]: + """构建 chat websocket 用户消息段(复用 webchat 逻辑)""" + return await build_webchat_message_parts( + message, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=False, + ) + async def _handle_message(self, session: LiveChatSession, message: dict) -> None: """处理 WebSocket 消息""" msg_type = message.get("t") # 使用 t 代替 type diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index c25870ebb..653e22cbf 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -1,15 +1,22 @@ -from pathlib import Path +import asyncio +import hashlib +import json from uuid import uuid4 -from quart import g, request +from quart import g, request, websocket from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video -from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform.message_session import MessageSesion - +from astrbot.core.platform.sources.webchat.message_parts_helper import ( + build_message_chain_from_payload, + strip_message_parts_path_fields, + webchat_message_parts_have_content, +) +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr + +from .api_key import ALL_OPEN_API_SCOPES from .chat import ChatRoute from .route import Response, Route, RouteContext @@ -37,6 +44,7 @@ def __init__( "/v1/im/bots": ("GET", self.get_bots), } self.register_routes() + self.app.websocket("/api/v1/chat/ws")(self.chat_ws) @staticmethod def _resolve_open_username( @@ -181,6 +189,348 @@ async def chat_send(self): finally: g.username = original_username + @staticmethod + def _extract_ws_api_key() -> str | None: + if key := websocket.args.get("api_key"): + return key.strip() + if key := websocket.args.get("key"): + return key.strip() + if key := websocket.headers.get("X-API-Key"): + return key.strip() + + auth_header = websocket.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return auth_header.removeprefix("Bearer ").strip() + if auth_header.startswith("ApiKey "): + return auth_header.removeprefix("ApiKey ").strip() + return None + + async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: + raw_key = self._extract_ws_api_key() + if not raw_key: + return False, "Missing API key" + + key_hash = hashlib.pbkdf2_hmac( + "sha256", + raw_key.encode("utf-8"), + b"astrbot_api_key", + 100_000, + ).hex() + api_key = await self.db.get_active_api_key_by_hash(key_hash) + if not api_key: + return False, "Invalid API key" + + if isinstance(api_key.scopes, list): + scopes = api_key.scopes + else: + scopes = list(ALL_OPEN_API_SCOPES) + + if "*" not in scopes and "chat" not in scopes: + return False, "Insufficient API key scope" + + await self.db.touch_api_key(api_key.key_id) + return True, None + + async def _send_chat_ws_error(self, message: str, code: str) -> None: + await websocket.send_json( + { + "type": "error", + "code": code, + "data": message, + } + ) + + async def _update_session_config_route( + self, + *, + username: str, + session_id: str, + config_id: str | None, + ) -> str | None: + if not config_id: + return None + + umo = f"webchat:FriendMessage:webchat!{username}!{session_id}" + try: + if config_id == "default": + await self.core_lifecycle.umop_config_router.delete_route(umo) + else: + await self.core_lifecycle.umop_config_router.update_route( + umo, config_id + ) + except Exception as e: + logger.error( + "Failed to update chat config route for %s with %s: %s", + umo, + config_id, + e, + exc_info=True, + ) + return f"Failed to update chat config route: {e}" + return None + + async def _handle_chat_ws_send(self, post_data: dict) -> None: + effective_username, username_err = self._resolve_open_username( + post_data.get("username") + ) + if username_err or not effective_username: + await self._send_chat_ws_error( + username_err or "Invalid username", "BAD_USER" + ) + return + + message = post_data.get("message") + if message is None: + await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE") + return + + raw_session_id = post_data.get("session_id", post_data.get("conversation_id")) + session_id = str(raw_session_id).strip() if raw_session_id is not None else "" + if not session_id: + session_id = str(uuid4()) + + ensure_session_err = await self._ensure_chat_session( + effective_username, + session_id, + ) + if ensure_session_err: + await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR") + return + + config_id, resolve_err = self._resolve_chat_config_id(post_data) + if resolve_err: + await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR") + return + + config_err = await self._update_session_config_route( + username=effective_username, + session_id=session_id, + config_id=config_id, + ) + if config_err: + await self._send_chat_ws_error(config_err, "CONFIG_ERROR") + return + + message_parts = await self.chat_route._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + await self._send_chat_ws_error( + "Message content is empty (reply only is not allowed)", + "INVALID_MESSAGE", + ) + return + + message_id = str(post_data.get("message_id") or uuid4()) + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") + enable_streaming = post_data.get("enable_streaming", True) + + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + effective_username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ) + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.chat_route.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=effective_username, + sender_name=effective_username, + ) + + await websocket.send_json( + { + "type": "session_id", + "data": None, + "session_id": session_id, + "message_id": message_id, + } + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + while True: + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + + if "message_id" in result and result["message_id"] != message_id: + logger.warning("openapi ws stream message_id mismatch") + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + + if chain_type == "agent_stats": + try: + stats_info = { + "type": "agent_stats", + "data": json.loads(result_text), + } + await websocket.send_json(stats_info) + agent_stats = stats_info["data"] + except Exception: + pass + continue + + await websocket.send_json(result) + + if msg_type == "plain": + if chain_type == "tool_call": + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + elif chain_type == "tool_call_result": + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + {"type": "tool_call", "tool_calls": [tool_calls[tc_id]]} + ) + tool_calls.pop(tc_id, None) + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "record" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "video" + ) + if part: + accumulated_parts.append(part) + + if msg_type == "end": + break + if (streaming and msg_type == "complete") or not streaming: + if chain_type in ("tool_call", "tool_call_result"): + continue + try: + refs = self.chat_route._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"Open API WS failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self.chat_route._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await websocket.send_json( + { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + "session_id": session_id, + } + ) + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + except Exception as e: + logger.exception(f"Open API WS chat failed: {e}", exc_info=True) + await self._send_chat_ws_error( + f"Failed to process message: {e}", "PROCESSING_ERROR" + ) + finally: + webchat_queue_mgr.remove_back_queue(message_id) + + async def chat_ws(self) -> None: + authed, auth_err = await self._authenticate_chat_ws_api_key() + if not authed: + await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED") + await websocket.close(1008, auth_err or "Unauthorized") + return + + try: + while True: + message = await websocket.receive_json() + if not isinstance(message, dict): + await self._send_chat_ws_error( + "message must be an object", + "INVALID_MESSAGE", + ) + continue + + msg_type = message.get("t", "send") + if msg_type == "ping": + await websocket.send_json({"type": "pong"}) + continue + if msg_type != "send": + await self._send_chat_ws_error( + f"Unsupported message type: {msg_type}", + "INVALID_MESSAGE", + ) + continue + + await self._handle_chat_ws_send(message) + except Exception as e: + logger.debug("Open API WS connection closed: %s", e) + async def upload_file(self): return await self.chat_route.post_file() @@ -254,83 +604,12 @@ async def get_chat_configs(self): async def _build_message_chain_from_payload( self, message_payload: str | list, - ) -> MessageChain: - if isinstance(message_payload, str): - text = message_payload.strip() - if not text: - raise ValueError("Message is empty") - return MessageChain(chain=[Plain(text=text)]) - - if not isinstance(message_payload, list): - raise ValueError("message must be a string or list") - - components = [] - has_content = False - - for part in message_payload: - if not isinstance(part, dict): - raise ValueError("message part must be an object") - - part_type = str(part.get("type", "")).strip() - if part_type == "plain": - text = str(part.get("text", "")) - if text: - has_content = True - components.append(Plain(text=text)) - continue - - if part_type == "reply": - message_id = part.get("message_id") - if message_id is None: - raise ValueError("reply part missing message_id") - components.append( - Reply( - id=str(message_id), - message_str=str(part.get("selected_text", "")), - chain=[], - ) - ) - continue - - if part_type not in {"image", "record", "file", "video"}: - raise ValueError(f"unsupported message part type: {part_type}") - - has_content = True - file_path: Path | None = None - resolved_type = part_type - filename = str(part.get("filename", "")).strip() - - attachment_id = part.get("attachment_id") - if attachment_id: - attachment = await self.db.get_attachment_by_id(str(attachment_id)) - if not attachment: - raise ValueError(f"attachment not found: {attachment_id}") - file_path = Path(attachment.path) - resolved_type = attachment.type - if not filename: - filename = file_path.name - else: - raise ValueError(f"{part_type} part missing attachment_id") - - if not file_path.exists(): - raise ValueError(f"file not found: {file_path!s}") - - file_path_str = str(file_path.resolve()) - if resolved_type == "image": - components.append(Image.fromFileSystem(file_path_str)) - elif resolved_type == "record": - components.append(Record.fromFileSystem(file_path_str)) - elif resolved_type == "video": - components.append(Video.fromFileSystem(file_path_str)) - else: - components.append( - File(name=filename or file_path.name, file=file_path_str) - ) - - if not components or not has_content: - raise ValueError("Message content is empty (reply only is not allowed)") - - return MessageChain(chain=components) + ): + return await build_message_chain_from_payload( + message_payload, + get_attachment_by_id=self.db.get_attachment_by_id, + strict=True, + ) async def send_message(self): post_data = await request.json or {} diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a9631fc09..a9650cd06 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -204,6 +204,10 @@ async def auth_middleware(self): @staticmethod def _extract_raw_api_key() -> str | None: + if key := request.args.get("api_key"): + return key.strip() + if key := request.args.get("key"): + return key.strip() if key := request.headers.get("X-API-Key"): return key.strip() auth_header = request.headers.get("Authorization", "").strip() @@ -217,6 +221,7 @@ def _extract_raw_api_key() -> str | None: def _get_required_open_api_scope(path: str) -> str | None: scope_map = { "/api/v1/chat": "chat", + "/api/v1/chat/ws": "chat", "/api/v1/chat/sessions": "chat", "/api/v1/configs": "config", "/api/v1/file": "file", diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 803c5d826..054a18662 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -10,6 +10,7 @@ :selectedSessions="selectedSessions" :currSessionId="currSessionId" :selectedProjectId="selectedProjectId" + :transportMode="transportMode" :isDark="isDark" :chatboxMode="chatboxMode" :isMobile="isMobile" @@ -26,6 +27,7 @@ @createProject="showCreateProjectDialog" @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" + @updateTransportMode="setTransportMode" /> @@ -301,11 +303,14 @@ const { isStreaming, isConvRunning, enableStreaming, + transportMode, currentSessionProject, getSessionMessages: getSessionMsg, sendMessage: sendMsg, stopMessage: stopMsg, - toggleStreaming + toggleStreaming, + setTransportMode, + cleanupTransport } = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions); // 组件引用 @@ -695,6 +700,7 @@ onMounted(() => { onBeforeUnmount(() => { window.removeEventListener('resize', checkMobile); cleanupMediaCache(); + cleanupTransport(); }); diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index a728930d9..97f2179e7 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -117,6 +117,27 @@ {{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }} + + + + {{ tm('transport.title') }} + + +