diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..8afcf39519 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -86,6 +86,9 @@ async def to_dict(self) -> dict: # 默认情况下,回退到旧的同步 toDict() return self.toDict() + def empty(self) -> bool: + return True + class Plain(BaseMessageComponent): type: ComponentType = ComponentType.Plain @@ -100,6 +103,9 @@ def toDict(self) -> dict: async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} + def empty(self) -> bool: + return not bool(self.text and self.text.strip()) + class Face(BaseMessageComponent): type: ComponentType = ComponentType.Face @@ -108,6 +114,9 @@ class Face(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return self.id is None + class Record(BaseMessageComponent): type: ComponentType = ComponentType.Record @@ -125,6 +134,9 @@ def __init__(self, file: str | None, **_) -> None: # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}") super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromFileSystem(path, **_): return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @@ -224,6 +236,9 @@ class Video(BaseMessageComponent): def __init__(self, file: str, **_) -> None: super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromFileSystem(path, **_): return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @@ -307,6 +322,9 @@ class At(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.qq) or bool(self.name)) + def toDict(self): return { "type": "at", @@ -327,6 +345,9 @@ class RPS(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Dice(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Dice @@ -334,6 +355,9 @@ class Dice(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Shake(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Shake @@ -341,6 +365,9 @@ class Shake(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Share(BaseMessageComponent): type: ComponentType = ComponentType.Share @@ -352,6 +379,9 @@ class Share(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.url) or bool(self.title)) + class Contact(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Contact @@ -361,6 +391,9 @@ class Contact(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self._type and self.id) + class Location(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Location @@ -372,6 +405,9 @@ class Location(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self.lat is not None and self.lon is not None) + class Music(BaseMessageComponent): type: ComponentType = ComponentType.Music @@ -389,6 +425,12 @@ def __init__(self, **_) -> None: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(**_) + def empty(self) -> bool: + return not ( + (self.id and self._type and self._type != "custom") + or (self._type == "custom" and self.url and self.audio and self.title) + ) + class Image(BaseMessageComponent): type: ComponentType = ComponentType.Image @@ -401,6 +443,9 @@ class Image(BaseMessageComponent): def __init__(self, file: str | None, **_) -> None: super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromURL(url: str, **_): if url.startswith("http://") or url.startswith("https://"): @@ -525,6 +570,9 @@ class Reply(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.id) and self.sender_id is not None) + class Poke(BaseMessageComponent): type: ComponentType = ComponentType.Poke @@ -551,6 +599,9 @@ def target_id(self) -> str | None: return text return None + def empty(self) -> bool: + return self.target_id() is None + def toDict(self): target_id = self.target_id() data = {"type": str(self._type or "126")} @@ -566,6 +617,9 @@ class Forward(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self.id) + class Node(BaseMessageComponent): """群合并转发消息""" @@ -584,6 +638,9 @@ def __init__(self, content: list[BaseMessageComponent], **_) -> None: content = [content] super().__init__(content=content, **_) + def empty(self) -> bool: + return not bool(self.content) + async def to_dict(self): data_content = [] for comp in self.content: @@ -628,6 +685,9 @@ class Nodes(BaseMessageComponent): def __init__(self, nodes: list[Node], **_) -> None: super().__init__(nodes=nodes, **_) + def empty(self) -> bool: + return not bool(self.nodes) + def toDict(self): """Deprecated. Use to_dict instead""" ret = { @@ -656,11 +716,17 @@ def __init__(self, data: str | dict, **_) -> None: data = json.loads(data) super().__init__(data=data, **_) + def empty(self) -> bool: + return not bool(self.data) + class Unknown(BaseMessageComponent): type: ComponentType = ComponentType.Unknown text: str + def empty(self) -> bool: + return not bool(self.text and self.text.strip()) + class File(BaseMessageComponent): """文件消息段""" @@ -674,6 +740,9 @@ def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) + def empty(self) -> bool: + return not bool(self.file_ or self.url) + @property def file(self) -> str: """获取文件路径,如果文件不存在但有URL,则同步下载文件 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 604f1ded0e..9b77974f8a 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -17,38 +17,6 @@ @register_stage class RespondStage(Stage): - # 组件类型到其非空判断函数的映射 - _component_validators = { - Comp.Plain: lambda comp: bool( - comp.text and comp.text.strip(), - ), # 纯文本消息需要strip - Comp.Face: lambda comp: comp.id is not None, # QQ表情 - Comp.Record: lambda comp: bool(comp.file), # 语音 - Comp.Video: lambda comp: bool(comp.file), # 视频 - Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ - Comp.Image: lambda comp: bool(comp.file), # 图片 - Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 - Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳 - Comp.Node: lambda comp: bool(comp.content), # 转发节点 - Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 - Comp.File: lambda comp: bool(comp.file_ or comp.url), - Comp.Json: lambda comp: bool(comp.data), # Json 卡片 - Comp.Share: lambda comp: bool(comp.url) or bool(comp.title), - Comp.Music: lambda comp: ( - (comp.id and comp._type and comp._type != "custom") - or (comp._type == "custom" and comp.url and comp.audio and comp.title) - ), # 音乐分享 - Comp.Forward: lambda comp: bool(comp.id), # 合并转发 - Comp.Location: lambda comp: bool( - comp.lat is not None and comp.lon is not None - ), # 位置 - Comp.Contact: lambda comp: bool(comp._type and comp.id), # 推荐好友 or 群 - Comp.Shake: lambda _: True, # 窗口抖动(戳一戳) - Comp.Dice: lambda _: True, # 掷骰子魔法表情 - Comp.RPS: lambda _: True, # 猜拳魔法表情 - Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), - } - async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config @@ -117,12 +85,8 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo return True for comp in chain: - comp_type = type(comp) - - # 检查组件类型是否在字典中 - if comp_type in self._component_validators: - if self._component_validators[comp_type](comp): - return False + if not comp.empty(): + return False # 如果所有组件都为空 return True diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 433509f5e1..2a6bce7c1e 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -8,6 +8,14 @@ class DiscordEmbed(BaseMessageComponent): """Discord Embed消息组件""" type: str = "discord_embed" + title: str | None = None + description: str | None = None + color: int | None = None + url: str | None = None + thumbnail: str | None = None + image: str | None = None + footer: str | None = None + fields: list[dict] | None = None def __init__( self, @@ -20,14 +28,33 @@ def __init__( footer: str | None = None, fields: list[dict] | None = None, ) -> None: - self.title = title - self.description = description - self.color = color - self.url = url - self.thumbnail = thumbnail - self.image = image - self.footer = footer - self.fields = fields or [] + super().__init__( + title=title, + description=description, + color=color, + url=url, + thumbnail=thumbnail, + image=image, + footer=footer, + fields=fields or [], + ) + + def empty(self) -> bool: + return not ( + any( + bool(value) + for value in ( + self.title, + self.description, + self.url, + self.thumbnail, + self.image, + self.footer, + self.fields, + ) + ) + or self.color is not None + ) def to_discord_embed(self) -> discord.Embed: """转换为Discord Embed对象""" @@ -37,7 +64,7 @@ def to_discord_embed(self) -> discord.Embed: embed.title = self.title if self.description: embed.description = self.description - if self.color: + if self.color is not None: embed.color = self.color if self.url: embed.url = self.url @@ -48,7 +75,7 @@ def to_discord_embed(self) -> discord.Embed: if self.footer: embed.set_footer(text=self.footer) - for field in self.fields: + for field in self.fields or []: embed.add_field( name=field.get("name", ""), value=field.get("value", ""), @@ -62,6 +89,12 @@ class DiscordButton(BaseMessageComponent): """Discord按钮组件""" type: str = "discord_button" + label: str + custom_id: str | None = None + style: str = "primary" + emoji: str | None = None + url: str | None = None + disabled: bool = False def __init__( self, @@ -72,42 +105,55 @@ def __init__( url: str | None = None, disabled: bool = False, ) -> None: - self.label = label - self.custom_id = custom_id - self.style = style - self.emoji = emoji - self.url = url - self.disabled = disabled + super().__init__( + label=label, + custom_id=custom_id, + style=style, + emoji=emoji, + url=url, + disabled=disabled, + ) + + def empty(self) -> bool: + return not bool(self.label or self.url or self.custom_id or self.emoji) class DiscordReference(BaseMessageComponent): """Discord引用组件""" type: str = "discord_reference" + message_id: str + channel_id: str def __init__(self, message_id: str, channel_id: str) -> None: - self.message_id = message_id - self.channel_id = channel_id + super().__init__(message_id=message_id, channel_id=channel_id) + + def empty(self) -> bool: + return not bool(self.message_id and self.channel_id) class DiscordView(BaseMessageComponent): """Discord视图组件,包含按钮和选择菜单""" type: str = "discord_view" + components: list[BaseMessageComponent] | None = None + timeout: float | None = None def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, ) -> None: - self.components = components or [] - self.timeout = timeout + super().__init__(components=components or [], timeout=timeout) + + def empty(self) -> bool: + return not bool(self.components) def to_discord_view(self) -> discord.ui.View: """转换为Discord View对象""" view = discord.ui.View(timeout=self.timeout) - for component in self.components: + for component in self.components or []: if isinstance(component, DiscordButton): button_style = getattr( discord.ButtonStyle, diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 02d4dae868..99cae2f578 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -12,7 +12,6 @@ from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( - BaseMessageComponent, File, Image, Plain, @@ -24,14 +23,6 @@ from .components import DiscordEmbed, DiscordView -# 自定义Discord视图组件(兼容旧版本) -class DiscordViewComponent(BaseMessageComponent): - type: str = "discord_view" - - def __init__(self, view: discord.ui.View) -> None: - self.view = view - - class DiscordPlatformEvent(AstrMessageEvent): def __init__( self, @@ -254,12 +245,10 @@ async def _parse_to_discord( elif isinstance(i, DiscordView): # Discord视图组件(按钮、选择菜单等) view = i.to_discord_view() - elif isinstance(i, DiscordViewComponent): - # 如果消息链中包含Discord视图组件(兼容旧版本) - if isinstance(i.view, discord.ui.View): - view = i.view else: - logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}") + logger.debug( + f"[Discord] 忽略了不支持的消息组件: {getattr(i, 'type', None)}" + ) content = "".join(content_parts) if len(content) > 2000: diff --git a/tests/unit/test_aiocqhttp_poke.py b/tests/unit/test_aiocqhttp_poke.py index dff886b91e..d3e6cb4146 100644 --- a/tests/unit/test_aiocqhttp_poke.py +++ b/tests/unit/test_aiocqhttp_poke.py @@ -22,6 +22,7 @@ def test_poke_to_dict_matches_onebot_v11_segment_format(): async def test_respond_stage_treats_poke_with_target_as_non_empty(): stage = RespondStage() chain = [Comp.Poke(type="126", id=2003)] + assert chain[0].empty() is False assert await stage._is_empty_message_chain(chain) is False