Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,28 +576,31 @@
"""
self.tool_registry.cleanup()

def add_hook(

Check warning on line 579 in src/strands/agent/agent.py

View workflow job for this annotation

GitHub Actions / check-api

Agent.add_hook(kwargs)

Parameter was removed
self, callback: HookCallback[TEvent], event_type: type[TEvent] | None = None, **kwargs: dict[str, Any]
self, callback: HookCallback[TEvent], event_type: type[TEvent] | list[type[TEvent]] | None = None
) -> None:
"""Register a callback function for a specific event type.

This method supports two call patterns:
This method supports multiple call patterns:
1. ``add_hook(callback)`` - Event type inferred from callback's type hint
2. ``add_hook(callback, event_type)`` - Event type specified explicitly
3. ``add_hook(callback, [TypeA, TypeB])`` - Register for multiple event types

When the callback's type hint is a union type (``A | B`` or ``Union[A, B]``),
the callback is automatically registered for each event type in the union.

Callbacks can be either synchronous or asynchronous functions.

Args:
callback: The callback function to invoke when events of this type occur.
event_type: The class type of events this callback should handle.
If not provided, the event type will be inferred from the callback's
first parameter type hint.
**kwargs: Additional arguments (ignored).

event_type: The class type(s) of events this callback should handle.
Can be a single type, a list of types, or None to infer from
the callback's first parameter type hint. If a list is provided,
the callback is registered for each type in the list.

Raises:
ValueError: If event_type is not provided and cannot be inferred from
the callback's type hints.
the callback's type hints, or if the event_type list is empty.

Example:
```python
Expand All @@ -611,6 +614,16 @@

# With explicit event type
agent.add_hook(log_model_call, BeforeModelCallEvent)

# With union type hint (registers for all types)
def log_event(event: BeforeModelCallEvent | AfterModelCallEvent) -> None:
print(f"Event: {type(event).__name__}")
agent.add_hook(log_event)

# With list of event types
def multi_handler(event) -> None:
print(f"Event: {type(event).__name__}")
agent.add_hook(multi_handler, [BeforeModelCallEvent, AfterModelCallEvent])
```
Docs:
https://strandsagents.com/latest/documentation/docs/user-guide/concepts/agents/hooks/
Expand Down
106 changes: 85 additions & 21 deletions src/strands/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import inspect
import logging
import types
from collections.abc import Awaitable, Generator
from dataclasses import dataclass
from typing import (
Expand All @@ -17,6 +18,10 @@
Generic,
Protocol,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
runtime_checkable,
)
Expand Down Expand Up @@ -167,22 +172,27 @@ def __init__(self) -> None:

def add_callback(
self,
event_type: type[TEvent] | None,
event_type: type[TEvent] | list[type[TEvent]] | None,
callback: HookCallback[TEvent],
) -> None:
"""Register a callback function for a specific event type.

If ``event_type`` is None, then this will check the callback handler type hint
for the lifecycle event type.
for the lifecycle event type. Union types (``A | B`` or ``Union[A, B]``) in
type hints will register the callback for each event type in the union.

If ``event_type`` is a list, the callback will be registered for each event
type in the list (duplicates are ignored).

Args:
event_type: The class type of events this callback should handle.
event_type: The lifecycle event type(s) this callback should handle.
Can be a single type, a list of types, or None to infer from type hints.
callback: The callback function to invoke when events of this type occur.

Raises:
ValueError: If event_type is not provided and cannot be inferred from
the callback's type hints, or if AgentInitializedEvent is registered
with an async callback.
with an async callback, or if the event_type list is empty.

Example:
```python
Expand All @@ -194,35 +204,77 @@ def my_handler(event: StartRequestEvent):

# With event type inferred from type hint
registry.add_callback(None, my_handler)

# With union type hint (registers for both types)
def union_handler(event: BeforeModelCallEvent | AfterModelCallEvent):
print(f"Event: {type(event).__name__}")
registry.add_callback(None, union_handler)

# With list of event types
def multi_handler(event):
print(f"Event: {type(event).__name__}")
registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], multi_handler)
```
"""
resolved_event_type: type[TEvent]

# Support both add_callback(None, callback) and add_callback(event_type, callback)
if event_type is None:
# callback provided but event_type is None - infer it
resolved_event_type = self._infer_event_type(callback)
resolved_event_types: list[type[TEvent]]

