From 52f9f9dc5ebf7d3b83de96c0be08cd8f402de440 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 18 Feb 2026 20:40:04 +0000 Subject: [PATCH 1/6] feat(hooks): support union types and list of types for add_hook - Add union type support (A | B and Union[A, B]) for callback type hints - Allow passing list of event types to add_hook/add_callback - Remove unused **kwargs from Agent.add_hook() - Deduplicate event types when registering callbacks - Validate all types in unions and lists are BaseHookEvent subclasses - Error on None or invalid types in unions Resolves #1714 --- src/strands/agent/agent.py | 29 +++-- src/strands/hooks/registry.py | 110 +++++++++++++++---- tests/strands/hooks/test_registry.py | 156 ++++++++++++++++++++++++++- 3 files changed, 265 insertions(+), 30 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index e199608a2..9a0ad8373 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -577,27 +577,30 @@ def cleanup(self) -> None: self.tool_registry.cleanup() def add_hook( - self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any] + self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None ) -> None: """Register a callback function for a specific event type. - This method supports two call patterns: + This method supports multiple call patterns: 1. ``add_hook(callback)`` - Event type inferred from callback's type hint 2. ``add_hook(callback, event_type)`` - Event type specified explicitly + 3. ``add_hook(callback, [TypeA, TypeB])`` - Register for multiple event types + + When the callback's type hint is a union type (``A | B`` or ``Union[A, B]``), + the callback is automatically registered for each event type in the union. Callbacks can be either synchronous or asynchronous functions. Args: callback: The callback function to invoke when events of this type occur. - event_type: The class type of events this callback should handle. - If not provided, the event type will be inferred from the callback's - first parameter type hint. - **kwargs: Additional arguments (ignored). - + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from + the callback's first parameter type hint. If a list is provided, + the callback is registered for each type in the list. Raises: ValueError: If event_type is not provided and cannot be inferred from - the callback's type hints. + the callback's type hints, or if the event_type list is empty. Example: ```python @@ -611,6 +614,16 @@ def log_model_call(event: BeforeModelCallEvent) -> None: # With explicit event type agent.add_hook(log_model_call, BeforeModelCallEvent) + + # With union type hint (registers for both types) + def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(log_event) + + # With list of event types + def multi_handler(event) -> None: + print(f"Event: {type(event).__name__}") + agent.add_hook(multi_handler, [BeforeModelCallEvent, AfterModelCallEvent]) ``` Docs: https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/ diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 2f465a751..25663f725 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,6 +9,7 @@ import inspect import logging +import types from collections.abc import Awaitable, Generator from dataclasses import dataclass from typing import ( @@ -17,6 +18,9 @@ Generic, Protocol, TypeVar, + Union, + get_args, + get_origin, get_type_hints, runtime_checkable, ) @@ -167,22 +171,27 @@ def __init__(self) -> None: def add_callback( self, - event_type: type[TEvent] | None, + event_type: type[TEvent] | list[type[TEvent]] | None, callback: HookCallback[TEvent], ) -> None: """Register a callback function for a specific event type. If ``event_type`` is None, then this will check the callback handler type hint - for the lifecycle event type. + for the lifecycle event type. Union types (``A | B`` or ``Union[A, B]``) in + type hints will register the callback for each event type in the union. + + If ``event_type`` is a list, the callback will be registered for each event + type in the list (duplicates are ignored). Args: - event_type: The class type of events this callback should handle. + event_type: The class type(s) of events this callback should handle. + Can be a single type, a list of types, or None to infer from type hints. callback: The callback function to invoke when events of this type occur. Raises: ValueError: If event_type is not provided and cannot be inferred from the callback's type hints, or if AgentInitializedEvent is registered - with an async callback. + with an async callback, or if the event_type list is empty. Example: ```python @@ -194,35 +203,82 @@ def my_handler(event: StartRequestEvent): # With event type inferred from type hint registry.add_callback(None, my_handler) + + # With union type hint (registers for both types) + def union_handler(event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + registry.add_callback(None, union_handler) + + # With list of event types + def multi_handler(event): + print(f"Event: {type(event).__name__}") + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], multi_handler) ``` """ - resolved_event_type: type[TEvent] - - # Support both add_callback(None, callback) and add_callback(event_type, callback) - if event_type is None: - # callback provided but event_type is None - infer it - resolved_event_type = self._infer_event_type(callback) + resolved_event_types: list[type[TEvent]] + + # Handle list of event types + if isinstance(event_type, list): + if not event_type: + raise ValueError("event_type list cannot be empty") + resolved_event_types = self._validate_event_type_list(event_type) + elif event_type is None: + # Infer event type(s) from callback type hints + resolved_event_types = self._infer_event_types(callback) else: - resolved_event_type = event_type + # Single event type provided explicitly + resolved_event_types = [event_type] - # Related issue: https://github.com/strands-agents/sdk-python/issues/330 - if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): - raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + # Deduplicate event types + seen: set[type[TEvent]] = set() + unique_event_types: list[type[TEvent]] = [] + for et in resolved_event_types: + if et not in seen: + seen.add(et) + unique_event_types.append(et) - callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) - callbacks.append(callback) + # Register callback for each event type + for resolved_event_type in unique_event_types: + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") - def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: - """Infer the event type from a callback's type hints. + callbacks = self._registered_callbacks.setdefault(resolved_event_type, []) + callbacks.append(callback) + + def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]: + """Validate that all types in a list are valid BaseHookEvent subclasses. + + Args: + event_types: List of event types to validate. + + Returns: + The validated list of event types. + + Raises: + ValueError: If any type is not a valid BaseHookEvent subclass. + """ + validated: list[type[TEvent]] = [] + for et in event_types: + if not (isinstance(et, type) and issubclass(et, BaseHookEvent)): + raise ValueError(f"Invalid event type: {et} | must be a subclass of BaseHookEvent") + validated.append(et) + return validated + + def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). Args: callback: The callback function to inspect. Returns: - The event type inferred from the callback's first parameter type hint. + A list of event types inferred from the callback's first parameter type hint. Raises: - ValueError: If the event type cannot be inferred from the callback's type hints. + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. """ try: hints = get_type_hints(callback) @@ -250,9 +306,21 @@ def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]: "cannot infer event type, please provide event_type explicitly" ) + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(arg) # type: ignore[arg-type] + return event_types + # Handle single type if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return type_hint # type: ignore[return-value] + return [type_hint] # type: ignore[list-item] raise ValueError( f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 5331bfa43..5b0f3c574 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -1,8 +1,16 @@ import unittest.mock +from typing import Union import pytest -from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.hooks import ( + AfterModelCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookRegistry, +) from strands.interrupt import Interrupt, _InterruptState @@ -155,3 +163,149 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for union type support ========== + + +def test_hook_registry_add_callback_infers_union_types_pipe_syntax(registry): + """Test that add_callback registers callback for each type in A | B union.""" + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_infers_union_types_union_syntax(registry): + """Test that add_callback registers callback for each type in Union[A, B].""" + + def union_callback(event: Union[BeforeModelCallEvent, AfterModelCallEvent]) -> None: # noqa: UP007 + pass + + registry.add_callback(None, union_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert union_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_union_with_none_raises_error(registry): + """Test that add_callback raises error when union contains None.""" + + def callback_with_none(event: BeforeModelCallEvent | None) -> None: + pass + + with pytest.raises(ValueError, match="None is not a valid event type"): + registry.add_callback(None, callback_with_none) + + +def test_hook_registry_add_callback_union_with_invalid_type_raises_error(registry): + """Test that add_callback raises error when union contains non-BaseHookEvent type.""" + + def callback_with_invalid_type(event: BeforeModelCallEvent | str) -> None: + pass + + with pytest.raises(ValueError, match="Invalid type in union"): + registry.add_callback(None, callback_with_invalid_type) + + +def test_hook_registry_add_callback_union_multiple_types(registry): + """Test that add_callback handles union with more than two types.""" + + def multi_union_callback(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent) -> None: + pass + + registry.add_callback(None, multi_union_callback) + + # Callback should be registered for all three event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert BeforeInvocationEvent in registry._registered_callbacks + assert multi_union_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[AfterModelCallEvent] + assert multi_union_callback in registry._registered_callbacks[BeforeInvocationEvent] + + +# ========== Tests for list of types support ========== + + +def test_hook_registry_add_callback_with_list_of_types(registry): + """Test that add_callback registers callback for each type in a list.""" + + def my_callback(event) -> None: + pass + + registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered for both event types + assert BeforeModelCallEvent in registry._registered_callbacks + assert AfterModelCallEvent in registry._registered_callbacks + assert my_callback in registry._registered_callbacks[BeforeModelCallEvent] + assert my_callback in registry._registered_callbacks[AfterModelCallEvent] + + +def test_hook_registry_add_callback_with_list_deduplicates(registry): + """Test that add_callback deduplicates event types in a list.""" + + def my_callback(event) -> None: + pass + + # Same type appears multiple times + registry.add_callback([BeforeModelCallEvent, BeforeModelCallEvent, AfterModelCallEvent], my_callback) + + # Callback should be registered only once per event type + assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1 + assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1 + + +def test_hook_registry_add_callback_with_list_validates_types(registry): + """Test that add_callback validates all types in a list are BaseHookEvent subclasses.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="Invalid event type"): + registry.add_callback([BeforeModelCallEvent, str], my_callback) + + +def test_hook_registry_add_callback_with_empty_list_raises_error(registry): + """Test that add_callback raises error when given an empty list.""" + + def my_callback(event) -> None: + pass + + with pytest.raises(ValueError, match="event_type list cannot be empty"): + registry.add_callback([], my_callback) + + +@pytest.mark.asyncio +async def test_hook_registry_union_callback_invoked_for_each_type(registry, agent): + """Test that a union-registered callback is invoked correctly for each event type.""" + call_count = {"before": 0, "after": 0} + + def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: + if isinstance(event, BeforeModelCallEvent): + call_count["before"] += 1 + elif isinstance(event, AfterModelCallEvent): + call_count["after"] += 1 + + registry.add_callback(None, union_callback) + + # Invoke BeforeModelCallEvent + before_event = BeforeModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(before_event) + assert call_count["before"] == 1 + + # Invoke AfterModelCallEvent + after_event = AfterModelCallEvent(agent=agent) + await registry.invoke_callbacks_async(after_event) + assert call_count["after"] == 1 From 2ad60725cc035458ccde5f9492f95634e08f5d3c Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:07:18 +0000 Subject: [PATCH 2/6] refactor: address review feedback - Use set() for deduplication instead of manual loop - Replace type: ignore with cast() for explicit type narrowing - Restore missing test case --- .artifact/write_operations.jsonl | 3 +++ src/strands/hooks/registry.py | 14 +++++--------- tests/strands/hooks/test_registry.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 .artifact/write_operations.jsonl diff --git a/.artifact/write_operations.jsonl b/.artifact/write_operations.jsonl new file mode 100644 index 000000000..9a87e609a --- /dev/null +++ b/.artifact/write_operations.jsonl @@ -0,0 +1,3 @@ +{"timestamp": "2026-02-18T21:07:11.256137Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824420583, "reply_text": "Updated to use `cast()` instead - cleaner and more explicit.", "repo": null}} +{"timestamp": "2026-02-18T21:07:11.257190Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824414612, "reply_text": "Good suggestion! Simplified to use `set()` directly for deduplication.", "repo": null}} +{"timestamp": "2026-02-18T21:07:11.259500Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824423120, "reply_text": "Updated to use `cast()` instead - cleaner and more explicit.", "repo": null}} diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 25663f725..bad618488 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -19,6 +19,7 @@ Protocol, TypeVar, Union, + cast, get_args, get_origin, get_type_hints, @@ -229,13 +230,8 @@ def multi_handler(event): # Single event type provided explicitly resolved_event_types = [event_type] - # Deduplicate event types - seen: set[type[TEvent]] = set() - unique_event_types: list[type[TEvent]] = [] - for et in resolved_event_types: - if et not in seen: - seen.add(et) - unique_event_types.append(et) + # Deduplicate event types while preserving order + unique_event_types: set[type[TEvent]] = set(resolved_event_types) # Register callback for each event type for resolved_event_type in unique_event_types: @@ -315,12 +311,12 @@ def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent raise ValueError("None is not a valid event type in union") if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(arg) # type: ignore[arg-type] + event_types.append(cast(type[TEvent], arg)) return event_types # Handle single type if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [type_hint] # type: ignore[list-item] + return [cast(type[TEvent], type_hint)] raise ValueError( f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 5b0f3c574..c53e6f3bf 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -165,6 +165,19 @@ def callback(event: BeforeInvocationEvent) -> None: assert callback in registry._registered_callbacks[BeforeInvocationEvent] +def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): + """Test that add_callback raises error when get_type_hints fails.""" + + class BadCallback: + def __call__(self, event: "NonExistentType") -> None: # noqa: F821 + pass + + callback = BadCallback() + + with pytest.raises(ValueError, match="failed to get type hints for callback"): + registry.add_callback(None, callback) + + # ========== Tests for union type support ========== From 95a0293ea3abf3453f92153517b402a54f001f53 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:07:53 +0000 Subject: [PATCH 3/6] Additional changes from write operations --- .artifact/write_operations.jsonl | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .artifact/write_operations.jsonl diff --git a/.artifact/write_operations.jsonl b/.artifact/write_operations.jsonl deleted file mode 100644 index 9a87e609a..000000000 --- a/.artifact/write_operations.jsonl +++ /dev/null @@ -1,3 +0,0 @@ -{"timestamp": "2026-02-18T21:07:11.256137Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824420583, "reply_text": "Updated to use `cast()` instead - cleaner and more explicit.", "repo": null}} -{"timestamp": "2026-02-18T21:07:11.257190Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824414612, "reply_text": "Good suggestion! Simplified to use `set()` directly for deduplication.", "repo": null}} -{"timestamp": "2026-02-18T21:07:11.259500Z", "function": "reply_to_review_comment", "args": [], "kwargs": {"pr_number": 1719, "comment_id": 2824423120, "reply_text": "Updated to use `cast()` instead - cleaner and more explicit.", "repo": null}} From dd408471cf1300663997351bc42d819f916f847a Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Feb 2026 16:16:21 -0500 Subject: [PATCH 4/6] Apply suggestions from code review --- src/strands/hooks/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index bad618488..886ea5644 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -185,7 +185,7 @@ def add_callback( type in the list (duplicates are ignored). Args: - event_type: The class type(s) of events this callback should handle. + event_type: The lifecycle event type(s) this callback should handle. Can be a single type, a list of types, or None to infer from type hints. callback: The callback function to invoke when events of this type occur. From feb1cbce2a140c4923e305524c6fdab784fdacce Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Feb 2026 16:17:16 -0500 Subject: [PATCH 5/6] Apply suggestion from @Unshure --- tests/strands/hooks/test_registry.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index c53e6f3bf..79829b92b 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -164,20 +164,6 @@ def callback(event: BeforeInvocationEvent) -> None: assert BeforeInvocationEvent in registry._registered_callbacks assert callback in registry._registered_callbacks[BeforeInvocationEvent] - -def test_hook_registry_add_callback_raises_error_on_type_hints_failure(registry): - """Test that add_callback raises error when get_type_hints fails.""" - - class BadCallback: - def __call__(self, event: "NonExistentType") -> None: # noqa: F821 - pass - - callback = BadCallback() - - with pytest.raises(ValueError, match="failed to get type hints for callback"): - registry.add_callback(None, callback) - - # ========== Tests for union type support ========== From 6bf6cd6edd183c62555f25528a5339f239f76986 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 18 Feb 2026 16:44:38 -0500 Subject: [PATCH 6/6] Update src/strands/agent/agent.py --- src/strands/agent/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9a0ad8373..7350ab7ed 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -615,7 +615,7 @@ def log_model_call(event: BeforeModelCallEvent) -> None: # With explicit event type agent.add_hook(log_model_call, BeforeModelCallEvent) - # With union type hint (registers for both types) + # With union type hint (registers for all types) def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None: print(f"Event: {type(event).__name__}") agent.add_hook(log_event)