diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py
new file mode 100644
index 000000000..43072ec1c
--- /dev/null
+++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py
@@ -0,0 +1,465 @@
+import json
+import mimetypes
+import shutil
+import uuid
+from collections.abc import Awaitable, Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+from astrbot.core.db.po import Attachment
+from astrbot.core.message.components import (
+ File,
+ Image,
+ Json,
+ Plain,
+ Record,
+ Reply,
+ Video,
+)
+from astrbot.core.message.message_event_result import MessageChain
+
+AttachmentGetter = Callable[[str], Awaitable[Attachment | None]]
+AttachmentInserter = Callable[[str, str, str], Awaitable[Attachment | None]]
+ReplyHistoryGetter = Callable[
+ [Any],
+ Awaitable[tuple[list[dict], str | None, str | None] | None],
+]
+
+MEDIA_PART_TYPES = {"image", "record", "file", "video"}
+
+
+def strip_message_parts_path_fields(message_parts: list[dict]) -> list[dict]:
+ return [{k: v for k, v in part.items() if k != "path"} for part in message_parts]
+
+
+def webchat_message_parts_have_content(message_parts: list[dict]) -> bool:
+ return any(
+ part.get("type") in ("plain", "image", "record", "file", "video")
+ and (part.get("text") or part.get("attachment_id") or part.get("filename"))
+ for part in message_parts
+ )
+
+
+async def parse_webchat_message_parts(
+ message_parts: list,
+ *,
+ strict: bool = False,
+ include_empty_plain: bool = False,
+ verify_media_path_exists: bool = True,
+ reply_history_getter: ReplyHistoryGetter | None = None,
+ current_depth: int = 0,
+ max_reply_depth: int = 0,
+ cast_reply_id_to_str: bool = True,
+) -> tuple[list, list[str], bool]:
+ """Parse webchat message parts into components/text parts.
+
+ Returns:
+ tuple[list, list[str], bool]:
+ (components, plain_text_parts, has_non_reply_content)
+ """
+ components = []
+ text_parts: list[str] = []
+ has_content = False
+
+ for part in message_parts:
+ if not isinstance(part, dict):
+ if strict:
+ raise ValueError("message part must be an object")
+ continue
+
+ part_type = str(part.get("type", "")).strip()
+ if part_type == "plain":
+ text = str(part.get("text", ""))
+ if text or include_empty_plain:
+ components.append(Plain(text=text))
+ text_parts.append(text)
+ if text:
+ has_content = True
+ continue
+
+ if part_type == "reply":
+ message_id = part.get("message_id")
+ if message_id is None:
+ if strict:
+ raise ValueError("reply part missing message_id")
+ continue
+
+ reply_chain = []
+ reply_message_str = str(part.get("selected_text", ""))
+ sender_id = None
+ sender_name = None
+
+ if reply_message_str:
+ reply_chain = [Plain(text=reply_message_str)]
+ elif (
+ reply_history_getter
+ and current_depth < max_reply_depth
+ and message_id is not None
+ ):
+ reply_info = await reply_history_getter(message_id)
+ if reply_info:
+ reply_parts, sender_id, sender_name = reply_info
+ (
+ reply_chain,
+ reply_text_parts,
+ _,
+ ) = await parse_webchat_message_parts(
+ reply_parts,
+ strict=strict,
+ include_empty_plain=include_empty_plain,
+ verify_media_path_exists=verify_media_path_exists,
+ reply_history_getter=reply_history_getter,
+ current_depth=current_depth + 1,
+ max_reply_depth=max_reply_depth,
+ cast_reply_id_to_str=cast_reply_id_to_str,
+ )
+ reply_message_str = "".join(reply_text_parts)
+
+ reply_id = str(message_id) if cast_reply_id_to_str else message_id
+ components.append(
+ Reply(
+ id=reply_id,
+ message_str=reply_message_str,
+ chain=reply_chain,
+ sender_id=sender_id,
+ sender_nickname=sender_name,
+ )
+ )
+ continue
+
+ if part_type not in MEDIA_PART_TYPES:
+ if strict:
+ raise ValueError(f"unsupported message part type: {part_type}")
+ continue
+
+ path = part.get("path")
+ if not path:
+ if strict:
+ raise ValueError(f"{part_type} part missing path")
+ continue
+
+ file_path = Path(str(path))
+ if verify_media_path_exists and not file_path.exists():
+ if strict:
+ raise ValueError(f"file not found: {file_path!s}")
+ continue
+
+ file_path_str = (
+ str(file_path.resolve()) if verify_media_path_exists else str(file_path)
+ )
+ has_content = True
+ if part_type == "image":
+ components.append(Image.fromFileSystem(file_path_str))
+ elif part_type == "record":
+ components.append(Record.fromFileSystem(file_path_str))
+ elif part_type == "video":
+ components.append(Video.fromFileSystem(file_path_str))
+ else:
+ filename = str(part.get("filename", "")).strip() or file_path.name
+ components.append(File(name=filename, file=file_path_str))
+
+ return components, text_parts, has_content
+
+
+async def build_webchat_message_parts(
+ message_payload: str | list,
+ *,
+ get_attachment_by_id: AttachmentGetter,
+ strict: bool = False,
+) -> list[dict]:
+ if isinstance(message_payload, str):
+ text = message_payload.strip()
+ return [{"type": "plain", "text": text}] if text else []
+
+ if not isinstance(message_payload, list):
+ if strict:
+ raise ValueError("message must be a string or list")
+ return []
+
+ message_parts: list[dict] = []
+ for part in message_payload:
+ if not isinstance(part, dict):
+ if strict:
+ raise ValueError("message part must be an object")
+ continue
+
+ part_type = str(part.get("type", "")).strip()
+ if part_type == "plain":
+ text = str(part.get("text", ""))
+ if text:
+ message_parts.append({"type": "plain", "text": text})
+ continue
+
+ if part_type == "reply":
+ message_id = part.get("message_id")
+ if message_id is None:
+ if strict:
+ raise ValueError("reply part missing message_id")
+ continue
+ message_parts.append(
+ {
+ "type": "reply",
+ "message_id": message_id,
+ "selected_text": str(part.get("selected_text", "")),
+ }
+ )
+ continue
+
+ if part_type not in MEDIA_PART_TYPES:
+ if strict:
+ raise ValueError(f"unsupported message part type: {part_type}")
+ continue
+
+ attachment_id = part.get("attachment_id")
+ if not attachment_id:
+ if strict:
+ raise ValueError(f"{part_type} part missing attachment_id")
+ continue
+
+ attachment = await get_attachment_by_id(str(attachment_id))
+ if not attachment:
+ if strict:
+ raise ValueError(f"attachment not found: {attachment_id}")
+ continue
+
+ attachment_path = Path(attachment.path)
+ message_parts.append(
+ {
+ "type": attachment.type,
+ "attachment_id": attachment.attachment_id,
+ "filename": attachment_path.name,
+ "path": str(attachment_path),
+ }
+ )
+
+ return message_parts
+
+
+def webchat_message_parts_to_message_chain(
+ message_parts: list[dict],
+ *,
+ strict: bool = False,
+) -> MessageChain:
+ components = []
+ has_content = False
+
+ for part in message_parts:
+ if not isinstance(part, dict):
+ if strict:
+ raise ValueError("message part must be an object")
+ continue
+
+ part_type = str(part.get("type", "")).strip()
+ if part_type == "plain":
+ text = str(part.get("text", ""))
+ if text:
+ components.append(Plain(text=text))
+ has_content = True
+ continue
+
+ if part_type == "reply":
+ message_id = part.get("message_id")
+ if message_id is None:
+ if strict:
+ raise ValueError("reply part missing message_id")
+ continue
+ components.append(
+ Reply(
+ id=str(message_id),
+ message_str=str(part.get("selected_text", "")),
+ chain=[],
+ )
+ )
+ continue
+
+ if part_type not in MEDIA_PART_TYPES:
+ if strict:
+ raise ValueError(f"unsupported message part type: {part_type}")
+ continue
+
+ path = part.get("path")
+ if not path:
+ if strict:
+ raise ValueError(f"{part_type} part missing path")
+ continue
+
+ file_path = Path(str(path))
+ if not file_path.exists():
+ if strict:
+ raise ValueError(f"file not found: {file_path!s}")
+ continue
+
+ file_path_str = str(file_path.resolve())
+ has_content = True
+ if part_type == "image":
+ components.append(Image.fromFileSystem(file_path_str))
+ elif part_type == "record":
+ components.append(Record.fromFileSystem(file_path_str))
+ elif part_type == "video":
+ components.append(Video.fromFileSystem(file_path_str))
+ else:
+ filename = str(part.get("filename", "")).strip() or file_path.name
+ components.append(File(name=filename, file=file_path_str))
+
+ if strict and (not components or not has_content):
+ raise ValueError("Message content is empty (reply only is not allowed)")
+
+ return MessageChain(chain=components)
+
+
+async def build_message_chain_from_payload(
+ message_payload: str | list,
+ *,
+ get_attachment_by_id: AttachmentGetter,
+ strict: bool = True,
+) -> MessageChain:
+ message_parts = await build_webchat_message_parts(
+ message_payload,
+ get_attachment_by_id=get_attachment_by_id,
+ strict=strict,
+ )
+ components, _, has_content = await parse_webchat_message_parts(
+ message_parts,
+ strict=strict,
+ )
+ if strict and (not components or not has_content):
+ raise ValueError("Message content is empty (reply only is not allowed)")
+ return MessageChain(chain=components)
+
+
+async def create_attachment_part_from_existing_file(
+ filename: str,
+ *,
+ attach_type: str,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: str | Path,
+ fallback_dirs: Sequence[str | Path] = (),
+) -> dict | None:
+ basename = Path(filename).name
+ candidate_paths = [Path(attachments_dir) / basename]
+ candidate_paths.extend(Path(p) / basename for p in fallback_dirs)
+
+ file_path = next((path for path in candidate_paths if path.exists()), None)
+ if not file_path:
+ return None
+
+ mime_type, _ = mimetypes.guess_type(str(file_path))
+ attachment = await insert_attachment(
+ str(file_path),
+ attach_type,
+ mime_type or "application/octet-stream",
+ )
+ if not attachment:
+ return None
+
+ return {
+ "type": attach_type,
+ "attachment_id": attachment.attachment_id,
+ "filename": file_path.name,
+ }
+
+
+async def message_chain_to_storage_message_parts(
+ message_chain: MessageChain,
+ *,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: str | Path,
+) -> list[dict]:
+ target_dir = Path(attachments_dir)
+ target_dir.mkdir(parents=True, exist_ok=True)
+
+ parts: list[dict] = []
+ for comp in message_chain.chain:
+ if isinstance(comp, Plain):
+ if comp.text:
+ parts.append({"type": "plain", "text": comp.text})
+ continue
+
+ if isinstance(comp, Json):
+ parts.append(
+ {"type": "plain", "text": json.dumps(comp.data, ensure_ascii=False)}
+ )
+ continue
+
+ if isinstance(comp, Image):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="image",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, Record):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="record",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, Video):
+ file_path = await comp.convert_to_file_path()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="video",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ if isinstance(comp, File):
+ file_path = await comp.get_file()
+ attachment_part = await _copy_file_to_attachment_part(
+ file_path=file_path,
+ attach_type="file",
+ insert_attachment=insert_attachment,
+ attachments_dir=target_dir,
+ display_name=comp.name,
+ )
+ if attachment_part:
+ parts.append(attachment_part)
+ continue
+
+ return parts
+
+
+async def _copy_file_to_attachment_part(
+ *,
+ file_path: str,
+ attach_type: str,
+ insert_attachment: AttachmentInserter,
+ attachments_dir: Path,
+ display_name: str | None = None,
+) -> dict | None:
+ src_path = Path(file_path)
+ if not src_path.exists() or not src_path.is_file():
+ return None
+
+ suffix = src_path.suffix
+ target_path = attachments_dir / f"{uuid.uuid4().hex}{suffix}"
+ shutil.copy2(src_path, target_path)
+
+ mime_type, _ = mimetypes.guess_type(target_path.name)
+ attachment = await insert_attachment(
+ str(target_path),
+ attach_type,
+ mime_type or "application/octet-stream",
+ )
+ if not attachment:
+ return None
+
+ return {
+ "type": attach_type,
+ "attachment_id": attachment.attachment_id,
+ "filename": display_name or src_path.name,
+ }
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 047417aaa..54718fefb 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -3,12 +3,12 @@
import time
import uuid
from collections.abc import Callable, Coroutine
+from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
-from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -21,10 +21,23 @@
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
+from .message_parts_helper import (
+ message_chain_to_storage_message_parts,
+ parse_webchat_message_parts,
+)
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
+def _extract_conversation_id(session_id: str) -> str:
+ """Extract raw webchat conversation id from event/session id."""
+ if session_id.startswith("webchat!"):
+ parts = session_id.split("!", 2)
+ if len(parts) == 3:
+ return parts[2]
+ return session_id
+
+
class QueueListener:
def __init__(
self,
@@ -57,13 +70,15 @@ def __init__(
self.settings = platform_settings
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
+ self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
os.makedirs(self.imgs_dir, exist_ok=True)
+ self.attachments_dir.mkdir(parents=True, exist_ok=True)
self.metadata = PlatformMetadata(
name="webchat",
description="webchat",
id="webchat",
- support_proactive_message=False,
+ support_proactive_message=True,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
@@ -73,10 +88,67 @@ async def send_by_session(
session: MessageSesion,
message_chain: MessageChain,
) -> None:
- message_id = f"active_{str(uuid.uuid4())}"
- await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
+ conversation_id = _extract_conversation_id(session.session_id)
+ active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
+ conversation_id
+ )
+ subscription_request_ids = [
+ req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
+ ]
+ target_request_ids = subscription_request_ids or active_request_ids
+
+ if target_request_ids:
+ for request_id in target_request_ids:
+ await WebChatMessageEvent._send(
+ request_id,
+ message_chain,
+ session.session_id,
+ )
+ else:
+ message_id = f"active_{uuid.uuid4()!s}"
+ await WebChatMessageEvent._send(
+ message_id,
+ message_chain,
+ session.session_id,
+ )
+
+ should_persist = (
+ bool(subscription_request_ids)
+ or not active_request_ids
+ or all(req_id.startswith("active_") for req_id in active_request_ids)
+ )
+ if should_persist:
+ try:
+ await self._save_proactive_message(conversation_id, message_chain)
+ except Exception as e:
+ logger.error(
+ f"[WebChatAdapter] Failed to save proactive message: {e}",
+ exc_info=True,
+ )
+
await super().send_by_session(session, message_chain)
+ async def _save_proactive_message(
+ self,
+ conversation_id: str,
+ message_chain: MessageChain,
+ ) -> None:
+ message_parts = await message_chain_to_storage_message_parts(
+ message_chain,
+ insert_attachment=db_helper.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ )
+ if not message_parts:
+ return
+
+ await db_helper.insert_platform_message_history(
+ platform_id="webchat",
+ user_id=conversation_id,
+ content={"type": "bot", "message": message_parts},
+ sender_id="bot",
+ sender_name="bot",
+ )
+
async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
@@ -98,72 +170,30 @@ async def _parse_message_parts(
Returns:
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
"""
- components = []
- text_parts = []
-
- for part in message_parts:
- part_type = part.get("type")
- if part_type == "plain":
- text = part.get("text", "")
- components.append(Plain(text=text))
- text_parts.append(text)
- elif part_type == "reply":
- message_id = part.get("message_id")
- reply_chain = []
- reply_message_str = part.get("selected_text", "")
- sender_id = None
- sender_name = None
-
- if reply_message_str:
- reply_chain = [Plain(text=reply_message_str)]
-
- # recursively get the content of the referenced message, if selected_text is empty
- if not reply_message_str and depth < max_depth and message_id:
- history = await self._get_message_history(message_id)
- if history and history.content:
- reply_parts = history.content.get("message", [])
- if isinstance(reply_parts, list):
- (
- reply_chain,
- reply_text_parts,
- ) = await self._parse_message_parts(
- reply_parts,
- depth=depth + 1,
- max_depth=max_depth,
- )
- reply_message_str = "".join(reply_text_parts)
- sender_id = history.sender_id
- sender_name = history.sender_name
-
- components.append(
- Reply(
- id=message_id,
- chain=reply_chain,
- message_str=reply_message_str,
- sender_id=sender_id,
- sender_nickname=sender_name,
- )
- )
- elif part_type == "image":
- path = part.get("path")
- if path:
- components.append(Image.fromFileSystem(path))
- elif part_type == "record":
- path = part.get("path")
- if path:
- components.append(Record.fromFileSystem(path))
- elif part_type == "file":
- path = part.get("path")
- if path:
- filename = part.get("filename") or (
- os.path.basename(path) if path else "file"
- )
- components.append(File(name=filename, file=path))
- elif part_type == "video":
- path = part.get("path")
- if path:
- components.append(Video.fromFileSystem(path))
+ async def get_reply_parts(
+ message_id: Any,
+ ) -> tuple[list[dict], str | None, str | None] | None:
+ history = await self._get_message_history(message_id)
+ if not history or not history.content:
+ return None
+
+ reply_parts = history.content.get("message", [])
+ if not isinstance(reply_parts, list):
+ return None
+
+ return reply_parts, history.sender_id, history.sender_name
+
+ components, text_parts, _ = await parse_webchat_message_parts(
+ message_parts,
+ strict=False,
+ include_empty_plain=True,
+ verify_media_path_exists=False,
+ reply_history_getter=get_reply_parts,
+ current_depth=depth,
+ max_reply_depth=max_depth,
+ cast_reply_id_to_str=False,
+ )
return components, text_parts
async def convert_message(self, data: tuple) -> AstrBotMessage:
diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py
index a680f7617..b7da864aa 100644
--- a/astrbot/core/platform/sources/webchat/webchat_event.py
+++ b/astrbot/core/platform/sources/webchat/webchat_event.py
@@ -14,6 +14,15 @@
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
+def _extract_conversation_id(session_id: str) -> str:
+ """Extract raw webchat conversation id from event/session id."""
+ if session_id.startswith("webchat!"):
+ parts = session_id.split("!", 2)
+ if len(parts) == 3:
+ return parts[2]
+ return session_id
+
+
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
@@ -27,7 +36,7 @@ async def _send(
streaming: bool = False,
) -> str | None:
request_id = str(message_id)
- conversation_id = session_id.split("!")[-1]
+ conversation_id = _extract_conversation_id(session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
@@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None:
reasoning_content = ""
message_id = self.message_obj.message_id
request_id = str(message_id)
- conversation_id = self.session_id.split("!")[-1]
+ conversation_id = _extract_conversation_id(self.session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
index fd35e837c..f3ade1589 100644
--- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
+++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
@@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str):
if task is not None:
task.cancel()
+ def list_back_request_ids(self, conversation_id: str) -> list[str]:
+ """List active back-queue request IDs for a conversation."""
+ return list(self._conversation_back_requests.get(conversation_id, set()))
+
def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index 1235dd381..0602cc074 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -1,6 +1,5 @@
import asyncio
import json
-import mimetypes
import os
import re
import uuid
@@ -14,6 +13,12 @@
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_type import MessageType
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_webchat_message_parts,
+ create_attachment_part_from_existing_file,
+ strip_message_parts_path_fields,
+ webchat_message_parts_have_content,
+)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.active_event_registry import active_event_registry
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -166,83 +171,24 @@ async def post_file(self):
)
async def _build_user_message_parts(self, message: str | list) -> list[dict]:
- """构建用户消息的部分列表
-
- Args:
- message: 文本消息 (str) 或消息段列表 (list)
- """
- parts = []
-
- if isinstance(message, list):
- for part in message:
- part_type = part.get("type")
- if part_type == "plain":
- parts.append({"type": "plain", "text": part.get("text", "")})
- elif part_type == "reply":
- parts.append(
- {
- "type": "reply",
- "message_id": part.get("message_id"),
- "selected_text": part.get("selected_text", ""),
- }
- )
- elif attachment_id := part.get("attachment_id"):
- attachment = await self.db.get_attachment_by_id(attachment_id)
- if attachment:
- parts.append(
- {
- "type": attachment.type,
- "attachment_id": attachment.attachment_id,
- "filename": os.path.basename(attachment.path),
- "path": attachment.path, # will be deleted
- }
- )
- return parts
-
- if message:
- parts.append({"type": "plain", "text": message})
-
- return parts
+ """构建用户消息的部分列表。"""
+ return await build_webchat_message_parts(
+ message,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=False,
+ )
async def _create_attachment_from_file(
self, filename: str, attach_type: str
) -> dict | None:
- """从本地文件创建 attachment 并返回消息部分
-
- 用于处理 bot 回复中的媒体文件
-
- Args:
- filename: 存储的文件名
- attach_type: 附件类型 (image, record, file, video)
- """
- basename = os.path.basename(filename)
- candidate_paths = [
- os.path.join(self.attachments_dir, basename),
- os.path.join(self.legacy_img_dir, basename),
- ]
- file_path = next((p for p in candidate_paths if os.path.exists(p)), None)
- if not file_path:
- return None
-
- # guess mime type
- mime_type, _ = mimetypes.guess_type(filename)
- if not mime_type:
- mime_type = "application/octet-stream"
-
- # insert attachment
- attachment = await self.db.insert_attachment(
- path=file_path,
- type=attach_type,
- mime_type=mime_type,
+ """从本地文件创建 attachment 并返回消息部分。"""
+ return await create_attachment_part_from_existing_file(
+ filename,
+ attach_type=attach_type,
+ insert_attachment=self.db.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ fallback_dirs=[self.legacy_img_dir],
)
- if not attachment:
- return None
-
- return {
- "type": attach_type,
- "attachment_id": attachment.attachment_id,
- "filename": os.path.basename(file_path),
- }
def _extract_web_search_refs(
self, accumulated_text: str, accumulated_parts: list
@@ -356,21 +302,6 @@ async def chat(self, post_data: dict | None = None):
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
- # 检查消息是否为空
- if isinstance(message, list):
- has_content = any(
- part.get("type") in ("plain", "image", "record", "file", "video")
- for part in message
- )
- if not has_content:
- return (
- Response()
- .error("Message content is empty (reply only is not allowed)")
- .__dict__
- )
- elif not message:
- return Response().error("Message are both empty").__dict__
-
if not session_id:
return Response().error("session_id is empty").__dict__
@@ -378,6 +309,12 @@ async def chat(self, post_data: dict | None = None):
# 构建用户消息段(包含 path 用于传递给 adapter)
message_parts = await self._build_user_message_parts(message)
+ if not webchat_message_parts_have_content(message_parts):
+ return (
+ Response()
+ .error("Message content is empty (reply only is not allowed)")
+ .__dict__
+ )
message_id = str(uuid.uuid4())
back_queue = webchat_queue_mgr.get_or_create_back_queue(
@@ -583,10 +520,7 @@ async def stream():
),
)
- message_parts_for_storage = []
- for part in message_parts:
- part_copy = {k: v for k, v in part.items() if k != "path"}
- message_parts_for_storage.append(part_copy)
+ message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.platform_history_mgr.insert(
platform_id="webchat",
diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py
index 8c922ab69..25438565e 100644
--- a/astrbot/dashboard/routes/live_chat.py
+++ b/astrbot/dashboard/routes/live_chat.py
@@ -1,6 +1,7 @@
import asyncio
import json
import os
+import re
import time
import uuid
import wave
@@ -10,9 +11,16 @@
from quart import websocket
from astrbot import logger
+from astrbot.core import sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_webchat_message_parts,
+ create_attachment_part_from_existing_file,
+ strip_message_parts_path_fields,
+ webchat_message_parts_have_content,
+)
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
-from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
+from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .route import Route, RouteContext
@@ -30,6 +38,9 @@ def __init__(self, session_id: str, username: str) -> None:
self.audio_frames: list[bytes] = []
self.current_stamp: str | None = None
self.temp_audio_path: str | None = None
+ self.chat_subscriptions: dict[str, str] = {}
+ self.chat_subscription_tasks: dict[str, asyncio.Task] = {}
+ self.ws_send_lock = asyncio.Lock()
def start_speaking(self, stamp: str) -> None:
"""开始说话"""
@@ -106,13 +117,26 @@ def __init__(
self.core_lifecycle = core_lifecycle
self.db = db
self.plugin_manager = core_lifecycle.plugin_manager
+ self.platform_history_mgr = core_lifecycle.platform_message_history_manager
self.sessions: dict[str, LiveChatSession] = {}
+ self.attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")
+ self.legacy_img_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
+ os.makedirs(self.attachments_dir, exist_ok=True)
# 注册 WebSocket 路由
self.app.websocket("/api/live_chat/ws")(self.live_chat_ws)
+ self.app.websocket("/api/unified_chat/ws")(self.unified_chat_ws)
async def live_chat_ws(self) -> None:
- """Live Chat WebSocket 处理器"""
+ """Legacy Live Chat WebSocket 处理器(默认 ct=live)"""
+ await self._unified_ws_loop(force_ct="live")
+
+ async def unified_chat_ws(self) -> None:
+ """Unified Chat WebSocket 处理器(支持 ct=live/chat)"""
+ await self._unified_ws_loop(force_ct=None)
+
+ async def _unified_ws_loop(self, force_ct: str | None = None) -> None:
+ """统一 WebSocket 循环"""
# WebSocket 不能通过 header 传递 token,需要从 query 参数获取
# 注意:WebSocket 上下文使用 websocket.args 而不是 request.args
token = websocket.args.get("token")
@@ -140,7 +164,11 @@ async def live_chat_ws(self) -> None:
try:
while True:
message = await websocket.receive_json()
- await self._handle_message(live_session, message)
+ ct = force_ct or message.get("ct", "live")
+ if ct == "chat":
+ await self._handle_chat_message(live_session, message)
+ else:
+ await self._handle_message(live_session, message)
except Exception as e:
logger.error(f"[Live Chat] WebSocket 错误: {e}", exc_info=True)
@@ -148,10 +176,488 @@ async def live_chat_ws(self) -> None:
finally:
# 清理会话
if session_id in self.sessions:
+ await self._cleanup_chat_subscriptions(live_session)
live_session.cleanup()
del self.sessions[session_id]
logger.info(f"[Live Chat] WebSocket 连接关闭: {username}")
+ async def _create_attachment_from_file(
+ self, filename: str, attach_type: str
+ ) -> dict | None:
+ """从本地文件创建 attachment 并返回消息部分。"""
+ return await create_attachment_part_from_existing_file(
+ filename,
+ attach_type=attach_type,
+ insert_attachment=self.db.insert_attachment,
+ attachments_dir=self.attachments_dir,
+ fallback_dirs=[self.legacy_img_dir],
+ )
+
+ def _extract_web_search_refs(
+ self, accumulated_text: str, accumulated_parts: list
+ ) -> dict:
+ """从消息中提取 web_search 引用。"""
+ supported = ["web_search_tavily", "web_search_bocha"]
+ web_search_results = {}
+ tool_call_parts = [
+ p
+ for p in accumulated_parts
+ if p.get("type") == "tool_call" and p.get("tool_calls")
+ ]
+
+ for part in tool_call_parts:
+ for tool_call in part["tool_calls"]:
+ if tool_call.get("name") not in supported or not tool_call.get(
+ "result"
+ ):
+ continue
+ try:
+ result_data = json.loads(tool_call["result"])
+ for item in result_data.get("results", []):
+ if idx := item.get("index"):
+ web_search_results[idx] = {
+ "url": item.get("url"),
+ "title": item.get("title"),
+ "snippet": item.get("snippet"),
+ }
+ except (json.JSONDecodeError, KeyError):
+ pass
+
+ if not web_search_results:
+ return {}
+
+ ref_indices = {
+ m.strip() for m in re.findall(r"[(.*?)]", accumulated_text)
+ }
+
+ used_refs = []
+ for ref_index in ref_indices:
+ if ref_index not in web_search_results:
+ continue
+ payload = {"index": ref_index, **web_search_results[ref_index]}
+ if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
+ payload["favicon"] = favicon
+ used_refs.append(payload)
+
+ return {"used": used_refs} if used_refs else {}
+
+ async def _save_bot_message(
+ self,
+ webchat_conv_id: str,
+ text: str,
+ media_parts: list,
+ reasoning: str,
+ agent_stats: dict,
+ refs: dict,
+ ):
+ """保存 bot 消息到历史记录。"""
+ bot_message_parts = []
+ bot_message_parts.extend(media_parts)
+ if text:
+ bot_message_parts.append({"type": "plain", "text": text})
+
+ new_his = {"type": "bot", "message": bot_message_parts}
+ if reasoning:
+ new_his["reasoning"] = reasoning
+ if agent_stats:
+ new_his["agent_stats"] = agent_stats
+ if refs:
+ new_his["refs"] = refs
+
+ return await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=webchat_conv_id,
+ content=new_his,
+ sender_id="bot",
+ sender_name="bot",
+ )
+
+ async def _send_chat_payload(self, session: LiveChatSession, payload: dict) -> None:
+ async with session.ws_send_lock:
+ await websocket.send_json(payload)
+
+ async def _forward_chat_subscription(
+ self,
+ session: LiveChatSession,
+ chat_session_id: str,
+ request_id: str,
+ ) -> None:
+ back_queue = webchat_queue_mgr.get_or_create_back_queue(
+ request_id, chat_session_id
+ )
+ try:
+ while True:
+ result = await back_queue.get()
+ if not result:
+ continue
+ await self._send_chat_payload(session, {"ct": "chat", **result})
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ logger.error(
+ f"[Live Chat] chat subscription forward failed ({chat_session_id}): {e}",
+ exc_info=True,
+ )
+ finally:
+ webchat_queue_mgr.remove_back_queue(request_id)
+ if session.chat_subscriptions.get(chat_session_id) == request_id:
+ session.chat_subscriptions.pop(chat_session_id, None)
+ session.chat_subscription_tasks.pop(chat_session_id, None)
+
+ async def _ensure_chat_subscription(
+ self,
+ session: LiveChatSession,
+ chat_session_id: str,
+ ) -> str:
+ existing_request_id = session.chat_subscriptions.get(chat_session_id)
+ existing_task = session.chat_subscription_tasks.get(chat_session_id)
+ if existing_request_id and existing_task and not existing_task.done():
+ return existing_request_id
+
+ request_id = f"ws_sub_{uuid.uuid4().hex}"
+ session.chat_subscriptions[chat_session_id] = request_id
+ task = asyncio.create_task(
+ self._forward_chat_subscription(session, chat_session_id, request_id),
+ name=f"chat_ws_sub_{chat_session_id}",
+ )
+ session.chat_subscription_tasks[chat_session_id] = task
+ return request_id
+
+ async def _cleanup_chat_subscriptions(self, session: LiveChatSession) -> None:
+ tasks = list(session.chat_subscription_tasks.values())
+ for task in tasks:
+ task.cancel()
+ if tasks:
+ await asyncio.gather(*tasks, return_exceptions=True)
+
+ for request_id in list(session.chat_subscriptions.values()):
+ webchat_queue_mgr.remove_back_queue(request_id)
+ session.chat_subscriptions.clear()
+ session.chat_subscription_tasks.clear()
+
+ async def _handle_chat_message(
+ self, session: LiveChatSession, message: dict
+ ) -> None:
+ """处理 Chat Mode 消息(ct=chat)"""
+ msg_type = message.get("t")
+
+ if msg_type == "bind":
+ chat_session_id = message.get("session_id")
+ if not isinstance(chat_session_id, str) or not chat_session_id:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "session_id is required",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ request_id = await self._ensure_chat_subscription(session, chat_session_id)
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "session_bound",
+ "session_id": chat_session_id,
+ "message_id": request_id,
+ },
+ )
+ return
+
+ if msg_type == "interrupt":
+ session.should_interrupt = True
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "INTERRUPTED",
+ "code": "INTERRUPTED",
+ },
+ )
+ return
+
+ if msg_type != "send":
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": f"Unsupported message type: {msg_type}",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ if session.is_processing:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "Session is busy",
+ "code": "PROCESSING_ERROR",
+ },
+ )
+ return
+
+ payload = message.get("message")
+ session_id = message.get("session_id") or session.session_id
+ message_id = message.get("message_id") or str(uuid.uuid4())
+ selected_provider = message.get("selected_provider")
+ selected_model = message.get("selected_model")
+ selected_stt_provider = message.get("selected_stt_provider")
+ selected_tts_provider = message.get("selected_tts_provider")
+ persona_prompt = message.get("persona_prompt")
+ show_reasoning = message.get("show_reasoning")
+ enable_streaming = message.get("enable_streaming", True)
+
+ if not isinstance(payload, list):
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "message must be list",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ message_parts = await self._build_chat_message_parts(payload)
+ has_content = webchat_message_parts_have_content(message_parts)
+ if not has_content:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": "Message content is empty",
+ "code": "INVALID_MESSAGE_FORMAT",
+ },
+ )
+ return
+
+ await self._ensure_chat_subscription(session, session_id)
+
+ session.is_processing = True
+ session.should_interrupt = False
+ back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
+
+ try:
+ chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
+ await chat_queue.put(
+ (
+ session.username,
+ session_id,
+ {
+ "message": message_parts,
+ "selected_provider": selected_provider,
+ "selected_model": selected_model,
+ "selected_stt_provider": selected_stt_provider,
+ "selected_tts_provider": selected_tts_provider,
+ "persona_prompt": persona_prompt,
+ "show_reasoning": show_reasoning,
+ "enable_streaming": enable_streaming,
+ "message_id": message_id,
+ },
+ ),
+ )
+
+ message_parts_for_storage = strip_message_parts_path_fields(message_parts)
+ await self.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=session_id,
+ content={"type": "user", "message": message_parts_for_storage},
+ sender_id=session.username,
+ sender_name=session.username,
+ )
+
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ tool_calls = {}
+ agent_stats = {}
+ refs = {}
+
+ while True:
+ if session.should_interrupt:
+ session.should_interrupt = False
+ break
+
+ try:
+ result = await asyncio.wait_for(back_queue.get(), timeout=1)
+ except asyncio.TimeoutError:
+ continue
+
+ if not result:
+ continue
+ if result.get("message_id") and result.get("message_id") != message_id:
+ continue
+
+ result_text = result.get("data", "")
+ msg_type = result.get("type")
+ streaming = result.get("streaming", False)
+ chain_type = result.get("chain_type")
+ if chain_type == "agent_stats":
+ try:
+ parsed_agent_stats = json.loads(result_text)
+ agent_stats = parsed_agent_stats
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "agent_stats",
+ "data": parsed_agent_stats,
+ },
+ )
+ except Exception:
+ pass
+ continue
+
+ outgoing = {"ct": "chat", **result}
+ await self._send_chat_payload(session, outgoing)
+
+ if msg_type == "plain":
+ if chain_type == "tool_call":
+ try:
+ tool_call = json.loads(result_text)
+ tool_calls[tool_call.get("id")] = tool_call
+ if accumulated_text:
+ accumulated_parts.append(
+ {"type": "plain", "text": accumulated_text}
+ )
+ accumulated_text = ""
+ except Exception:
+ pass
+ elif chain_type == "tool_call_result":
+ try:
+ tcr = json.loads(result_text)
+ tc_id = tcr.get("id")
+ if tc_id in tool_calls:
+ tool_calls[tc_id]["result"] = tcr.get("result")
+ tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
+ accumulated_parts.append(
+ {
+ "type": "tool_call",
+ "tool_calls": [tool_calls[tc_id]],
+ }
+ )
+ tool_calls.pop(tc_id, None)
+ except Exception:
+ pass
+ elif chain_type == "reasoning":
+ accumulated_reasoning += result_text
+ elif streaming:
+ accumulated_text += result_text
+ else:
+ accumulated_text = result_text
+ elif msg_type == "image":
+ filename = str(result_text).replace("[IMAGE]", "")
+ part = await self._create_attachment_from_file(filename, "image")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "record":
+ filename = str(result_text).replace("[RECORD]", "")
+ part = await self._create_attachment_from_file(filename, "record")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "file":
+ filename = str(result_text).replace("[FILE]", "").split("|", 1)[0]
+ part = await self._create_attachment_from_file(filename, "file")
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "video":
+ filename = str(result_text).replace("[VIDEO]", "").split("|", 1)[0]
+ part = await self._create_attachment_from_file(filename, "video")
+ if part:
+ accumulated_parts.append(part)
+
+ should_save = False
+ if msg_type == "end":
+ should_save = bool(
+ accumulated_parts
+ or accumulated_text
+ or accumulated_reasoning
+ or refs
+ or agent_stats
+ )
+ elif (streaming and msg_type == "complete") or not streaming:
+ if chain_type not in (
+ "tool_call",
+ "tool_call_result",
+ "agent_stats",
+ ):
+ should_save = True
+
+ if should_save:
+ try:
+ refs = self._extract_web_search_refs(
+ accumulated_text,
+ accumulated_parts,
+ )
+ except Exception as e:
+ logger.exception(
+ f"[Live Chat] Failed to extract web search refs: {e}",
+ exc_info=True,
+ )
+
+ saved_record = await self._save_bot_message(
+ session_id,
+ accumulated_text,
+ accumulated_parts,
+ accumulated_reasoning,
+ agent_stats,
+ refs,
+ )
+ if saved_record:
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "type": "message_saved",
+ "data": {
+ "id": saved_record.id,
+ "created_at": saved_record.created_at.astimezone().isoformat(),
+ },
+ },
+ )
+
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ agent_stats = {}
+ refs = {}
+
+ if msg_type == "end":
+ break
+
+ except Exception as e:
+ logger.error(f"[Live Chat] 处理 chat 消息失败: {e}", exc_info=True)
+ await self._send_chat_payload(
+ session,
+ {
+ "ct": "chat",
+ "t": "error",
+ "data": f"处理失败: {str(e)}",
+ "code": "PROCESSING_ERROR",
+ },
+ )
+ finally:
+ session.is_processing = False
+ webchat_queue_mgr.remove_back_queue(message_id)
+
+ async def _build_chat_message_parts(self, message: list[dict]) -> list[dict]:
+ """构建 chat websocket 用户消息段(复用 webchat 逻辑)"""
+ return await build_webchat_message_parts(
+ message,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=False,
+ )
+
async def _handle_message(self, session: LiveChatSession, message: dict) -> None:
"""处理 WebSocket 消息"""
msg_type = message.get("t") # 使用 t 代替 type
diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py
index c25870ebb..653e22cbf 100644
--- a/astrbot/dashboard/routes/open_api.py
+++ b/astrbot/dashboard/routes/open_api.py
@@ -1,15 +1,22 @@
-from pathlib import Path
+import asyncio
+import hashlib
+import json
from uuid import uuid4
-from quart import g, request
+from quart import g, request, websocket
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
-from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
-from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSesion
-
+from astrbot.core.platform.sources.webchat.message_parts_helper import (
+ build_message_chain_from_payload,
+ strip_message_parts_path_fields,
+ webchat_message_parts_have_content,
+)
+from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
+
+from .api_key import ALL_OPEN_API_SCOPES
from .chat import ChatRoute
from .route import Response, Route, RouteContext
@@ -37,6 +44,7 @@ def __init__(
"/v1/im/bots": ("GET", self.get_bots),
}
self.register_routes()
+ self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
@staticmethod
def _resolve_open_username(
@@ -181,6 +189,348 @@ async def chat_send(self):
finally:
g.username = original_username
+ @staticmethod
+ def _extract_ws_api_key() -> str | None:
+ if key := websocket.args.get("api_key"):
+ return key.strip()
+ if key := websocket.args.get("key"):
+ return key.strip()
+ if key := websocket.headers.get("X-API-Key"):
+ return key.strip()
+
+ auth_header = websocket.headers.get("Authorization", "").strip()
+ if auth_header.startswith("Bearer "):
+ return auth_header.removeprefix("Bearer ").strip()
+ if auth_header.startswith("ApiKey "):
+ return auth_header.removeprefix("ApiKey ").strip()
+ return None
+
+ async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
+ raw_key = self._extract_ws_api_key()
+ if not raw_key:
+ return False, "Missing API key"
+
+ key_hash = hashlib.pbkdf2_hmac(
+ "sha256",
+ raw_key.encode("utf-8"),
+ b"astrbot_api_key",
+ 100_000,
+ ).hex()
+ api_key = await self.db.get_active_api_key_by_hash(key_hash)
+ if not api_key:
+ return False, "Invalid API key"
+
+ if isinstance(api_key.scopes, list):
+ scopes = api_key.scopes
+ else:
+ scopes = list(ALL_OPEN_API_SCOPES)
+
+ if "*" not in scopes and "chat" not in scopes:
+ return False, "Insufficient API key scope"
+
+ await self.db.touch_api_key(api_key.key_id)
+ return True, None
+
+ async def _send_chat_ws_error(self, message: str, code: str) -> None:
+ await websocket.send_json(
+ {
+ "type": "error",
+ "code": code,
+ "data": message,
+ }
+ )
+
+ async def _update_session_config_route(
+ self,
+ *,
+ username: str,
+ session_id: str,
+ config_id: str | None,
+ ) -> str | None:
+ if not config_id:
+ return None
+
+ umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
+ try:
+ if config_id == "default":
+ await self.core_lifecycle.umop_config_router.delete_route(umo)
+ else:
+ await self.core_lifecycle.umop_config_router.update_route(
+ umo, config_id
+ )
+ except Exception as e:
+ logger.error(
+ "Failed to update chat config route for %s with %s: %s",
+ umo,
+ config_id,
+ e,
+ exc_info=True,
+ )
+ return f"Failed to update chat config route: {e}"
+ return None
+
+ async def _handle_chat_ws_send(self, post_data: dict) -> None:
+ effective_username, username_err = self._resolve_open_username(
+ post_data.get("username")
+ )
+ if username_err or not effective_username:
+ await self._send_chat_ws_error(
+ username_err or "Invalid username", "BAD_USER"
+ )
+ return
+
+ message = post_data.get("message")
+ if message is None:
+ await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
+ return
+
+ raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
+ session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
+ if not session_id:
+ session_id = str(uuid4())
+
+ ensure_session_err = await self._ensure_chat_session(
+ effective_username,
+ session_id,
+ )
+ if ensure_session_err:
+ await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
+ return
+
+ config_id, resolve_err = self._resolve_chat_config_id(post_data)
+ if resolve_err:
+ await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
+ return
+
+ config_err = await self._update_session_config_route(
+ username=effective_username,
+ session_id=session_id,
+ config_id=config_id,
+ )
+ if config_err:
+ await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
+ return
+
+ message_parts = await self.chat_route._build_user_message_parts(message)
+ if not webchat_message_parts_have_content(message_parts):
+ await self._send_chat_ws_error(
+ "Message content is empty (reply only is not allowed)",
+ "INVALID_MESSAGE",
+ )
+ return
+
+ message_id = str(post_data.get("message_id") or uuid4())
+ selected_provider = post_data.get("selected_provider")
+ selected_model = post_data.get("selected_model")
+ enable_streaming = post_data.get("enable_streaming", True)
+
+ back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
+ try:
+ chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
+ await chat_queue.put(
+ (
+ effective_username,
+ session_id,
+ {
+ "message": message_parts,
+ "selected_provider": selected_provider,
+ "selected_model": selected_model,
+ "enable_streaming": enable_streaming,
+ "message_id": message_id,
+ },
+ )
+ )
+
+ message_parts_for_storage = strip_message_parts_path_fields(message_parts)
+ await self.chat_route.platform_history_mgr.insert(
+ platform_id="webchat",
+ user_id=session_id,
+ content={"type": "user", "message": message_parts_for_storage},
+ sender_id=effective_username,
+ sender_name=effective_username,
+ )
+
+ await websocket.send_json(
+ {
+ "type": "session_id",
+ "data": None,
+ "session_id": session_id,
+ "message_id": message_id,
+ }
+ )
+
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ tool_calls = {}
+ agent_stats = {}
+ refs = {}
+ while True:
+ try:
+ result = await asyncio.wait_for(back_queue.get(), timeout=1)
+ except asyncio.TimeoutError:
+ continue
+
+ if not result:
+ continue
+
+ if "message_id" in result and result["message_id"] != message_id:
+ logger.warning("openapi ws stream message_id mismatch")
+ continue
+
+ result_text = result.get("data", "")
+ msg_type = result.get("type")
+ streaming = result.get("streaming", False)
+ chain_type = result.get("chain_type")
+
+ if chain_type == "agent_stats":
+ try:
+ stats_info = {
+ "type": "agent_stats",
+ "data": json.loads(result_text),
+ }
+ await websocket.send_json(stats_info)
+ agent_stats = stats_info["data"]
+ except Exception:
+ pass
+ continue
+
+ await websocket.send_json(result)
+
+ if msg_type == "plain":
+ if chain_type == "tool_call":
+ tool_call = json.loads(result_text)
+ tool_calls[tool_call.get("id")] = tool_call
+ if accumulated_text:
+ accumulated_parts.append(
+ {"type": "plain", "text": accumulated_text}
+ )
+ accumulated_text = ""
+ elif chain_type == "tool_call_result":
+ tcr = json.loads(result_text)
+ tc_id = tcr.get("id")
+ if tc_id in tool_calls:
+ tool_calls[tc_id]["result"] = tcr.get("result")
+ tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
+ accumulated_parts.append(
+ {"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
+ )
+ tool_calls.pop(tc_id, None)
+ elif chain_type == "reasoning":
+ accumulated_reasoning += result_text
+ elif streaming:
+ accumulated_text += result_text
+ else:
+ accumulated_text = result_text
+ elif msg_type == "image":
+ filename = str(result_text).replace("[IMAGE]", "")
+ part = await self.chat_route._create_attachment_from_file(
+ filename, "image"
+ )
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "record":
+ filename = str(result_text).replace("[RECORD]", "")
+ part = await self.chat_route._create_attachment_from_file(
+ filename, "record"
+ )
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "file":
+ filename = str(result_text).replace("[FILE]", "")
+ part = await self.chat_route._create_attachment_from_file(
+ filename, "file"
+ )
+ if part:
+ accumulated_parts.append(part)
+ elif msg_type == "video":
+ filename = str(result_text).replace("[VIDEO]", "")
+ part = await self.chat_route._create_attachment_from_file(
+ filename, "video"
+ )
+ if part:
+ accumulated_parts.append(part)
+
+ if msg_type == "end":
+ break
+ if (streaming and msg_type == "complete") or not streaming:
+ if chain_type in ("tool_call", "tool_call_result"):
+ continue
+ try:
+ refs = self.chat_route._extract_web_search_refs(
+ accumulated_text,
+ accumulated_parts,
+ )
+ except Exception as e:
+ logger.exception(
+ f"Open API WS failed to extract web search refs: {e}",
+ exc_info=True,
+ )
+
+ saved_record = await self.chat_route._save_bot_message(
+ session_id,
+ accumulated_text,
+ accumulated_parts,
+ accumulated_reasoning,
+ agent_stats,
+ refs,
+ )
+ if saved_record:
+ await websocket.send_json(
+ {
+ "type": "message_saved",
+ "data": {
+ "id": saved_record.id,
+ "created_at": saved_record.created_at.astimezone().isoformat(),
+ },
+ "session_id": session_id,
+ }
+ )
+ accumulated_parts = []
+ accumulated_text = ""
+ accumulated_reasoning = ""
+ agent_stats = {}
+ refs = {}
+ except Exception as e:
+ logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
+ await self._send_chat_ws_error(
+ f"Failed to process message: {e}", "PROCESSING_ERROR"
+ )
+ finally:
+ webchat_queue_mgr.remove_back_queue(message_id)
+
+ async def chat_ws(self) -> None:
+ authed, auth_err = await self._authenticate_chat_ws_api_key()
+ if not authed:
+ await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
+ await websocket.close(1008, auth_err or "Unauthorized")
+ return
+
+ try:
+ while True:
+ message = await websocket.receive_json()
+ if not isinstance(message, dict):
+ await self._send_chat_ws_error(
+ "message must be an object",
+ "INVALID_MESSAGE",
+ )
+ continue
+
+ msg_type = message.get("t", "send")
+ if msg_type == "ping":
+ await websocket.send_json({"type": "pong"})
+ continue
+ if msg_type != "send":
+ await self._send_chat_ws_error(
+ f"Unsupported message type: {msg_type}",
+ "INVALID_MESSAGE",
+ )
+ continue
+
+ await self._handle_chat_ws_send(message)
+ except Exception as e:
+ logger.debug("Open API WS connection closed: %s", e)
+
async def upload_file(self):
return await self.chat_route.post_file()
@@ -254,83 +604,12 @@ async def get_chat_configs(self):
async def _build_message_chain_from_payload(
self,
message_payload: str | list,
- ) -> MessageChain:
- if isinstance(message_payload, str):
- text = message_payload.strip()
- if not text:
- raise ValueError("Message is empty")
- return MessageChain(chain=[Plain(text=text)])
-
- if not isinstance(message_payload, list):
- raise ValueError("message must be a string or list")
-
- components = []
- has_content = False
-
- for part in message_payload:
- if not isinstance(part, dict):
- raise ValueError("message part must be an object")
-
- part_type = str(part.get("type", "")).strip()
- if part_type == "plain":
- text = str(part.get("text", ""))
- if text:
- has_content = True
- components.append(Plain(text=text))
- continue
-
- if part_type == "reply":
- message_id = part.get("message_id")
- if message_id is None:
- raise ValueError("reply part missing message_id")
- components.append(
- Reply(
- id=str(message_id),
- message_str=str(part.get("selected_text", "")),
- chain=[],
- )
- )
- continue
-
- if part_type not in {"image", "record", "file", "video"}:
- raise ValueError(f"unsupported message part type: {part_type}")
-
- has_content = True
- file_path: Path | None = None
- resolved_type = part_type
- filename = str(part.get("filename", "")).strip()
-
- attachment_id = part.get("attachment_id")
- if attachment_id:
- attachment = await self.db.get_attachment_by_id(str(attachment_id))
- if not attachment:
- raise ValueError(f"attachment not found: {attachment_id}")
- file_path = Path(attachment.path)
- resolved_type = attachment.type
- if not filename:
- filename = file_path.name
- else:
- raise ValueError(f"{part_type} part missing attachment_id")
-
- if not file_path.exists():
- raise ValueError(f"file not found: {file_path!s}")
-
- file_path_str = str(file_path.resolve())
- if resolved_type == "image":
- components.append(Image.fromFileSystem(file_path_str))
- elif resolved_type == "record":
- components.append(Record.fromFileSystem(file_path_str))
- elif resolved_type == "video":
- components.append(Video.fromFileSystem(file_path_str))
- else:
- components.append(
- File(name=filename or file_path.name, file=file_path_str)
- )
-
- if not components or not has_content:
- raise ValueError("Message content is empty (reply only is not allowed)")
-
- return MessageChain(chain=components)
+ ):
+ return await build_message_chain_from_payload(
+ message_payload,
+ get_attachment_by_id=self.db.get_attachment_by_id,
+ strict=True,
+ )
async def send_message(self):
post_data = await request.json or {}
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index a9631fc09..a9650cd06 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -204,6 +204,10 @@ async def auth_middleware(self):
@staticmethod
def _extract_raw_api_key() -> str | None:
+ if key := request.args.get("api_key"):
+ return key.strip()
+ if key := request.args.get("key"):
+ return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
@@ -217,6 +221,7 @@ def _extract_raw_api_key() -> str | None:
def _get_required_open_api_scope(path: str) -> str | None:
scope_map = {
"/api/v1/chat": "chat",
+ "/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config",
"/api/v1/file": "file",
diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue
index 803c5d826..054a18662 100644
--- a/dashboard/src/components/chat/Chat.vue
+++ b/dashboard/src/components/chat/Chat.vue
@@ -10,6 +10,7 @@
:selectedSessions="selectedSessions"
:currSessionId="currSessionId"
:selectedProjectId="selectedProjectId"
+ :transportMode="transportMode"
:isDark="isDark"
:chatboxMode="chatboxMode"
:isMobile="isMobile"
@@ -26,6 +27,7 @@
@createProject="showCreateProjectDialog"
@editProject="showEditProjectDialog"
@deleteProject="handleDeleteProject"
+ @updateTransportMode="setTransportMode"
/>
@@ -301,11 +303,14 @@ const {
isStreaming,
isConvRunning,
enableStreaming,
+ transportMode,
currentSessionProject,
getSessionMessages: getSessionMsg,
sendMessage: sendMsg,
stopMessage: stopMsg,
- toggleStreaming
+ toggleStreaming,
+ setTransportMode,
+ cleanupTransport
} = useMessages(currSessionId, getMediaFile, updateSessionTitle, getSessions);
// 组件引用
@@ -695,6 +700,7 @@ onMounted(() => {
onBeforeUnmount(() => {
window.removeEventListener('resize', checkMobile);
cleanupMediaCache();
+ cleanupTransport();
});
diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue
index a728930d9..97f2179e7 100644
--- a/dashboard/src/components/chat/ConversationSidebar.vue
+++ b/dashboard/src/components/chat/ConversationSidebar.vue
@@ -117,6 +117,27 @@
{{ isDark ? tm('modes.lightMode') : tm('modes.darkMode') }}
+
+
+