# Handle list of event types
if isinstance(event_type, list):
if not event_type:
raise ValueError("event_type list cannot be empty")
resolved_event_types = self._validate_event_type_list(event_type)
elif event_type is None:
# Infer event type(s) from callback type hints
resolved_event_types = self._infer_event_types(callback)
else:
resolved_event_type = event_type
# Single event type provided explicitly
resolved_event_types = [event_type]

# Related issue: https://github.com/strands-agents/sdk-python/issues/330
if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")
# Deduplicate event types while preserving order
unique_event_types: set[type[TEvent]] = set(resolved_event_types)

callbacks = self._registered_callbacks.setdefault(resolved_event_type, [])
callbacks.append(callback)
# Register callback for each event type
for resolved_event_type in unique_event_types:
# Related issue: https://github.com/strands-agents/sdk-python/issues/330
if resolved_event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback):
raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback")

def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]:
"""Infer the event type from a callback's type hints.
callbacks = self._registered_callbacks.setdefault(resolved_event_type, [])
callbacks.append(callback)

def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[type[TEvent]]:
"""Validate that all types in a list are valid BaseHookEvent subclasses.

Args:
event_types: List of event types to validate.

Returns:
The validated list of event types.

Raises:
ValueError: If any type is not a valid BaseHookEvent subclass.
"""
validated: list[type[TEvent]] = []
for et in event_types:
if not (isinstance(et, type) and issubclass(et, BaseHookEvent)):
raise ValueError(f"Invalid event type: {et} | must be a subclass of BaseHookEvent")
validated.append(et)
return validated

def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]:
"""Infer the event type(s) from a callback's type hints.

Supports both single types and union types (A | B or Union[A, B]).

Args:
callback: The callback function to inspect.

Returns:
The event type inferred from the callback's first parameter type hint.
A list of event types inferred from the callback's first parameter type hint.

