From 778ca2bd122e6f84ad33e95ad6613f1779c012c6 Mon Sep 17 00:00:00 2001 From: SGSxingchen <853304398@qq.com> Date: Sun, 3 May 2026 08:30:09 +0000 Subject: [PATCH 1/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20Discord=20?= =?UTF-8?q?=E4=B8=93=E7=94=A8=E6=B6=88=E6=81=AF=E7=BB=84=E4=BB=B6=E6=9E=84?= =?UTF-8?q?=E9=80=A0=E4=B8=8E=E5=8F=91=E9=80=81=E8=AF=86=E5=88=AB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/respond/stage.py | 27 ++++ .../platform/sources/discord/components.py | 70 +++++++--- .../sources/discord/discord_platform_event.py | 79 ++++++++++-- tests/unit/test_discord_message_components.py | 122 ++++++++++++++++++ 4 files changed, 266 insertions(+), 32 deletions(-) create mode 100644 tests/unit/test_discord_message_components.py diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 604f1ded0e..2aacceb9db 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -48,6 +48,27 @@ class RespondStage(Stage): Comp.RPS: lambda _: True, # 猜拳魔法表情 Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), } + _platform_component_validators = { + "discord_embed": lambda comp: any( + bool(getattr(comp, attr, None)) + for attr in ( + "title", + "description", + "color", + "url", + "thumbnail", + "image", + "footer", + "fields", + ) + ) + or callable(getattr(comp, "to_discord_embed", None)), + "discord_view": lambda comp: bool( + getattr(comp, "components", None) + or getattr(comp, "view", None) + or callable(getattr(comp, "to_discord_view", None)), + ), + } async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -118,6 +139,12 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo for comp in chain: comp_type = type(comp) + component_type = getattr(comp, "type", None) + component_type = getattr(component_type, "value", component_type) + + if component_type in self._platform_component_validators: + if self._platform_component_validators[component_type](comp): + return False # 检查组件类型是否在字典中 if comp_type in self._component_validators: diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 433509f5e1..a7a2871097 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -8,6 +8,14 @@ class DiscordEmbed(BaseMessageComponent): """Discord Embed消息组件""" type: str = "discord_embed" + title: str | None = None + description: str | None = None + color: int | None = None + url: str | None = None + thumbnail: str | None = None + image: str | None = None + footer: str | None = None + fields: list[dict] | None = None def __init__( self, @@ -20,14 +28,16 @@ def __init__( footer: str | None = None, fields: list[dict] | None = None, ) -> None: - self.title = title - self.description = description - self.color = color - self.url = url - self.thumbnail = thumbnail - self.image = image - self.footer = footer - self.fields = fields or [] + super().__init__( + title=title, + description=description, + color=color, + url=url, + thumbnail=thumbnail, + image=image, + footer=footer, + fields=fields or [], + ) def to_discord_embed(self) -> discord.Embed: """转换为Discord Embed对象""" @@ -48,7 +58,7 @@ def to_discord_embed(self) -> discord.Embed: if self.footer: embed.set_footer(text=self.footer) - for field in self.fields: + for field in self.fields or []: embed.add_field( name=field.get("name", ""), value=field.get("value", ""), @@ -62,6 +72,12 @@ class DiscordButton(BaseMessageComponent): """Discord按钮组件""" type: str = "discord_button" + label: str + custom_id: str | None = None + style: str = "primary" + emoji: str | None = None + url: str | None = None + disabled: bool = False def __init__( self, @@ -72,43 +88,55 @@ def __init__( url: str | None = None, disabled: bool = False, ) -> None: - self.label = label - self.custom_id = custom_id - self.style = style - self.emoji = emoji - self.url = url - self.disabled = disabled + super().__init__( + label=label, + custom_id=custom_id, + style=style, + emoji=emoji, + url=url, + disabled=disabled, + ) class DiscordReference(BaseMessageComponent): """Discord引用组件""" type: str = "discord_reference" + message_id: str + channel_id: str def __init__(self, message_id: str, channel_id: str) -> None: - self.message_id = message_id - self.channel_id = channel_id + super().__init__(message_id=message_id, channel_id=channel_id) class DiscordView(BaseMessageComponent): """Discord视图组件,包含按钮和选择菜单""" type: str = "discord_view" + components: list[BaseMessageComponent] | None = None + timeout: float | None = None def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, ) -> None: - self.components = components or [] - self.timeout = timeout + super().__init__(components=components or [], timeout=timeout) def to_discord_view(self) -> discord.ui.View: """转换为Discord View对象""" view = discord.ui.View(timeout=self.timeout) - for component in self.components: - if isinstance(component, DiscordButton): + for component in self.components or []: + if ( + isinstance(component, DiscordButton) + or getattr( + component, + "type", + None, + ) + == "discord_button" + ): button_style = getattr( discord.ButtonStyle, component.style, diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 02d4dae868..9ba455d2a1 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path -from typing import cast +from typing import Any, cast import discord from discord.types.interactions import ComponentInteractionData @@ -27,9 +27,62 @@ # 自定义Discord视图组件(兼容旧版本) class DiscordViewComponent(BaseMessageComponent): type: str = "discord_view" + view: Any def __init__(self, view: discord.ui.View) -> None: - self.view = view + super().__init__(view=view) + + +def _component_type(component: BaseMessageComponent) -> str | None: + component_type = getattr(component, "type", None) + return getattr(component_type, "value", component_type) + + +def _is_component_type(component: BaseMessageComponent, component_type: str) -> bool: + return _component_type(component) == component_type + + +def _to_discord_embed(component: BaseMessageComponent) -> discord.Embed: + converter = getattr(component, "to_discord_embed", None) + if callable(converter): + return converter() + + embed = discord.Embed() + if title := getattr(component, "title", None): + embed.title = title + if description := getattr(component, "description", None): + embed.description = description + if color := getattr(component, "color", None): + embed.color = color + if url := getattr(component, "url", None): + embed.url = url + if thumbnail := getattr(component, "thumbnail", None): + embed.set_thumbnail(url=thumbnail) + if image := getattr(component, "image", None): + embed.set_image(url=image) + if footer := getattr(component, "footer", None): + embed.set_footer(text=footer) + + for field in getattr(component, "fields", None) or []: + embed.add_field( + name=field.get("name", ""), + value=field.get("value", ""), + inline=field.get("inline", False), + ) + + return embed + + +def _to_discord_view(component: BaseMessageComponent) -> discord.ui.View | None: + existing_view = getattr(component, "view", None) + if isinstance(existing_view, discord.ui.View): + return existing_view + + converter = getattr(component, "to_discord_view", None) + if callable(converter): + return converter() + + return None class DiscordPlatformEvent(AstrMessageEvent): @@ -248,18 +301,22 @@ async def _parse_to_discord( logger.warning(f"[Discord] 获取文件失败: {i.name}") except Exception as e: logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}") - elif isinstance(i, DiscordEmbed): + elif isinstance(i, DiscordEmbed) or _is_component_type(i, "discord_embed"): # Discord Embed消息 - embeds.append(i.to_discord_embed()) - elif isinstance(i, DiscordView): + embeds.append(_to_discord_embed(i)) + elif ( + isinstance(i, DiscordView) + or isinstance(i, DiscordViewComponent) + or _is_component_type(i, "discord_view") + ): # Discord视图组件(按钮、选择菜单等) - view = i.to_discord_view() - elif isinstance(i, DiscordViewComponent): - # 如果消息链中包含Discord视图组件(兼容旧版本) - if isinstance(i.view, discord.ui.View): - view = i.view + parsed_view = _to_discord_view(i) + if parsed_view: + view = parsed_view else: - logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}") + logger.debug( + f"[Discord] 忽略了不支持的消息组件: {getattr(i, 'type', None)}" + ) content = "".join(content_parts) if len(content) > 2000: diff --git a/tests/unit/test_discord_message_components.py b/tests/unit/test_discord_message_components.py new file mode 100644 index 0000000000..c027ce3666 --- /dev/null +++ b/tests/unit/test_discord_message_components.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.api.message_components import BaseMessageComponent, File, Image, Plain +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.pipeline.respond.stage import RespondStage +from astrbot.core.platform.sources.discord.components import ( + DiscordButton, + DiscordEmbed, + DiscordReference, + DiscordView, +) +from astrbot.core.platform.sources.discord.discord_platform_event import ( + DiscordPlatformEvent, + DiscordViewComponent, +) + + +def test_discord_components_construct() -> None: + embed = DiscordEmbed(title="test") + button = DiscordButton(label="Click") + view = DiscordView(components=[button]) + reference = DiscordReference(message_id="1", channel_id="2") + view_component = DiscordViewComponent(object()) + + assert embed.title == "test" + assert button.label == "Click" + assert view.components == [button] + assert reference.message_id == "1" + assert view_component.view is not None + + +@pytest.mark.asyncio +async def test_parse_to_discord_handles_discord_embed() -> None: + chain = MessageChain(chain=[DiscordEmbed(title="test")]) + + ( + content, + files, + view, + embeds, + reference, + ) = await DiscordPlatformEvent._parse_to_discord( + object(), + chain, + ) + + assert content == "" + assert files == [] + assert view is None + assert reference is None + assert len(embeds) == 1 + assert embeds[0].title == "test" + + +@pytest.mark.asyncio +async def test_parse_to_discord_handles_duck_typed_discord_embed() -> None: + class CompatibleDiscordEmbed(BaseMessageComponent): + type: str = "discord_embed" + title: str | None = None + + def __init__(self, title: str) -> None: + super().__init__(title=title) + + chain = SimpleNamespace(chain=[CompatibleDiscordEmbed("duck")]) + + _, _, _, embeds, _ = await DiscordPlatformEvent._parse_to_discord(object(), chain) + + assert len(embeds) == 1 + assert embeds[0].title == "duck" + + +@pytest.mark.asyncio +async def test_respond_stage_keeps_non_empty_discord_components() -> None: + stage = RespondStage() + + assert await stage._is_empty_message_chain([DiscordEmbed(title="test")]) is False + assert ( + await stage._is_empty_message_chain( + [DiscordView(components=[DiscordButton(label="Click")])], + ) + is False + ) + + +@pytest.mark.asyncio +async def test_plain_image_file_regression(tmp_path) -> None: + stage = RespondStage() + file_path = tmp_path / "example.txt" + file_path.write_text("hello") + + assert await stage._is_empty_message_chain([Plain("hello")]) is False + assert await stage._is_empty_message_chain([Image.fromBase64("aGVsbG8=")]) is False + assert ( + await stage._is_empty_message_chain( + [File(name="example.txt", file=str(file_path))], + ) + is False + ) + + ( + content, + files, + view, + embeds, + reference, + ) = await DiscordPlatformEvent._parse_to_discord( + object(), + MessageChain( + chain=[ + Plain("hello"), + Image.fromBase64("aGVsbG8="), + ], + ), + ) + + assert content == "hello" + assert len(files) == 1 + assert view is None + assert embeds == [] + assert reference is None From 4b3367843d0fe3ca6c9ebddd8c83f1a568ee333d Mon Sep 17 00:00:00 2001 From: SGSxingchen <853304398@qq.com> Date: Sun, 3 May 2026 08:59:18 +0000 Subject: [PATCH 2/7] fix: address Discord component review comments --- astrbot/core/pipeline/respond/stage.py | 24 ++++++++++--------- .../platform/sources/discord/components.py | 12 +++------- .../sources/discord/discord_platform_event.py | 4 +++- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 2aacceb9db..e488d694c4 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -49,18 +49,20 @@ class RespondStage(Stage): Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), } _platform_component_validators = { - "discord_embed": lambda comp: any( - bool(getattr(comp, attr, None)) - for attr in ( - "title", - "description", - "color", - "url", - "thumbnail", - "image", - "footer", - "fields", + "discord_embed": lambda comp: ( + any( + bool(getattr(comp, attr, None)) + for attr in ( + "title", + "description", + "url", + "thumbnail", + "image", + "footer", + "fields", + ) ) + or getattr(comp, "color", None) is not None ) or callable(getattr(comp, "to_discord_embed", None)), "discord_view": lambda comp: bool( diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index a7a2871097..73bf894be2 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -128,15 +128,9 @@ def to_discord_view(self) -> discord.ui.View: view = discord.ui.View(timeout=self.timeout) for component in self.components or []: - if ( - isinstance(component, DiscordButton) - or getattr( - component, - "type", - None, - ) - == "discord_button" - ): + raw_type = getattr(component, "type", None) + comp_type = getattr(raw_type, "value", raw_type) + if isinstance(component, DiscordButton) or comp_type == "discord_button": button_style = getattr( discord.ButtonStyle, component.style, diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 9ba455d2a1..bcd5914d02 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -80,7 +80,9 @@ def _to_discord_view(component: BaseMessageComponent) -> discord.ui.View | None: converter = getattr(component, "to_discord_view", None) if callable(converter): - return converter() + result = converter() + if isinstance(result, discord.ui.View): + return result return None From c47e2db6a7b5a5db9e538f94b573d23fffeaf12c Mon Sep 17 00:00:00 2001 From: SGSxingchen <853304398@qq.com> Date: Sun, 3 May 2026 09:22:00 +0000 Subject: [PATCH 3/7] fix: harden Discord component review follow-ups --- astrbot/core/pipeline/respond/stage.py | 61 ++++++++++--------- .../platform/sources/discord/components.py | 26 +++++--- 2 files changed, 49 insertions(+), 38 deletions(-) diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index e488d694c4..64f9e6ffdf 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -48,29 +48,6 @@ class RespondStage(Stage): Comp.RPS: lambda _: True, # 猜拳魔法表情 Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), } - _platform_component_validators = { - "discord_embed": lambda comp: ( - any( - bool(getattr(comp, attr, None)) - for attr in ( - "title", - "description", - "url", - "thumbnail", - "image", - "footer", - "fields", - ) - ) - or getattr(comp, "color", None) is not None - ) - or callable(getattr(comp, "to_discord_embed", None)), - "discord_view": lambda comp: bool( - getattr(comp, "components", None) - or getattr(comp, "view", None) - or callable(getattr(comp, "to_discord_view", None)), - ), - } async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx @@ -129,6 +106,37 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) + def _platform_component_type(self, comp: BaseMessageComponent) -> str | None: + t = getattr(comp, "type", None) + return getattr(t, "value", t) + + def _is_platform_component_non_empty(self, comp: BaseMessageComponent) -> bool: + ctype = self._platform_component_type(comp) + if ctype == "discord_embed": + return ( + any( + bool(getattr(comp, attr, None)) + for attr in ( + "title", + "description", + "url", + "thumbnail", + "image", + "footer", + "fields", + ) + ) + or getattr(comp, "color", None) is not None + or callable(getattr(comp, "to_discord_embed", None)) + ) + if ctype == "discord_view": + return bool( + getattr(comp, "components", None) + or getattr(comp, "view", None) + or callable(getattr(comp, "to_discord_view", None)), + ) + return False + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 @@ -141,12 +149,9 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo for comp in chain: comp_type = type(comp) - component_type = getattr(comp, "type", None) - component_type = getattr(component_type, "value", component_type) - if component_type in self._platform_component_validators: - if self._platform_component_validators[component_type](comp): - return False + if self._is_platform_component_non_empty(comp): + return False # 检查组件类型是否在字典中 if comp_type in self._component_validators: diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 73bf894be2..d91555ad0b 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -131,29 +131,35 @@ def to_discord_view(self) -> discord.ui.View: raw_type = getattr(component, "type", None) comp_type = getattr(raw_type, "value", raw_type) if isinstance(component, DiscordButton) or comp_type == "discord_button": + style_name = getattr(component, "style", "primary") + label = getattr(component, "label", "") + url = getattr(component, "url", None) + custom_id = getattr(component, "custom_id", None) + emoji = getattr(component, "emoji", None) + disabled = getattr(component, "disabled", False) button_style = getattr( discord.ButtonStyle, - component.style, + style_name, discord.ButtonStyle.primary, ) - if component.url: + if url: # URL按钮 button = discord.ui.Button( - label=component.label, + label=label, style=discord.ButtonStyle.link, - url=component.url, - emoji=component.emoji, - disabled=component.disabled, + url=url, + emoji=emoji, + disabled=disabled, ) else: # 普通按钮 button = discord.ui.Button( - label=component.label, + label=label, style=button_style, - custom_id=component.custom_id, - emoji=component.emoji, - disabled=component.disabled, + custom_id=custom_id, + emoji=emoji, + disabled=disabled, ) view.add_item(button) From e6d318a5a2d21e2bbc26f8ed62d98c6c6fe10fd8 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 3 May 2026 20:50:40 +0800 Subject: [PATCH 4/7] =?UTF-8?q?fix:=20=E6=B7=BB=E5=8A=A0=E7=BB=84=E4=BB=B6?= =?UTF-8?q?=E7=9A=84=20empty()=20=E6=96=B9=E6=B3=95=E4=BB=A5=E5=88=A4?= =?UTF-8?q?=E6=96=AD=E5=85=B6=E6=98=AF=E5=90=A6=E4=B8=BA=E7=A9=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/message/components.py | 69 ++++++++++++++++++ astrbot/core/pipeline/respond/stage.py | 72 +------------------ .../platform/sources/discord/components.py | 26 +++++++ .../sources/discord/discord_platform_event.py | 3 + tests/unit/test_aiocqhttp_poke.py | 1 + tests/unit/test_discord_message_components.py | 2 + 6 files changed, 102 insertions(+), 71 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..8afcf39519 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -86,6 +86,9 @@ async def to_dict(self) -> dict: # 默认情况下,回退到旧的同步 toDict() return self.toDict() + def empty(self) -> bool: + return True + class Plain(BaseMessageComponent): type: ComponentType = ComponentType.Plain @@ -100,6 +103,9 @@ def toDict(self) -> dict: async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} + def empty(self) -> bool: + return not bool(self.text and self.text.strip()) + class Face(BaseMessageComponent): type: ComponentType = ComponentType.Face @@ -108,6 +114,9 @@ class Face(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return self.id is None + class Record(BaseMessageComponent): type: ComponentType = ComponentType.Record @@ -125,6 +134,9 @@ def __init__(self, file: str | None, **_) -> None: # Protocol.warn(f"go-cqhttp doesn't support send {self.type} by {k}") super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromFileSystem(path, **_): return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @@ -224,6 +236,9 @@ class Video(BaseMessageComponent): def __init__(self, file: str, **_) -> None: super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromFileSystem(path, **_): return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @@ -307,6 +322,9 @@ class At(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.qq) or bool(self.name)) + def toDict(self): return { "type": "at", @@ -327,6 +345,9 @@ class RPS(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Dice(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Dice @@ -334,6 +355,9 @@ class Dice(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Shake(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Shake @@ -341,6 +365,9 @@ class Shake(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return False + class Share(BaseMessageComponent): type: ComponentType = ComponentType.Share @@ -352,6 +379,9 @@ class Share(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.url) or bool(self.title)) + class Contact(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Contact @@ -361,6 +391,9 @@ class Contact(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self._type and self.id) + class Location(BaseMessageComponent): # TODO type: ComponentType = ComponentType.Location @@ -372,6 +405,9 @@ class Location(BaseMessageComponent): # TODO def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self.lat is not None and self.lon is not None) + class Music(BaseMessageComponent): type: ComponentType = ComponentType.Music @@ -389,6 +425,12 @@ def __init__(self, **_) -> None: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") super().__init__(**_) + def empty(self) -> bool: + return not ( + (self.id and self._type and self._type != "custom") + or (self._type == "custom" and self.url and self.audio and self.title) + ) + class Image(BaseMessageComponent): type: ComponentType = ComponentType.Image @@ -401,6 +443,9 @@ class Image(BaseMessageComponent): def __init__(self, file: str | None, **_) -> None: super().__init__(file=file, **_) + def empty(self) -> bool: + return not bool(self.file) + @staticmethod def fromURL(url: str, **_): if url.startswith("http://") or url.startswith("https://"): @@ -525,6 +570,9 @@ class Reply(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not (bool(self.id) and self.sender_id is not None) + class Poke(BaseMessageComponent): type: ComponentType = ComponentType.Poke @@ -551,6 +599,9 @@ def target_id(self) -> str | None: return text return None + def empty(self) -> bool: + return self.target_id() is None + def toDict(self): target_id = self.target_id() data = {"type": str(self._type or "126")} @@ -566,6 +617,9 @@ class Forward(BaseMessageComponent): def __init__(self, **_) -> None: super().__init__(**_) + def empty(self) -> bool: + return not bool(self.id) + class Node(BaseMessageComponent): """群合并转发消息""" @@ -584,6 +638,9 @@ def __init__(self, content: list[BaseMessageComponent], **_) -> None: content = [content] super().__init__(content=content, **_) + def empty(self) -> bool: + return not bool(self.content) + async def to_dict(self): data_content = [] for comp in self.content: @@ -628,6 +685,9 @@ class Nodes(BaseMessageComponent): def __init__(self, nodes: list[Node], **_) -> None: super().__init__(nodes=nodes, **_) + def empty(self) -> bool: + return not bool(self.nodes) + def toDict(self): """Deprecated. Use to_dict instead""" ret = { @@ -656,11 +716,17 @@ def __init__(self, data: str | dict, **_) -> None: data = json.loads(data) super().__init__(data=data, **_) + def empty(self) -> bool: + return not bool(self.data) + class Unknown(BaseMessageComponent): type: ComponentType = ComponentType.Unknown text: str + def empty(self) -> bool: + return not bool(self.text and self.text.strip()) + class File(BaseMessageComponent): """文件消息段""" @@ -674,6 +740,9 @@ def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) + def empty(self) -> bool: + return not bool(self.file_ or self.url) + @property def file(self) -> str: """获取文件路径,如果文件不存在但有URL,则同步下载文件 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 64f9e6ffdf..9b77974f8a 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -17,38 +17,6 @@ @register_stage class RespondStage(Stage): - # 组件类型到其非空判断函数的映射 - _component_validators = { - Comp.Plain: lambda comp: bool( - comp.text and comp.text.strip(), - ), # 纯文本消息需要strip - Comp.Face: lambda comp: comp.id is not None, # QQ表情 - Comp.Record: lambda comp: bool(comp.file), # 语音 - Comp.Video: lambda comp: bool(comp.file), # 视频 - Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ - Comp.Image: lambda comp: bool(comp.file), # 图片 - Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 - Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳 - Comp.Node: lambda comp: bool(comp.content), # 转发节点 - Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 - Comp.File: lambda comp: bool(comp.file_ or comp.url), - Comp.Json: lambda comp: bool(comp.data), # Json 卡片 - Comp.Share: lambda comp: bool(comp.url) or bool(comp.title), - Comp.Music: lambda comp: ( - (comp.id and comp._type and comp._type != "custom") - or (comp._type == "custom" and comp.url and comp.audio and comp.title) - ), # 音乐分享 - Comp.Forward: lambda comp: bool(comp.id), # 合并转发 - Comp.Location: lambda comp: bool( - comp.lat is not None and comp.lon is not None - ), # 位置 - Comp.Contact: lambda comp: bool(comp._type and comp.id), # 推荐好友 or 群 - Comp.Shake: lambda _: True, # 窗口抖动(戳一戳) - Comp.Dice: lambda _: True, # 掷骰子魔法表情 - Comp.RPS: lambda _: True, # 猜拳魔法表情 - Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), - } - async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config @@ -106,37 +74,6 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - def _platform_component_type(self, comp: BaseMessageComponent) -> str | None: - t = getattr(comp, "type", None) - return getattr(t, "value", t) - - def _is_platform_component_non_empty(self, comp: BaseMessageComponent) -> bool: - ctype = self._platform_component_type(comp) - if ctype == "discord_embed": - return ( - any( - bool(getattr(comp, attr, None)) - for attr in ( - "title", - "description", - "url", - "thumbnail", - "image", - "footer", - "fields", - ) - ) - or getattr(comp, "color", None) is not None - or callable(getattr(comp, "to_discord_embed", None)) - ) - if ctype == "discord_view": - return bool( - getattr(comp, "components", None) - or getattr(comp, "view", None) - or callable(getattr(comp, "to_discord_view", None)), - ) - return False - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 @@ -148,16 +85,9 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bo return True for comp in chain: - comp_type = type(comp) - - if self._is_platform_component_non_empty(comp): + if not comp.empty(): return False - # 检查组件类型是否在字典中 - if comp_type in self._component_validators: - if self._component_validators[comp_type](comp): - return False - # 如果所有组件都为空 return True diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index d91555ad0b..a864a29dfb 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -39,6 +39,23 @@ def __init__( fields=fields or [], ) + def empty(self) -> bool: + return not ( + any( + bool(value) + for value in ( + self.title, + self.description, + self.url, + self.thumbnail, + self.image, + self.footer, + self.fields, + ) + ) + or self.color is not None + ) + def to_discord_embed(self) -> discord.Embed: """转换为Discord Embed对象""" embed = discord.Embed() @@ -97,6 +114,9 @@ def __init__( disabled=disabled, ) + def empty(self) -> bool: + return not bool(self.label or self.url or self.custom_id or self.emoji) + class DiscordReference(BaseMessageComponent): """Discord引用组件""" @@ -108,6 +128,9 @@ class DiscordReference(BaseMessageComponent): def __init__(self, message_id: str, channel_id: str) -> None: super().__init__(message_id=message_id, channel_id=channel_id) + def empty(self) -> bool: + return not bool(self.message_id and self.channel_id) + class DiscordView(BaseMessageComponent): """Discord视图组件,包含按钮和选择菜单""" @@ -123,6 +146,9 @@ def __init__( ) -> None: super().__init__(components=components or [], timeout=timeout) + def empty(self) -> bool: + return not bool(self.components) + def to_discord_view(self) -> discord.ui.View: """转换为Discord View对象""" view = discord.ui.View(timeout=self.timeout) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index bcd5914d02..fce41f8200 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -32,6 +32,9 @@ class DiscordViewComponent(BaseMessageComponent): def __init__(self, view: discord.ui.View) -> None: super().__init__(view=view) + def empty(self) -> bool: + return self.view is None + def _component_type(component: BaseMessageComponent) -> str | None: component_type = getattr(component, "type", None) diff --git a/tests/unit/test_aiocqhttp_poke.py b/tests/unit/test_aiocqhttp_poke.py index dff886b91e..d3e6cb4146 100644 --- a/tests/unit/test_aiocqhttp_poke.py +++ b/tests/unit/test_aiocqhttp_poke.py @@ -22,6 +22,7 @@ def test_poke_to_dict_matches_onebot_v11_segment_format(): async def test_respond_stage_treats_poke_with_target_as_non_empty(): stage = RespondStage() chain = [Comp.Poke(type="126", id=2003)] + assert chain[0].empty() is False assert await stage._is_empty_message_chain(chain) is False diff --git a/tests/unit/test_discord_message_components.py b/tests/unit/test_discord_message_components.py index c027ce3666..a17601b750 100644 --- a/tests/unit/test_discord_message_components.py +++ b/tests/unit/test_discord_message_components.py @@ -75,6 +75,8 @@ def __init__(self, title: str) -> None: async def test_respond_stage_keeps_non_empty_discord_components() -> None: stage = RespondStage() + assert DiscordEmbed().empty() is True + assert DiscordView().empty() is True assert await stage._is_empty_message_chain([DiscordEmbed(title="test")]) is False assert ( await stage._is_empty_message_chain( From 8ffccd66daa294a8e07387093aeed0fd9f9ba12b Mon Sep 17 00:00:00 2001 From: Weilong Liao <37870767+Soulter@users.noreply.github.com> Date: Sun, 3 May 2026 20:51:21 +0800 Subject: [PATCH 5/7] Delete tests/unit/test_discord_message_components.py --- tests/unit/test_discord_message_components.py | 124 ------------------ 1 file changed, 124 deletions(-) delete mode 100644 tests/unit/test_discord_message_components.py diff --git a/tests/unit/test_discord_message_components.py b/tests/unit/test_discord_message_components.py deleted file mode 100644 index a17601b750..0000000000 --- a/tests/unit/test_discord_message_components.py +++ /dev/null @@ -1,124 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from astrbot.api.message_components import BaseMessageComponent, File, Image, Plain -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.pipeline.respond.stage import RespondStage -from astrbot.core.platform.sources.discord.components import ( - DiscordButton, - DiscordEmbed, - DiscordReference, - DiscordView, -) -from astrbot.core.platform.sources.discord.discord_platform_event import ( - DiscordPlatformEvent, - DiscordViewComponent, -) - - -def test_discord_components_construct() -> None: - embed = DiscordEmbed(title="test") - button = DiscordButton(label="Click") - view = DiscordView(components=[button]) - reference = DiscordReference(message_id="1", channel_id="2") - view_component = DiscordViewComponent(object()) - - assert embed.title == "test" - assert button.label == "Click" - assert view.components == [button] - assert reference.message_id == "1" - assert view_component.view is not None - - -@pytest.mark.asyncio -async def test_parse_to_discord_handles_discord_embed() -> None: - chain = MessageChain(chain=[DiscordEmbed(title="test")]) - - ( - content, - files, - view, - embeds, - reference, - ) = await DiscordPlatformEvent._parse_to_discord( - object(), - chain, - ) - - assert content == "" - assert files == [] - assert view is None - assert reference is None - assert len(embeds) == 1 - assert embeds[0].title == "test" - - -@pytest.mark.asyncio -async def test_parse_to_discord_handles_duck_typed_discord_embed() -> None: - class CompatibleDiscordEmbed(BaseMessageComponent): - type: str = "discord_embed" - title: str | None = None - - def __init__(self, title: str) -> None: - super().__init__(title=title) - - chain = SimpleNamespace(chain=[CompatibleDiscordEmbed("duck")]) - - _, _, _, embeds, _ = await DiscordPlatformEvent._parse_to_discord(object(), chain) - - assert len(embeds) == 1 - assert embeds[0].title == "duck" - - -@pytest.mark.asyncio -async def test_respond_stage_keeps_non_empty_discord_components() -> None: - stage = RespondStage() - - assert DiscordEmbed().empty() is True - assert DiscordView().empty() is True - assert await stage._is_empty_message_chain([DiscordEmbed(title="test")]) is False - assert ( - await stage._is_empty_message_chain( - [DiscordView(components=[DiscordButton(label="Click")])], - ) - is False - ) - - -@pytest.mark.asyncio -async def test_plain_image_file_regression(tmp_path) -> None: - stage = RespondStage() - file_path = tmp_path / "example.txt" - file_path.write_text("hello") - - assert await stage._is_empty_message_chain([Plain("hello")]) is False - assert await stage._is_empty_message_chain([Image.fromBase64("aGVsbG8=")]) is False - assert ( - await stage._is_empty_message_chain( - [File(name="example.txt", file=str(file_path))], - ) - is False - ) - - ( - content, - files, - view, - embeds, - reference, - ) = await DiscordPlatformEvent._parse_to_discord( - object(), - MessageChain( - chain=[ - Plain("hello"), - Image.fromBase64("aGVsbG8="), - ], - ), - ) - - assert content == "hello" - assert len(files) == 1 - assert view is None - assert embeds == [] - assert reference is None From 2fdec05ebf21565e11d10bafeccbb25e47e76a22 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 3 May 2026 21:00:25 +0800 Subject: [PATCH 6/7] =?UTF-8?q?fix:=20=E7=A7=BB=E9=99=A4=E5=86=97=E4=BD=99?= =?UTF-8?q?=E7=9A=84=20Discord=20=E7=BB=84=E4=BB=B6=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/discord/discord_platform_event.py | 83 ++----------------- 1 file changed, 5 insertions(+), 78 deletions(-) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index fce41f8200..99cae2f578 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path -from typing import Any, cast +from typing import cast import discord from discord.types.interactions import ComponentInteractionData @@ -12,7 +12,6 @@ from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( - BaseMessageComponent, File, Image, Plain, @@ -24,72 +23,6 @@ from .components import DiscordEmbed, DiscordView -# 自定义Discord视图组件(兼容旧版本) -class DiscordViewComponent(BaseMessageComponent): - type: str = "discord_view" - view: Any - - def __init__(self, view: discord.ui.View) -> None: - super().__init__(view=view) - - def empty(self) -> bool: - return self.view is None - - -def _component_type(component: BaseMessageComponent) -> str | None: - component_type = getattr(component, "type", None) - return getattr(component_type, "value", component_type) - - -def _is_component_type(component: BaseMessageComponent, component_type: str) -> bool: - return _component_type(component) == component_type - - -def _to_discord_embed(component: BaseMessageComponent) -> discord.Embed: - converter = getattr(component, "to_discord_embed", None) - if callable(converter): - return converter() - - embed = discord.Embed() - if title := getattr(component, "title", None): - embed.title = title - if description := getattr(component, "description", None): - embed.description = description - if color := getattr(component, "color", None): - embed.color = color - if url := getattr(component, "url", None): - embed.url = url - if thumbnail := getattr(component, "thumbnail", None): - embed.set_thumbnail(url=thumbnail) - if image := getattr(component, "image", None): - embed.set_image(url=image) - if footer := getattr(component, "footer", None): - embed.set_footer(text=footer) - - for field in getattr(component, "fields", None) or []: - embed.add_field( - name=field.get("name", ""), - value=field.get("value", ""), - inline=field.get("inline", False), - ) - - return embed - - -def _to_discord_view(component: BaseMessageComponent) -> discord.ui.View | None: - existing_view = getattr(component, "view", None) - if isinstance(existing_view, discord.ui.View): - return existing_view - - converter = getattr(component, "to_discord_view", None) - if callable(converter): - result = converter() - if isinstance(result, discord.ui.View): - return result - - return None - - class DiscordPlatformEvent(AstrMessageEvent): def __init__( self, @@ -306,18 +239,12 @@ async def _parse_to_discord( logger.warning(f"[Discord] 获取文件失败: {i.name}") except Exception as e: logger.warning(f"[Discord] 处理文件失败: {i.name}, 错误: {e}") - elif isinstance(i, DiscordEmbed) or _is_component_type(i, "discord_embed"): + elif isinstance(i, DiscordEmbed): # Discord Embed消息 - embeds.append(_to_discord_embed(i)) - elif ( - isinstance(i, DiscordView) - or isinstance(i, DiscordViewComponent) - or _is_component_type(i, "discord_view") - ): + embeds.append(i.to_discord_embed()) + elif isinstance(i, DiscordView): # Discord视图组件(按钮、选择菜单等) - parsed_view = _to_discord_view(i) - if parsed_view: - view = parsed_view + view = i.to_discord_view() else: logger.debug( f"[Discord] 忽略了不支持的消息组件: {getattr(i, 'type', None)}" From f6993c981958eac775ccb20e738d49e9ea8e2c87 Mon Sep 17 00:00:00 2001 From: SGSxingchen <853304398@qq.com> Date: Sun, 3 May 2026 15:55:23 +0000 Subject: [PATCH 7/7] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20Discord=20=E7=BB=84?= =?UTF-8?q?=E4=BB=B6=E8=AF=86=E5=88=AB=E5=92=8C=20Embed=20=E9=A2=9C?= =?UTF-8?q?=E8=89=B2=E5=88=A4=E7=A9=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/sources/discord/components.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index a864a29dfb..2a6bce7c1e 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -64,7 +64,7 @@ def to_discord_embed(self) -> discord.Embed: embed.title = self.title if self.description: embed.description = self.description - if self.color: + if self.color is not None: embed.color = self.color if self.url: embed.url = self.url @@ -154,38 +154,30 @@ def to_discord_view(self) -> discord.ui.View: view = discord.ui.View(timeout=self.timeout) for component in self.components or []: - raw_type = getattr(component, "type", None) - comp_type = getattr(raw_type, "value", raw_type) - if isinstance(component, DiscordButton) or comp_type == "discord_button": - style_name = getattr(component, "style", "primary") - label = getattr(component, "label", "") - url = getattr(component, "url", None) - custom_id = getattr(component, "custom_id", None) - emoji = getattr(component, "emoji", None) - disabled = getattr(component, "disabled", False) + if isinstance(component, DiscordButton): button_style = getattr( discord.ButtonStyle, - style_name, + component.style, discord.ButtonStyle.primary, ) - if url: + if component.url: # URL按钮 button = discord.ui.Button( - label=label, + label=component.label, style=discord.ButtonStyle.link, - url=url, - emoji=emoji, - disabled=disabled, + url=component.url, + emoji=component.emoji, + disabled=component.disabled, ) else: # 普通按钮 button = discord.ui.Button( - label=label, + label=component.label, style=button_style, - custom_id=custom_id, - emoji=emoji, - disabled=disabled, + custom_id=component.custom_id, + emoji=component.emoji, + disabled=component.disabled, ) view.add_item(button)