-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat: implement websockets transport mode selection for chat #5410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,12 @@ | |
| import time | ||
| import uuid | ||
| from collections.abc import Callable, Coroutine | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| from astrbot import logger | ||
| from astrbot.core import db_helper | ||
| from astrbot.core.db.po import PlatformMessageHistory | ||
| from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video | ||
| from astrbot.core.message.message_event_result import MessageChain | ||
| from astrbot.core.platform import ( | ||
| AstrBotMessage, | ||
|
|
@@ -21,10 +21,23 @@ | |
| from astrbot.core.utils.astrbot_path import get_astrbot_data_path | ||
|
|
||
| from ...register import register_platform_adapter | ||
| from .message_parts_helper import ( | ||
| message_chain_to_storage_message_parts, | ||
| parse_webchat_message_parts, | ||
| ) | ||
| from .webchat_event import WebChatMessageEvent | ||
| from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr | ||
|
|
||
|
|
||
| def _extract_conversation_id(session_id: str) -> str: | ||
| """Extract raw webchat conversation id from event/session id.""" | ||
| if session_id.startswith("webchat!"): | ||
| parts = session_id.split("!", 2) | ||
| if len(parts) == 3: | ||
| return parts[2] | ||
| return session_id | ||
|
|
||
|
|
||
| class QueueListener: | ||
| def __init__( | ||
| self, | ||
|
|
@@ -57,13 +70,15 @@ def __init__( | |
|
|
||
| self.settings = platform_settings | ||
| self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") | ||
| self.attachments_dir = Path(get_astrbot_data_path()) / "attachments" | ||
| os.makedirs(self.imgs_dir, exist_ok=True) | ||
| self.attachments_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| self.metadata = PlatformMetadata( | ||
| name="webchat", | ||
| description="webchat", | ||
| id="webchat", | ||
| support_proactive_message=False, | ||
| support_proactive_message=True, | ||
| ) | ||
| self._shutdown_event = asyncio.Event() | ||
| self._webchat_queue_mgr = webchat_queue_mgr | ||
|
|
@@ -73,10 +88,67 @@ async def send_by_session( | |
| session: MessageSesion, | ||
| message_chain: MessageChain, | ||
| ) -> None: | ||
| message_id = f"active_{str(uuid.uuid4())}" | ||
| await WebChatMessageEvent._send(message_id, message_chain, session.session_id) | ||
| conversation_id = _extract_conversation_id(session.session_id) | ||
| active_request_ids = self._webchat_queue_mgr.list_back_request_ids( | ||
| conversation_id | ||
| ) | ||
| subscription_request_ids = [ | ||
| req_id for req_id in active_request_ids if req_id.startswith("ws_sub_") | ||
| ] | ||
| target_request_ids = subscription_request_ids or active_request_ids | ||
|
|
||
| if target_request_ids: | ||
| for request_id in target_request_ids: | ||
| await WebChatMessageEvent._send( | ||
| request_id, | ||
| message_chain, | ||
| session.session_id, | ||
| ) | ||
| else: | ||
| message_id = f"active_{uuid.uuid4()!s}" | ||
| await WebChatMessageEvent._send( | ||
| message_id, | ||
| message_chain, | ||
| session.session_id, | ||
| ) | ||
|
|
||
| should_persist = ( | ||
| bool(subscription_request_ids) | ||
| or not active_request_ids | ||
| or all(req_id.startswith("active_") for req_id in active_request_ids) | ||
| ) | ||
|
Comment on lines
+115
to
+119
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic to determine |
||
| if should_persist: | ||
| try: | ||
| await self._save_proactive_message(conversation_id, message_chain) | ||
| except Exception as e: | ||
| logger.error( | ||
| f"[WebChatAdapter] Failed to save proactive message: {e}", | ||
| exc_info=True, | ||
| ) | ||
|
|
||
| await super().send_by_session(session, message_chain) | ||
|
|
||
| async def _save_proactive_message( | ||
| self, | ||
| conversation_id: str, | ||
| message_chain: MessageChain, | ||
| ) -> None: | ||
| message_parts = await message_chain_to_storage_message_parts( | ||
| message_chain, | ||
| insert_attachment=db_helper.insert_attachment, | ||
| attachments_dir=self.attachments_dir, | ||
| ) | ||
| if not message_parts: | ||
| return | ||
|
|
||
| await db_helper.insert_platform_message_history( | ||
| platform_id="webchat", | ||
| user_id=conversation_id, | ||
| content={"type": "bot", "message": message_parts}, | ||
| sender_id="bot", | ||
| sender_name="bot", | ||
| ) | ||
|
|
||
| async def _get_message_history( | ||
| self, message_id: int | ||
| ) -> PlatformMessageHistory | None: | ||
|
|
@@ -98,72 +170,30 @@ async def _parse_message_parts( | |
| Returns: | ||
| tuple[list, list[str]]: (消息组件列表, 纯文本列表) | ||
| """ | ||
| components = [] | ||
| text_parts = [] | ||
|
|
||
| for part in message_parts: | ||
| part_type = part.get("type") | ||
| if part_type == "plain": | ||
| text = part.get("text", "") | ||
| components.append(Plain(text=text)) | ||
| text_parts.append(text) | ||
| elif part_type == "reply": | ||
| message_id = part.get("message_id") | ||
| reply_chain = [] | ||
| reply_message_str = part.get("selected_text", "") | ||
| sender_id = None | ||
| sender_name = None | ||
|
|
||
| if reply_message_str: | ||
| reply_chain = [Plain(text=reply_message_str)] | ||
|
|
||
| # recursively get the content of the referenced message, if selected_text is empty | ||
| if not reply_message_str and depth < max_depth and message_id: | ||
| history = await self._get_message_history(message_id) | ||
| if history and history.content: | ||
| reply_parts = history.content.get("message", []) | ||
| if isinstance(reply_parts, list): | ||
| ( | ||
| reply_chain, | ||
| reply_text_parts, | ||
| ) = await self._parse_message_parts( | ||
| reply_parts, | ||
| depth=depth + 1, | ||
| max_depth=max_depth, | ||
| ) | ||
| reply_message_str = "".join(reply_text_parts) | ||
| sender_id = history.sender_id | ||
| sender_name = history.sender_name | ||
|
|
||
| components.append( | ||
| Reply( | ||
| id=message_id, | ||
| chain=reply_chain, | ||
| message_str=reply_message_str, | ||
| sender_id=sender_id, | ||
| sender_nickname=sender_name, | ||
| ) | ||
| ) | ||
| elif part_type == "image": | ||
| path = part.get("path") | ||
| if path: | ||
| components.append(Image.fromFileSystem(path)) | ||
| elif part_type == "record": | ||
| path = part.get("path") | ||
| if path: | ||
| components.append(Record.fromFileSystem(path)) | ||
| elif part_type == "file": | ||
| path = part.get("path") | ||
| if path: | ||
| filename = part.get("filename") or ( | ||
| os.path.basename(path) if path else "file" | ||
| ) | ||
| components.append(File(name=filename, file=path)) | ||
| elif part_type == "video": | ||
| path = part.get("path") | ||
| if path: | ||
| components.append(Video.fromFileSystem(path)) | ||
|
|
||
| async def get_reply_parts( | ||
| message_id: Any, | ||
| ) -> tuple[list[dict], str | None, str | None] | None: | ||
| history = await self._get_message_history(message_id) | ||
| if not history or not history.content: | ||
| return None | ||
|
|
||
| reply_parts = history.content.get("message", []) | ||
| if not isinstance(reply_parts, list): | ||
| return None | ||
|
|
||
| return reply_parts, history.sender_id, history.sender_name | ||
|
|
||
| components, text_parts, _ = await parse_webchat_message_parts( | ||
| message_parts, | ||
| strict=False, | ||
| include_empty_plain=True, | ||
| verify_media_path_exists=False, | ||
| reply_history_getter=get_reply_parts, | ||
| current_depth=depth, | ||
| max_reply_depth=max_depth, | ||
| cast_reply_id_to_str=False, | ||
| ) | ||
| return components, text_parts | ||
|
|
||
| async def convert_message(self, data: tuple) -> AstrBotMessage: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,15 @@ | |
| attachments_dir = os.path.join(get_astrbot_data_path(), "attachments") | ||
|
|
||
|
|
||
| def _extract_conversation_id(session_id: str) -> str: | ||
| """Extract raw webchat conversation id from event/session id.""" | ||
| if session_id.startswith("webchat!"): | ||
| parts = session_id.split("!", 2) | ||
| if len(parts) == 3: | ||
| return parts[2] | ||
| return session_id | ||
|
Comment on lines
+17
to
+23
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function |
||
|
|
||
|
|
||
| class WebChatMessageEvent(AstrMessageEvent): | ||
| def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: | ||
| super().__init__(message_str, message_obj, platform_meta, session_id) | ||
|
|
@@ -27,7 +36,7 @@ async def _send( | |
| streaming: bool = False, | ||
| ) -> str | None: | ||
| request_id = str(message_id) | ||
| conversation_id = session_id.split("!")[-1] | ||
| conversation_id = _extract_conversation_id(session_id) | ||
| web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( | ||
| request_id, | ||
| conversation_id, | ||
|
|
@@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None: | |
| reasoning_content = "" | ||
| message_id = self.message_obj.message_id | ||
| request_id = str(message_id) | ||
| conversation_id = self.session_id.split("!")[-1] | ||
| conversation_id = _extract_conversation_id(self.session_id) | ||
| web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue( | ||
| request_id, | ||
| conversation_id, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When sending proactive messages and no active requests are found for a conversation, the adapter creates a new back queue with a random
message_idand attempts to send the message. Since no client is listening for this randommessage_id, the messages will accumulate in the queue. Once the queue reaches its maximum size (default 512), subsequent calls toWebChatMessageEvent._send(which usesawait queue.put()) will block indefinitely, potentially hanging the bot's execution for that task.