Raises:
ValueError: If the event type cannot be inferred from the callback's type hints.
ValueError: If the event type cannot be inferred from the callback's type hints,
or if a union contains None or non-BaseHookEvent types.
"""
try:
hints = get_type_hints(callback)
Expand Down Expand Up @@ -250,9 +302,21 @@ def _infer_event_type(self, callback: HookCallback[TEvent]) -> type[TEvent]:
"cannot infer event type, please provide event_type explicitly"
)

# Check if it's a Union type (Union[A, B] or A | B)
origin = get_origin(type_hint)
if origin is Union or origin is types.UnionType:
event_types: list[type[TEvent]] = []
for arg in get_args(type_hint):
if arg is type(None):
raise ValueError("None is not a valid event type in union")
if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)):
raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent")
event_types.append(cast(type[TEvent], arg))
return event_types

# Handle single type
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
return type_hint # type: ignore[return-value]
return [cast(type[TEvent], type_hint)]

raise ValueError(
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
Expand Down
155 changes: 154 additions & 1 deletion tests/strands/hooks/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import unittest.mock
from typing import Union

import pytest

from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry
from strands.hooks import (
AfterModelCallEvent,
AgentInitializedEvent,
BeforeInvocationEvent,
BeforeModelCallEvent,
BeforeToolCallEvent,
HookRegistry,
)
from strands.interrupt import Interrupt, _InterruptState


Expand Down Expand Up @@ -155,3 +163,148 @@ def callback(event: BeforeInvocationEvent) -> None:

assert BeforeInvocationEvent in registry._registered_callbacks
assert callback in registry._registered_callbacks[BeforeInvocationEvent]

# ========== Tests for union type support ==========


def test_hook_registry_add_callback_infers_union_types_pipe_syntax(registry):
"""Test that add_callback registers callback for each type in A | B union."""

def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None:
pass

registry.add_callback(None, union_callback)

# Callback should be registered for both event types
assert BeforeModelCallEvent in registry._registered_callbacks
assert AfterModelCallEvent in registry._registered_callbacks
assert union_callback in registry._registered_callbacks[BeforeModelCallEvent]
assert union_callback in registry._registered_callbacks[AfterModelCallEvent]


def test_hook_registry_add_callback_infers_union_types_union_syntax(registry):
"""Test that add_callback registers callback for each type in Union[A, B]."""

def union_callback(event: Union[BeforeModelCallEvent, AfterModelCallEvent]) -> None: # noqa: UP007
pass

registry.add_callback(None, union_callback)

# Callback should be registered for both event types
assert BeforeModelCallEvent in registry._registered_callbacks
assert AfterModelCallEvent in registry._registered_callbacks
assert union_callback in registry._registered_callbacks[BeforeModelCallEvent]
assert union_callback in registry._registered_callbacks[AfterModelCallEvent]


def test_hook_registry_add_callback_union_with_none_raises_error(registry):
"""Test that add_callback raises error when union contains None."""

def callback_with_none(event: BeforeModelCallEvent | None) -> None:
pass

with pytest.raises(ValueError, match="None is not a valid event type"):
registry.add_callback(None, callback_with_none)


def test_hook_registry_add_callback_union_with_invalid_type_raises_error(registry):
"""Test that add_callback raises error when union contains non-BaseHookEvent type."""

def callback_with_invalid_type(event: BeforeModelCallEvent | str) -> None:
pass

with pytest.raises(ValueError, match="Invalid type in union"):
registry.add_callback(None, callback_with_invalid_type)


def test_hook_registry_add_callback_union_multiple_types(registry):
"""Test that add_callback handles union with more than two types."""

def multi_union_callback(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent) -> None:
pass

registry.add_callback(None, multi_union_callback)

# Callback should be registered for all three event types
assert BeforeModelCallEvent in registry._registered_callbacks
assert AfterModelCallEvent in registry._registered_callbacks
assert BeforeInvocationEvent in registry._registered_callbacks
assert multi_union_callback in registry._registered_callbacks[BeforeModelCallEvent]
assert multi_union_callback in registry._registered_callbacks[AfterModelCallEvent]
assert multi_union_callback in registry._registered_callbacks[BeforeInvocationEvent]


# ========== Tests for list of types support ==========


def test_hook_registry_add_callback_with_list_of_types(registry):
"""Test that add_callback registers callback for each type in a list."""

def my_callback(event) -> None:
pass

registry.add_callback([BeforeModelCallEvent, AfterModelCallEvent], my_callback)

# Callback should be registered for both event types
assert BeforeModelCallEvent in registry._registered_callbacks
assert AfterModelCallEvent in registry._registered_callbacks
assert my_callback in registry._registered_callbacks[BeforeModelCallEvent]
assert my_callback in registry._registered_callbacks[AfterModelCallEvent]


def test_hook_registry_add_callback_with_list_deduplicates(registry):
"""Test that add_callback deduplicates event types in a list."""

def my_callback(event) -> None:
pass

# Same type appears multiple times
registry.add_callback([BeforeModelCallEvent, BeforeModelCallEvent, AfterModelCallEvent], my_callback)

# Callback should be registered only once per event type
assert len(registry._registered_callbacks[BeforeModelCallEvent]) == 1
assert len(registry._registered_callbacks[AfterModelCallEvent]) == 1


def test_hook_registry_add_callback_with_list_validates_types(registry):
"""Test that add_callback validates all types in a list are BaseHookEvent subclasses."""

def my_callback(event) -> None:
pass

with pytest.raises(ValueError, match="Invalid event type"):
registry.add_callback([BeforeModelCallEvent, str], my_callback)


def test_hook_registry_add_callback_with_empty_list_raises_error(registry):
"""Test that add_callback raises error when given an empty list."""

def my_callback(event) -> None:
pass

with pytest.raises(ValueError, match="event_type list cannot be empty"):
registry.add_callback([], my_callback)


@pytest.mark.asyncio
async def test_hook_registry_union_callback_invoked_for_each_type(registry, agent):
"""Test that a union-registered callback is invoked correctly for each event type."""
call_count = {"before": 0, "after": 0}

def union_callback(event: BeforeModelCallEvent | AfterModelCallEvent) -> None:
if isinstance(event, BeforeModelCallEvent):
call_count["before"] += 1
elif isinstance(event, AfterModelCallEvent):
call_count["after"] += 1

registry.add_callback(None, union_callback)

# Invoke BeforeModelCallEvent
before_event = BeforeModelCallEvent(agent=agent)
await registry.invoke_callbacks_async(before_event)
assert call_count["before"] == 1

# Invoke AfterModelCallEvent
after_event = AfterModelCallEvent(agent=agent)
await registry.invoke_callbacks_async(after_event)
assert call_count["after"] == 1
Loading