Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 465 additions & 0 deletions astrbot/core/platform/sources/webchat/message_parts_helper.py

Large diffs are not rendered by default.

168 changes: 99 additions & 69 deletions astrbot/core/platform/sources/webchat/webchat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import time
import uuid
from collections.abc import Callable, Coroutine
from pathlib import Path
from typing import Any

from astrbot import logger
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
Expand All @@ -21,10 +21,23 @@
from astrbot.core.utils.astrbot_path import get_astrbot_data_path

from ...register import register_platform_adapter
from .message_parts_helper import (
message_chain_to_storage_message_parts,
parse_webchat_message_parts,
)
from .webchat_event import WebChatMessageEvent
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr


def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id


class QueueListener:
def __init__(
self,
Expand Down Expand Up @@ -57,13 +70,15 @@ def __init__(

self.settings = platform_settings
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
self.attachments_dir = Path(get_astrbot_data_path()) / "attachments"
os.makedirs(self.imgs_dir, exist_ok=True)
self.attachments_dir.mkdir(parents=True, exist_ok=True)

self.metadata = PlatformMetadata(
name="webchat",
description="webchat",
id="webchat",
support_proactive_message=False,
support_proactive_message=True,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
Expand All @@ -73,10 +88,67 @@ async def send_by_session(
session: MessageSesion,
message_chain: MessageChain,
) -> None:
message_id = f"active_{str(uuid.uuid4())}"
await WebChatMessageEvent._send(message_id, message_chain, session.session_id)
conversation_id = _extract_conversation_id(session.session_id)
active_request_ids = self._webchat_queue_mgr.list_back_request_ids(
conversation_id
)
subscription_request_ids = [
req_id for req_id in active_request_ids if req_id.startswith("ws_sub_")
]
target_request_ids = subscription_request_ids or active_request_ids

if target_request_ids:
for request_id in target_request_ids:
await WebChatMessageEvent._send(
request_id,
message_chain,
session.session_id,
)
else:
message_id = f"active_{uuid.uuid4()!s}"
await WebChatMessageEvent._send(
message_id,
message_chain,
session.session_id,
)
Comment on lines +108 to +113
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

When sending proactive messages and no active requests are found for a conversation, the adapter creates a new back queue with a random message_id and attempts to send the message. Since no client is listening for this random message_id, the messages will accumulate in the queue. Once the queue reaches its maximum size (default 512), subsequent calls to WebChatMessageEvent._send (which uses await queue.put()) will block indefinitely, potentially hanging the bot's execution for that task.


should_persist = (
bool(subscription_request_ids)
or not active_request_ids
or all(req_id.startswith("active_") for req_id in active_request_ids)
)
Comment on lines +115 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to determine should_persist is quite complex and relies on the naming conventions of request IDs. To improve code clarity and make it easier for future developers to understand, please add a comment explaining the conditions under which a proactive message is persisted versus when persistence is delegated to another handler (e.g., for a standard request-response cycle).

if should_persist:
try:
await self._save_proactive_message(conversation_id, message_chain)
except Exception as e:
logger.error(
f"[WebChatAdapter] Failed to save proactive message: {e}",
exc_info=True,
)

await super().send_by_session(session, message_chain)

async def _save_proactive_message(
self,
conversation_id: str,
message_chain: MessageChain,
) -> None:
message_parts = await message_chain_to_storage_message_parts(
message_chain,
insert_attachment=db_helper.insert_attachment,
attachments_dir=self.attachments_dir,
)
if not message_parts:
return

await db_helper.insert_platform_message_history(
platform_id="webchat",
user_id=conversation_id,
content={"type": "bot", "message": message_parts},
sender_id="bot",
sender_name="bot",
)

async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
Expand All @@ -98,72 +170,30 @@ async def _parse_message_parts(
Returns:
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
"""
components = []
text_parts = []

for part in message_parts:
part_type = part.get("type")
if part_type == "plain":
text = part.get("text", "")
components.append(Plain(text=text))
text_parts.append(text)
elif part_type == "reply":
message_id = part.get("message_id")
reply_chain = []
reply_message_str = part.get("selected_text", "")
sender_id = None
sender_name = None

if reply_message_str:
reply_chain = [Plain(text=reply_message_str)]

# recursively get the content of the referenced message, if selected_text is empty
if not reply_message_str and depth < max_depth and message_id:
history = await self._get_message_history(message_id)
if history and history.content:
reply_parts = history.content.get("message", [])
if isinstance(reply_parts, list):
(
reply_chain,
reply_text_parts,
) = await self._parse_message_parts(
reply_parts,
depth=depth + 1,
max_depth=max_depth,
)
reply_message_str = "".join(reply_text_parts)
sender_id = history.sender_id
sender_name = history.sender_name

components.append(
Reply(
id=message_id,
chain=reply_chain,
message_str=reply_message_str,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
elif part_type == "image":
path = part.get("path")
if path:
components.append(Image.fromFileSystem(path))
elif part_type == "record":
path = part.get("path")
if path:
components.append(Record.fromFileSystem(path))
elif part_type == "file":
path = part.get("path")
if path:
filename = part.get("filename") or (
os.path.basename(path) if path else "file"
)
components.append(File(name=filename, file=path))
elif part_type == "video":
path = part.get("path")
if path:
components.append(Video.fromFileSystem(path))

async def get_reply_parts(
message_id: Any,
) -> tuple[list[dict], str | None, str | None] | None:
history = await self._get_message_history(message_id)
if not history or not history.content:
return None

reply_parts = history.content.get("message", [])
if not isinstance(reply_parts, list):
return None

return reply_parts, history.sender_id, history.sender_name

components, text_parts, _ = await parse_webchat_message_parts(
message_parts,
strict=False,
include_empty_plain=True,
verify_media_path_exists=False,
reply_history_getter=get_reply_parts,
current_depth=depth,
max_reply_depth=max_depth,
cast_reply_id_to_str=False,
)
return components, text_parts

async def convert_message(self, data: tuple) -> AstrBotMessage:
Expand Down
13 changes: 11 additions & 2 deletions astrbot/core/platform/sources/webchat/webchat_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
attachments_dir = os.path.join(get_astrbot_data_path(), "attachments")


def _extract_conversation_id(session_id: str) -> str:
"""Extract raw webchat conversation id from event/session id."""
if session_id.startswith("webchat!"):
parts = session_id.split("!", 2)
if len(parts) == 3:
return parts[2]
return session_id
Comment on lines +17 to +23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function _extract_conversation_id is also defined in astrbot/core/platform/sources/webchat/webchat_adapter.py. To adhere to the Don't Repeat Yourself (DRY) principle and improve maintainability, this duplicated logic should be consolidated into a single, shared location. Consider moving it to a utility file, such as the newly created message_parts_helper.py, and importing it where needed.



class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
Expand All @@ -27,7 +36,7 @@ async def _send(
streaming: bool = False,
) -> str | None:
request_id = str(message_id)
conversation_id = session_id.split("!")[-1]
conversation_id = _extract_conversation_id(session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
Expand Down Expand Up @@ -130,7 +139,7 @@ async def send_streaming(self, generator, use_fallback: bool = False) -> None:
reasoning_content = ""
message_id = self.message_obj.message_id
request_id = str(message_id)
conversation_id = self.session_id.split("!")[-1]
conversation_id = _extract_conversation_id(self.session_id)
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
request_id,
conversation_id,
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/platform/sources/webchat/webchat_queue_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def remove_queue(self, conversation_id: str):
if task is not None:
task.cancel()

def list_back_request_ids(self, conversation_id: str) -> list[str]:
"""List active back-queue request IDs for a conversation."""
return list(self._conversation_back_requests.get(conversation_id, set()))

def has_queue(self, conversation_id: str) -> bool:
"""Check if a queue exists for the given conversation ID"""
return conversation_id in self.queues
Expand Down
Loading