diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e199608a2..7350ab7ed 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -577,27 +577,30 @@ def cleanup(self) -> None: self.tool_registry.cleanup() def add_hook( - self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any] + self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None ) -> None: """Register a callback function for a specific event type. - This method supports two call patterns: + This method supports multiple call patterns: 1. ``add_hook(callback)`` - Event type inferred from callback's type hint 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + 3. ``add_hook(callback, [TypeA, TypeB])`` - Register for multiple event types + + When the callback's type hint is a union type (``A | B`` or ``Union[A, B]``), + the callback is automatically registered for each event type in the union. Callbacks can be either synchronous or asynchronous functions. Args: callback: The callback function to invoke when events of this type occur. - event_type: The class type of events this callback should handle. - If not provided, the event type will be inferred from the callback's - first parameter type hint. - **kwargs: Additional arguments (ignored). - + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. If a list is provided, + the callback is registered for each type in the list. Raises: ValueError: If event_type is not provided and cannot be inferred from - the callback's type hints. + the callback's type hints, or if the event_type list is empty. Example: ```python @@ -611,6 +614,16 @@ def log_model_call(event: BeforeModelCallEvent) -> None: # With explicit event type agent.add_hook(log_model_call, BeforeModelCallEvent) + + # With union type hint (registers for all types) + def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(log_event) + + # With list of event types + def multi_handler(event) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(multi_handler, [BeforeModelCallEvent, AfterModelCallEvent]) ``` Docs: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 2f465a751..886ea5644 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,6 +9,7 @@ import inspect import logging +import types from collections.abc import Awaitable, Generator from dataclasses import dataclass from typing import ( @@ -17,6 +18,10 @@ Generic, Protocol, TypeVar, + Union, + cast, + get_args, + get_origin, get_type_hints, runtime_checkable, ) @@ -167,22 +172,27 @@ def __init__(self) -> None: def add_callback( self, - event_type: type[TEvent] | None, + event_type: type[TEvent] | list[type[TEvent]] | None, callback: HookCallback[TEvent], ) -> None: """Register a callback function for a specific event type. If ``event_type`` is None, then this will check the callback handler type hint - for the lifecycle event type. + for the lifecycle event type. Union types (``A | B`` or ``Union[A, B]``) in + type hints will register the callback for each event type in the union. + + If ``event_type`` is a list, the callback will be registered for each event + type in the list (duplicates are ignored). Args: - event_type: The class type of events this callback should handle. + event_type: The lifecycle event type(s) this callback should handle. + Can be a single type, a list of types, or None to infer from type hints. callback: The callback function to invoke when events of this type occur. Raises: ValueError: If event_type is not provided and cannot be inferred from the callback's type hints, or if AgentInitializedEvent is registered - with an async callback. + with an async callback, or if the event_type list is empty. Example: ```python @@ -194,35 +204,77 @@ def my_handler(event: StartRequestEvent): # With event type inferred from type hint registry.add_callback(None, my_handler) + + # With union type hint (registers for both types) + def union_handler(event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + registry.add_callback(None, union_handler) + + # With list of event types + def multi_handler(event): + print(f"Event: {type(event).__name__}") + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], multi_handler) ``` """ - resolved_event_type: type[TEvent] - - # Support both add_callback(None, callback) and add_callback(event_type, callback) - if event_type is None: - # callback provided but event_type is None - infer it - resolved_event_type = self._infer_event_type(callback) + resolved_event_types: list[type[TEvent]] + + # Handle list of event types + if isinstance(event_type, list): + if not event_type: + raise ValueError("event_type list cannot be empty") + resolved_event_types = self._validate_event_type_list(event_type) + elif event_type is None: + # Infer event type(s) from callback type hints + resolved_event_types = self._infer_event_types(callback) else: - resolved_event_type = event_type + # Single event type provided explicitly + resolved_event_types = [event_type] - # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): - raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + # Deduplicate event types while preserving order + unique_event_types: set[type[TEvent]] = set(resolved_event_types) - callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) - callbacks.append(callback) + # Register callback for each event type + for resolved_event_type in unique_event_types: + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: - """Infer the event type from a callback's type hints. + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(callback) + + def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]: + """Validate that all types in a list are valid BaseHookEvent subclasses. + + Args: + event_types: List of event types to validate. + + Returns: + The validated list of event types. + + Raises: + ValueError: If any type is not a valid BaseHookEvent subclass. + """ + validated: list[type[TEvent]] = [] + for et in event_types: + if not (isinstance(et, type) and issubclass(et, BaseHookEvent)): + raise ValueError(f"Invalid event type: {et} | must be a subclass of BaseHookEvent") + validated.append(et) + return validated + + def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). Args: callback: The callback function to inspect. Returns: - The event type inferred from the callback's first parameter type hint. + A list of event types inferred from the callback's first parameter type hint. Raises: - ValueError: If the event type cannot be inferred from the callback's type hints. + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. """ try: hints = get_type_hints(callback) @@ -250,9 +302,21 @@ def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: "cannot infer event type, please provide event_type explicitly" ) + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast(type[TEvent], arg)) + return event_types + # Handle single type if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return type_hint # type: ignore[return-value] + return [cast(type[TEvent], type_hint)] raise ValueError( f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 5331bfa43..79829b92b 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -1,8 +1,16 @@ import unittest.mock +from typing import Union import pytest -from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.hooks import ( + AfterModelCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookRegistry, +) from strands.interrupt import Interrupt, _InterruptState @@ -155,3 +163,148 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] + +# ========== Tests for union type support ========== + + +def test_hook_registry_add_callback_infers_union_types_pipe_syntax(registry): + """Test that add_callback registers callback for each type in A | B union.""" + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_infers_union_types_union_syntax(registry): + """Test that add_callback registers callback for each type in Union[A, B].""" + + def union_callback(event: Union[BeforeModelCallEvent, AfterModelCallEvent]) -> None: # noqa: UP007 + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_union_with_none_raises_error(registry): + """Test that add_callback raises error when union contains None.""" + + def callback_with_none(event: BeforeModelCallEvent | None) -> None: + pass + + with pytest.raises(ValueError, match="None is not a valid event type"): + registry.add_callback(None, callback_with_none) + + +def test_hook_registry_add_callback_union_with_invalid_type_raises_error(registry): + """Test that add_callback raises error when union contains non-BaseHookEvent type.""" + + def callback_with_invalid_type(event: BeforeModelCallEvent | str) -> None: + pass + + with pytest.raises(ValueError, match="Invalid type in union"): + registry.add_callback(None, callback_with_invalid_type) + + +def test_hook_registry_add_callback_union_multiple_types(registry): + """Test that add_callback handles union with more than two types.""" + + def multi_union_callback(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, multi_union_callback) + + # Callback should be registered for all three event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert BeforeInvocationEvent in registry._registered_callbacks + assert multi_union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[AfterModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for list of types support ========== + + +def test_hook_registry_add_callback_with_list_of_types(registry): + """Test that add_callback registers callback for each type in a list.""" + + def my_callback(event) -> None: + pass + + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert my_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert my_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_with_list_deduplicates(registry): + """Test that add_callback deduplicates event types in a list.""" + + def my_callback(event) -> None: + pass + + # Same type appears multiple times + registry.add_callback([BeforeModelCallEvent, BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered only once per event type + assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1 + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +def test_hook_registry_add_callback_with_list_validates_types(registry): + """Test that add_callback validates all types in a list are BaseHookEvent subclasses.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="Invalid event type"): + registry.add_callback([BeforeModelCallEvent, str], my_callback) + + +def test_hook_registry_add_callback_with_empty_list_raises_error(registry): + """Test that add_callback raises error when given an empty list.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="event_type list cannot be empty"): + registry.add_callback([], my_callback) + + +@pytest.mark.asyncio +async def test_hook_registry_union_callback_invoked_for_each_type(registry, agent): + """Test that a union-registered callback is invoked correctly for each event type.""" + call_count = {"before": 0, "after": 0} + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + if isinstance(event, BeforeModelCallEvent): + call_count["before"] += 1 + elif isinstance(event, AfterModelCallEvent): + call_count["after"] += 1 + + registry.add_callback(None, union_callback) + + # Invoke BeforeModelCallEvent + before_event = BeforeModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(before_event) + assert call_count["before"] == 1 + + # Invoke AfterModelCallEvent + after_event = AfterModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(after_event) + assert call_count["after"] == 1