Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdks/python/ag_ui/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,11 +279,16 @@ class StepFinishedEvent(BaseEvent):
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageChunkEvent,
ThinkingTextMessageStartEvent,
ThinkingTextMessageContentEvent,
ThinkingTextMessageEndEvent,
ToolCallStartEvent,
ToolCallArgsEvent,
ToolCallEndEvent,
ToolCallChunkEvent,
ToolCallResultEvent,
ThinkingStartEvent,
ThinkingEndEvent,
StateSnapshotEvent,
StateDeltaEvent,
MessagesSnapshotEvent,
Expand Down
27 changes: 27 additions & 0 deletions sdks/python/tests/test_events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest
import json
import typing
from datetime import datetime
from pydantic import ValidationError, TypeAdapter

from ag_ui.core import events as events_module
from ag_ui.core.types import Message, UserMessage, AssistantMessage, FunctionCall, ToolCall
from ag_ui.core.events import (
EventType,
Expand Down Expand Up @@ -596,6 +598,31 @@ def test_event_with_unicode_and_special_chars(self):
# Verify Unicode and special characters are preserved
self.assertEqual(deserialized.delta, text)

def test_all_event_subclasses_in_event_union(self):
"""Ensure all BaseEvent subclasses are included in the Event union type"""

# Get all classes defined in the events module that are subclasses of BaseEvent
event_subclasses = set()
for name in dir(events_module):
obj = getattr(events_module, name)
if (
isinstance(obj, type)
and issubclass(obj, BaseEvent)
and obj is not BaseEvent
):
event_subclasses.add(obj)

# Get all types in the Event union
union_types = set(typing.get_args(typing.get_args(Event)[0]))

# Check that all event subclasses are in the union
missing_from_union = event_subclasses - union_types
self.assertEqual(
missing_from_union,
set(),
f"The following event types are missing from the Event union: {missing_from_union}"
)


if __name__ == "__main__":
unittest.main()