diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index e38942b2c..4986acf1f 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -59,6 +59,20 @@ jobs: uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Install system audio dependencies (Linux) + if: matrix.os-name == 'linux' + run: | + sudo apt-get update + sudo apt-get install -y portaudio19-dev libasound2-dev + - name: Install system audio dependencies (macOS) + if: matrix.os-name == 'macOS' + run: | + brew install portaudio + - name: Install system audio dependencies (Windows) + if: matrix.os-name == 'windows' + run: | + # Windows typically has audio libraries available by default + echo "Windows audio dependencies handled by PyAudio wheels" - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -89,6 +103,11 @@ jobs: python-version: '3.10' cache: 'pip' + - name: Install system audio dependencies (Linux) + run: | + sudo apt-get update + sudo apt-get install -y portaudio19-dev libasound2-dev + - name: Install dependencies run: | pip install --no-cache-dir hatch @@ -97,3 +116,4 @@ jobs: id: lint run: hatch fmt --linter --check continue-on-error: false + diff --git a/.gitignore b/.gitignore index e92a233f8..8b0fd989c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist repl_state .kiro uv.lock +.audio_cache diff --git a/README.md b/README.md index 3ff0ec2e4..e7d1b2a7e 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,74 @@ agent("What is the square root of 1764") It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). +### Bidirectional Streaming + +> **⚠️ Experimental Feature**: Bidirectional streaming is currently in experimental status. APIs may change in future releases as we refine the feature based on user feedback and evolving model capabilities. + +Build real-time voice and audio conversations with persistent streaming connections. Unlike traditional request-response patterns, bidirectional streaming maintains long-running conversations where users can interrupt, provide continuous input, and receive real-time audio responses. Get started with your first BidiAgent by following the [Quickstart](https://strandsagents.com/latest/documentation/docs/user-guide/concepts/experimental/bidirectional-streaming/quickstart) guide. 
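+
+Beyond the `run()` helper shown in the Quick Example below, `BidiAgent` exposes an explicit `start()`/`send()`/`receive()`/`stop()` lifecycle, and its async context manager starts and stops the connection automatically. Here is a minimal text-only sketch (the event handling is illustrative):
+
+```python
+import asyncio
+from strands.experimental.bidi import BidiAgent, BidiTranscriptStreamEvent
+from strands.experimental.bidi.models import BidiNovaSonicModel
+
+async def main():
+    # Entering the context calls start(); exiting calls stop()
+    async with BidiAgent(model=BidiNovaSonicModel()) as agent:
+        await agent.send("What is the square root of 1764?")
+        async for event in agent.receive():
+            # Transcript events carry incremental and final text
+            if isinstance(event, BidiTranscriptStreamEvent) and event["is_final"]:
+                print(event["text"])
+                break  # illustrative: stop after the first complete response
+
+asyncio.run(main())
+```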
+ +**Supported Model Providers:** +- Amazon Nova Sonic (`amazon.nova-sonic-v1:0`) +- Google Gemini Live (`gemini-2.5-flash-native-audio-preview-09-2025`) +- OpenAI Realtime API (`gpt-realtime`) + +**Quick Example:** + +```python +import asyncio +from strands.experimental.bidi import BidiAgent +from strands.experimental.bidi.models import BidiNovaSonicModel +from strands.experimental.bidi.io import BidiAudioIO, BidiTextIO +from strands.experimental.bidi.tools import stop_conversation +from strands_tools import calculator + +async def main(): + # Create bidirectional agent with audio model + model = BidiNovaSonicModel() + agent = BidiAgent(model=model, tools=[calculator, stop_conversation]) + + # Setup audio and text I/O + audio_io = BidiAudioIO() + text_io = BidiTextIO() + + # Run with real-time audio streaming + # Say "stop conversation" to gracefully end the conversation + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output(), text_io.output()] + ) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +**Configuration Options:** + +```python +# Configure audio settings +model = BidiNovaSonicModel( + provider_config={ + "audio": { + "input_rate": 16000, + "output_rate": 16000, + "voice": "matthew" + }, + "inference": { + "max_tokens": 2048, + "temperature": 0.7 + } + } +) + +# Configure I/O devices +audio_io = BidiAudioIO( + input_device_index=0, # Specific microphone + output_device_index=1, # Specific speaker + input_buffer_size=10, + output_buffer_size=10 +) +``` + ## Documentation For detailed guidance & examples, explore our documentation: diff --git a/pyproject.toml b/pyproject.toml index b542c7481..f5738a68b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,18 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] + +bidi = [ + "aws_sdk_bedrock_runtime; python_version>='3.12'", + "prompt_toolkit>=3.0.0,<4.0.0", + "pyaudio>=0.2.13,<1.0.0", + "smithy-aws-core>=0.0.1; python_version>='3.12'", +] +bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] +bidi-openai = ["websockets>=15.0.0,<16.0.0"] + all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", @@ -104,7 +115,7 @@ features = ["all"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.13.0,<0.14.0", - # Include required pacakge dependencies for mypy + # Include required package dependencies for mypy "strands-agents @ {root:uri}", ] @@ -118,7 +129,7 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy -p src" + "mypy ./src" ] lint-fix = [ "ruff check --fix" @@ -192,11 +203,16 @@ warn_no_return = true warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false +exclude = ["src/strands/experimental/bidi"] +[[tool.mypy.overrides]] +module = ["strands.experimental.bidi.*"] +follow_imports = "skip" [tool.ruff] line-length = 120 include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"] +exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"] [tool.ruff.lint] select = [ @@ -219,6 +235,7 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" +addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" [tool.coverage.run] @@ -227,6 +244,7 @@ source = ["src"] context = "thread" parallel = true concurrency = ["thread", 
"multiprocessing"] +omit = ["src/strands/experimental/bidi/*"] [tool.coverage.report] show_missing = true @@ -256,3 +274,48 @@ style = [ ["text", ""], ["disabled", "fg:#858585 italic"] ] + +# ========================= +# Bidi development configs +# ========================= + +[tool.hatch.envs.bidi] +dev-mode = true +features = ["dev", "bidi-all"] +installer = "uv" + +[tool.hatch.envs.bidi.scripts] +prepare = [ + "hatch run bidi-lint:format-fix", + "hatch run bidi-lint:quality-fix", + "hatch run bidi-lint:type-check", + "hatch run bidi-test:test-cov", +] + +[tools.hatch.envs.bidi-lint] +template = "bidi" + +[tool.hatch.envs.bidi-lint.scripts] +format-check = "format-fix --check" +format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" +quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" +quality-fix = "quality-check --fix" +type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py" + +[tool.hatch.envs.bidi-test] +template = "bidi" + +[tool.hatch.envs.bidi-test.scripts] +test = "pytest {args} tests/strands/experimental/bidi" +test-cov = """ +test \ + --cov=strands.experimental.bidi \ + --cov-config= \ + --cov-branch \ + --cov-report=term-missing \ + --cov-report=xml:build/coverage/bidi-coverage.xml \ + --cov-report=html:build/coverage/bidi-html +""" + +[[tool.hatch.envs.bidi-test.matrix]] +python = ["3.13", "3.12"] diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py new file mode 100644 index 000000000..57986062e --- /dev/null +++ b/src/strands/experimental/bidi/__init__.py @@ -0,0 +1,78 @@ +"""Bidirectional streaming package.""" + +import sys + +if sys.version_info < (3, 12): + raise ImportError("bidi only supported for >= Python 3.12") + +# Main components - Primary user interface +# Re-export standard agent events for tool handling +from ...types._events import ( + ToolResultEvent, + ToolStreamEvent, + ToolUseStreamEvent, +) +from .agent.agent import BidiAgent + +# IO channels - Hardware abstraction +from .io.audio import BidiAudioIO + +# Model interface (for custom implementations) +from .models.model import BidiModel +from .models.nova_sonic import BidiNovaSonicModel + +# Built-in tools +from .tools import stop_conversation + +# Event types - For type hints and event handling +from .types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, +) + +__all__ = [ + # Main interface + "BidiAgent", + # IO channels + "BidiAudioIO", + # Model providers + "BidiNovaSonicModel", + # Built-in tools + "stop_conversation", + # Input Event types + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", + "BidiInputEvent", + # Output Event types + "BidiConnectionStartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", + "ModalityUsage", + "BidiErrorEvent", + "BidiOutputEvent", + # Tool Event types (reused from standard agent) + "ToolUseStreamEvent", + "ToolResultEvent", + "ToolStreamEvent", + # Model interface + "BidiModel", +] diff --git 
a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py new file mode 100644 index 000000000..6cee3264d --- /dev/null +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -0,0 +1,29 @@ +"""Utilities for async operations.""" + +from typing import Awaitable, Callable + +from ._task_pool import _TaskPool + +__all__ = ["_TaskPool"] + + +async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: + """Call all stops in sequence and aggregate errors. + + A failure in one stop call will not block subsequent stop calls. + + Args: + funcs: Stop functions to call in sequence. + + Raises: + ExceptionGroup: If any stop function raises an exception. + """ + exceptions = [] + for func in funcs: + try: + await func() + except Exception as exception: + exceptions.append(exception) + + if exceptions: + raise ExceptionGroup("failed stop sequence", exceptions) diff --git a/src/strands/experimental/bidi/_async/_task_pool.py b/src/strands/experimental/bidi/_async/_task_pool.py new file mode 100644 index 000000000..83146fd5f --- /dev/null +++ b/src/strands/experimental/bidi/_async/_task_pool.py @@ -0,0 +1,43 @@ +"""Manage pool of active async tasks. + +This is particularly useful for cancelling multiple tasks at once. +""" + +import asyncio +from typing import Any, Coroutine + + +class _TaskPool: + """Manage pool of active async tasks.""" + + def __init__(self) -> None: + """Setup task container.""" + self._tasks: set[asyncio.Task] = set() + + def __len__(self) -> int: + """Number of active tasks.""" + return len(self._tasks) + + def create(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: + """Create async task. + + Adds a cleanup callback to run after the task completes. + + Returns: + The created task. + """ + task = asyncio.create_task(coro) + task.add_done_callback(lambda task: self._tasks.remove(task)) + + self._tasks.add(task) + return task + + async def cancel(self) -> None: + """Cancel all active tasks in pool.""" + for task in self._tasks: + task.cancel() + + try: + await asyncio.gather(*self._tasks) + except asyncio.CancelledError: + pass diff --git a/src/strands/experimental/bidi/agent/__init__.py b/src/strands/experimental/bidi/agent/__init__.py new file mode 100644 index 000000000..564973099 --- /dev/null +++ b/src/strands/experimental/bidi/agent/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidiAgent + +__all__ = ["BidiAgent"] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py new file mode 100644 index 000000000..360dfe707 --- /dev/null +++ b/src/strands/experimental/bidi/agent/agent.py @@ -0,0 +1,398 @@ +"""Bidirectional Agent for real-time streaming conversations. + +Provides real-time audio and text interaction through persistent streaming connections. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. + +Key capabilities: + +- Persistent conversation connections with concurrent processing +- Real-time audio input/output streaming +- Automatic interruption detection and tool execution +- Event-driven communication with model providers +""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, AsyncGenerator + +from .... 
import _identifier +from ....agent.state import AgentState +from ....hooks import HookProvider, HookRegistry +from ....interrupt import _InterruptState +from ....tools._caller import _ToolCaller +from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor +from ....tools.registry import ToolRegistry +from ....tools.watcher import ToolWatcher +from ....types.content import Messages +from ....types.tools import AgentTool +from ...hooks.events import BidiAgentInitializedEvent +from ...tools import ToolProvider +from .._async import stop_all +from ..models.model import BidiModel +from ..models.nova_sonic import BidiNovaSonicModel +from ..types.agent import BidiAgentInput +from ..types.events import ( + BidiAudioInputEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiOutputEvent, + BidiTextInputEvent, +) +from ..types.io import BidiInput, BidiOutput +from .loop import _BidiAgentLoop + +if TYPE_CHECKING: + from ....session.session_manager import SessionManager + +logger = logging.getLogger(__name__) + +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + + +class BidiAgent: + """Agent for bidirectional streaming conversations. + + Enables real-time audio and text interaction with AI models through persistent + connections. Supports concurrent tool execution and interruption handling. + """ + + def __init__( + self, + model: BidiModel | str | None = None, + tools: list[str | AgentTool | ToolProvider] | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: str | None = None, + name: str | None = None, + description: str | None = None, + hooks: list[HookProvider] | None = None, + state: AgentState | dict | None = None, + session_manager: "SessionManager | None" = None, + tool_executor: ToolExecutor | None = None, + **kwargs: Any, + ): + """Initialize bidirectional agent. + + Args: + model: BidiModel instance, string model_id, or None to use the default Nova Sonic model. + tools: Optional list of tools with flexible format support. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for connection management and multi-agent scenarios. + name: Name of the Agent. + description: Description of what the Agent does. + hooks: Optional list of hook providers to register for lifecycle events. + state: Stateful information for the agent. Can be either an AgentState object or a JSON-serializable dict. + session_manager: Manager for handling agent sessions including conversation history and state. + If provided, enables session-based persistence and state management. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + **kwargs: Additional configuration for future extensibility. + + Raises: + ValueError: If model configuration is invalid or state is an invalid type. + TypeError: If model type is unsupported. 
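+
+        Example (illustrative of the three accepted model forms):
+            ```python
+            agent = BidiAgent()                                # defaults to Nova Sonic
+            agent = BidiAgent(model="amazon.nova-sonic-v1:0")  # model id string
+            agent = BidiAgent(model=BidiNovaSonicModel())      # explicit model instance
+            ```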
+ """ + self.model = ( + BidiNovaSonicModel() + if not model + else BidiNovaSonicModel(model_id=model) + if isinstance(model, str) + else model + ) + self.system_prompt = system_prompt + self.messages = messages or [] + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Initialize tool registry + self.tool_registry = ToolRegistry() + + if tools is not None: + self.tool_registry.process_tools(tools) + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize agent state management + if state is not None: + if isinstance(state, dict): + self.state = AgentState(state) + elif isinstance(state, AgentState): + self.state = state + else: + raise ValueError("state must be an AgentState object or a dict") + else: + self.state = AgentState() + + # Initialize other components + self._tool_caller = _ToolCaller(self) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks registry + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize session management functionality + self._session_manager = session_manager + if self._session_manager: + self.hooks.add_hook(self._session_manager) + + self._loop = _BidiAgentLoop(self) + + # Emit initialization event + self.hooks.invoke_callbacks(BidiAgentInitializedEvent(agent=self)) + + # TODO: Determine if full support is required + self._interrupt_state = _InterruptState() + + self._started = False + + @property + def tool(self) -> _ToolCaller: + """Call tool as a function. + + Returns: + ToolCaller for method-style tool execution. + + Example: + ``` + agent = BidiAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self._tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + async def start(self, invocation_state: dict[str, Any] | None = None) -> None: + """Start a persistent bidirectional conversation connection. + + Initializes the streaming connection and starts background tasks for processing + model events, tool execution, and connection management. + + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + + Raises: + RuntimeError: + If agent already started. 
+ + Example: + ```python + await agent.start(invocation_state={ + "user_id": "user_123", + "session_id": "session_456", + "database": db_connection, + }) + ``` + """ + if self._started: + raise RuntimeError("agent already started | call stop before starting again") + + logger.debug("agent starting") + await self._loop.start(invocation_state) + self._started = True + + async def send(self, input_data: BidiAgentInput | dict[str, Any]) -> None: + """Send input to the model (text, audio, image, or event dict). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. Accepts TypedEvent instances or plain dicts + (e.g., from WebSocket clients) which are automatically reconstructed. + + Args: + input_data: Can be: + + - str: Text message from user + - BidiInputEvent: TypedEvent + - dict: Event dictionary (will be reconstructed to TypedEvent) + + Raises: + RuntimeError: If start has not been called. + ValueError: If the input type is invalid. + + Example: + await agent.send("Hello") + await agent.send(BidiAudioInputEvent(audio="base64...", format="pcm", ...)) + await agent.send({"type": "bidi_text_input", "text": "Hello", "role": "user"}) + """ + if not self._started: + raise RuntimeError("agent not started | call start before sending") + + input_event: BidiInputEvent + + if isinstance(input_data, str): + input_event = BidiTextInputEvent(text=input_data) + + elif isinstance(input_data, BidiInputEvent): + input_event = input_data + + elif isinstance(input_data, dict) and "type" in input_data: + input_type = input_data["type"] + input_data = {key: value for key, value in input_data.items() if key != "type"} + if input_type == "bidi_text_input": + input_event = BidiTextInputEvent(**input_data) + elif input_type == "bidi_audio_input": + input_event = BidiAudioInputEvent(**input_data) + elif input_type == "bidi_image_input": + input_event = BidiImageInputEvent(**input_data) + else: + raise ValueError(f"input_type=<{input_type}> | input type not supported") + + else: + raise ValueError("invalid input | must be str, BidiInputEvent, or event dict") + + await self._loop.send(input_event) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive events from the model including audio, text, and tool calls. + + Yields: + Model output events processed by background tasks including audio output, + text responses, tool calls, and connection updates. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._started: + raise RuntimeError("agent not started | call start before receiving") + + async for event in self._loop.receive(): + yield event + + async def stop(self) -> None: + """End the conversation connection and clean up all resources. + + Terminates the streaming connection, cancels background tasks, and + closes the connection to the model provider. + """ + self._started = False + await self._loop.stop() + + async def __aenter__(self, invocation_state: dict[str, Any] | None = None) -> "BidiAgent": + """Async context manager entry point. + + Automatically starts the bidirectional connection when entering the context. + + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + + Returns: + Self for use in the context. 
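+
+        Example:
+            ```python
+            async with BidiAgent(model=model) as agent:
+                await agent.send("Hello")
+            ```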
+ """ + logger.debug("context_manager= | starting agent") + await self.start(invocation_state) + return self + + async def __aexit__(self, *_: Any) -> None: + """Async context manager exit point. + + Automatically ends the connection and cleans up resources including + when exiting the context, regardless of whether an exception occurred. + """ + logger.debug("context_manager= | stopping agent") + await self.stop() + + async def run( + self, inputs: list[BidiInput], outputs: list[BidiOutput], invocation_state: dict[str, Any] | None = None + ) -> None: + """Run the agent using provided IO channels for bidirectional communication. + + Args: + inputs: Input callables to read data from a source + outputs: Output callables to receive events from the agent + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + + Example: + ```python + # Using model defaults: + model = BidiNovaSonicModel() + audio_io = BidiAudioIO() + text_io = BidiTextIO() + agent = BidiAgent(model=model, tools=[calculator]) + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output(), text_io.output()], + invocation_state={"user_id": "user_123"} + ) + + # Using custom audio config: + model = BidiNovaSonicModel( + provider_config={"audio": {"input_rate": 48000, "output_rate": 24000}} + ) + audio_io = BidiAudioIO() + agent = BidiAgent(model=model, tools=[calculator]) + await agent.run( + inputs=[audio_io.input()], + outputs=[audio_io.output()], + ) + ``` + """ + + async def run_inputs() -> None: + async def task(input_: BidiInput) -> None: + while True: + event = await input_() + await self.send(event) + + await asyncio.gather(*[task(input_) for input_ in inputs]) + + async def run_outputs(inputs_task: asyncio.Task) -> None: + async for event in self.receive(): + await asyncio.gather(*[output(event) for output in outputs]) + + inputs_task.cancel() + + try: + await self.start(invocation_state) + + input_starts = [input_.start for input_ in inputs if isinstance(input_, BidiInput)] + output_starts = [output.start for output in outputs if isinstance(output, BidiOutput)] + for start in [*input_starts, *output_starts]: + await start(self) + + async with asyncio.TaskGroup() as task_group: + inputs_task = task_group.create_task(run_inputs()) + task_group.create_task(run_outputs(inputs_task)) + + finally: + input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] + output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)] + + await stop_all(*input_stops, *output_stops, self.stop) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py new file mode 100644 index 000000000..13b7033a4 --- /dev/null +++ b/src/strands/experimental/bidi/agent/loop.py @@ -0,0 +1,315 @@ +"""Agent loop. + +The agent loop handles the events received from the model and executes tools when given a tool use request. 
+""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast + +from ....types._events import ToolInterruptEvent, ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse +from ...hooks.events import ( + BidiAfterConnectionRestartEvent, + BidiAfterInvocationEvent, + BidiBeforeConnectionRestartEvent, + BidiBeforeInvocationEvent, + BidiMessageAddedEvent, +) +from ...hooks.events import ( + BidiInterruptionEvent as BidiInterruptionHookEvent, +) +from .._async import _TaskPool, stop_all +from ..models import BidiModelTimeoutError +from ..types.events import ( + BidiConnectionCloseEvent, + BidiConnectionRestartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) + +if TYPE_CHECKING: + from .agent import BidiAgent + +logger = logging.getLogger(__name__) + + +class _BidiAgentLoop: + """Agent loop. + + Attributes: + _agent: BidiAgent instance to loop. + _started: Flag if agent loop has started. + _task_pool: Track active async tasks created in loop. + _event_queue: Queue model and tool call events for receiver. + _invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + _send_gate: Gate the sending of events to the model. + Blocks when agent is reseting the model connection after timeout. + _message_lock: Lock to ensure that paired messages are added to history in sequence without interference. + For example, tool use and tool result messages must be added adjacent to each other. + """ + + def __init__(self, agent: "BidiAgent") -> None: + """Initialize members of the agent loop. + + Note, before receiving events from the loop, the user must call `start`. + + Args: + agent: Bidirectional agent to loop over. + """ + self._agent = agent + self._started = False + self._task_pool = _TaskPool() + self._event_queue: asyncio.Queue + self._invocation_state: dict[str, Any] + + self._send_gate = asyncio.Event() + self._message_lock = asyncio.Lock() + + async def start(self, invocation_state: dict[str, Any] | None = None) -> None: + """Start the agent loop. + + The agent model is started as part of this call. + + Args: + invocation_state: Optional context to pass to tools during execution. + This allows passing custom data (user_id, session_id, database connections, etc.) + that tools can access via their invocation_state parameter. + + Raises: + RuntimeError: If loop already started. 
+ """ + if self._started: + raise RuntimeError("loop already started | call stop before starting again") + + logger.debug("agent loop starting") + await self._agent.hooks.invoke_callbacks_async(BidiBeforeInvocationEvent(agent=self._agent)) + + await self._agent.model.start( + system_prompt=self._agent.system_prompt, + tools=self._agent.tool_registry.get_all_tool_specs(), + messages=self._agent.messages, + ) + + self._event_queue = asyncio.Queue(maxsize=1) + + self._task_pool = _TaskPool() + self._task_pool.create(self._run_model()) + + self._invocation_state = invocation_state or {} + self._send_gate.set() + self._started = True + + async def stop(self) -> None: + """Stop the agent loop.""" + logger.debug("agent loop stopping") + + self._started = False + self._send_gate.clear() + self._invocation_state = {} + + async def stop_tasks() -> None: + await self._task_pool.cancel() + + async def stop_model() -> None: + await self._agent.model.stop() + + try: + await stop_all(stop_tasks, stop_model) + finally: + await self._agent.hooks.invoke_callbacks_async(BidiAfterInvocationEvent(agent=self._agent)) + + async def send(self, event: BidiInputEvent | ToolResultEvent) -> None: + """Send model event. + + Additionally, add text input to messages array. + + Args: + event: User input event or tool result. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._started: + raise RuntimeError("loop not started | call start before sending") + + if not self._send_gate.is_set(): + logger.debug("waiting for model send signal") + await self._send_gate.wait() + + if isinstance(event, BidiTextInputEvent): + message: Message = {"role": "user", "content": [{"text": event.text}]} + await self._add_messages(message) + + await self._agent.model.send(event) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive model and tool call events. + + Returns: + Model and tool call events. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._started: + raise RuntimeError("loop not started | call start before receiving") + + while True: + event = await self._event_queue.get() + if isinstance(event, BidiModelTimeoutError): + logger.debug("model timeout error received") + yield BidiConnectionRestartEvent(event) + await self._restart_connection(event) + continue + + if isinstance(event, Exception): + raise event + + # Check for graceful shutdown event + if isinstance(event, BidiConnectionCloseEvent) and event.reason == "user_request": + yield event + break + + yield event + + async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> None: + """Restart the model connection after timeout. + + Args: + timeout_error: Timeout error reported by the model. + """ + logger.debug("reseting model connection") + + self._send_gate.clear() + + await self._agent.hooks.invoke_callbacks_async(BidiBeforeConnectionRestartEvent(self._agent, timeout_error)) + + restart_exception = None + try: + await self._agent.model.stop() + await self._agent.model.start( + self._agent.system_prompt, + self._agent.tool_registry.get_all_tool_specs(), + self._agent.messages, + **timeout_error.restart_config, + ) + self._task_pool.create(self._run_model()) + except Exception as exception: + restart_exception = exception + finally: + await self._agent.hooks.invoke_callbacks_async( + BidiAfterConnectionRestartEvent(self._agent, restart_exception) + ) + + self._send_gate.set() + + async def _run_model(self) -> None: + """Task for running the model. 
+ + Events are streamed through the event queue. + """ + logger.debug("model task starting") + + try: + async for event in self._agent.model.receive(): + await self._event_queue.put(event) + + if isinstance(event, BidiTranscriptStreamEvent): + if event["is_final"]: + message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} + await self._add_messages(message) + + elif isinstance(event, ToolUseStreamEvent): + tool_use = event["current_tool_use"] + self._task_pool.create(self._run_tool(tool_use)) + + elif isinstance(event, BidiInterruptionEvent): + await self._agent.hooks.invoke_callbacks_async( + BidiInterruptionHookEvent( + agent=self._agent, + reason=event["reason"], + interrupted_response_id=event.get("interrupted_response_id"), + ) + ) + + except Exception as error: + await self._event_queue.put(error) + + async def _run_tool(self, tool_use: ToolUse) -> None: + """Task for running a tool requested by the model using the tool executor. + + Args: + tool_use: Tool use request from the model. + """ + logger.debug("tool_name=<%s> | tool execution starting", tool_use["name"]) + + tool_results: list[ToolResult] = [] + + invocation_state: dict[str, Any] = { + **self._invocation_state, + "agent": self._agent, + "model": self._agent.model, + "messages": self._agent.messages, + "system_prompt": self._agent.system_prompt, + } + + try: + tool_events = self._agent.tool_executor._stream( + self._agent, + tool_use, + tool_results, + invocation_state, + structured_output_context=None, + ) + + async for tool_event in tool_events: + if isinstance(tool_event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + interrupt_names = [interrupt.name for interrupt in tool_event.interrupts] + raise RuntimeError(f"interrupts={interrupt_names} | tool interrupts are not supported in bidi") + + await self._event_queue.put(tool_event) + + # Normal flow for all tools (including stop_conversation) + tool_result_event = cast(ToolResultEvent, tool_event) + + tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_result_message: Message = {"role": "user", "content": [{"toolResult": tool_result_event.tool_result}]} + await self._add_messages(tool_use_message, tool_result_message) + + await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) + + # Check for stop_conversation before sending to model + if tool_use["name"] == "stop_conversation": + logger.info("tool_name=<%s> | conversation stop requested, skipping model send", tool_use["name"]) + connection_id = getattr(self._agent.model, "_connection_id", "unknown") + await self._event_queue.put( + BidiConnectionCloseEvent(connection_id=connection_id, reason="user_request") + ) + return # Skip the model send + + # Send result to model (all tools except stop_conversation) + await self.send(tool_result_event) + + except Exception as error: + await self._event_queue.put(error) + + async def _add_messages(self, *messages: Message) -> None: + """Add messages to history in sequence without interference. + + Args: + *messages: Messages to add to history. 
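+
+        Example (illustrative, mirroring how _run_tool keeps a pair adjacent):
+            ```python
+            await self._add_messages(
+                {"role": "assistant", "content": [{"toolUse": tool_use}]},
+                {"role": "user", "content": [{"toolResult": tool_result}]},
+            )
+            ```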
+ """ + async with self._message_lock: + for message in messages: + self._agent.messages.append(message) + await self._agent.hooks.invoke_callbacks_async( + BidiMessageAddedEvent(agent=self._agent, message=message) + ) diff --git a/src/strands/experimental/bidi/io/__init__.py b/src/strands/experimental/bidi/io/__init__.py new file mode 100644 index 000000000..d099cba2f --- /dev/null +++ b/src/strands/experimental/bidi/io/__init__.py @@ -0,0 +1,6 @@ +"""IO channel implementations for bidirectional streaming.""" + +from .audio import BidiAudioIO +from .text import BidiTextIO + +__all__ = ["BidiAudioIO", "BidiTextIO"] diff --git a/src/strands/experimental/bidi/io/audio.py b/src/strands/experimental/bidi/io/audio.py new file mode 100644 index 000000000..5eff829e9 --- /dev/null +++ b/src/strands/experimental/bidi/io/audio.py @@ -0,0 +1,294 @@ +"""Send and receive audio data from devices. + +Reads user audio from input device and sends agent audio to output device using PyAudio. If a user interrupts the agent, +the output buffer is cleared to stop playback. + +Audio configuration is provided by the model via agent.model.config["audio"]. +""" + +import asyncio +import base64 +import logging +import queue +from typing import TYPE_CHECKING, Any + +import pyaudio + +from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent +from ..types.io import BidiInput, BidiOutput + +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + +logger = logging.getLogger(__name__) + + +class _BidiAudioBuffer: + """Buffer chunks of audio data between agent and PyAudio.""" + + _buffer: queue.Queue + _data: bytearray + + def __init__(self, size: int | None = None): + """Initialize buffer settings. + + Args: + size: Size of the buffer (default: unbounded). + """ + self._size = size or 0 + + def start(self) -> None: + """Setup buffer.""" + self._buffer = queue.Queue(self._size) + self._data = bytearray() + + def stop(self) -> None: + """Tear down buffer.""" + if hasattr(self, "_data"): + self._data.clear() + if hasattr(self, "_buffer"): + # Unblocking waited get calls by putting an empty chunk + # Note, Queue.shutdown exists but is a 3.13+ only feature + # We simulate shutdown with the below logic + self._buffer.put_nowait(b"") + self._buffer = queue.Queue(self._size) + + def put(self, chunk: bytes) -> None: + """Put data chunk into buffer. + + If full, removes the oldest chunk. + """ + if self._buffer.full(): + logger.debug("buffer is full | removing oldest chunk") + try: + self._buffer.get_nowait() + except queue.Empty: + logger.debug("buffer already empty") + pass + + self._buffer.put_nowait(chunk) + + def get(self, byte_count: int | None = None) -> bytes: + """Get the number of bytes specified from the buffer. + + Args: + byte_count: Number of bytes to get from buffer. + + - If the number of bytes specified is not available, the return is padded with silence. + - If the number of bytes is not specified, get the first chunk put in the buffer. + + Returns: + Specified number of bytes. 
+ """ + if not byte_count: + self._data.extend(self._buffer.get()) + byte_count = len(self._data) + + while len(self._data) < byte_count: + try: + self._data.extend(self._buffer.get_nowait()) + except queue.Empty: + break + + padding_bytes = b"\x00" * max(byte_count - len(self._data), 0) + self._data.extend(padding_bytes) + + data = self._data[:byte_count] + del self._data[:byte_count] + + return bytes(data) + + def clear(self) -> None: + """Clear the buffer.""" + while True: + try: + self._buffer.get_nowait() + except queue.Empty: + break + + +class _BidiAudioInput(BidiInput): + """Handle audio input from user. + + Attributes: + _audio: PyAudio instance for audio system access. + _stream: Audio input stream. + _buffer: Buffer for sharing audio data between agent and PyAudio. + """ + + _audio: pyaudio.PyAudio + _stream: pyaudio.Stream + + _BUFFER_SIZE = None + _DEVICE_INDEX = None + _FRAMES_PER_BUFFER = 512 + + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs.""" + self._buffer_size = config.get("input_buffer_size", _BidiAudioInput._BUFFER_SIZE) + self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX) + self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER) + + self._buffer = _BidiAudioBuffer(self._buffer_size) + + async def start(self, agent: "BidiAgent") -> None: + """Start input stream. + + Args: + agent: The BidiAgent instance, providing access to model configuration. + """ + logger.debug("starting audio input stream") + + self._channels = agent.model.config["audio"]["channels"] + self._format = agent.model.config["audio"]["format"] + self._rate = agent.model.config["audio"]["input_rate"] + + self._buffer.start() + self._audio = pyaudio.PyAudio() + self._stream = self._audio.open( + channels=self._channels, + format=pyaudio.paInt16, + frames_per_buffer=self._frames_per_buffer, + input=True, + input_device_index=self._device_index, + rate=self._rate, + stream_callback=self._callback, + ) + + logger.debug("audio input stream started") + + async def stop(self) -> None: + """Stop input stream.""" + logger.debug("stopping audio input stream") + + if hasattr(self, "_stream"): + self._stream.close() + if hasattr(self, "_audio"): + self._audio.terminate() + if hasattr(self, "_buffer"): + self._buffer.stop() + + logger.debug("audio input stream stopped") + + async def __call__(self) -> BidiAudioInputEvent: + """Read audio from input stream.""" + data = await asyncio.to_thread(self._buffer.get) + + return BidiAudioInputEvent( + audio=base64.b64encode(data).decode("utf-8"), + channels=self._channels, + format=self._format, + sample_rate=self._rate, + ) + + def _callback(self, in_data: bytes, *_: Any) -> tuple[None, Any]: + """Callback to receive audio data from PyAudio.""" + self._buffer.put(in_data) + return (None, pyaudio.paContinue) + + +class _BidiAudioOutput(BidiOutput): + """Handle audio output from bidi agent. + + Attributes: + _audio: PyAudio instance for audio system access. + _stream: Audio output stream. + _buffer: Buffer for sharing audio data between agent and PyAudio. 
+ """ + + _audio: pyaudio.PyAudio + _stream: pyaudio.Stream + + _BUFFER_SIZE = None + _DEVICE_INDEX = None + _FRAMES_PER_BUFFER = 512 + + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs.""" + self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE) + self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX) + self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER) + + self._buffer = _BidiAudioBuffer(self._buffer_size) + + async def start(self, agent: "BidiAgent") -> None: + """Start output stream. + + Args: + agent: The BidiAgent instance, providing access to model configuration. + """ + logger.debug("starting audio output stream") + + self._channels = agent.model.config["audio"]["channels"] + self._rate = agent.model.config["audio"]["output_rate"] + + self._buffer.start() + self._audio = pyaudio.PyAudio() + self._stream = self._audio.open( + channels=self._channels, + format=pyaudio.paInt16, + frames_per_buffer=self._frames_per_buffer, + output=True, + output_device_index=self._device_index, + rate=self._rate, + stream_callback=self._callback, + ) + + logger.debug("audio output stream started") + + async def stop(self) -> None: + """Stop output stream.""" + logger.debug("stopping audio output stream") + + if hasattr(self, "_stream"): + self._stream.close() + if hasattr(self, "_audio"): + self._audio.terminate() + if hasattr(self, "_buffer"): + self._buffer.stop() + + logger.debug("audio output stream stopped") + + async def __call__(self, event: BidiOutputEvent) -> None: + """Send audio to output stream.""" + if isinstance(event, BidiAudioStreamEvent): + data = base64.b64decode(event["audio"]) + self._buffer.put(data) + logger.debug("audio_bytes=<%d> | audio chunk buffered for playback", len(data)) + + elif isinstance(event, BidiInterruptionEvent): + logger.debug("reason=<%s> | clearing audio buffer due to interruption", event["reason"]) + self._buffer.clear() + + def _callback(self, _in_data: None, frame_count: int, *_: Any) -> tuple[bytes, Any]: + """Callback to send audio data to PyAudio.""" + byte_count = frame_count * pyaudio.get_sample_size(pyaudio.paInt16) + data = self._buffer.get(byte_count) + return (data, pyaudio.paContinue) + + +class BidiAudioIO: + """Send and receive audio data from devices.""" + + def __init__(self, **config: Any) -> None: + """Initialize audio devices. 
+ + Args: + **config: Optional device configuration: + + - input_buffer_size (int): Maximum input buffer size (default: None) + - input_device_index (int): Specific input device (default: None = system default) + - input_frames_per_buffer (int): PyAudio frames per buffer for input (default: 512) + - output_buffer_size (int): Maximum output buffer size (default: None) + - output_device_index (int): Specific output device (default: None = system default) + - output_frames_per_buffer (int): PyAudio frames per buffer for output (default: 512) + """ + self._config = config + + def input(self) -> _BidiAudioInput: + """Return audio processing BidiInput.""" + return _BidiAudioInput(self._config) + + def output(self) -> _BidiAudioOutput: + """Return audio processing BidiOutput.""" + return _BidiAudioOutput(self._config) diff --git a/src/strands/experimental/bidi/io/text.py b/src/strands/experimental/bidi/io/text.py new file mode 100644 index 000000000..f575c5606 --- /dev/null +++ b/src/strands/experimental/bidi/io/text.py @@ -0,0 +1,87 @@ +"""Handle text input and output to and from bidi agent.""" + +import logging +from typing import Any + +from prompt_toolkit import PromptSession + +from ..types.events import ( + BidiConnectionCloseEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) +from ..types.io import BidiInput, BidiOutput + +logger = logging.getLogger(__name__) + + +class _BidiTextInput(BidiInput): + """Handle text input from user.""" + + def __init__(self, config: dict[str, Any]) -> None: + """Extract configs and setup prompt session.""" + prompt = config.get("input_prompt", "") + self._session: PromptSession = PromptSession(prompt) + + async def __call__(self) -> BidiTextInputEvent: + """Read user input from stdin.""" + text = await self._session.prompt_async() + return BidiTextInputEvent(text.strip(), role="user") + + +class _BidiTextOutput(BidiOutput): + """Handle text output from bidi agent.""" + + async def __call__(self, event: BidiOutputEvent) -> None: + """Print text events to stdout.""" + if isinstance(event, BidiInterruptionEvent): + logger.debug("reason=<%s> | text output interrupted", event["reason"]) + print("interrupted") + + elif isinstance(event, BidiConnectionCloseEvent): + if event.reason == "user_request": + print("user requested connection close using the stop_conversation tool.") + logger.debug("connection_id=<%s> | user requested connection close", event.connection_id) + elif isinstance(event, BidiTranscriptStreamEvent): + text = event["text"] + is_final = event["is_final"] + role = event["role"] + + logger.debug( + "role=<%s>, is_final=<%s>, text_length=<%d> | text transcript received", + role, + is_final, + len(text), + ) + + if not is_final: + text = f"Preview: {text}" + + print(text) + + +class BidiTextIO: + """Handle text input and output to and from bidi agent. + + Accepts input from stdin and outputs to stdout. + """ + + def __init__(self, **config: Any) -> None: + """Initialize I/O. + + Args: + **config: Optional I/O configurations. 
+ + - input_prompt (str): Input prompt to display on screen (default: blank) + """ + self._config = config + + def input(self) -> _BidiTextInput: + """Return text processing BidiInput.""" + return _BidiTextInput(self._config) + + def output(self) -> _BidiTextOutput: + """Return text processing BidiOutput.""" + return _BidiTextOutput() diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py new file mode 100644 index 000000000..cc62c9987 --- /dev/null +++ b/src/strands/experimental/bidi/models/__init__.py @@ -0,0 +1,10 @@ +"""Bidirectional model interfaces and implementations.""" + +from .model import BidiModel, BidiModelTimeoutError +from .nova_sonic import BidiNovaSonicModel + +__all__ = [ + "BidiModel", + "BidiModelTimeoutError", + "BidiNovaSonicModel", +] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py new file mode 100644 index 000000000..88d7f5a0c --- /dev/null +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -0,0 +1,527 @@ +"""Gemini Live API bidirectional model provider using official Google GenAI SDK. + +Implements the BidiModel interface for Google's Gemini Live API using the +official Google GenAI SDK for simplified and robust WebSocket communication. + +Key improvements over custom WebSocket implementation: + +- Uses official google-genai SDK with native Live API support +- Simplified session management with client.aio.live.connect() +- Built-in tool integration and event handling +- Automatic WebSocket connection management and error handling +- Native support for audio/text streaming and interruption +""" + +import base64 +import logging +import uuid +from typing import Any, AsyncGenerator, cast + +from google import genai +from google.genai import types as genai_types +from google.genai.types import LiveConnectConfigOrDict, LiveServerMessage + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.events import ( + AudioChannel, + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, +) +from ..types.model import AudioConfig +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Audio format constants +GEMINI_INPUT_SAMPLE_RATE: AudioSampleRate = 16000 +GEMINI_OUTPUT_SAMPLE_RATE: AudioSampleRate = 24000 +GEMINI_CHANNELS: AudioChannel = 1 + + +class BidiGeminiLiveModel(BidiModel): + """Gemini Live API implementation using official Google GenAI SDK. + + Combines model configuration and connection state in a single class. + Provides a clean interface to Gemini Live API using the official SDK, + eliminating custom WebSocket handling and providing robust error handling. + """ + + def __init__( + self, + model_id: str = "gemini-2.5-flash-native-audio-preview-09-2025", + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ): + """Initialize Gemini Live API bidirectional model. 
+ + Args: + model_id: Model identifier (default: gemini-2.5-flash-native-audio-preview-09-2025) + provider_config: Model behavior (audio, inference) + client_config: Authentication (api_key, http_options) + **kwargs: Reserved for future parameters. + + """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store API key for later use + self.api_key = self._client_config.get("api_key") + + # Create Gemini client + self._client = genai.Client(**self._client_config) + + # Connection state (initialized in start()) + self._live_session: Any = None + self._live_session_context_manager: Any = None + self._live_session_handle: str | None = None + self._connection_id: str | None = None + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config (sets default http_options if not provided).""" + resolved = config.copy() + + # Set default http_options if not provided + if "http_options" not in resolved: + resolved["http_options"] = {"api_version": "v1alpha"} + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + default_audio: AudioConfig = { + "input_rate": GEMINI_INPUT_SAMPLE_RATE, + "output_rate": GEMINI_OUTPUT_SAMPLE_RATE, + "channels": GEMINI_CHANNELS, + "format": "pcm", + } + default_inference = { + "response_modalities": ["AUDIO"], + "outputAudioTranscription": {}, + "inputAudioTranscription": {}, + } + + resolved = { + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": { + **default_inference, + **config.get("inference", {}), + }, + } + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection with Gemini Live API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + self._connection_id = str(uuid.uuid4()) + + # Build live config + live_config = self._build_live_config(system_prompt, tools, **kwargs) + + # Create the context manager and session + self._live_session_context_manager = self._client.aio.live.connect( + model=self.model_id, config=cast(LiveConnectConfigOrDict, live_config) + ) + self._live_session = await self._live_session_context_manager.__aenter__() + + # Gemini itself restores message history when resuming from session + if messages and "live_session_handle" not in kwargs: + await self._send_message_history(messages) + + async def _send_message_history(self, messages: Messages) -> None: + """Send conversation history to Gemini Live API. + + Sends each message as a separate turn with the correct role to maintain + proper conversation context. Follows the same pattern as the non-bidirectional + Gemini model implementation. 
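+
+        Example (illustrative history; "assistant" maps to Gemini's "model" role):
+            ```python
+            await self._send_message_history([
+                {"role": "user", "content": [{"text": "Hello"}]},
+                {"role": "assistant", "content": [{"text": "Hi, how can I help?"}]},
+            ])
+            ```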
+ """ + if not messages: + return + + # Convert each message to Gemini format and send separately + for message in messages: + content_parts = [] + for content_block in message["content"]: + if "text" in content_block: + content_parts.append(genai_types.Part(text=content_block["text"])) + + if content_parts: + # Map role correctly - Gemini uses "user" and "model" roles + # "assistant" role from Messages format maps to "model" in Gemini + role = "model" if message["role"] == "assistant" else message["role"] + content = genai_types.Content(role=role, parts=content_parts) + await self._live_session.send_client_content(turns=content) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive Gemini Live API events and convert to provider-agnostic format.""" + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") + + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while True: + async for message in self._live_session.receive(): + for event in self._convert_gemini_live_event(message): + yield event + + def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOutputEvent]: + """Convert Gemini Live API events to provider-agnostic format. + + Handles different types of content: + + - inputTranscription: User's speech transcribed to text + - outputTranscription: Model's audio transcribed to text + - modelTurn text: Text response from the model + - usageMetadata: Token usage information + + Returns: + List of event dicts (empty list if no events to emit). + + Raises: + BidiModelTimeoutError: If gemini responds with go away message. + """ + if message.go_away: + raise BidiModelTimeoutError( + message.go_away.model_dump_json(), live_session_handle=self._live_session_handle + ) + + if message.session_resumption_update: + resumption_update = message.session_resumption_update + if resumption_update.resumable and resumption_update.new_handle: + self._live_session_handle = resumption_update.new_handle + logger.debug("session_handle=<%s> | updating gemini session handle", self._live_session_handle) + return [] + + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + return [BidiInterruptionEvent(reason="user_speech")] + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, "text") and input_transcript.text: + transcription_text = input_transcript.text + logger.debug("text_length=<%d> | gemini input transcription detected", len(transcription_text)) + return [ + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role="user", + # TODO: https://github.com/googleapis/python-genai/issues/1504 + is_final=bool(input_transcript.finished), + current_transcript=transcription_text, + ) + ] + + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, "text") and output_transcript.text: + transcription_text = 
output_transcript.text + logger.debug("text_length=<%d> | gemini output transcription detected", len(transcription_text)) + return [ + BidiTranscriptStreamEvent( + delta={"text": transcription_text}, + text=transcription_text, + role="assistant", + # TODO: https://github.com/googleapis/python-genai/issues/1504 + is_final=bool(output_transcript.finished), + current_transcript=transcription_text, + ) + ] + + # Handle audio output using SDK's built-in data property + # Check this BEFORE text to avoid triggering warning on mixed content + if message.data: + # Convert bytes to base64 string for JSON serializability + audio_b64 = base64.b64encode(message.data).decode("utf-8") + return [ + BidiAudioStreamEvent( + audio=audio_b64, + format="pcm", + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), + ) + ] + + # Handle text output from model_turn (avoids warning by checking parts directly) + if message.server_content and message.server_content.model_turn: + model_turn = message.server_content.model_turn + if model_turn.parts: + # Concatenate all text parts (Gemini may send multiple parts) + text_parts = [] + for part in model_turn.parts: + # Check if part has text attribute and it's not empty + if hasattr(part, "text") and part.text: + text_parts.append(part.text) + + if text_parts: + full_text = " ".join(text_parts) + return [ + BidiTranscriptStreamEvent( + delta={"text": full_text}, + text=full_text, + role="assistant", + is_final=True, + current_transcript=full_text, + ) + ] + + # Handle tool calls - return list to support multiple tool calls + if message.tool_call and message.tool_call.function_calls: + tool_events: list[BidiOutputEvent] = [] + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": cast(str, func_call.id), + "name": cast(str, func_call.name), + "input": func_call.args or {}, + } + # Create ToolUseStreamEvent for consistency with standard agent + tool_events.append( + ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + ) + return tool_events + + # Handle usage metadata + if hasattr(message, "usage_metadata") and message.usage_metadata: + usage = message.usage_metadata + + # Build modality details from token details + modality_details = [] + + # Process prompt tokens details + if usage.prompt_tokens_details: + for detail in usage.prompt_tokens_details: + if detail.modality and detail.token_count: + modality_details.append( + { + "modality": str(detail.modality).lower(), + "input_tokens": detail.token_count, + "output_tokens": 0, + } + ) + + # Process response tokens details + if usage.response_tokens_details: + for detail in usage.response_tokens_details: + if detail.modality and detail.token_count: + # Find or create modality entry + modality_str = str(detail.modality).lower() + existing = next((m for m in modality_details if m["modality"] == modality_str), None) + if existing: + existing["output_tokens"] = detail.token_count + else: + modality_details.append( + {"modality": modality_str, "input_tokens": 0, "output_tokens": detail.token_count} + ) + + return [ + BidiUsageEvent( + input_tokens=usage.prompt_token_count or 0, + output_tokens=usage.response_token_count or 0, + total_tokens=usage.total_token_count or 0, + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, + cache_read_input_tokens=usage.cached_content_token_count + if usage.cached_content_token_count + else 
None,
+                )
+            ]
+
+        # Silently ignore setup_complete and generation_complete messages
+        return []
+
+    async def send(
+        self,
+        content: BidiInputEvent | ToolResultEvent,
+    ) -> None:
+        """Unified send method for all content types. Sends the given content to the Gemini Live API.
+
+        Dispatches to appropriate internal handler based on content type.
+
+        Args:
+            content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent).
+
+        Raises:
+            ValueError: If the content type is not supported.
+        """
+        if not self._connection_id:
+            raise RuntimeError("model not started | call start before sending/receiving")
+
+        if isinstance(content, BidiTextInputEvent):
+            await self._send_text_content(content.text)
+        elif isinstance(content, BidiAudioInputEvent):
+            await self._send_audio_content(content)
+        elif isinstance(content, BidiImageInputEvent):
+            await self._send_image_content(content)
+        elif isinstance(content, ToolResultEvent):
+            tool_result = content.get("tool_result")
+            if tool_result:
+                await self._send_tool_result(tool_result)
+        else:
+            raise ValueError(f"content_type={type(content)} | content not supported")
+
+    async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None:
+        """Internal: Send audio content using Gemini Live API.
+
+        Gemini Live expects continuous audio streaming via send_realtime_input.
+        This automatically triggers VAD and can interrupt ongoing responses.
+        """
+        # Decode base64 audio to bytes for SDK
+        audio_bytes = base64.b64decode(audio_input.audio)
+
+        # Create audio blob for the SDK
+        mime_type = f"audio/pcm;rate={self.config['audio']['input_rate']}"
+        audio_blob = genai_types.Blob(data=audio_bytes, mime_type=mime_type)
+
+        # Send real-time audio input - this automatically handles VAD and interruption
+        await self._live_session.send_realtime_input(audio=audio_blob)
+
+    async def _send_image_content(self, image_input: BidiImageInputEvent) -> None:
+        """Internal: Send image content using Gemini Live API.
+
+        Sends image frames following the same pattern as the GitHub example.
+        Images are sent as base64-encoded data with MIME type.
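+
+        Example (illustrative sketch; the field values are assumptions):
+
+        ```
+        # image is expected to already be a base64-encoded string here
+        await model.send(BidiImageInputEvent(image=b64_jpeg, mime_type="image/jpeg", encoding="base64"))
+        ```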
+ """ + # Image is already base64 encoded in the event + msg = {"mime_type": image_input.mime_type, "data": image_input.image} + + # Send using the same method as the GitHub example + await self._live_session.send(input=msg) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Gemini Live API.""" + # Create content with text + content = genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) + + # Send as client content + await self._live_session.send_client_content(turns=content) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Gemini Live API.""" + tool_use_id = tool_result.get("toolUseId") + content = tool_result.get("content", []) + + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Gemini Live API" + ) + + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data = cast(dict[str, Any], content[0]) + else: + # Multiple items - send as array + result_data = {"result": content} + + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result_data, + ) + + # Send tool response + await self._live_session.send_tool_response(function_responses=[func_response]) + + async def stop(self) -> None: + """Close Gemini Live API connection.""" + + async def stop_session() -> None: + if not self._live_session_context_manager: + return + + await self._live_session_context_manager.__aexit__(None, None, None) + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_session, stop_connection) + + def _build_live_config( + self, system_prompt: str | None = None, tools: list[ToolSpec] | None = None, **kwargs: Any + ) -> dict[str, Any]: + """Build LiveConnectConfig for the official SDK. + + Simply passes through all config parameters from provider_config, allowing users + to configure any Gemini Live API parameter directly. 
+ """ + config_dict: dict[str, Any] = self.config["inference"].copy() + + config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")} + + # Add system instruction if provided + if system_prompt: + config_dict["system_instruction"] = system_prompt + + # Add tools if provided + if tools: + config_dict["tools"] = self._format_tools_for_live_api(tools) + + if "voice" in self.config["audio"]: + config_dict.setdefault("speech_config", {}).setdefault("voice_config", {}).setdefault( + "prebuilt_voice_config", {} + )["voice_name"] = self.config["audio"]["voice"] + + return config_dict + + def _format_tools_for_live_api(self, tool_specs: list[ToolSpec]) -> list[genai_types.Tool]: + """Format tool specs for Gemini Live API.""" + if not tool_specs: + return [] + + return [ + genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ] diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py new file mode 100644 index 000000000..f5e34aa50 --- /dev/null +++ b/src/strands/experimental/bidi/models/model.py @@ -0,0 +1,134 @@ +"""Bidirectional streaming model interface. + +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. + +Features: + +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) +- Provider-agnostic event normalization +- Support for audio, text, image, and tool result streaming +""" + +import logging +from typing import Any, AsyncIterable, Protocol + +from ....types._events import ToolResultEvent +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.events import ( + BidiInputEvent, + BidiOutputEvent, +) + +logger = logging.getLogger(__name__) + + +class BidiModel(Protocol): + """Protocol for bidirectional streaming models. + + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. + + Attributes: + config: Configuration dictionary with provider-specific settings. + """ + + config: dict[str, Any] + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. + """ + ... + + async def stop(self) -> None: + """Close the streaming connection and release resources. 
+
+        Terminates the active bidirectional connection and cleans up any associated
+        resources such as network connections, buffers, or background tasks. After
+        calling stop(), the model instance cannot be used until start() is called again.
+        """
+        ...
+
+    def receive(self) -> AsyncIterable[BidiOutputEvent]:
+        """Receive streaming events from the model.
+
+        Continuously yields events from the model as they arrive over the connection.
+        Events are normalized to a provider-agnostic format for uniform processing.
+        This method should be called in a loop or async task to process model responses.
+
+        The stream continues until the connection is closed or an error occurs.
+
+        Yields:
+            BidiOutputEvent: Standardized event objects containing audio output,
+                transcripts, tool calls, or control signals.
+        """
+        ...
+
+    async def send(
+        self,
+        content: BidiInputEvent | ToolResultEvent,
+    ) -> None:
+        """Send content to the model over the active connection.
+
+        Transmits user input or tool results to the model during an active streaming
+        session. Supports multiple content types including text, audio, images, and
+        tool execution results. Can be called multiple times during a conversation.
+
+        Args:
+            content: The content to send. Must be one of:
+
+                - BidiTextInputEvent: Text message from the user
+                - BidiAudioInputEvent: Audio data for speech input
+                - BidiImageInputEvent: Image data for visual understanding
+                - ToolResultEvent: Result from a tool execution
+
+        Example:
+            ```
+            await model.send(BidiTextInputEvent(text="Hello", role="user"))
+            await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1))
+            await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw"))
+            await model.send(ToolResultEvent(tool_result))
+            ```
+        """
+        ...
+
+
+class BidiModelTimeoutError(Exception):
+    """Model timeout error.
+
+    Bidirectional models are often configured with a connection time limit. Nova Sonic, for example, keeps the
+    connection open for at most 8 minutes. Upon receiving a timeout, the agent loop is configured to restart the
+    model connection so as to create a seamless, uninterrupted experience for the user.
+    """
+
+    def __init__(self, message: str, **restart_config: Any) -> None:
+        """Initialize error.
+
+        Args:
+            message: Timeout message from model.
+            **restart_config: Configure restart-specific behaviors in the call to model start.
+        """
+        super().__init__(message)
+
+        self.restart_config = restart_config
diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py
new file mode 100644
index 000000000..6a2477e22
--- /dev/null
+++ b/src/strands/experimental/bidi/models/nova_sonic.py
@@ -0,0 +1,758 @@
+"""Nova Sonic bidirectional model provider for real-time streaming conversations.
+
+Implements the BidiModel interface for Amazon's Nova Sonic, handling the
+complex event sequencing and audio processing required by Nova Sonic's
+InvokeModelWithBidirectionalStream protocol.
+
+Nova Sonic specifics:
+
+- Hierarchical event sequences: sessionStart → promptStart → content streaming
+- Base64-encoded audio input and output payloads
+- Tool execution with content containers and identifier tracking
+- 8-minute connection limits with proper cleanup sequences
+- Interruption detection through stopReason events
+"""
+
+import asyncio
+import base64
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, cast
+
+import boto3
+from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput
+from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme
+from aws_sdk_bedrock_runtime.models import (
+    BidirectionalInputPayloadPart,
+    InvokeModelWithBidirectionalStreamInputChunk,
+    ModelTimeoutException,
+    ValidationException,
+)
+from smithy_aws_core.identity.static import StaticCredentialsResolver
+from smithy_core.aio.eventstream import DuplexEventStream
+from smithy_core.shapes import ShapeID
+
+from ....types._events import ToolResultEvent, ToolUseStreamEvent
+from ....types.content import Messages
+from ....types.tools import ToolResult, ToolSpec, ToolUse
+from .._async import stop_all
+from ..types.events import (
+    AudioChannel,
+    AudioSampleRate,
+    BidiAudioInputEvent,
+    BidiAudioStreamEvent,
+    BidiConnectionStartEvent,
+    BidiInputEvent,
+    BidiInterruptionEvent,
+    BidiOutputEvent,
+    BidiResponseCompleteEvent,
+    BidiResponseStartEvent,
+    BidiTextInputEvent,
+    BidiTranscriptStreamEvent,
+    BidiUsageEvent,
+)
+from ..types.model import AudioConfig
+from .model import BidiModel, BidiModelTimeoutError
+
+logger = logging.getLogger(__name__)
+
+_NOVA_INFERENCE_CONFIG_KEYS = {
+    "max_tokens": "maxTokens",
+    "temperature": "temperature",
+    "top_p": "topP",
+}
+
+NOVA_AUDIO_INPUT_CONFIG = {
+    "mediaType": "audio/lpcm",
+    "sampleRateHertz": 16000,
+    "sampleSizeBits": 16,
+    "channelCount": 1,
+    "audioType": "SPEECH",
+    "encoding": "base64",
+}
+
+NOVA_AUDIO_OUTPUT_CONFIG = {
+    "mediaType": "audio/lpcm",
+    "sampleRateHertz": 16000,
+    "sampleSizeBits": 16,
+    "channelCount": 1,
+    "voiceId": "matthew",
+    "encoding": "base64",
+    "audioType": "SPEECH",
+}
+
+NOVA_TEXT_CONFIG = {"mediaType": "text/plain"}
+NOVA_TOOL_CONFIG = {"mediaType": "application/json"}
+
+
+class BidiNovaSonicModel(BidiModel):
+    """Nova Sonic implementation for bidirectional streaming.
+
+    Combines model configuration and connection state in a single class.
+    Manages Nova Sonic's complex event sequencing, audio format conversion, and
+    tool execution patterns while providing the standard BidiModel interface.
+
+    Attributes:
+        _stream: Open Bedrock stream to Nova Sonic.
+    """
+
+    _stream: DuplexEventStream
+
+    def __init__(
+        self,
+        model_id: str = "amazon.nova-sonic-v1:0",
+        provider_config: dict[str, Any] | None = None,
+        client_config: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize Nova Sonic bidirectional model.
+
+        Args:
+            model_id: Model identifier (default: amazon.nova-sonic-v1:0)
+            provider_config: Model behavior (audio, inference settings)
+            client_config: AWS authentication (boto_session OR region, not both)
+            **kwargs: Reserved for future parameters.
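+
+        Example (illustrative; the profile name is a placeholder):
+
+        ```
+        # Either pass an explicit region...
+        model = BidiNovaSonicModel(client_config={"region": "us-west-2"})
+        # ...or a pre-configured boto session (not both)
+        model = BidiNovaSonicModel(client_config={"boto_session": boto3.Session(profile_name="dev")})
+        ```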
+ """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store session and region for later use + self._session = self._client_config["boto_session"] + self.region = self._client_config["region"] + + # Track API-provided identifiers + self._connection_id: str | None = None + self._audio_content_name: str | None = None + self._current_completion_id: str | None = None + + # Indicates if model is done generating transcript + self._generation_stage: str | None = None + + # Ensure certain events are sent in sequence when required + self._send_lock = asyncio.Lock() + + logger.debug("model_id=<%s> | nova sonic model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve AWS client config (creates boto session if needed).""" + if "boto_session" in config and "region" in config: + raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") + + resolved = config.copy() + + # Create boto session if not provided + if "boto_session" not in resolved: + resolved["boto_session"] = boto3.Session() + + # Resolve region from session or use default + if "region" not in resolved: + resolved["region"] = resolved["boto_session"].region_name or "us-east-1" + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + default_audio: AudioConfig = { + "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), + "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), + "format": "pcm", + "voice": cast(str, NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), + } + + resolved = { + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": config.get("inference", {}), + } + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + + Raises: + RuntimeError: If user calls start again without first stopping. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("nova connection starting") + + self._connection_id = str(uuid.uuid4()) + + # Get credentials from boto3 session (full credential chain) + credentials = self._session.get_credentials() + + if not credentials: + raise ValueError( + "no AWS credentials found. configure credentials via environment variables, " + "credential files, IAM roles, or SSO." 
+            )
+
+        # Use static resolver with credentials configured as properties
+        resolver = StaticCredentialsResolver()
+
+        config = Config(
+            endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com",
+            region=self.region,
+            aws_credentials_identity_resolver=resolver,
+            auth_scheme_resolver=HTTPAuthSchemeResolver(),
+            auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")},
+            # Configure static credentials as properties
+            aws_access_key_id=credentials.access_key,
+            aws_secret_access_key=credentials.secret_key,
+            aws_session_token=credentials.token,
+        )
+
+        # Create the client once and reuse it to open the bidirectional stream
+        self.client = BedrockRuntimeClient(config=config)
+        logger.debug("region=<%s> | nova sonic client initialized", self.region)
+
+        self._stream = await self.client.invoke_model_with_bidirectional_stream(
+            InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id)
+        )
+        logger.debug("model_id=<%s> | nova sonic bidirectional stream opened", self.model_id)
+
+        init_events = self._build_initialization_events(system_prompt, tools, messages)
+        logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events))
+        await self._send_nova_events(init_events)
+
+        logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id)
+
+    def _build_initialization_events(
+        self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None
+    ) -> list[str]:
+        """Build the sequence of initialization events."""
+        tools = tools or []
+        events = [
+            self._get_connection_start_event(),
+            self._get_prompt_start_event(tools),
+            *self._get_system_prompt_events(system_prompt),
+        ]
+
+        # Add conversation history if provided
+        if messages:
+            events.extend(self._get_message_history_events(messages))
+            logger.debug("message_count=<%d> | conversation history added to initialization", len(messages))
+
+        return events
+
+    def _log_event_type(self, nova_event: dict[str, Any]) -> None:
+        """Log specific Nova Sonic event types for debugging."""
+        if "usageEvent" in nova_event:
+            logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"])
+        elif "textOutput" in nova_event:
+            logger.debug("nova text output received")
+        elif "toolUse" in nova_event:
+            tool_use = nova_event["toolUse"]
+            logger.debug(
+                "tool_name=<%s>, tool_use_id=<%s> | nova tool use received",
+                tool_use["toolName"],
+                tool_use["toolUseId"],
+            )
+        elif "audioOutput" in nova_event:
+            audio_content = nova_event["audioOutput"]["content"]
+            audio_bytes = base64.b64decode(audio_content)
+            logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes))
+
+    async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]:
+        """Receive Nova Sonic events and convert to provider-agnostic format.
+
+        Raises:
+            RuntimeError: If start has not been called.
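+
+        Example (illustrative consumer loop; the playback helper is hypothetical
+        and attribute-style access mirrors the input event classes):
+
+        ```
+        async for event in model.receive():
+            if isinstance(event, BidiAudioStreamEvent):
+                play_audio(event.audio)  # hypothetical helper; audio is a base64 string
+        ```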
+ """ + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") + + logger.debug("nova event stream starting") + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + _, output = await self._stream.await_output() + while True: + try: + event_data = await output.receive() + + except ValidationException as error: + if "InternalErrorCode=531" in error.message: + # nova also times out if user is silent for 175 seconds + raise BidiModelTimeoutError(error.message) from error + raise + + except ModelTimeoutException as error: + raise BidiModelTimeoutError(error.message) from error + + if not event_data: + continue + + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + self._log_event_type(nova_event) + + model_event = self._convert_nova_event(nova_event) + if model_event: + yield model_event + + async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: + """Unified send method for all content types. Sends the given content to Nova Sonic. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Input event. + + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" + logger.debug("nova audio connection starting") + self._audio_content_name = str(uuid.uuid4()) + + # Build audio input configuration from config + audio_input_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["input_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "audioType": "SPEECH", + "encoding": "base64", + } + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": audio_input_config, + } + } + } + ) + + await self._send_nova_events([audio_content_start]) + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" + # Start audio connection if not already active + if not self._audio_content_name: + await self._start_audio_connection() + + # Audio is already base64 encoded in the event + # Send audio input event + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, + "content": audio_input.audio, + } + } + } + ) + + await self._send_nova_events([audio_event]) + + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" + if not self._audio_content_name: + return + + logger.debug("nova audio connection ending") + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": 
self._connection_id, "contentName": self._audio_content_name}}} + ) + + await self._send_nova_events([audio_content_end]) + self._audio_content_name = None + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result["toolUseId"] + + logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) + + # Validate content types and preserve structure + content = tool_result.get("content", []) + + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Nova Sonic" + ) + + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data = cast(dict[str, Any], content[0]) + else: + # Multiple items - send as array + result_data = {"content": content} + + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result_data), + self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def stop(self) -> None: + """Close Nova Sonic connection with proper cleanup sequence.""" + logger.debug("nova connection cleanup starting") + + async def stop_events() -> None: + if not self._connection_id: + return + + await self._end_audio_input() + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + await self._send_nova_events(cleanup_events) + + async def stop_stream() -> None: + if not hasattr(self, "_stream"): + return + + await self._stream.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_events, stop_stream, stop_connection) + + logger.debug("nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" + # Handle completion start - track completionId + if "completionStart" in nova_event: + completion_data = nova_event["completionStart"] + self._current_completion_id = completion_data.get("completionId") + logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) + return None + + # Handle completion end + if "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + completion_id = completion_data.get("completionId", self._current_completion_id) + stop_reason = completion_data.get("stopReason", "END_TURN") + + event = BidiResponseCompleteEvent( + response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", + ) + + # Clear completion tracking + self._current_completion_id = None + return event + + # Handle audio output + if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic + audio_content = nova_event["audioOutput"]["content"] + return BidiAudioStreamEvent( + audio=audio_content, + 
format="pcm", + sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), + ) + + # Handle text output (transcripts) + elif "textOutput" in nova_event: + text_output = nova_event["textOutput"] + text_content = text_output["content"] + # Check for Nova Sonic interruption pattern + if '{ "interrupted" : true }' in text_content: + logger.debug("nova interruption detected in text output") + return BidiInterruptionEvent(reason="user_speech") + + return BidiTranscriptStreamEvent( + delta={"text": text_content}, + text=text_content, + role=text_output["role"].lower(), + is_final=self._generation_stage == "FINAL", + current_transcript=text_content, + ) + + # Handle tool use + if "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + # Return ToolUseStreamEvent - cast to dict for type compatibility + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + + # Handle interruption + if nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("nova interruption detected via stop reason") + return BidiInterruptionEvent(reason="user_speech") + + # Handle usage events - convert to multimodal usage format + if "usageEvent" in nova_event: + usage_data = nova_event["usageEvent"] + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return BidiUsageEvent( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output), + ) + + # Handle content start events (emit response start) + if "contentStart" in nova_event: + content_data = nova_event["contentStart"] + if content_data["type"] == "TEXT": + self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] + + # Emit response start event using API-provided completionId + # completionId should already be tracked from completionStart event + return BidiResponseStartEvent( + response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing + ) + + if "contentEnd" in nova_event: + self._generation_stage = None + + # Ignore all other events + return None + + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + inference_config = {_NOVA_INFERENCE_CONFIG_KEYS[key]: value for key, value in self.config["inference"].items()} + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": inference_config}}}) + + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + # Build audio output configuration from config + audio_output_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["output_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "voiceId": self.config["audio"].get("voice", "matthew"), + "encoding": "base64", + "audioType": "SPEECH", + } + + prompt_start_event: dict[str, Any] = { + "event": { + "promptStart": { + "promptName": self._connection_id, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": audio_output_config, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = 
NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: + """Build tool configuration from tool specs.""" + tool_config: list[dict[str, Any]] = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt or ""), + self._get_content_end_event(content_name), + ] + + def _get_message_history_events(self, messages: Messages) -> list[str]: + """Generate conversation history events from agent messages. + + Converts agent message history to Nova Sonic format following the + contentStart/textInput/contentEnd pattern for each message. + + Args: + messages: List of conversation messages with role and content. + + Returns: + List of JSON event strings for Nova Sonic. + """ + events = [] + + for message in messages: + role = message["role"].upper() # Convert to ASSISTANT or USER + content_blocks = message.get("content", []) + + # Extract text content from content blocks + text_parts = [] + for block in content_blocks: + if "text" in block: + text_parts.append(block["text"]) + + # Combine all text parts + if text_parts: + combined_text = "\n".join(text_parts) + content_name = str(uuid.uuid4()) + + # Add contentStart, textInput, and contentEnd events + events.extend( + [ + self._get_text_content_start_event(content_name, role), + self._get_text_input_event(content_name, combined_text), + self._get_content_end_event(content_name), + ] + ) + + return events + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } + } + } + ) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, + } + } + } + ) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps( + {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} + ) + + def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self._connection_id, + "contentName": content_name, + "content": json.dumps(result), + } + } + } + ) + + def _get_content_end_event(self, content_name: str) -> 
str: + """Generate content end event.""" + return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) + + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" + return json.dumps({"event": {"connectionEnd": {}}}) + + async def _send_nova_events(self, events: list[str]) -> None: + """Send event JSON string to Nova Sonic stream. + + A lock is used to send events in sequence when required (e.g., tool result start, content, and end). + + Args: + events: Jsonified events. + """ + async with self._send_lock: + for event in events: + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self._stream.input_stream.send(chunk) + logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py new file mode 100644 index 000000000..9196a39d5 --- /dev/null +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -0,0 +1,793 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import json +import logging +import os +import time +import uuid +from typing import Any, AsyncGenerator, Literal, cast + +import websockets +from websockets import ClientConnection + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.events import ( + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, + Role, + StopReason, +) +from ..types.model import AudioConfig +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Test idle_timeout_ms + +# OpenAI Realtime API configuration +OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes +"""Max timeout before closing connection. + +OpenAI documents a 60 minute limit on realtime sessions +([docs](https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events)). However, OpenAI does +not emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully +handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. +""" +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" +DEFAULT_SAMPLE_RATE = 24000 + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. 
Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, + "transcription": {"model": "gpt-4o-transcribe"}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, + }, +} + + +class BidiOpenAIRealtimeModel(BidiModel): + """OpenAI Realtime API implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + _websocket: ClientConnection + _start_time: int + + def __init__( + self, + model_id: str = DEFAULT_MODEL, + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model_id: Model identifier (default: gpt-realtime) + provider_config: Model behavior (audio, instructions, turn_detection, etc.) + client_config: Authentication (api_key, organization, project) + Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars + **kwargs: Reserved for future parameters. + + """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults and env vars + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store client config values for later use + self.api_key = self._client_config["api_key"] + self.organization = self._client_config.get("organization") + self.project = self._client_config.get("project") + self.timeout_s = self._client_config["timeout_s"] + + if self.timeout_s > OPENAI_MAX_TIMEOUT_S: + raise ValueError( + f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + ) + + # Connection state (initialized in start()) + self._connection_id: str | None = None + + self._function_call_buffer: dict[str, Any] = {} + + logger.debug("model=<%s> | openai realtime model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config with env var fallback (config takes precedence).""" + resolved = config.copy() + + if "api_key" not in resolved: + resolved["api_key"] = os.getenv("OPENAI_API_KEY") + + if not resolved.get("api_key"): + raise ValueError( + "OpenAI API key is required. Provide via client_config={'api_key': '...'} " + "or set OPENAI_API_KEY environment variable." 
+ ) + if "organization" not in resolved: + env_org = os.getenv("OPENAI_ORGANIZATION") + if env_org: + resolved["organization"] = env_org + + if "project" not in resolved: + env_project = os.getenv("OPENAI_PROJECT") + if env_project: + resolved["project"] = env_project + + if "timeout_s" not in resolved: + resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + default_audio: AudioConfig = { + "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "channels": 1, + "format": "pcm", + "voice": "alloy", + } + + resolved = { + "audio": { + **default_audio, + **config.get("audio", {}), + }, + "inference": config.get("inference", {}), + } + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("openai realtime connection starting") + + # Initialize connection state + self._connection_id = str(uuid.uuid4()) + self._start_time = int(time.time()) + + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) + + self._websocket = await websockets.connect(url, additional_headers=headers) + logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) + + # Configure session + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + # Add conversation history if provided + if messages: + await self._add_conversation_history(messages) + + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: + """Create standardized transcript event. 
+ + Args: + text: The transcript text + role: The role (will be normalized to lowercase) + is_final: Whether this is the final transcript + """ + # Normalize role to lowercase and ensure it's either "user" or "assistant" + normalized_role = role.lower() if isinstance(role, str) else "assistant" + if normalized_role not in ["user", "assistant"]: + normalized_role = "assistant" + + return BidiTranscriptStreamEvent( + delta={"text": text}, + text=text, + role=cast(Role, normalized_role), + is_final=is_final, + current_transcript=text if is_final else None, + ) + + def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return BidiInterruptionEvent(reason="user_speech") + # Other voice activity events are logged but don't create events + return None + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: + """Build session configuration for OpenAI Realtime API.""" + config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + # Apply user-provided session configuration + supported_params = { + "max_output_tokens", + "output_modalities", + "tool_choice", + } + for key, value in self.config["inference"].items(): + if key in supported_params: + config[key] = value + else: + logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) + + audio_config = self.config["audio"] + + if "voice" in audio_config: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] + + if "input_rate" in audio_config: + config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ + "input_rate" + ] + + if "output_rate" in audio_config: + config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ + "output_rate" + ] + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session. + + Converts agent message history to OpenAI Realtime API format using + conversation.item.create events for each message. + + Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate + UUIDs consistently to ensure tool calls and their results match. + + Args: + messages: List of conversation messages with role and content. 
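+
+        Example (illustrative of the truncation rule; the id is a placeholder):
+
+        ```
+        original_id = "tooluse_0b1f2e3d-4c5b-6a79-8897-a6b5c4d3e2f1"
+        call_id = original_id[:32]  # same mapping reused for the matching function_call_output
+        ```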
+ """ + # Track tool call IDs to ensure consistency between calls and results + call_id_map: dict[str, str] = {} + + # First pass: collect all tool call IDs + for message in messages: + for block in message.get("content", []): + if "toolUse" in block: + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + call_id = original_id[:32] + call_id_map[original_id] = call_id + + # Second pass: send messages + for message in messages: + role = message["role"] + content_blocks = message.get("content", []) + + # Build content array for OpenAI format + openai_content = [] + + for block in content_blocks: + if "text" in block: + # Text content - use appropriate type based on role + # User messages use "input_text", assistant messages use "output_text" + if role == "user": + openai_content.append({"type": "input_text", "text": block["text"]}) + else: # assistant + openai_content.append({"type": "output_text", "text": block["text"]}) + elif "toolUse" in block: + # Tool use - create as function_call item + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + # Use pre-mapped call_id + call_id = call_id_map[original_id] + + tool_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call", + "call_id": call_id, + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + await self._send_event(tool_item) + continue # Tool use is sent separately, not in message content + elif "toolResult" in block: + # Tool result - create as function_call_output item + tool_result = block["toolResult"] + original_id = tool_result["toolUseId"] + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for result_block in tool_result["content"]: + if "text" not in result_block and "json" not in result_block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{original_id}>, content_types=<{list(result_block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + # Use mapped call_id if available, otherwise skip orphaned result + if original_id not in call_id_map: + continue # Skip this tool result since we don't have the call + + call_id = call_id_map[original_id] + + result_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result_output, + }, + } + await self._send_event(result_item) + continue # Tool result is sent separately, not in message content + + # Only create message item if there's text content + if openai_content: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": role, "content": openai_content}, + } + await self._send_event(conversation_item) + + logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive OpenAI events and convert to Strands TypedEvent format.""" + if not self._connection_id: + raise RuntimeError("model not started | call start before sending/receiving") + + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + while True: + duration = time.time() - self._start_time + if duration >= self.timeout_s: + raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") + + 
try: + message = await asyncio.wait_for(self._websocket.recv(), timeout=10) + except asyncio.TimeoutError: + continue + + openai_event = json.loads(message) + + for event in self._convert_openai_event(openai_event) or []: + yield event + + def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: + """Convert OpenAI events to Strands TypedEvent format.""" + event_type = openai_event.get("type") + + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [BidiResponseStartEvent(response_id=response_id)] + + # Audio output + elif event_type == "response.output_audio.delta": + # Audio is already base64 string from OpenAI + # Use the resolved output sample rate from our merged configuration + sample_rate = self.config["audio"]["output_rate"] + + # Channels from config is guaranteed to be 1 or 2 + channels = cast(Literal[1, 2], self.config["audio"]["channels"]) + return [ + BidiAudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=sample_rate, + channels=channels, + ) + ] + + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: + role = openai_event.get("role", "assistant") + return [ + self._create_text_event( + openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False + ) + ] + + elif event_type in ["response.output_audio_transcript.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["transcript"], role)] + + elif event_type in ["response.output_text.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["text"], role)] + + # User transcription events - combine multiple similar events + elif event_type in [ + "conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed", + ]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + role = openai_event.get("role", "user") + is_final = "completed" in event_type + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] + if text.strip() + else None + ) + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + role = segment_data.get("role", "user") + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] + if text.strip() + else None + ) + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in 
self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] + except (json.JSONDecodeError, KeyError) as e: + logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection - speech_started triggers interruption + elif event_type == "input_audio_buffer.speech_started": + # This is the primary interruption signal - handle it first + return [BidiInterruptionEvent(reason="user_speech")] + + # Response cancelled - handle interruption + elif event_type == "response.cancelled": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + logger.debug("response_id=<%s> | openai response cancelled", response_id) + return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted", + } + + # Build list of events to return + events: list[Any] = [] + + # Always add response complete event + events.append( + BidiResponseCompleteEvent( + response_id=response_id, + stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), + ), + ) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append( + {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} + ) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append( + {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} + ) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({"modality": "image", "input_tokens": image_input, "output_tokens": 0}) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append( + BidiUsageEvent( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, + ) + ) + + # Return list of events + return events + + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: + 
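+            # Log-only bookkeeping; these item events carry no deltas to surface.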
item = openai_event.get("item", {}) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) + return None + + # Response output events - combine similar events + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug( + "event_type=<%s>, item_id=<%s> | openai output event", + event_type, + item_data.get("id") if item_data else "unknown", + ) + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + # Session/buffer events - combine simple log-only events + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("event_type=<%s> | openai event received", event_type) + return None + + elif event_type == "error": + error_data = openai_event.get("error", {}) + error_code = error_data.get("code", "") + + # Suppress expected errors that don't affect session state + if error_code == "response_cancel_not_active": + # This happens when trying to cancel a response that's not active + # It's safe to ignore as the session remains functional + logger.debug("openai response cancel attempted when no response active") + return None + + # Log other errors + logger.error("error=<%s> | openai realtime error", error_data) + return None + + else: + logger.debug("event_type=<%s> | unhandled openai event type", event_type) + return None + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given content to OpenAI. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + + Raises: + ValueError: If content type not supported (e.g., image content). 
+ """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" + await self._send_event({"type": "response.cancel"}) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") + + logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for block in tool_result["content"]: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def stop(self) -> None: + """Close session and cleanup resources.""" + logger.debug("openai realtime connection cleanup starting") + + async def stop_websocket() -> None: + if not hasattr(self, "_websocket"): + return + + await self._websocket.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_websocket, stop_connection) + + logger.debug("openai realtime connection closed") + + async def _send_event(self, event: dict[str, Any]) -> None: + """Send event to OpenAI via WebSocket.""" + message = json.dumps(event) + await self._websocket.send(message) + logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/src/strands/experimental/bidi/tools/__init__.py b/src/strands/experimental/bidi/tools/__init__.py new file mode 100644 index 000000000..c665dc65a --- /dev/null +++ b/src/strands/experimental/bidi/tools/__init__.py @@ -0,0 +1,5 @@ +"""Built-in tools for bidirectional agents.""" + +from .stop_conversation import stop_conversation + +__all__ = ["stop_conversation"] diff --git 
a/src/strands/experimental/bidi/tools/stop_conversation.py b/src/strands/experimental/bidi/tools/stop_conversation.py new file mode 100644 index 000000000..9c7e1c6cd --- /dev/null +++ b/src/strands/experimental/bidi/tools/stop_conversation.py @@ -0,0 +1,16 @@ +"""Tool to gracefully stop a bidirectional connection.""" + +from ....tools.decorator import tool + + +@tool +def stop_conversation() -> str: + """Stop the bidirectional conversation gracefully. + + Use ONLY when user says "stop conversation" exactly. + Do NOT use for: "stop", "goodbye", "bye", "exit", "quit", "end" or other farewells or phrases. + + Returns: + Success message confirming the conversation will end + """ + return "Ending conversation" diff --git a/src/strands/experimental/bidi/types/__init__.py b/src/strands/experimental/bidi/types/__init__.py new file mode 100644 index 000000000..903a54508 --- /dev/null +++ b/src/strands/experimental/bidi/types/__init__.py @@ -0,0 +1,46 @@ +"""Type definitions for bidirectional streaming.""" + +from .agent import BidiAgentInput +from .events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionRestartEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, +) +from .io import BidiInput, BidiOutput + +__all__ = [ + "BidiInput", + "BidiOutput", + "BidiAgentInput", + # Input Events + "BidiTextInputEvent", + "BidiAudioInputEvent", + "BidiImageInputEvent", + "BidiInputEvent", + # Output Events + "BidiConnectionStartEvent", + "BidiConnectionRestartEvent", + "BidiConnectionCloseEvent", + "BidiResponseStartEvent", + "BidiResponseCompleteEvent", + "BidiAudioStreamEvent", + "BidiTranscriptStreamEvent", + "BidiInterruptionEvent", + "BidiUsageEvent", + "ModalityUsage", + "BidiErrorEvent", + "BidiOutputEvent", +] diff --git a/src/strands/experimental/bidi/types/agent.py b/src/strands/experimental/bidi/types/agent.py new file mode 100644 index 000000000..8d1e9aab7 --- /dev/null +++ b/src/strands/experimental/bidi/types/agent.py @@ -0,0 +1,10 @@ +"""Agent-related type definitions for bidirectional streaming. + +This module defines the types used for BidiAgent. +""" + +from typing import TypeAlias + +from .events import BidiAudioInputEvent, BidiImageInputEvent, BidiTextInputEvent + +BidiAgentInput: TypeAlias = str | BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py new file mode 100644 index 000000000..9d44fc660 --- /dev/null +++ b/src/strands/experimental/bidi/types/events.py @@ -0,0 +1,612 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. 
+ +Key features: + +- Audio input/output events with standardized formats +- Interruption detection and handling +- Connection lifecycle management +- Provider-agnostic event types +- Type-safe discriminated unions with TypedEvent +- JSON-serializable events (audio/images stored as base64 strings) + +Audio format normalization: + +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings +- Audio data stored as base64-encoded strings for JSON compatibility +""" + +from typing import TYPE_CHECKING, Any, Literal, cast + +from ....types._events import ModelStreamEvent, ToolUseStreamEvent, TypedEvent +from ....types.streaming import ContentBlockDelta + +if TYPE_CHECKING: + from ..models.model import BidiModelTimeoutError + +AudioChannel = Literal[1, 2] +"""Number of audio channels. + +- Mono: 1 +- Stereo: 2 +""" +AudioFormat = Literal["pcm", "wav", "opus", "mp3"] +"""Audio encoding format.""" +AudioSampleRate = Literal[16000, 24000, 48000] +"""Audio sample rate in Hz.""" + +Role = Literal["user", "assistant"] +"""Role of a message sender. + +- "user": Messages from the user to the assistant. +- "assistant": Messages from the assistant to the user. +""" + +StopReason = Literal["complete", "error", "interrupted", "tool_use"] +"""Reason for the model ending its response generation. + +- "complete": Model completed its response. +- "error": Model encountered an error. +- "interrupted": Model was interrupted by the user. +- "tool_use": Model is requesting a tool use. +""" + +# ============================================================================ +# Input Events (sent via agent.send()) +# ============================================================================ + + +class BidiTextInputEvent(TypedEvent): + """Text input event for sending text to the model. + + Used for sending text content through the send() method. + + Parameters: + text: The text content to send to the model. + role: The role of the message sender (default: "user"). + """ + + def __init__(self, text: str, role: Role = "user"): + """Initialize text input event.""" + super().__init__( + { + "type": "bidi_text_input", + "text": text, + "role": role, + } + ) + + @property + def text(self) -> str: + """The text content to send to the model.""" + return cast(str, self["text"]) + + @property + def role(self) -> Role: + """The role of the message sender.""" + return cast(Role, self["role"]) + + +class BidiAudioInputEvent(TypedEvent): + """Audio input event for sending audio to the model. + + Used for sending audio data through the send() method. + + Parameters: + audio: Base64-encoded audio string to send to model. + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sample_rate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. 
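+
+    Example: an illustrative (non-normative) sketch of constructing this event;
+    ``pcm_bytes`` is assumed to be raw 16 kHz mono PCM captured by the caller:
+
+        import base64
+
+        event = BidiAudioInputEvent(
+            audio=base64.b64encode(pcm_bytes).decode("utf-8"),
+            format="pcm",
+            sample_rate=16000,
+            channels=1,
+        )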
+ """ + + def __init__( + self, + audio: str, + format: AudioFormat | str, + sample_rate: AudioSampleRate, + channels: AudioChannel, + ): + """Initialize audio input event.""" + super().__init__( + { + "type": "bidi_audio_input", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> str: + """Base64-encoded audio string.""" + return cast(str, self["audio"]) + + @property + def format(self) -> AudioFormat: + """Audio encoding format.""" + return cast(AudioFormat, self["format"]) + + @property + def sample_rate(self) -> AudioSampleRate: + """Number of audio samples per second in Hz.""" + return cast(AudioSampleRate, self["sample_rate"]) + + @property + def channels(self) -> AudioChannel: + """Number of audio channels (1=mono, 2=stereo).""" + return cast(AudioChannel, self["channels"]) + + +class BidiImageInputEvent(TypedEvent): + """Image input event for sending images/video frames to the model. + + Used for sending image data through the send() method. + + Parameters: + image: Base64-encoded image string. + mime_type: MIME type (e.g., "image/jpeg", "image/png"). + """ + + def __init__( + self, + image: str, + mime_type: str, + ): + """Initialize image input event.""" + super().__init__( + { + "type": "bidi_image_input", + "image": image, + "mime_type": mime_type, + } + ) + + @property + def image(self) -> str: + """Base64-encoded image string.""" + return cast(str, self["image"]) + + @property + def mime_type(self) -> str: + """MIME type of the image (e.g., "image/jpeg", "image/png").""" + return cast(str, self["mime_type"]) + + +# ============================================================================ +# Output Events (received via agent.receive()) +# ============================================================================ + + +class BidiConnectionStartEvent(TypedEvent): + """Streaming connection established and ready for interaction. + + Parameters: + connection_id: Unique identifier for this streaming connection. + model: Model identifier (e.g., "gpt-realtime", "gemini-2.0-flash-live"). + """ + + def __init__(self, connection_id: str, model: str): + """Initialize connection start event.""" + super().__init__( + { + "type": "bidi_connection_start", + "connection_id": connection_id, + "model": model, + } + ) + + @property + def connection_id(self) -> str: + """Unique identifier for this streaming connection.""" + return cast(str, self["connection_id"]) + + @property + def model(self) -> str: + """Model identifier (e.g., 'gpt-realtime', 'gemini-2.0-flash-live').""" + return cast(str, self["model"]) + + +class BidiConnectionRestartEvent(TypedEvent): + """Agent is restarting the model connection after timeout.""" + + def __init__(self, timeout_error: "BidiModelTimeoutError"): + """Initialize. + + Args: + timeout_error: Timeout error reported by the model. + """ + super().__init__( + { + "type": "bidi_connection_restart", + "timeout_error": timeout_error, + } + ) + + @property + def timeout_error(self) -> "BidiModelTimeoutError": + """Model timeout error.""" + return cast("BidiModelTimeoutError", self["timeout_error"]) + + +class BidiResponseStartEvent(TypedEvent): + """Model starts generating a response. + + Parameters: + response_id: Unique identifier for this response (used in response.complete). 
+ """ + + def __init__(self, response_id: str): + """Initialize response start event.""" + super().__init__({"type": "bidi_response_start", "response_id": response_id}) + + @property + def response_id(self) -> str: + """Unique identifier for this response.""" + return cast(str, self["response_id"]) + + +class BidiAudioStreamEvent(TypedEvent): + """Streaming audio output from the model. + + Parameters: + audio: Base64-encoded audio string. + format: Audio encoding format. + sample_rate: Number of audio samples per second in Hz. + channels: Number of audio channels (1=mono, 2=stereo). + """ + + def __init__( + self, + audio: str, + format: AudioFormat, + sample_rate: AudioSampleRate, + channels: AudioChannel, + ): + """Initialize audio stream event.""" + super().__init__( + { + "type": "bidi_audio_stream", + "audio": audio, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + ) + + @property + def audio(self) -> str: + """Base64-encoded audio string.""" + return cast(str, self["audio"]) + + @property + def format(self) -> AudioFormat: + """Audio encoding format.""" + return cast(AudioFormat, self["format"]) + + @property + def sample_rate(self) -> AudioSampleRate: + """Number of audio samples per second in Hz.""" + return cast(AudioSampleRate, self["sample_rate"]) + + @property + def channels(self) -> AudioChannel: + """Number of audio channels (1=mono, 2=stereo).""" + return cast(AudioChannel, self["channels"]) + + +class BidiTranscriptStreamEvent(ModelStreamEvent): + """Audio transcription streaming (user or assistant speech). + + Supports incremental transcript updates for providers that send partial + transcripts before the final version. + + Parameters: + delta: The incremental transcript change (ContentBlockDelta). + text: The delta text (same as delta content for convenience). + role: Who is speaking ("user" or "assistant"). + is_final: Whether this is the final/complete transcript. + current_transcript: The accumulated transcript text so far (None for first delta). + """ + + def __init__( + self, + delta: ContentBlockDelta, + text: str, + role: Role, + is_final: bool, + current_transcript: str | None = None, + ): + """Initialize transcript stream event.""" + super().__init__( + { + "type": "bidi_transcript_stream", + "delta": delta, + "text": text, + "role": role, + "is_final": is_final, + "current_transcript": current_transcript, + } + ) + + @property + def delta(self) -> ContentBlockDelta: + """The incremental transcript change.""" + return cast(ContentBlockDelta, self["delta"]) + + @property + def text(self) -> str: + """The text content to send to the model.""" + return cast(str, self["text"]) + + @property + def role(self) -> Role: + """The role of the message sender.""" + return cast(Role, self["role"]) + + @property + def is_final(self) -> bool: + """Whether this is the final/complete transcript.""" + return cast(bool, self["is_final"]) + + @property + def current_transcript(self) -> str | None: + """The accumulated transcript text so far.""" + return cast(str | None, self.get("current_transcript")) + + +class BidiInterruptionEvent(TypedEvent): + """Model generation was interrupted. + + Parameters: + reason: Why the interruption occurred. 
+ """ + + def __init__(self, reason: Literal["user_speech", "error"]): + """Initialize interruption event.""" + super().__init__( + { + "type": "bidi_interruption", + "reason": reason, + } + ) + + @property + def reason(self) -> str: + """Why the interruption occurred.""" + return cast(str, self["reason"]) + + +class BidiResponseCompleteEvent(TypedEvent): + """Model finished generating response. + + Parameters: + response_id: ID of the response that completed (matches response.start). + stop_reason: Why the response ended. + """ + + def __init__( + self, + response_id: str, + stop_reason: StopReason, + ): + """Initialize response complete event.""" + super().__init__( + { + "type": "bidi_response_complete", + "response_id": response_id, + "stop_reason": stop_reason, + } + ) + + @property + def response_id(self) -> str: + """Unique identifier for this response.""" + return cast(str, self["response_id"]) + + @property + def stop_reason(self) -> StopReason: + """Why the response ended.""" + return cast(StopReason, self["stop_reason"]) + + +class ModalityUsage(dict): + """Token usage for a specific modality. + + Attributes: + modality: Type of content. + input_tokens: Tokens used for this modality's input. + output_tokens: Tokens used for this modality's output. + """ + + modality: Literal["text", "audio", "image", "cached"] + input_tokens: int + output_tokens: int + + +class BidiUsageEvent(TypedEvent): + """Token usage event with modality breakdown for bidirectional streaming. + + Tracks token consumption across different modalities (audio, text, images) + during bidirectional streaming sessions. + + Parameters: + input_tokens: Total tokens used for all input modalities. + output_tokens: Total tokens used for all output modalities. + total_tokens: Sum of input and output tokens. + modality_details: Optional list of token usage per modality. + cache_read_input_tokens: Optional tokens read from cache. + cache_write_input_tokens: Optional tokens written to cache. 
+ """ + + def __init__( + self, + input_tokens: int, + output_tokens: int, + total_tokens: int, + modality_details: list[ModalityUsage] | None = None, + cache_read_input_tokens: int | None = None, + cache_write_input_tokens: int | None = None, + ): + """Initialize usage event.""" + data: dict[str, Any] = { + "type": "bidi_usage", + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "totalTokens": total_tokens, + } + if modality_details is not None: + data["modality_details"] = modality_details + if cache_read_input_tokens is not None: + data["cacheReadInputTokens"] = cache_read_input_tokens + if cache_write_input_tokens is not None: + data["cacheWriteInputTokens"] = cache_write_input_tokens + super().__init__(data) + + @property + def input_tokens(self) -> int: + """Total tokens used for all input modalities.""" + return cast(int, self["inputTokens"]) + + @property + def output_tokens(self) -> int: + """Total tokens used for all output modalities.""" + return cast(int, self["outputTokens"]) + + @property + def total_tokens(self) -> int: + """Sum of input and output tokens.""" + return cast(int, self["totalTokens"]) + + @property + def modality_details(self) -> list[ModalityUsage]: + """Optional list of token usage per modality.""" + return cast(list[ModalityUsage], self.get("modality_details", [])) + + @property + def cache_read_input_tokens(self) -> int | None: + """Optional tokens read from cache.""" + return cast(int | None, self.get("cacheReadInputTokens")) + + @property + def cache_write_input_tokens(self) -> int | None: + """Optional tokens written to cache.""" + return cast(int | None, self.get("cacheWriteInputTokens")) + + +class BidiConnectionCloseEvent(TypedEvent): + """Streaming connection closed. + + Parameters: + connection_id: Unique identifier for this streaming connection (matches BidiConnectionStartEvent). + reason: Why the connection was closed. + """ + + def __init__( + self, + connection_id: str, + reason: Literal["client_disconnect", "timeout", "error", "complete", "user_request"], + ): + """Initialize connection close event.""" + super().__init__( + { + "type": "bidi_connection_close", + "connection_id": connection_id, + "reason": reason, + } + ) + + @property + def connection_id(self) -> str: + """Unique identifier for this streaming connection.""" + return cast(str, self["connection_id"]) + + @property + def reason(self) -> str: + """Why the interruption occurred.""" + return cast(str, self["reason"]) + + +class BidiErrorEvent(TypedEvent): + """Error occurred during the session. + + Stores the full Exception object as an instance attribute for debugging while + keeping the event dict JSON-serializable. The exception can be accessed via + the `error` property for re-raising or type-based error handling. + + Parameters: + error: The exception that occurred. + details: Optional additional error information. + """ + + def __init__( + self, + error: Exception, + details: dict[str, Any] | None = None, + ): + """Initialize error event.""" + # Store serializable data in dict (for JSON serialization) + super().__init__( + { + "type": "bidi_error", + "message": str(error), + "code": type(error).__name__, + "details": details, + } + ) + # Store exception as instance attribute (not serialized) + self._error = error + + @property + def error(self) -> Exception: + """The original exception that occurred. + + Can be used for re-raising or type-based error handling. 
+ """ + return self._error + + @property + def code(self) -> str: + """Error code derived from exception class name.""" + return cast(str, self["code"]) + + @property + def message(self) -> str: + """Human-readable error message from the exception.""" + return cast(str, self["message"]) + + @property + def details(self) -> dict[str, Any] | None: + """Additional error context beyond the exception itself.""" + return cast(dict[str, Any] | None, self.get("details")) + + +# ============================================================================ +# Type Unions +# ============================================================================ + +# Note: ToolResultEvent is imported from strands.types._events and used alongside +# BidiInputEvent in send() methods for sending tool results back to the model. + +BidiInputEvent = BidiTextInputEvent | BidiAudioInputEvent | BidiImageInputEvent +"""Union of different bidi input event types.""" + +BidiOutputEvent = ( + BidiConnectionStartEvent + | BidiConnectionRestartEvent + | BidiResponseStartEvent + | BidiAudioStreamEvent + | BidiTranscriptStreamEvent + | BidiInterruptionEvent + | BidiResponseCompleteEvent + | BidiUsageEvent + | BidiConnectionCloseEvent + | BidiErrorEvent + | ToolUseStreamEvent +) +"""Union of different bidi output event types.""" diff --git a/src/strands/experimental/bidi/types/io.py b/src/strands/experimental/bidi/types/io.py new file mode 100644 index 000000000..bdb7d9c9d --- /dev/null +++ b/src/strands/experimental/bidi/types/io.py @@ -0,0 +1,63 @@ +"""Protocol for bidirectional streaming IO channels. + +Defines callable protocols for input and output channels that can be used +with BidiAgent. This approach provides better typing and flexibility +by separating input and output concerns into independent callables. +""" + +from typing import TYPE_CHECKING, Awaitable, Protocol, runtime_checkable + +from ..types.events import BidiInputEvent, BidiOutputEvent + +if TYPE_CHECKING: + from ..agent.agent import BidiAgent + + +@runtime_checkable +class BidiInput(Protocol): + """Protocol for bidirectional input callables. + + Input callables read data from a source (microphone, camera, websocket, etc.) + and return events to be sent to the agent. + """ + + async def start(self, agent: "BidiAgent") -> None: + """Start input.""" + return + + async def stop(self) -> None: + """Stop input.""" + return + + def __call__(self) -> Awaitable[BidiInputEvent]: + """Read input data from the source. + + Returns: + Awaitable that resolves to an input event (audio, text, image, etc.) + """ + ... + + +@runtime_checkable +class BidiOutput(Protocol): + """Protocol for bidirectional output callables. + + Output callables receive events from the agent and handle them appropriately + (play audio, display text, send over websocket, etc.). + """ + + async def start(self, agent: "BidiAgent") -> None: + """Start output.""" + return + + async def stop(self) -> None: + """Stop output.""" + return + + def __call__(self, event: BidiOutputEvent) -> Awaitable[None]: + """Process output events from the agent. + + Args: + event: Output event from the agent (audio, text, tool calls, etc.) + """ + ... diff --git a/src/strands/experimental/bidi/types/model.py b/src/strands/experimental/bidi/types/model.py new file mode 100644 index 000000000..de41de1a9 --- /dev/null +++ b/src/strands/experimental/bidi/types/model.py @@ -0,0 +1,36 @@ +"""Model-related type definitions for bidirectional streaming. 
+ +Defines types and configurations that are central to model providers, +including audio configuration that models use to specify their audio +processing requirements. +""" + +from typing import TypedDict + +from .events import AudioChannel, AudioFormat, AudioSampleRate + + +class AudioConfig(TypedDict, total=False): + """Audio configuration for bidirectional streaming models. + + Defines standard audio parameters that model providers use to specify + their audio processing requirements. All fields are optional to support + models that may not use audio or only need specific parameters. + + Model providers build this configuration by merging user-provided values + with their own defaults. The resulting configuration is then used by + audio I/O implementations to configure hardware appropriately. + + Attributes: + input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) + output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) + channels: Number of audio channels (1=mono, 2=stereo) + format: Audio encoding format + voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") + """ + + input_rate: AudioSampleRate + output_rate: AudioSampleRate + channels: AudioChannel + format: AudioFormat + voice: str diff --git a/src/strands/experimental/hooks/__init__.py b/src/strands/experimental/hooks/__init__.py index 098d4cf0d..c76b57ea4 100644 --- a/src/strands/experimental/hooks/__init__.py +++ b/src/strands/experimental/hooks/__init__.py @@ -5,6 +5,13 @@ AfterToolInvocationEvent, BeforeModelInvocationEvent, BeforeToolInvocationEvent, + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, ) __all__ = [ @@ -12,4 +19,12 @@ "AfterToolInvocationEvent", "BeforeModelInvocationEvent", "AfterModelInvocationEvent", + # BidiAgent hooks + "BidiAgentInitializedEvent", + "BidiBeforeInvocationEvent", + "BidiAfterInvocationEvent", + "BidiMessageAddedEvent", + "BidiBeforeToolCallEvent", + "BidiAfterToolCallEvent", + "BidiInterruptionEvent", ] diff --git a/src/strands/experimental/hooks/events.py b/src/strands/experimental/hooks/events.py index d711dd7ed..8a8d80629 100644 --- a/src/strands/experimental/hooks/events.py +++ b/src/strands/experimental/hooks/events.py @@ -1,16 +1,24 @@ -"""Experimental hook events emitted as part of invoking Agents. +"""Experimental hook events emitted as part of invoking Agents and BidiAgents. -This module defines the events that are emitted as Agents run through the lifecycle of a request. +This module defines the events that are emitted as Agents and BidiAgents run through the lifecycle of a request. """ import warnings -from typing import TypeAlias +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, TypeAlias from ...hooks.events import AfterModelCallEvent, AfterToolCallEvent, BeforeModelCallEvent, BeforeToolCallEvent +from ...hooks.registry import BaseHookEvent +from ...types.content import Message +from ...types.tools import AgentTool, ToolResult, ToolUse + +if TYPE_CHECKING: + from ..bidi.agent.agent import BidiAgent + from ..bidi.models import BidiModelTimeoutError warnings.warn( - "These events have been moved to production with updated names. Use BeforeModelCallEvent, " - "AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent from strands.hooks instead.", + "BeforeModelCallEvent, AfterModelCallEvent, BeforeToolCallEvent, and AfterToolCallEvent are no longer experimental." 
+ "Import from strands.hooks instead.", DeprecationWarning, stacklevel=2, ) @@ -19,3 +27,191 @@ AfterToolInvocationEvent: TypeAlias = AfterToolCallEvent BeforeModelInvocationEvent: TypeAlias = BeforeModelCallEvent AfterModelInvocationEvent: TypeAlias = AfterModelCallEvent + + +# BidiAgent Hook Events + + +@dataclass +class BidiHookEvent(BaseHookEvent): + """Base class for BidiAgent hook events. + + Attributes: + agent: The BidiAgent instance that triggered this event. + """ + + agent: "BidiAgent" + + +@dataclass +class BidiAgentInitializedEvent(BidiHookEvent): + """Event triggered when a BidiAgent has finished initialization. + + This event is fired after the BidiAgent has been fully constructed and all + built-in components have been initialized. Hook providers can use this + event to perform setup tasks that require a fully initialized agent. + """ + + pass + + +@dataclass +class BidiBeforeInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent starts a streaming session. + + This event is fired before the BidiAgent begins a streaming session, + before any model connection or audio processing occurs. Hook providers can + use this event to perform session-level setup, logging, or validation. + + This event is triggered at the beginning of agent.start(). + """ + + pass + + +@dataclass +class BidiAfterInvocationEvent(BidiHookEvent): + """Event triggered when BidiAgent ends a streaming session. + + This event is fired after the BidiAgent has completed a streaming session, + regardless of whether it completed successfully or encountered an error. + Hook providers can use this event for cleanup, logging, or state persistence. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + This event is triggered at the end of agent.stop(). + """ + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiMessageAddedEvent(BidiHookEvent): + """Event triggered when BidiAgent adds a message to the conversation. + + This event is fired whenever the BidiAgent adds a new message to its internal + message history, including user messages (from transcripts), assistant responses, + and tool results. Hook providers can use this event for logging, monitoring, or + implementing custom message processing logic. + + Note: This event is only triggered for messages added by the framework + itself, not for messages manually added by tools or external code. + + Attributes: + message: The message that was added to the conversation history. + """ + + message: Message + + +@dataclass +class BidiBeforeToolCallEvent(BidiHookEvent): + """Event triggered before BidiAgent executes a tool. + + This event is fired just before the BidiAgent executes a tool during a streaming + session, allowing hook providers to inspect, modify, or replace the tool that + will be executed. The selected_tool can be modified by hook callbacks to change + which tool gets executed. + + Attributes: + selected_tool: The tool that will be invoked. Can be modified by hooks + to change which tool gets executed. This may be None if tool lookup failed. + tool_use: The tool parameters that will be passed to selected_tool. + invocation_state: Keyword arguments that will be passed to the tool. + cancel_tool: A user defined message that when set, will cancel the tool call. + The message will be placed into a tool result with an error status. 
If set to `True`, Strands will cancel + the tool call and use a default cancel message. + """ + + selected_tool: AgentTool | None + tool_use: ToolUse + invocation_state: dict[str, Any] + cancel_tool: bool | str = False + + def _can_write(self, name: str) -> bool: + return name in ["cancel_tool", "selected_tool", "tool_use"] + + +@dataclass +class BidiAfterToolCallEvent(BidiHookEvent): + """Event triggered after BidiAgent executes a tool. + + This event is fired after the BidiAgent has finished executing a tool during + a streaming session, regardless of whether the execution was successful or + resulted in an error. Hook providers can use this event for cleanup, logging, + or post-processing. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + selected_tool: The tool that was invoked. It may be None if tool lookup failed. + tool_use: The tool parameters that were passed to the tool invoked. + invocation_state: Keyword arguments that were passed to the tool. + result: The result of the tool invocation. Either a ToolResult on success + or an Exception if the tool execution failed. + exception: Exception if the tool execution failed, None if successful. + cancel_message: The cancellation message if the user cancelled the tool call. + """ + + selected_tool: AgentTool | None + tool_use: ToolUse + invocation_state: dict[str, Any] + result: ToolResult + exception: Exception | None = None + cancel_message: str | None = None + + def _can_write(self, name: str) -> bool: + return name == "result" + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BidiInterruptionEvent(BidiHookEvent): + """Event triggered when model generation is interrupted. + + This event is fired when the user interrupts the assistant (e.g., by speaking + during the assistant's response) or when an error causes interruption. This is + specific to bidirectional streaming and doesn't exist in standard agents. + + Hook providers can use this event to log interruptions, implement custom + interruption handling, or trigger cleanup logic. + + Attributes: + reason: The reason for the interruption ("user_speech" or "error"). + interrupted_response_id: Optional ID of the response that was interrupted. + """ + + reason: Literal["user_speech", "error"] + interrupted_response_id: str | None = None + + +@dataclass +class BidiBeforeConnectionRestartEvent(BidiHookEvent): + """Event emitted before agent attempts to restart model connection after timeout. + + Attributes: + timeout_error: Timeout error reported by the model. + """ + + timeout_error: "BidiModelTimeoutError" + + +@dataclass +class BidiAfterConnectionRestartEvent(BidiHookEvent): + """Event emitted after agent attempts to restart model connection after timeout. + + Attribtues: + exception: Populated if exception was raised during connection restart. + None value means the restart was successful. 
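+
+    Example: an illustrative hook callback sketch; ``logger`` is an assumed
+    module-level logger, not part of this API:
+
+        def on_restart(event: BidiAfterConnectionRestartEvent) -> None:
+            if event.exception is not None:
+                logger.warning("connection restart failed: %s", event.exception)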
+ """ + + exception: Exception | None = None diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a042452d3..ad4733a35 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..experimental.bidi.agent.agent import BidiAgent from ..multiagent.base import MultiAgentBase logger = logging.getLogger(__name__) @@ -226,3 +227,87 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non else: logger.debug("session_id=<%s> | restoring multi-agent state", self.session_id) source.deserialize_state(state) + + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Initialize a bidirectional agent with a session. + + Args: + agent: BidiAgent to initialize from the session + **kwargs: Additional keyword arguments for future extensibility. + """ + if agent.agent_id in self._latest_agent_message: + raise SessionException("The `agent_id` of an agent must be unique in a session.") + self._latest_agent_message[agent.agent_id] = None + + session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id) + + if session_agent is None: + logger.debug( + "agent_id=<%s> | session_id=<%s> | creating bidi agent", + agent.agent_id, + self.session_id, + ) + + session_agent = SessionAgent.from_bidi_agent(agent) + self.session_repository.create_agent(self.session_id, session_agent) + # Initialize messages with sequential indices + session_message = None + for i, message in enumerate(agent.messages): + session_message = SessionMessage.from_message(message, i) + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + self._latest_agent_message[agent.agent_id] = session_message + else: + logger.debug( + "agent_id=<%s> | session_id=<%s> | restoring bidi agent", + agent.agent_id, + self.session_id, + ) + agent.state = AgentState(session_agent.state) + + session_agent.initialize_bidi_internal_state(agent) + + # BidiAgent has no conversation_manager, so no prepend_messages or removed_message_count + session_messages = self.session_repository.list_messages( + session_id=self.session_id, + agent_id=agent.agent_id, + offset=0, + ) + if len(session_messages) > 0: + self._latest_agent_message[agent.agent_id] = session_messages[-1] + + # Restore the agents messages array + agent.messages = [session_message.to_message() for session_message in session_messages] + + # Fix broken session histories: https://github.com/strands-agents/sdk-python/issues/859 + agent.messages = self._fix_broken_tool_use(agent.messages) + + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: + """Append a message to the bidirectional agent's session. + + Args: + message: Message to add to the agent in the session + agent: BidiAgent to append the message to + **kwargs: Additional keyword arguments for future extensibility. 
+ """ + # Calculate the next index (0 if this is the first message, otherwise increment the previous index) + latest_agent_message = self._latest_agent_message[agent.agent_id] + if latest_agent_message: + next_index = latest_agent_message.message_id + 1 + else: + next_index = 0 + + session_message = SessionMessage.from_message(message, next_index) + self._latest_agent_message[agent.agent_id] = session_message + self.session_repository.create_message(self.session_id, agent.agent_id, session_message) + + def sync_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Serialize and update the bidirectional agent into the session repository. + + Args: + agent: BidiAgent to sync to the session. + **kwargs: Additional keyword arguments for future extensibility. + """ + self.session_repository.update_agent( + self.session_id, + SessionAgent.from_bidi_agent(agent), + ) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index fb9132828..ba4356089 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -4,6 +4,11 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAgentInitializedEvent, + BidiMessageAddedEvent, +) from ..experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, @@ -15,6 +20,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..experimental.bidi.agent.agent import BidiAgent from ..multiagent.base import MultiAgentBase logger = logging.getLogger(__name__) @@ -47,6 +53,12 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) + # Register BidiAgent hooks + registry.add_callback(BidiAgentInitializedEvent, lambda event: self.initialize_bidi_agent(event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.append_bidi_message(event.message, event.agent)) + registry.add_callback(BidiMessageAddedEvent, lambda event: self.sync_bidi_agent(event.agent)) + registry.add_callback(BidiAfterInvocationEvent, lambda event: self.sync_bidi_agent(event.agent)) + @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: """Redact the message most recently appended to the agent in the session. @@ -114,3 +126,43 @@ def initialize_multi_agent(self, source: "MultiAgentBase", **kwargs: Any) -> Non "(initialize_multi_agent). Provide an implementation or use a " "SessionManager with session_type=SessionType.MULTI_AGENT." ) + + def initialize_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Initialize a bidirectional agent with a session. + + Args: + agent: BidiAgent to initialize + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(initialize_bidi_agent). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) + + def append_bidi_message(self, message: Message, agent: "BidiAgent", **kwargs: Any) -> None: + """Append a message to the bidirectional agent's session. 
+ + Args: + message: Message to add to the agent in the session + agent: BidiAgent to append the message to + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(append_bidi_message). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) + + def sync_bidi_agent(self, agent: "BidiAgent", **kwargs: Any) -> None: + """Serialize and sync the bidirectional agent with the session storage. + + Args: + agent: BidiAgent who should be synchronized with the session storage + **kwargs: Additional keyword arguments for future extensibility. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support bidirectional agent persistence " + "(sync_bidi_agent). Provide an implementation or use a " + "SessionManager with bidirectional agent support." + ) diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index fc7a3efb9..3ab576947 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -19,12 +19,13 @@ if TYPE_CHECKING: from ..agent import Agent + from ..experimental.bidi.agent import BidiAgent class _ToolCaller: """Call tool as a function.""" - def __init__(self, agent: "Agent") -> None: + def __init__(self, agent: "Agent | BidiAgent") -> None: """Initialize instance. Args: @@ -104,7 +105,11 @@ async def acall() -> ToolResult: return tool_result tool_result = run_async(acall) - self._agent.conversation_manager.apply_management(self._agent) + + # Apply conversation management if agent supports it (traditional agents) + if hasattr(self._agent, "conversation_manager"): + self._agent.conversation_manager.apply_management(self._agent) + return tool_result return caller diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index fe4fa135c..a4f9e7e1f 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -11,16 +11,19 @@ from opentelemetry import trace as trace_api +from ...experimental.hooks.events import BidiAfterToolCallEvent, BidiBeforeToolCallEvent from ...hooks import AfterToolCallEvent, BeforeToolCallEvent from ...telemetry.metrics import Trace from ...telemetry.tracer import get_tracer, serialize from ...types._events import ToolCancelEvent, ToolInterruptEvent, ToolResultEvent, ToolStreamEvent, TypedEvent from ...types.content import Message +from ...types.interrupt import Interrupt from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse from ..structured_output._structured_output_context import StructuredOutputContext if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi import BidiAgent logger = logging.getLogger(__name__) @@ -28,9 +31,61 @@ class ToolExecutor(abc.ABC): """Abstract base class for tool executors.""" + @staticmethod + def _is_agent(agent: "Agent | BidiAgent") -> bool: + """Check if the agent is an Agent instance, otherwise we assume BidiAgent. + + Note, we use a runtime import to avoid a circular dependency error. 
+ """ + from ...agent import Agent + + return isinstance(agent, Agent) + + @staticmethod + async def _invoke_before_tool_call_hook( + agent: "Agent | BidiAgent", + tool_func: Any, + tool_use: ToolUse, + invocation_state: dict[str, Any], + ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: + """Invoke the appropriate before tool call hook based on agent type.""" + event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent + return await agent.hooks.invoke_callbacks_async( + event_cls( + agent=agent, + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + + @staticmethod + async def _invoke_after_tool_call_hook( + agent: "Agent | BidiAgent", + selected_tool: Any, + tool_use: ToolUse, + invocation_state: dict[str, Any], + result: ToolResult, + exception: Exception | None = None, + cancel_message: str | None = None, + ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: + """Invoke the appropriate after tool call hook based on agent type.""" + event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent + return await agent.hooks.invoke_callbacks_async( + event_cls( + agent=agent, + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, + ) + ) + @staticmethod async def _stream( - agent: "Agent", + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], invocation_state: dict[str, Any], @@ -48,7 +103,7 @@ async def _stream( - Interrupt handling for human-in-the-loop workflows Args: - agent: The agent for which the tool is being executed. + agent: The agent (Agent or BidiAgent) for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. invocation_state: Context for the tool invocation. 
@@ -86,13 +141,8 @@ async def _stream( } ) - before_event, interrupts = await agent.hooks.invoke_callbacks_async( - BeforeToolCallEvent( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, - ) + before_event, interrupts = await ToolExecutor._invoke_before_tool_call_hook( + agent, tool_func, tool_use, invocation_state ) if interrupts: @@ -110,15 +160,9 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - tool_use=tool_use, - invocation_state=invocation_state, - selected_tool=None, - result=cancel_result, - cancel_message=cancel_message, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, None, tool_use, invocation_state, cancel_result, cancel_message=cancel_message ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @@ -148,14 +192,9 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @@ -185,14 +224,8 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - ) + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, result ) yield ToolResultEvent(after_event.result) @@ -205,22 +238,16 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = await agent.hooks.invoke_callbacks_async( - AfterToolCallEvent( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=error_result, - exception=e, - ) + + after_event, _ = await ToolExecutor._invoke_after_tool_call_hook( + agent, selected_tool, tool_use, invocation_state, error_result, exception=e ) yield ToolResultEvent(after_event.result) tool_results.append(after_event.result) @staticmethod async def _stream_with_trace( - agent: "Agent", + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -232,7 +259,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent for which the tool is being executed. + agent: The agent (Agent or BidiAgent) for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. 
@@ -271,7 +298,8 @@ async def _stream_with_trace( tool_success = result.get("status") == "success" tool_duration = time.time() - tool_start_time message = Message(role="user", content=[{"toolResult": result}]) - agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) + if ToolExecutor._is_agent(agent): + agent.event_loop_metrics.add_tool_usage(tool_use, tool_duration, tool_trace, tool_success, message) cycle_trace.add_child(tool_trace) tracer.end_tool_call_span(tool_call_span, result) @@ -280,7 +308,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: "Agent", + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -291,7 +319,7 @@ def _execute( """Execute the given tools according to this executor's strategy. Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 216eee379..da5c1ff10 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +22,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent", + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -32,7 +33,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -78,7 +79,7 @@ async def _execute( async def _task( self, - agent: "Agent", + agent: "Agent | BidiAgent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -93,7 +94,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent executing the tool. + agent: The agent (Agent or BidiAgent) executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index f78e60872..6163fc195 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent + from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -20,7 +21,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent", + agent: "Agent | BidiAgent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +34,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. 
Args: - agent: The agent for which tools are being executed. + agent: The agent (Agent or BidiAgent) for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 558d3e298..efe0894ea 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -145,7 +145,7 @@ class ToolUseStreamEvent(ModelStreamEvent): def __init__(self, delta: ContentBlockDelta, current_tool_use: dict[str, Any]) -> None: """Initialize with delta and current tool use state.""" - super().__init__({"delta": delta, "current_tool_use": current_tool_use}) + super().__init__({"type": "tool_use_stream", "delta": delta, "current_tool_use": current_tool_use}) class TextStreamEvent(ModelStreamEvent): @@ -281,12 +281,12 @@ def __init__(self, tool_result: ToolResult) -> None: Args: tool_result: Final result from the tool execution """ - super().__init__({"tool_result": tool_result}) + super().__init__({"type": "tool_result", "tool_result": tool_result}) @property def tool_use_id(self) -> str: """The toolUseId associated with this result.""" - return cast(str, cast(ToolResult, self.get("tool_result")).get("toolUseId")) + return cast(ToolResult, self.get("tool_result")).get("toolUseId") @property def tool_result(self) -> ToolResult: @@ -309,12 +309,12 @@ def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: tool_use: The tool invocation producing the stream tool_stream_data: The yielded event from the tool execution """ - super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) + super().__init__({"type": "tool_stream", "tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId") class ToolCancelEvent(TypedEvent): @@ -332,7 +332,7 @@ def __init__(self, tool_use: ToolUse, message: str) -> None: @property def tool_use_id(self) -> str: """The id of the tool cancelled.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_cancel_event")).get("tool_use")).get("toolUseId") @property def message(self) -> str: @@ -350,7 +350,7 @@ def __init__(self, tool_use: ToolUse, interrupts: list[Interrupt]) -> None: @property def tool_use_id(self) -> str: """The id of the tool interrupted.""" - return cast(str, cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId")) + return cast(ToolUse, cast(dict, self.get("tool_interrupt_event")).get("tool_use")).get("toolUseId") @property def interrupts(self) -> list[Interrupt]: diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 8b78ab448..5da3dcde8 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..experimental.bidi.agent.agent import BidiAgent class SessionType(str, Enum): @@ -136,6 +137,31 @@ def from_agent(cls, agent: "Agent") -> "SessionAgent": }, ) + @classmethod + def from_bidi_agent(cls, agent: "BidiAgent") -> "SessionAgent": + 
"""Convert a BidiAgent to a SessionAgent. + + Args: + agent: BidiAgent to convert + + Returns: + SessionAgent with empty conversation_manager_state (BidiAgent doesn't use conversation manager) + """ + if agent.agent_id is None: + raise ValueError("agent_id needs to be defined.") + + # BidiAgent doesn't have _interrupt_state yet, so we use empty dict for internal state + internal_state = {} + if hasattr(agent, "_interrupt_state"): + internal_state["interrupt_state"] = agent._interrupt_state.to_dict() + + return cls( + agent_id=agent.agent_id, + conversation_manager_state={}, # BidiAgent has no conversation_manager + state=agent.state.get(), + _internal_state=internal_state, + ) + @classmethod def from_dict(cls, env: dict[str, Any]) -> "SessionAgent": """Initialize a SessionAgent from a dictionary, ignoring keys that are not class parameters.""" @@ -150,6 +176,17 @@ def initialize_internal_state(self, agent: "Agent") -> None: if "interrupt_state" in self._internal_state: agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + def initialize_bidi_internal_state(self, agent: "BidiAgent") -> None: + """Initialize internal state of BidiAgent. + + Args: + agent: BidiAgent to initialize internal state for + """ + # BidiAgent doesn't have _interrupt_state yet, so we skip interrupt state restoration + # When BidiAgent adds _interrupt_state support, this will automatically work + if "interrupt_state" in self._internal_state and hasattr(agent, "_interrupt_state"): + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) + @dataclass class Session: diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8343647b2..8f4dba6b1 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -8,16 +8,13 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union from typing_extensions import NotRequired, TypedDict from .interrupt import _Interruptible from .media import DocumentContent, ImageContent -if TYPE_CHECKING: - from .. import Agent - JSONSchema = dict """Type alias for JSON Schema dictionaries.""" @@ -136,7 +133,7 @@ class ToolContext(_Interruptible): Attributes: tool_use: The complete ToolUse object containing tool invocation details. - agent: The Agent instance executing this tool, providing access to conversation history, + agent: The Agent or BidiAgent instance executing this tool, providing access to conversation history, model configuration, and other agent state. invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(), agent.invoke_async(), etc.). 
@@ -147,7 +144,7 @@ class ToolContext(_Interruptible): """ tool_use: ToolUse - agent: "Agent" + agent: Any # Agent or BidiAgent - using Any for backwards compatibility invocation_state: dict[str, Any] def _interrupt_id(self, name: str) -> str: diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 4fef595f8..7b189a5c6 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -138,6 +138,7 @@ async def test_stream_e2e_success(alist): "arg1": 1013, "current_tool_use": {"input": {}, "name": "normal_tool", "toolUseId": "123"}, "delta": {"toolUse": {"input": "{}"}}, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, @@ -195,6 +196,7 @@ async def test_stream_e2e_success(alist): "model": ANY, "system_prompt": None, "tool_config": tool_config, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, @@ -252,6 +254,7 @@ async def test_stream_e2e_success(alist): "model": ANY, "system_prompt": None, "tool_config": tool_config, + "type": "tool_use_stream", }, {"event": {"contentBlockStop": {}}}, {"event": {"messageStop": {"stopReason": "tool_use"}}}, @@ -268,13 +271,15 @@ async def test_stream_e2e_success(alist): "tool_stream_event": { "data": {"tool_streaming": True}, "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, - } + }, + "type": "tool_stream", }, { "tool_stream_event": { "data": "Final result", "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, - } + }, + "type": "tool_stream", }, { "message": { @@ -573,6 +578,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ""}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -582,6 +588,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"na'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": '{"na'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -591,6 +598,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'me"'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": 'me"'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -600,6 +608,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ': "J'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ': "J'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -609,6 +618,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": 'ohn"'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": 'ohn"'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -618,6 +628,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": {"delta": {"toolUse": {"input": ', "age": 3'}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": ', "age": 3'}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", @@ -627,6 +638,7 @@ class Person(BaseModel): }, {"event": {"contentBlockDelta": 
{"delta": {"toolUse": {"input": "1}"}}, "contentBlockIndex": 0}}}, { + "type": "tool_use_stream", "delta": {"toolUse": {"input": "1}"}}, "current_tool_use": { "toolUseId": "tooluse_efwXnrK_S6qTyxzcq1IUMQ", diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ea6b09b75..f133400a8 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -684,6 +684,7 @@ def test_agent__call__callback(mock_model, agent, callback_handler, agenerator): unittest.mock.call(event={"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "test"}}}}), unittest.mock.call(event={"contentBlockDelta": {"delta": {"toolUse": {"input": '{"value"}'}}}}), unittest.mock.call( + type="tool_use_stream", agent=agent, current_tool_use={"toolUseId": "123", "name": "test", "input": {}}, delta={"toolUse": {"input": '{"value"}'}}, diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 0a323b30d..52980729c 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,6 +6,7 @@ import strands import strands.telemetry +from strands import Agent from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -133,6 +134,7 @@ def tool_executor(): @pytest.fixture def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry, tool_executor): mock = unittest.mock.Mock(name="agent") + mock.__class__ = Agent mock.config.cache_points = [] mock.model = model mock.system_prompt = system_prompt diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 3f5a6c998..02be400b1 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -133,11 +133,12 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) @pytest.mark.parametrize( - ("event", "state", "exp_updated_state", "callback_args"), + ("event", "event_type", "state", "exp_updated_state", "callback_args"), [ # Tool Use - Existing input ( {"delta": {"toolUse": {"input": '"value"}'}}}, + {"type": "tool_use_stream"}, {"current_tool_use": {"input": '{"key": '}}, {"current_tool_use": {"input": '{"key": "value"}'}}, {"current_tool_use": {"input": '{"key": "value"}'}}, @@ -145,6 +146,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Tool Use - New input ( {"delta": {"toolUse": {"input": '{"key": '}}}, + {"type": "tool_use_stream"}, {"current_tool_use": {}}, {"current_tool_use": {"input": '{"key": '}}, {"current_tool_use": {"input": '{"key": '}}, @@ -152,6 +154,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Text ( {"delta": {"text": " world"}}, + {}, {"text": "hello"}, {"text": "hello world"}, {"data": " world"}, @@ -159,6 +162,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) # Reasoning - Text - Existing ( {"delta": {"reasoningContent": {"text": "king"}}}, + {}, {"reasoningText": "thin"}, {"reasoningText": "thinking"}, {"reasoningText": "king", "reasoning": True}, @@ -167,12 +171,14 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ( {"delta": {"reasoningContent": {"text": "thin"}}}, {}, + {}, {"reasoningText": "thin"}, {"reasoningText": "thin", "reasoning": True}, ), # Reasoning - Signature - Existing ( {"delta": {"reasoningContent": {"signature": "ue"}}}, + {}, {"signature": "val"}, {"signature": "value"}, 
{"reasoning_signature": "ue", "reasoning": True}, @@ -181,6 +187,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) ( {"delta": {"reasoningContent": {"signature": "val"}}}, {}, + {}, {"signature": "val"}, {"reasoning_signature": "val", "reasoning": True}, ), @@ -188,12 +195,14 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) pytest.param( {"delta": {"reasoningContent": {"redactedContent": b"encoded"}}}, {}, + {}, {"redactedContent": b"encoded"}, {"reasoningRedactedContent": b"encoded", "reasoning": True}, ), # Reasoning - redactedContent - Existing pytest.param( {"delta": {"reasoningContent": {"redactedContent": b"data"}}}, + {}, {"redactedContent": b"encoded_"}, {"redactedContent": b"encoded_data"}, {"reasoningRedactedContent": b"data", "reasoning": True}, @@ -204,6 +213,7 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, {}, + {}, ), # Empty ( @@ -211,11 +221,12 @@ def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use) {}, {}, {}, + {}, ), ], ) -def test_handle_content_block_delta(event: ContentBlockDeltaEvent, state, exp_updated_state, callback_args): - exp_callback_event = {**callback_args, "delta": event["delta"]} if callback_args else {} +def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, state, exp_updated_state, callback_args): + exp_callback_event = {**event_type, **callback_args, "delta": event["delta"]} if callback_args else {} tru_updated_state, tru_callback_event = strands.event_loop.streaming.handle_content_block_delta(event, state) @@ -526,6 +537,7 @@ def test_extract_usage_metrics_empty_metadata(): "input": '{"key": "value"}', }, }, + "type": "tool_use_stream", }, { "event": { diff --git a/tests/strands/experimental/__init__.py b/tests/strands/experimental/__init__.py index e69de29bb..ac8db1d74 100644 --- a/tests/strands/experimental/__init__.py +++ b/tests/strands/experimental/__init__.py @@ -0,0 +1 @@ +"""Experimental features tests.""" diff --git a/tests/strands/experimental/bidi/__init__.py b/tests/strands/experimental/bidi/__init__.py new file mode 100644 index 000000000..ea37091cc --- /dev/null +++ b/tests/strands/experimental/bidi/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming tests.""" diff --git a/tests/strands/experimental/bidi/_async/__init__.py b/tests/strands/experimental/bidi/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py new file mode 100644 index 000000000..f8df25e14 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -0,0 +1,36 @@ +from unittest.mock import AsyncMock + +import pytest + +from strands.experimental.bidi._async import stop_all + + +@pytest.mark.asyncio +async def test_stop_exception(): + func1 = AsyncMock() + func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) + func3 = AsyncMock() + + with pytest.raises(ExceptionGroup) as exc_info: + await stop_all(func1, func2, func3) + + func1.assert_called_once() + func2.assert_called_once() + func3.assert_called_once() + + assert len(exc_info.value.exceptions) == 1 + with pytest.raises(ValueError, match=r"stop 2 failed"): + raise exc_info.value.exceptions[0] + + +@pytest.mark.asyncio +async def test_stop_success(): + func1 = AsyncMock() + func2 = AsyncMock() + func3 = AsyncMock() + + await stop_all(func1, func2, func3) + + func1.assert_called_once() + 
func2.assert_called_once() + func3.assert_called_once() diff --git a/tests/strands/experimental/bidi/_async/test_task_pool.py b/tests/strands/experimental/bidi/_async/test_task_pool.py new file mode 100644 index 000000000..35f817954 --- /dev/null +++ b/tests/strands/experimental/bidi/_async/test_task_pool.py @@ -0,0 +1,54 @@ +import asyncio + +import pytest + +from strands.experimental.bidi._async._task_pool import _TaskPool + + +@pytest.fixture +def task_pool() -> _TaskPool: + return _TaskPool() + + +def test_len(task_pool): + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + +@pytest.mark.asyncio +async def test_create(task_pool: _TaskPool) -> None: + event = asyncio.Event() + + async def coro(): + await event.wait() + + task = task_pool.create(coro()) + + tru_len = len(task_pool) + exp_len = 1 + assert tru_len == exp_len + + event.set() + await task + + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + +@pytest.mark.asyncio +async def test_cancel(task_pool: _TaskPool) -> None: + event = asyncio.Event() + + async def coro(): + await event.wait() + + task = task_pool.create(coro()) + await task_pool.cancel() + + tru_len = len(task_pool) + exp_len = 0 + assert tru_len == exp_len + + assert task.done() diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py new file mode 100644 index 000000000..3359c6565 --- /dev/null +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming agent tests.""" \ No newline at end of file diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py new file mode 100644 index 000000000..19d3525d7 --- /dev/null +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -0,0 +1,343 @@ +"""Unit tests for BidiAgent.""" + +import unittest.mock +import asyncio +import pytest +from uuid import uuid4 + +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel +from strands.experimental.bidi.types.events import ( + BidiTextInputEvent, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiTranscriptStreamEvent, + BidiConnectionStartEvent, + BidiConnectionCloseEvent, +) + +class MockBidiModel: + """Mock bidirectional model for testing.""" + + def __init__(self, config=None, model_id="mock-model"): + self.config = config or {"audio": {"input_rate": 16000, "output_rate": 24000, "channels": 1}} + self.model_id = model_id + self._connection_id = None + self._started = False + self._events_to_yield = [] + + async def start(self, system_prompt=None, tools=None, messages=None, **kwargs): + if self._started: + raise RuntimeError("model already started | call stop before starting again") + self._connection_id = str(uuid4()) + self._started = True + + async def stop(self): + if self._started: + self._started = False + self._connection_id = None + + async def send(self, content): + if not self._started: + raise RuntimeError("model not started | call start before sending/receiving") + # Mock implementation - in real tests, this would trigger events + + async def receive(self): + """Async generator yielding mock events.""" + if not self._started: + raise RuntimeError("model not started | call start before sending/receiving") + + # Yield connection start event + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + # Yield any configured events + for event in 
self._events_to_yield: + yield event + + # Yield connection end event + yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") + + def set_events(self, events): + """Helper to set events this mock model will yield.""" + self._events_to_yield = events + +@pytest.fixture +def mock_model(): + """Create a mock BidiModel instance.""" + return MockBidiModel() + +@pytest.fixture +def mock_tool_registry(): + """Mock tool registry with some basic tools.""" + registry = unittest.mock.Mock() + registry.get_all_tool_specs.return_value = [ + { + "name": "calculator", + "description": "Perform calculations", + "inputSchema": {"json": {"type": "object", "properties": {}}} + } + ] + registry.get_all_tools_config.return_value = {"calculator": {}} + return registry + + +@pytest.fixture +def mock_tool_caller(): + """Mock tool caller for testing tool execution.""" + caller = unittest.mock.AsyncMock() + caller.call_tool = unittest.mock.AsyncMock() + return caller + + +@pytest.fixture +def agent(mock_model, mock_tool_registry, mock_tool_caller): + """Create a BidiAgent instance for testing.""" + with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class: + mock_registry_class.return_value = mock_tool_registry + + with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class: + mock_caller_class.return_value = mock_tool_caller + + # Don't pass tools to avoid real tool loading + agent = BidiAgent(model=mock_model) + return agent + +def test_bidi_agent_init_with_various_configurations(): + """Test agent initialization with various configurations.""" + # Test default initialization + mock_model = MockBidiModel() + agent = BidiAgent(model=mock_model) + + assert agent.model == mock_model + assert agent.system_prompt is None + assert not agent._started + assert agent.model._connection_id is None + + # Test with configuration + system_prompt = "You are a helpful assistant." 
+ agent_with_config = BidiAgent( + model=mock_model, + system_prompt=system_prompt, + agent_id="test_agent" + ) + + assert agent_with_config.system_prompt == system_prompt + assert agent_with_config.agent_id == "test_agent" + + # Test with string model ID + model_id = "amazon.nova-sonic-v1:0" + agent_with_string = BidiAgent(model=model_id) + + assert isinstance(agent_with_string.model, BidiNovaSonicModel) + assert agent_with_string.model.model_id == model_id + + # Test model config access + config = agent.model.config + assert config["audio"]["input_rate"] == 16000 + assert config["audio"]["output_rate"] == 24000 + assert config["audio"]["channels"] == 1 + +@pytest.mark.asyncio +async def test_bidi_agent_start_stop_lifecycle(agent): + """Test agent start/stop lifecycle and state management.""" + # Initial state + assert not agent._started + assert agent.model._connection_id is None + + # Start agent + await agent.start() + assert agent._started + assert agent.model._connection_id is not None + connection_id = agent.model._connection_id + + # Double start should error + with pytest.raises(RuntimeError, match="agent already started"): + await agent.start() + + # Stop agent + await agent.stop() + assert not agent._started + assert agent.model._connection_id is None + + # Multiple stops should be safe + await agent.stop() + await agent.stop() + + # Restart should work with new connection ID + await agent.start() + assert agent._started + assert agent.model._connection_id != connection_id + +@pytest.mark.asyncio +async def test_bidi_agent_send_with_input_types(agent): + """Test sending various input types through agent.send().""" + await agent.start() + + # Test text input with TypedEvent + text_input = BidiTextInputEvent(text="Hello", role="user") + await agent.send(text_input) + assert len(agent.messages) == 1 + assert agent.messages[0]["content"][0]["text"] == "Hello" + + # Test string input (shorthand) + await agent.send("World") + assert len(agent.messages) == 2 + assert agent.messages[1]["content"][0]["text"] == "World" + + # Test audio input (doesn't add to messages) + audio_input = BidiAudioInputEvent( + audio="dGVzdA==", # base64 "test" + format="pcm", + sample_rate=16000, + channels=1 + ) + await agent.send(audio_input) + assert len(agent.messages) == 2 # Still 2, audio doesn't add + + # Test concurrent sends + sends = [ + agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) + for i in range(3) + ] + await asyncio.gather(*sends) + assert len(agent.messages) == 5 # 2 + 3 new messages + +@pytest.mark.asyncio +async def test_bidi_agent_receive_events_from_model(agent): + """Test receiving events from model.""" + # Configure mock model to yield events + events = [ + BidiAudioStreamEvent( + audio="dGVzdA==", + format="pcm", + sample_rate=24000, + channels=1 + ), + BidiTranscriptStreamEvent( + text="Hello world", + role="assistant", + is_final=True, + delta={"text": "Hello world"}, + current_transcript="Hello world" + ) + ] + agent.model.set_events(events) + + await agent.start() + + received_events = [] + async for event in agent.receive(): + received_events.append(event) + if len(received_events) >= 4: # Stop after getting expected events + break + + # Verify event types and order + assert len(received_events) >= 3 + assert isinstance(received_events[0], BidiConnectionStartEvent) + assert isinstance(received_events[1], BidiAudioStreamEvent) + assert isinstance(received_events[2], BidiTranscriptStreamEvent) + + # Test empty events + agent.model.set_events([]) + await 
agent.stop() + await agent.start() + + empty_events = [] + async for event in agent.receive(): + empty_events.append(event) + if len(empty_events) >= 2: + break + + assert len(empty_events) >= 1 + assert isinstance(empty_events[0], BidiConnectionStartEvent) + +def test_bidi_agent_tool_integration(agent, mock_tool_registry): + """Test agent tool integration and properties.""" + # Test tool property access + assert hasattr(agent, 'tool') + assert agent.tool is not None + assert agent.tool == agent._tool_caller + + # Test tool names property + mock_tool_registry.get_all_tools_config.return_value = { + "calculator": {}, + "weather": {} + } + + tool_names = agent.tool_names + assert isinstance(tool_names, list) + assert len(tool_names) == 2 + assert "calculator" in tool_names + assert "weather" in tool_names + +@pytest.mark.asyncio +async def test_bidi_agent_send_receive_error_before_start(agent): + """Test error handling in various scenarios.""" + # Test send before start + with pytest.raises(RuntimeError, match="call start before"): + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + + # Test receive before start + with pytest.raises(RuntimeError, match="call start before"): + async for event in agent.receive(): + pass + + # Test send after stop + await agent.start() + await agent.stop() + with pytest.raises(RuntimeError, match="call start before"): + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + + # Test receive after stop + with pytest.raises(RuntimeError, match="call start before"): + async for event in agent.receive(): + pass + + +@pytest.mark.asyncio +async def test_bidi_agent_start_receive_propagates_model_errors(): + """Test that model errors are properly propagated.""" + # Test model start error + mock_model = MockBidiModel() + mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed")) + error_agent = BidiAgent(model=mock_model) + + with pytest.raises(Exception, match="Connection failed"): + await error_agent.start() + + # Test model receive error + mock_model2 = MockBidiModel() + agent2 = BidiAgent(model=mock_model2) + await agent2.start() + + async def failing_receive(): + yield BidiConnectionStartEvent(connection_id="test", model="test-model") + raise Exception("Receive failed") + + agent2.model.receive = failing_receive + with pytest.raises(Exception, match="Receive failed"): + async for event in agent2.receive(): + pass + +@pytest.mark.asyncio +async def test_bidi_agent_state_consistency(agent): + """Test that agent state remains consistent across operations.""" + # Initial state + assert not agent._started + assert agent.model._connection_id is None + + # Start + await agent.start() + assert agent._started + assert agent.model._connection_id is not None + connection_id = agent.model._connection_id + + # Send operations shouldn't change connection state + await agent.send(BidiTextInputEvent(text="Hello", role="user")) + assert agent._started + assert agent.model._connection_id == connection_id + + # Stop + await agent.stop() + assert not agent._started + assert agent.model._connection_id is None \ No newline at end of file diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py new file mode 100644 index 000000000..d19cada60 --- /dev/null +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -0,0 +1,107 @@ +import unittest.mock + +import pytest +import pytest_asyncio + +from strands import tool +from strands.experimental.bidi.agent.loop import 
_BidiAgentLoop +from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent +from strands.hooks import HookRegistry +from strands.tools.executors import SequentialToolExecutor +from strands.tools.registry import ToolRegistry +from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + async def func(): + return "12:00" + + return func + + +@pytest.fixture +def agent(time_tool): + mock = unittest.mock.Mock() + mock.hooks = HookRegistry() + mock.messages = [] + mock.model = unittest.mock.AsyncMock() + mock.tool_executor = SequentialToolExecutor() + mock.tool_registry = ToolRegistry() + mock.tool_registry.process_tools([time_tool]) + + return mock + + +@pytest_asyncio.fixture +async def loop(agent): + return _BidiAgentLoop(agent) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerator): + timeout_error = BidiModelTimeoutError("test timeout", test_restart_config=1) + text_event = BidiTextInputEvent(text="test after restart") + + agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) + + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 2: + break + + exp_events = [ + BidiConnectionRestartEvent(timeout_error), + text_event, + ] + assert tru_events == exp_events + + agent.model.stop.assert_called_once() + assert agent.model.start.call_count == 2 + agent.model.start.assert_called_with( + agent.system_prompt, + agent.tool_registry.get_all_tool_specs(), + agent.messages, + test_restart_config=1, + ) + + +@pytest.mark.asyncio +async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): + + tool_use = {"toolUseId": "t1", "name": "time_tool", "input": {}} + tool_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "12:00"}]} + + tool_use_event = ToolUseStreamEvent(current_tool_use=tool_use, delta="") + tool_result_event = ToolResultEvent(tool_result) + + agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) + + await loop.start() + + tru_events = [] + async for event in loop.receive(): + tru_events.append(event) + if len(tru_events) >= 3: + break + + exp_events = [ + tool_use_event, + tool_result_event, + ToolResultMessageEvent({"role": "user", "content": [{"toolResult": tool_result}]}), + ] + assert tru_events == exp_events + + tru_messages = agent.messages + exp_messages = [ + {"role": "assistant", "content": [{"toolUse": tool_use}]}, + {"role": "user", "content": [{"toolResult": tool_result}]}, + ] + assert tru_messages == exp_messages + + agent.model.send.assert_called_with(tool_result_event) diff --git a/tests/strands/experimental/bidi/io/__init__.py b/tests/strands/experimental/bidi/io/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py new file mode 100644 index 000000000..459faa78a --- /dev/null +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -0,0 +1,175 @@ +import base64 +import unittest.mock + +import pyaudio +import pytest +import pytest_asyncio + +from strands.experimental.bidi.io.audio import BidiAudioIO, _BidiAudioBuffer +from strands.experimental.bidi.types.events import BidiAudioInputEvent, BidiAudioStreamEvent, 
BidiInterruptionEvent + + +@pytest.fixture +def audio_buffer(): + buffer = _BidiAudioBuffer(size=1) + buffer.start() + yield buffer + buffer.stop() + + +@pytest.fixture +def agent(): + mock = unittest.mock.MagicMock() + mock.model.config = { + "audio": { + "input_rate": 24000, + "output_rate": 16000, + "channels": 2, + "format": "test-format", + "voice": "test-voice", + }, + } + return mock + + +@pytest.fixture +def py_audio(): + with unittest.mock.patch("strands.experimental.bidi.io.audio.pyaudio.PyAudio") as mock: + yield mock.return_value + + +@pytest.fixture +def config(): + return { + "input_buffer_size": 1, + "input_device_index": 1, + "input_frames_per_buffer": 1024, + "output_buffer_size": 2, + "output_device_index": 2, + "output_frames_per_buffer": 2048, + } + +@pytest.fixture +def audio_io(py_audio, config): + _ = py_audio + return BidiAudioIO(**config) + + +@pytest_asyncio.fixture +async def audio_input(audio_io, agent): + input_ = audio_io.input() + await input_.start(agent) + yield input_ + await input_.stop() + + +@pytest_asyncio.fixture +async def audio_output(audio_io, agent): + output = audio_io.output() + await output.start(agent) + yield output + await output.stop() + + +def test_bidi_audio_buffer_put(audio_buffer): + audio_buffer.put(b"test-chunk") + + tru_chunk = audio_buffer.get() + exp_chunk = b"test-chunk" + assert tru_chunk == exp_chunk + + +def test_bidi_audio_buffer_put_full(audio_buffer): + audio_buffer.put(b"test-chunk-1") + audio_buffer.put(b"test-chunk-2") + + tru_chunk = audio_buffer.get() + exp_chunk = b"test-chunk-2" + assert tru_chunk == exp_chunk + + +def test_bidi_audio_buffer_get_padding(audio_buffer): + audio_buffer.put(b"test-chunk") + + tru_chunk = audio_buffer.get(11) + exp_chunk = b"test-chunk\x00" + assert tru_chunk == exp_chunk + + +def test_bidi_audio_buffer_clear(audio_buffer): + audio_buffer.put(b"test-chunk") + audio_buffer.clear() + + tru_byte = audio_buffer.get(1) + exp_byte = b"\x00" + assert tru_byte == exp_byte + + +@pytest.mark.asyncio +async def test_bidi_audio_io_input(audio_input): + audio_input._callback(b"test-audio") + + tru_event = await audio_input() + exp_event = BidiAudioInputEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=24000, + ) + assert tru_event == exp_event + + +def test_bidi_audio_io_input_configs(py_audio, audio_input): + py_audio.open.assert_called_once_with( + channels=2, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + input_device_index=1, + rate=24000, + stream_callback=audio_input._callback, + ) + + +@pytest.mark.asyncio +async def test_bidi_audio_io_output(audio_output): + audio_event = BidiAudioStreamEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=16000, + ) + await audio_output(audio_event) + + tru_data, _ = audio_output._callback(None, frame_count=4) + exp_data = b"test-aud" + assert tru_data == exp_data + + +@pytest.mark.asyncio +async def test_bidi_audio_io_output_interrupt(audio_output): + audio_event = BidiAudioStreamEvent( + audio=base64.b64encode(b"test-audio").decode("utf-8"), + channels=2, + format="test-format", + sample_rate=16000, + ) + await audio_output(audio_event) + interrupt_event = BidiInterruptionEvent(reason="user_speech") + await audio_output(interrupt_event) + + tru_data, _ = audio_output._callback(None, frame_count=1) + exp_data = b"\x00\x00" + assert tru_data == exp_data + + +def test_bidi_audio_io_output_configs(py_audio, 
audio_output): + py_audio.open.assert_called_once_with( + channels=2, + format=pyaudio.paInt16, + frames_per_buffer=2048, + output=True, + output_device_index=2, + rate=16000, + stream_callback=audio_output._callback, + ) diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py new file mode 100644 index 000000000..5507a8c0f --- /dev/null +++ b/tests/strands/experimental/bidi/io/test_text.py @@ -0,0 +1,52 @@ +import unittest.mock + +import pytest + +from strands.experimental.bidi.io import BidiTextIO +from strands.experimental.bidi.types.events import BidiInterruptionEvent, BidiTextInputEvent, BidiTranscriptStreamEvent + + +@pytest.fixture +def prompt_session(): + with unittest.mock.patch("strands.experimental.bidi.io.text.PromptSession") as mock: + yield mock.return_value + + +@pytest.fixture +def text_io(): + return BidiTextIO() + + +@pytest.fixture +def text_input(text_io): + return text_io.input() + + +@pytest.fixture +def text_output(text_io): + return text_io.output() + + +@pytest.mark.asyncio +async def test_bidi_text_io_input(prompt_session, text_input): + prompt_session.prompt_async = unittest.mock.AsyncMock(return_value="test value") + + tru_event = await text_input() + exp_event = BidiTextInputEvent(text="test value", role="user") + assert tru_event == exp_event + + +@pytest.mark.parametrize( + ("event", "exp_print"), + [ + (BidiInterruptionEvent(reason="user_speech"), "interrupted"), + (BidiTranscriptStreamEvent(text="test text", delta="", is_final=False, role="user"), "Preview: test text"), + (BidiTranscriptStreamEvent(text="test text", delta="", is_final=True, role="user"), "test text"), + ] +) +@pytest.mark.asyncio +async def test_bidi_text_io_output(event, exp_print, text_output, capsys): + await text_output(event) + + tru_print = capsys.readouterr().out.strip() + assert tru_print == exp_print diff --git a/tests/strands/experimental/bidi/models/__init__.py b/tests/strands/experimental/bidi/models/__init__.py new file mode 100644 index 000000000..ea9fbb2d0 --- /dev/null +++ b/tests/strands/experimental/bidi/models/__init__.py @@ -0,0 +1 @@ +"""Bidirectional streaming model tests.""" diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py new file mode 100644 index 000000000..da516d4a0 --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -0,0 +1,751 @@ +"""Unit tests for Gemini Live bidirectional streaming model. 
+ +Tests the unified BidiGeminiLiveModel interface including: +- Model initialization and configuration +- Connection establishment and lifecycle +- Unified send() method with different content types +- Event receiving and conversion +""" + +import base64 +import unittest.mock + +import pytest +from google.genai import types as genai_types + +from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_genai_client(): + """Mock the Google GenAI client.""" + with unittest.mock.patch("strands.experimental.bidi.models.gemini_live.genai.Client") as mock_client_cls: + mock_client = mock_client_cls.return_value + mock_client.aio = unittest.mock.MagicMock() + + # Mock the live session + mock_live_session = unittest.mock.AsyncMock() + + # Mock the context manager + mock_live_session_cm = unittest.mock.MagicMock() + mock_live_session_cm.__aenter__ = unittest.mock.AsyncMock(return_value=mock_live_session) + mock_live_session_cm.__aexit__ = unittest.mock.AsyncMock(return_value=None) + + # Make connect return the context manager + mock_client.aio.live.connect = unittest.mock.MagicMock(return_value=mock_live_session_cm) + + yield mock_client, mock_live_session, mock_live_session_cm + + +@pytest.fixture +def model_id(): + return "models/gemini-2.0-flash-live-preview-04-09" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_genai_client, model_id, api_key): + """Create a BidiGeminiLiveModel instance.""" + _ = mock_genai_client + return BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(mock_genai_client, model_id, api_key): + """Test model initialization with various configurations.""" + _ = mock_genai_client + + # Test default config + model_default = BidiGeminiLiveModel() + assert model_default.model_id == "gemini-2.5-flash-native-audio-preview-09-2025" + assert model_default.api_key is None + assert model_default._live_session is None + # Check default config includes transcription + assert model_default.config["inference"]["response_modalities"] == ["AUDIO"] + assert "outputAudioTranscription" in model_default.config["inference"] + assert "inputAudioTranscription" in model_default.config["inference"] + + # Test with API key + model_with_key = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + assert model_with_key.model_id == model_id + assert model_with_key.api_key == api_key + + # Test with custom config (merges with defaults) + provider_config = {"inference": {"temperature": 0.7, "top_p": 0.9}} + model_custom = BidiGeminiLiveModel(model_id=model_id, provider_config=provider_config) + # Custom config 
should be merged with defaults + assert model_custom.config["inference"]["temperature"] == 0.7 + assert model_custom.config["inference"]["top_p"] == 0.9 + # Defaults should still be present + assert "response_modalities" in model_custom.config["inference"] + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_genai_client, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_client, mock_live_session, mock_live_session_cm = mock_genai_client + + # Test basic connection + await model.start() + assert model._connection_id is not None + assert model._live_session == mock_live_session + mock_client.aio.live.connect.assert_called_once() + + # Test close + await model.stop() + mock_live_session_cm.__aexit__.assert_called_once() + + # Test connection with system prompt + await model.start(system_prompt=system_prompt) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert config.get("system_instruction") == system_prompt + await model.stop() + + # Test connection with tools + await model.start(tools=[tool_spec]) + call_args = mock_client.aio.live.connect.call_args + config = call_args.kwargs.get("config", {}) + assert "tools" in config + assert len(config["tools"]) > 0 + await model.stop() + + # Test connection with messages + await model.start(messages=messages) + mock_live_session.send_client_content.assert_called() + await model.stop() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_genai_client, api_key, model_id): + """Test connection error handling and edge cases.""" + mock_client, _, mock_live_session_cm = mock_genai_client + + # Test connection error + model1 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + mock_client.aio.live.connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match=r"Connection failed"): + await model1.start() + + # Reset mock for next tests + mock_client.aio.live.connect.side_effect = None + + # Test double connection + model2 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + await model2.start() + with pytest.raises(RuntimeError, match="call stop before starting again"): + await model2.start() + await model2.stop() + + # Test close when not connected + model3 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + await model3.stop() # Should not raise + + # Test close error handling + model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + await model4.start() + mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") + with pytest.raises(ExceptionGroup): + await model4.stop() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_genai_client, model): + """Test sending all content types through unified send() method.""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + # Test text input + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + mock_live_session.send_client_content.assert_called_once() + call_args = mock_live_session.send_client_content.call_args + content = call_args.kwargs.get("turns") + assert content.role == "user" + assert content.parts[0].text == "Hello" + + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode("utf-8") + audio_input = BidiAudioInputEvent( + audio=audio_b64, + 
format="pcm", + sample_rate=16000, + channels=1, + ) + await model.send(audio_input) + mock_live_session.send_realtime_input.assert_called_once() + + # Test image input (base64 encoded, no encoding parameter) + image_b64 = base64.b64encode(b"image_bytes").decode("utf-8") + image_input = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + await model.send(image_input) + mock_live_session.send.assert_called_once() + + # Test tool result + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Result: 42"}], + } + await model.send(ToolResultEvent(tool_result)) + mock_live_session.send_tool_response.assert_called_once() + + await model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_genai_client, model): + """Test send() edge cases and error handling.""" + _, mock_live_session, _ = mock_genai_client + + # Test send when inactive + text_input = BidiTextInputEvent(text="Hello", role="user") + with pytest.raises(RuntimeError, match=r"call start before sending"): + await model.send(text_input) + mock_live_session.send_client_content.assert_not_called() + + # Test unknown content type + await model.start() + unknown_content = {"unknown_field": "value"} + with pytest.raises(ValueError, match=r"content not supported"): + await model.send(unknown_content) + + await model.stop() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): + """Test that receive() emits connection start and end events.""" + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive.return_value = agenerator([]) + + await model.start() + + async for event in model.receive(): + _ = event + break + + # Verify connection start and end + assert isinstance(event, BidiConnectionStartEvent) + assert event.get("type") == "bidi_connection_start" + assert event.connection_id == model._connection_id + + +@pytest.mark.asyncio +async def test_receive_timeout(mock_genai_client, model, agenerator): + mock_resumption_response = unittest.mock.Mock() + mock_resumption_response.go_away = None + mock_resumption_response.session_resumption_update = unittest.mock.Mock() + mock_resumption_response.session_resumption_update.resumable = True + mock_resumption_response.session_resumption_update.new_handle = "h1" + + mock_timeout_response = unittest.mock.Mock() + mock_timeout_response.go_away = unittest.mock.Mock() + mock_timeout_response.go_away.model_dump_json.return_value = "test timeout" + + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive = unittest.mock.Mock( + return_value=agenerator([mock_resumption_response, mock_timeout_response]) + ) + + await model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"test timeout"): + async for _ in model.receive(): + pass + + tru_handle = model._live_session_handle + exp_handle = "h1" + assert tru_handle == exp_handle + + +@pytest.mark.asyncio +async def test_event_conversion(mock_genai_client, model): + """Test conversion of all Gemini Live event types to standard format.""" + _, _, _ = mock_genai_client + await model.start() + + # Test text output (converted to transcript via model_turn.parts) + mock_text = unittest.mock.Mock() + mock_text.data = None + mock_text.go_away = None + mock_text.session_resumption_update = None + mock_text.tool_call = None + + # Create proper server_content structure with model_turn + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = False 
+ mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_model_turn = unittest.mock.Mock() + mock_part = unittest.mock.Mock() + mock_part.text = "Hello from Gemini" + mock_model_turn.parts = [mock_part] + mock_server_content.model_turn = mock_model_turn + + mock_text.server_content = mock_server_content + + text_events = model._convert_gemini_live_event(mock_text) + assert isinstance(text_events, list) + assert len(text_events) == 1 + text_event = text_events[0] + assert isinstance(text_event, BidiTranscriptStreamEvent) + assert text_event.get("type") == "bidi_transcript_stream" + assert text_event.text == "Hello from Gemini" + assert text_event.role == "assistant" + assert text_event.is_final is True + assert text_event.delta == {"text": "Hello from Gemini"} + assert text_event.current_transcript == "Hello from Gemini" + + # Test multiple text parts (should concatenate) + mock_multi_text = unittest.mock.Mock() + mock_multi_text.data = None + mock_multi_text.go_away = None + mock_multi_text.session_resumption_update = None + mock_multi_text.tool_call = None + + mock_server_content_multi = unittest.mock.Mock() + mock_server_content_multi.interrupted = False + mock_server_content_multi.input_transcription = None + mock_server_content_multi.output_transcription = None + + mock_model_turn_multi = unittest.mock.Mock() + mock_part1 = unittest.mock.Mock() + mock_part1.text = "Hello" + mock_part2 = unittest.mock.Mock() + mock_part2.text = "from Gemini" + mock_model_turn_multi.parts = [mock_part1, mock_part2] + mock_server_content_multi.model_turn = mock_model_turn_multi + + mock_multi_text.server_content = mock_server_content_multi + + multi_text_events = model._convert_gemini_live_event(mock_multi_text) + assert isinstance(multi_text_events, list) + assert len(multi_text_events) == 1 + multi_text_event = multi_text_events[0] + assert isinstance(multi_text_event, BidiTranscriptStreamEvent) + assert multi_text_event.text == "Hello from Gemini" # Concatenated with space + + # Test audio output (base64 encoded) + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_events = model._convert_gemini_live_event(mock_audio) + assert isinstance(audio_events, list) + assert len(audio_events) == 1 + audio_event = audio_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.get("type") == "bidi_audio_stream" + # Audio is now base64 encoded + expected_b64 = base64.b64encode(b"audio_data").decode("utf-8") + assert audio_event.audio == expected_b64 + assert audio_event.format == "pcm" + + # Test single tool call (returns list with one event) + mock_func_call = unittest.mock.Mock() + mock_func_call.id = "tool-123" + mock_func_call.name = "calculator" + mock_func_call.args = {"expression": "2+2"} + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.function_calls = [mock_func_call] + + mock_tool = unittest.mock.Mock() + mock_tool.text = None + mock_tool.data = None + mock_tool.go_away = None + mock_tool.session_resumption_update = None + mock_tool.tool_call = mock_tool_call + mock_tool.server_content = None + + tool_events = model._convert_gemini_live_event(mock_tool) + # Should return a list of ToolUseStreamEvent + assert isinstance(tool_events, list) + assert len(tool_events) == 1 + tool_event = tool_events[0] + # ToolUseStreamEvent has 
delta and current_tool_use, not a "type" field + assert "delta" in tool_event + assert "toolUse" in tool_event["delta"] + assert tool_event["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_event["delta"]["toolUse"]["name"] == "calculator" + + # Test multiple tool calls (returns list with multiple events) + mock_func_call_1 = unittest.mock.Mock() + mock_func_call_1.id = "tool-123" + mock_func_call_1.name = "calculator" + mock_func_call_1.args = {"expression": "2+2"} + + mock_func_call_2 = unittest.mock.Mock() + mock_func_call_2.id = "tool-456" + mock_func_call_2.name = "weather" + mock_func_call_2.args = {"location": "Seattle"} + + mock_tool_call_multi = unittest.mock.Mock() + mock_tool_call_multi.function_calls = [mock_func_call_1, mock_func_call_2] + + mock_tool_multi = unittest.mock.Mock() + mock_tool_multi.text = None + mock_tool_multi.data = None + mock_tool_multi.go_away = None + mock_tool_multi.session_resumption_update = None + mock_tool_multi.tool_call = mock_tool_call_multi + mock_tool_multi.server_content = None + + tool_events_multi = model._convert_gemini_live_event(mock_tool_multi) + # Should return a list with two ToolUseStreamEvent + assert isinstance(tool_events_multi, list) + assert len(tool_events_multi) == 2 + + # Verify first tool call + assert tool_events_multi[0]["delta"]["toolUse"]["toolUseId"] == "tool-123" + assert tool_events_multi[0]["delta"]["toolUse"]["name"] == "calculator" + assert tool_events_multi[0]["delta"]["toolUse"]["input"] == {"expression": "2+2"} + + # Verify second tool call + assert tool_events_multi[1]["delta"]["toolUse"]["toolUseId"] == "tool-456" + assert tool_events_multi[1]["delta"]["toolUse"]["name"] == "weather" + assert tool_events_multi[1]["delta"]["toolUse"]["input"] == {"location": "Seattle"} + + # Test interruption + mock_server_content = unittest.mock.Mock() + mock_server_content.interrupted = True + mock_server_content.input_transcription = None + mock_server_content.output_transcription = None + + mock_interrupt = unittest.mock.Mock() + mock_interrupt.text = None + mock_interrupt.data = None + mock_interrupt.go_away = None + mock_interrupt.session_resumption_update = None + mock_interrupt.tool_call = None + mock_interrupt.server_content = mock_server_content + + interrupt_events = model._convert_gemini_live_event(mock_interrupt) + assert isinstance(interrupt_events, list) + assert len(interrupt_events) == 1 + interrupt_event = interrupt_events[0] + assert isinstance(interrupt_event, BidiInterruptionEvent) + assert interrupt_event.get("type") == "bidi_interruption" + assert interrupt_event.reason == "user_speech" + + await model.stop() + + +# Audio Configuration Tests + + +def test_audio_config_defaults(mock_genai_client, model_id, api_key): + """Test default audio configuration.""" + _ = mock_genai_client + + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) + + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert "voice" not in model.config["audio"] # No default voice + + +def test_audio_config_partial_override(mock_genai_client, model_id, api_key): + """Test partial audio configuration override.""" + _ = mock_genai_client + + provider_config = {"audio": {"output_rate": 48000, "voice": "Puck"}} + model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config) + + # Overridden 
values
+    assert model.config["audio"]["output_rate"] == 48000
+    assert model.config["audio"]["voice"] == "Puck"
+
+    # Default values preserved
+    assert model.config["audio"]["input_rate"] == 16000
+    assert model.config["audio"]["channels"] == 1
+    assert model.config["audio"]["format"] == "pcm"
+
+
+def test_audio_config_full_override(mock_genai_client, model_id, api_key):
+    """Test full audio configuration override."""
+    _ = mock_genai_client
+
+    provider_config = {
+        "audio": {
+            "input_rate": 48000,
+            "output_rate": 48000,
+            "channels": 2,
+            "format": "pcm",
+            "voice": "Aoede",
+        }
+    }
+    model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config)
+
+    assert model.config["audio"]["input_rate"] == 48000
+    assert model.config["audio"]["output_rate"] == 48000
+    assert model.config["audio"]["channels"] == 2
+    assert model.config["audio"]["format"] == "pcm"
+    assert model.config["audio"]["voice"] == "Aoede"
+
+
+# Helper Method Tests
+
+
+def test_config_building(model, system_prompt, tool_spec):
+    """Test building live config with various options."""
+    # Test basic config
+    config_basic = model._build_live_config()
+    assert isinstance(config_basic, dict)
+
+    # Test with system prompt
+    config_prompt = model._build_live_config(system_prompt=system_prompt)
+    assert config_prompt["system_instruction"] == system_prompt
+
+    # Test with tools
+    config_tools = model._build_live_config(tools=[tool_spec])
+    assert "tools" in config_tools
+    assert len(config_tools["tools"]) > 0
+
+
+def test_tool_formatting(model, tool_spec):
+    """Test tool formatting for Gemini Live API."""
+    # Test with tools
+    formatted_tools = model._format_tools_for_live_api([tool_spec])
+    assert len(formatted_tools) == 1
+    assert isinstance(formatted_tools[0], genai_types.Tool)
+
+    # Test empty list
+    formatted_empty = model._format_tools_for_live_api([])
+    assert formatted_empty == []
+
+
+# Audio Rate Event Tests
+
+
+@pytest.mark.asyncio
+async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key):
+    """Test that audio events use configured sample rates and channels."""
+    _, _, _ = mock_genai_client
+
+    # Create model with custom audio configuration
+    provider_config = {"audio": {"output_rate": 48000, "channels": 2}}
+    model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}, provider_config=provider_config)
+    await model.start()
+
+    # Test audio output event uses custom configuration
+    mock_audio = unittest.mock.Mock()
+    mock_audio.text = None
+    mock_audio.data = b"audio_data"
+    mock_audio.go_away = None
+    mock_audio.session_resumption_update = None
+    mock_audio.tool_call = None
+    mock_audio.server_content = None
+
+    audio_events = model._convert_gemini_live_event(mock_audio)
+    assert len(audio_events) == 1
+    audio_event = audio_events[0]
+    assert isinstance(audio_event, BidiAudioStreamEvent)
+    # Should use configured rates, not constants
+    assert audio_event.sample_rate == 48000  # Custom config
+    assert audio_event.channels == 2  # Custom config
+    assert audio_event.format == "pcm"
+
+    await model.stop()
+
+
+@pytest.mark.asyncio
+async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_key):
+    """Test that audio events use default sample rates when no custom config."""
+    _, _, _ = mock_genai_client
+
+    # Create model without custom audio configuration
+    model = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
+    await model.start()
+
+    # Test audio output event uses 
defaults + mock_audio = unittest.mock.Mock() + mock_audio.text = None + mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None + mock_audio.tool_call = None + mock_audio.server_content = None + + audio_events = model._convert_gemini_live_event(mock_audio) + assert len(audio_events) == 1 + audio_event = audio_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + # Should use default rates + assert audio_event.sample_rate == 24000 # Default output rate + assert audio_event.channels == 1 # Default channels + assert audio_event.format == "pcm" + + await model.stop() + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_content_unwrapped(mock_genai_client, model): + """Test that single content item is unwrapped (optimization).""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Single result"}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the tool response was sent + mock_live_session.send_tool_response.assert_called_once() + call_args = mock_live_session.send_tool_response.call_args + function_responses = call_args.kwargs.get("function_responses", []) + + assert len(function_responses) == 1 + func_response = function_responses[0] + assert func_response.id == "tool-123" + # Single content should be unwrapped (not in array) + assert func_response.response == {"text": "Single result"} + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_as_array(mock_genai_client, model): + """Test that multiple content items are sent as array.""" + _, mock_live_session, _ = mock_genai_client + await model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the tool response was sent + mock_live_session.send_tool_response.assert_called_once() + call_args = mock_live_session.send_tool_response.call_args + function_responses = call_args.kwargs.get("function_responses", []) + + assert len(function_responses) == 1 + func_response = function_responses[0] + assert func_response.id == "tool-456" + # Multiple content should be in array format + assert "result" in func_response.response + assert isinstance(func_response.response["result"], list) + assert len(func_response.response["result"]) == 2 + assert func_response.response["result"][0] == {"text": "Part 1"} + assert func_response.response["result"][1] == {"json": {"data": "value"}} + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_unsupported_content_type(mock_genai_client, model): + """Test that unsupported content types raise ValueError.""" + _, _, _ = mock_genai_client + await model.start() + + # Test with image content (unsupported) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_image)) + + # Test with document content (unsupported) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + + with 
pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_doc)) + + # Test with mixed content (one unsupported) + tool_result_mixed: ToolResult = { + "toolUseId": "tool-777", + "status": "success", + "content": [{"text": "Valid text"}, {"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Gemini Live API"): + await model.send(ToolResultEvent(tool_result_mixed)) + + await model.stop() + + +# Helper fixture for async generator +@pytest.fixture +def agenerator(): + """Helper to create async generators for testing.""" + + async def _agenerator(items): + for item in items: + yield item + + return _agenerator diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py new file mode 100644 index 000000000..04f8043be --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -0,0 +1,763 @@ +"""Unit tests for Nova Sonic bidirectional model implementation. + +Tests the unified BidirectionalModel interface implementation for Amazon Nova Sonic, +covering connection lifecycle, event conversion, audio streaming, and tool execution. +""" + +import asyncio +import base64 +import json +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException + +from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models.nova_sonic import ( + BidiNovaSonicModel, +) +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +# Test fixtures +@pytest.fixture +def model_id(): + """Nova Sonic model identifier.""" + return "amazon.nova-sonic-v1:0" + + +@pytest.fixture +def region(): + """AWS region.""" + return "us-east-1" + + +@pytest.fixture +def mock_stream(): + """Mock Nova Sonic bidirectional stream.""" + stream = AsyncMock() + stream.input_stream = AsyncMock() + stream.input_stream.send = AsyncMock() + stream.input_stream.close = AsyncMock() + stream.await_output = AsyncMock() + return stream + + +@pytest.fixture +def mock_client(mock_stream): + """Mock Bedrock Runtime client.""" + with patch("strands.experimental.bidi.models.nova_sonic.BedrockRuntimeClient") as mock_cls: + mock_instance = AsyncMock() + mock_instance.invoke_model_with_bidirectional_stream = AsyncMock(return_value=mock_stream) + mock_cls.return_value = mock_instance + + yield mock_instance + + +@pytest_asyncio.fixture +def nova_model(model_id, region, mock_client): + """Create Nova Sonic model instance.""" + _ = mock_client + + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + yield model + + +# Initialization and Connection Tests + + +@pytest.mark.asyncio +async def test_model_initialization(model_id, region): + """Test model initialization with configuration.""" + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + + assert model.model_id == model_id + assert model.region == region + assert model._connection_id is None + + +# Audio Configuration Tests + +
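The audio-configuration tests that follow all assert the same override semantics: keys supplied under `provider_config["audio"]` win, and anything left unspecified falls back to the provider defaults. A minimal sketch of that merge rule, assuming a shallow dict merge and a hypothetical helper name (the real logic lives in the model constructors and may differ in detail):

```python
# Nova Sonic defaults as asserted by the tests below.
NOVA_AUDIO_DEFAULTS = {
    "input_rate": 16000,
    "output_rate": 16000,
    "channels": 1,
    "format": "pcm",
    "voice": "matthew",
}


def merge_audio_config(provider_config: dict | None) -> dict:
    """Overlay user-supplied audio keys on the provider defaults."""
    return {**NOVA_AUDIO_DEFAULTS, **(provider_config or {}).get("audio", {})}


assert merge_audio_config({"audio": {"voice": "ruth"}})["voice"] == "ruth"  # override applied
assert merge_audio_config({"audio": {"voice": "ruth"}})["input_rate"] == 16000  # default preserved
assert merge_audio_config(None) == NOVA_AUDIO_DEFAULTS
```

+@pytest.mark.asyncio +async def 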
test_audio_config_defaults(model_id, region): + """Test default audio configuration.""" + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["output_rate"] == 16000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "matthew" + + +@pytest.mark.asyncio +async def test_audio_config_partial_override(model_id, region): + """Test partial audio configuration override.""" + provider_config = {"audio": {"output_rate": 24000, "voice": "ruth"}} + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + + # Overridden values + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["voice"] == "ruth" + + # Default values preserved + assert model.config["audio"]["input_rate"] == 16000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + + +@pytest.mark.asyncio +async def test_audio_config_full_override(model_id, region): + """Test full audio configuration override.""" + provider_config = { + "audio": { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "stephen", + } + } + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + + assert model.config["audio"]["input_rate"] == 48000 + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["channels"] == 2 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "stephen" + + +@pytest.mark.asyncio +async def test_connection_lifecycle(nova_model, mock_client, mock_stream): + """Test complete connection lifecycle with various configurations.""" + + # Test basic connection + await nova_model.start(system_prompt="Test system prompt") + assert nova_model._stream == mock_stream + assert nova_model._connection_id is not None + assert mock_client.invoke_model_with_bidirectional_stream.called + + # Test close + await nova_model.stop() + assert mock_stream.close.called + + # Test connection with tools + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {}})}, + } + ] + await nova_model.start(system_prompt="You are helpful", tools=tools) + # Verify initialization events were sent (sessionStart, promptStart, system prompt) + assert mock_stream.input_stream.send.call_count >= 3 + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_model_stop_alone(nova_model): + """Test that stop() without a prior start() does not raise.""" + await nova_model.stop() # Should not raise + + +@pytest.mark.asyncio +async def test_connection_with_message_history(nova_model, mock_client, mock_stream): + """Test connection initialization with conversation history.""" + nova_model.client = mock_client + + # Create message history + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-123", "name": "get_weather", "input": {}}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "tool-123", "content": [{"text": "Sunny, 72°F"}]}}], + }, + {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, + ] + + # Start connection with 
message history + await nova_model.start(system_prompt="You are a helpful assistant", messages=messages) + + # Verify initialization events were sent + # Should include: sessionStart, promptStart, system prompt (3 events), + # and message history (only text messages: 3 messages * 3 events each = 9 events) + # Tool use/result messages are now skipped in history + # Total: 1 + 1 + 3 + 9 = 14 events minimum + assert mock_stream.input_stream.send.call_count >= 14 + + # Verify the events contain proper role information + sent_events = [call.args[0].value.bytes_.decode("utf-8") for call in mock_stream.input_stream.send.call_args_list] + + # Check that USER and ASSISTANT roles are present in contentStart events + user_events = [e for e in sent_events if '"role": "USER"' in e] + assistant_events = [e for e in sent_events if '"role": "ASSISTANT"' in e] + + # Only text messages are sent, so we expect 1 user message and 2 assistant messages + assert len(user_events) >= 1 + assert len(assistant_events) >= 2 + + await nova_model.stop() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(nova_model, mock_stream): + """Test sending all content types through unified send() method.""" + await nova_model.start() + + # Test text content + text_event = BidiTextInputEvent(text="Hello, Nova!", role="user") + await nova_model.send(text_event) + # Should send contentStart, textInput, and contentEnd + assert mock_stream.input_stream.send.call_count >= 3 + + # Test audio content (base64 encoded) + audio_b64 = base64.b64encode(b"audio data").decode("utf-8") + audio_event = BidiAudioInputEvent(audio=audio_b64, format="pcm", sample_rate=16000, channels=1) + await nova_model.send(audio_event) + # Should start audio connection and send audio + assert nova_model._audio_content_name + assert mock_stream.input_stream.send.called + + # Test tool result with single content item (should be unwrapped) + tool_result_single: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Weather is sunny"}], + } + await nova_model.send(ToolResultEvent(tool_result_single)) + # Should send contentStart, toolResult, and contentEnd + assert mock_stream.input_stream.send.called + + # Test tool result with multiple content items (should send as array) + tool_result_multi: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + await nova_model.send(ToolResultEvent(tool_result_multi)) + assert mock_stream.input_stream.send.called + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(nova_model): + """Test send() edge cases and error handling.""" + + # Test image content (not supported, base64 encoded, no encoding parameter) + await nova_model.start() + image_b64 = base64.b64encode(b"image data").decode("utf-8") + image_event = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + + with pytest.raises(ValueError, match=r"content not supported"): + await nova_model.send(image_event) + + await nova_model.stop() + + +# Receive and Event Conversion Tests + + +@pytest.mark.asyncio +async def test_event_conversion(nova_model): + """Test conversion of all Nova Sonic event types to standard format.""" + # Test audio output (now returns BidiAudioStreamEvent) + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = 
nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + assert result.get("type") == "bidi_audio_stream" + # Audio is kept as base64 string + assert result.get("audio") == audio_base64 + assert result.get("format") == "pcm" + assert result.get("sample_rate") == 16000 + + # Test text output (now returns BidiTranscriptStreamEvent) + nova_event = {"textOutput": {"content": "Hello, world!", "role": "ASSISTANT"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiTranscriptStreamEvent) + assert result.get("type") == "bidi_transcript_stream" + assert result.get("text") == "Hello, world!" + assert result.get("role") == "assistant" + assert result.delta == {"text": "Hello, world!"} + assert result.current_transcript == "Hello, world!" + + # Test tool use (now returns ToolUseStreamEvent from core strands) + tool_input = {"location": "Seattle"} + nova_event = {"toolUse": {"toolUseId": "tool-123", "toolName": "get_weather", "content": json.dumps(tool_input)}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in result + assert "toolUse" in result["delta"] + tool_use = result["delta"]["toolUse"] + assert tool_use["toolUseId"] == "tool-123" + assert tool_use["name"] == "get_weather" + assert tool_use["input"] == tool_input + + # Test interruption (now returns BidiInterruptionEvent) + nova_event = {"stopReason": "INTERRUPTED"} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiInterruptionEvent) + assert result.get("type") == "bidi_interruption" + assert result.get("reason") == "user_speech" + + # Test usage metrics (now returns BidiUsageEvent) + nova_event = { + "usageEvent": { + "totalTokens": 100, + "totalInputTokens": 40, + "totalOutputTokens": 60, + "details": {"total": {"output": {"speechTokens": 30}}}, + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiUsageEvent) + assert result.get("type") == "bidi_usage" + assert result.get("totalTokens") == 100 + assert result.get("inputTokens") == 40 + assert result.get("outputTokens") == 60 + + # Test content start tracks role and emits BidiResponseStartEvent + # TEXT type contentStart (matches API spec) + nova_event = { + "contentStart": { + "role": "ASSISTANT", + "type": "TEXT", + "additionalModelFields": '{"generationStage":"FINAL"}', + "contentId": "content-123", + } + } + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) + assert result.get("type") == "bidi_response_start" + assert nova_model._generation_stage == "FINAL" + + # Test AUDIO type contentStart (no additionalModelFields) + nova_event = {"contentStart": {"role": "ASSISTANT", "type": "AUDIO", "contentId": "content-456"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) + + # Test TOOL type contentStart + nova_event = {"contentStart": {"role": "TOOL", "type": "TOOL", "contentId": "content-789"}} + result = nova_model._convert_nova_event(nova_event) + assert result is not None + assert isinstance(result, BidiResponseStartEvent) + +
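test_event_conversion above feeds one raw Nova Sonic envelope per event family through `_convert_nova_event`. For readers new to the provider format, here is a minimal sketch of the key-dispatch pattern those assertions imply; the names are hypothetical and only two of the tested families are shown:

```python
from typing import Callable, Optional

# Hypothetical dispatch table; the real converter also covers toolUse,
# stopReason, usageEvent, and contentStart, as asserted above.
_HANDLERS: dict[str, Callable[[dict], dict]] = {
    "audioOutput": lambda e: {"type": "bidi_audio_stream", "audio": e["content"], "format": "pcm"},
    "textOutput": lambda e: {"type": "bidi_transcript_stream", "text": e["content"], "role": e["role"].lower()},
}


def convert(nova_event: dict) -> Optional[dict]:
    """Normalize a raw Nova Sonic envelope, or return None if unrecognized."""
    for key, handler in _HANDLERS.items():
        if key in nova_event:
            return handler(nova_event[key])
    return None


assert convert({"textOutput": {"content": "Hi", "role": "ASSISTANT"}})["role"] == "assistant"
assert convert({"unknownEvent": {}}) is None
```

+# Audio Streaming Tests + + +@pytest.mark.asyncio +async def test_audio_connection_lifecycle(nova_model): + """Test audio 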
connection start and end lifecycle.""" + + await nova_model.start() + + # Start audio connection + await nova_model._start_audio_connection() + assert nova_model._audio_content_name + + # End audio connection + await nova_model._end_audio_input() + assert not nova_model._audio_content_name + + await nova_model.stop() + + +# Helper Method Tests + + +@pytest.mark.asyncio +async def test_tool_configuration(nova_model): + """Test building tool configuration from tool specs.""" + tools = [ + { + "name": "get_weather", + "description": "Get weather information", + "inputSchema": {"json": json.dumps({"type": "object", "properties": {"location": {"type": "string"}}})}, + } + ] + + tool_config = nova_model._build_tool_configuration(tools) + + assert len(tool_config) == 1 + assert tool_config[0]["toolSpec"]["name"] == "get_weather" + assert tool_config[0]["toolSpec"]["description"] == "Get weather information" + assert "inputSchema" in tool_config[0]["toolSpec"] + + +@pytest.mark.asyncio +async def test_event_templates(nova_model): + """Test event template generation.""" + # Test connection start event + event_json = nova_model._get_connection_start_event() + event = json.loads(event_json) + assert "event" in event + assert "sessionStart" in event["event"] + assert "inferenceConfiguration" in event["event"]["sessionStart"] + + # Test prompt start event + nova_model._connection_id = "test-connection" + event_json = nova_model._get_prompt_start_event([]) + event = json.loads(event_json) + assert "event" in event + assert "promptStart" in event["event"] + assert event["event"]["promptStart"]["promptName"] == "test-connection" + + # Test text input event + content_name = "test-content" + event_json = nova_model._get_text_input_event(content_name, "Hello") + event = json.loads(event_json) + assert "event" in event + assert "textInput" in event["event"] + assert event["event"]["textInput"]["content"] == "Hello" + + # Test tool result event + result = {"result": "Success"} + event_json = nova_model._get_tool_result_event(content_name, result) + event = json.loads(event_json) + assert "event" in event + assert "toolResult" in event["event"] + assert json.loads(event["event"]["toolResult"]["content"]) == result + + +@pytest.mark.asyncio +async def test_message_history_conversion(nova_model): + """Test conversion of agent messages to Nova Sonic history events.""" + nova_model._connection_id = "test-connection" + + # Test with various message types + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi there!"}]}, + { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "tool-1", "name": "calculator", "input": {"expr": "2+2"}}}], + }, + {"role": "user", "content": [{"toolResult": {"toolUseId": "tool-1", "content": [{"text": "4"}]}}]}, + {"role": "assistant", "content": [{"text": "The answer is 4"}]}, + ] + + events = nova_model._get_message_history_events(messages) + + # Only text messages generate events (3 messages * 3 events each = 9 events) + # Tool use/result messages are now skipped in history + assert len(events) == 9 + + # Parse and verify events + parsed_events = [json.loads(e) for e in events] + + # Check first message (user) + assert "contentStart" in parsed_events[0]["event"] + assert parsed_events[0]["event"]["contentStart"]["role"] == "USER" + assert "textInput" in parsed_events[1]["event"] + assert parsed_events[1]["event"]["textInput"]["content"] == "Hello" + assert "contentEnd" in parsed_events[2]["event"] + + # Check second 
message (assistant) + assert "contentStart" in parsed_events[3]["event"] + assert parsed_events[3]["event"]["contentStart"]["role"] == "ASSISTANT" + assert "textInput" in parsed_events[4]["event"] + assert parsed_events[4]["event"]["textInput"]["content"] == "Hi there!" + + # Check third message (assistant - last text message) + assert "contentStart" in parsed_events[6]["event"] + assert parsed_events[6]["event"]["contentStart"]["role"] == "ASSISTANT" + assert "textInput" in parsed_events[7]["event"] + assert parsed_events[7]["event"]["textInput"]["content"] == "The answer is 4" + + +@pytest.mark.asyncio +async def test_message_history_empty_and_edge_cases(nova_model): + """Test message history conversion with empty and edge cases.""" + nova_model._connection_id = "test-connection" + + # Test with empty messages + events = nova_model._get_message_history_events([]) + assert len(events) == 0 + + # Test with message containing no text content + messages = [{"role": "user", "content": []}] + events = nova_model._get_message_history_events(messages) + assert len(events) == 0 # No events generated for empty content + + # Test with multiple text blocks in one message + messages = [{"role": "user", "content": [{"text": "First part"}, {"text": "Second part"}]}] + events = nova_model._get_message_history_events(messages) + assert len(events) == 3 # contentStart, textInput, contentEnd + parsed = json.loads(events[1]) + content = parsed["event"]["textInput"]["content"] + assert "First part" in content + assert "Second part" in content + + +# Audio Rate Tests + + +@pytest.mark.asyncio +async def test_custom_audio_rates_in_events(model_id, region): + """Test that audio events use configured sample rates.""" + # Create model with custom audio configuration + provider_config = {"audio": {"output_rate": 48000, "channels": 2}} + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}, provider_config=provider_config) + + # Test audio output event uses custom configuration + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = model._convert_nova_event(nova_event) + + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + # Should use configured rates, not constants + assert result.sample_rate == 48000 # Custom config + assert result.channels == 2 # Custom config + assert result.format == "pcm" + + +@pytest.mark.asyncio +async def test_default_audio_rates_in_events(model_id, region): + """Test that audio events use default sample rates when no custom config.""" + # Create model without custom audio configuration + model = BidiNovaSonicModel(model_id=model_id, client_config={"region": region}) + + # Test audio output event uses defaults + audio_bytes = b"test audio data" + audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") + nova_event = {"audioOutput": {"content": audio_base64}} + result = model._convert_nova_event(nova_event) + + assert result is not None + assert isinstance(result, BidiAudioStreamEvent) + # Should use default rates + assert result.sample_rate == 16000 # Default output rate + assert result.channels == 1 # Default channels + assert result.format == "pcm" + + +# Error Handling Tests + + +@pytest.mark.asyncio +async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): + """Test that stream timeouts surface as BidiModelTimeoutError.""" + mock_output = AsyncMock() + mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") + 
mock_stream.await_output.return_value = (None, mock_output) + + await nova_model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): + async for _ in nova_model.receive(): + pass + + +@pytest.mark.asyncio +async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock_stream): + """Test that timeout-coded validation errors also surface as BidiModelTimeoutError.""" + mock_output = AsyncMock() + mock_output.receive.side_effect = ValidationException("InternalErrorCode=531: Request timeout") + mock_stream.await_output.return_value = (None, mock_output) + + await nova_model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): + async for _ in nova_model.receive(): + pass + + +@pytest.mark.asyncio +async def test_error_handling(nova_model, mock_stream): + """Test error handling in various scenarios.""" + + # Test response processor handles errors gracefully + async def mock_error(*args, **kwargs): + raise Exception("Test error") + + mock_stream.await_output.side_effect = mock_error + + await nova_model.start() + + # Wait a bit for response processor to handle error + await asyncio.sleep(0.1) + + # Should still be able to close cleanly + await nova_model.stop() + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_content_unwrapped(nova_model, mock_stream): + """Test that single content item is unwrapped (optimization).""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": [{"text": "Single result"}], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Single content should be unwrapped (not in array) + content = json.loads(tool_result_event["content"]) + assert content == {"text": "Single result"} + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_as_array(nova_model, mock_stream): + """Test that multiple content items are sent as array.""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Multiple content should be in array format + content = json.loads(tool_result_event["content"]) + assert "content" in content + assert isinstance(content["content"], list) + assert len(content["content"]) == 2 + assert content["content"][0] == {"text": "Part 1"} + assert content["content"][1] == {"json": {"data": "value"}} + + await nova_model.stop() + +
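The two tests above and the empty-content test below pin down a single wrapping rule for Nova Sonic tool results: a lone content item is sent unwrapped, while zero or many items stay wrapped under a "content" key. A plain-function sketch of that rule (hypothetical helper; the model applies it inline before JSON-serializing the payload):

```python
# Sketch of the wrapping rule asserted by the surrounding tests.
def wrap_tool_result_content(content: list[dict]) -> dict:
    if len(content) == 1:
        return content[0]  # a single item is unwrapped
    return {"content": content}  # zero or many items stay wrapped


assert wrap_tool_result_content([{"text": "Single result"}]) == {"text": "Single result"}
assert wrap_tool_result_content([{"text": "Part 1"}, {"json": {"data": "value"}}]) == {
    "content": [{"text": "Part 1"}, {"json": {"data": "value"}}]
}
assert wrap_tool_result_content([]) == {"content": []}
```

+@pytest.mark.asyncio +async def 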
test_tool_result_empty_content(nova_model, mock_stream): + """Test that empty content is handled gracefully.""" + await nova_model.start() + + tool_result: ToolResult = { + "toolUseId": "tool-789", + "status": "success", + "content": [], + } + + await nova_model.send(ToolResultEvent(tool_result)) + + # Verify events were sent + assert mock_stream.input_stream.send.called + calls = mock_stream.input_stream.send.call_args_list + + # Find the toolResult event + tool_result_events = [] + for call in calls: + event_json = call.args[0].value.bytes_.decode("utf-8") + event = json.loads(event_json) + if "toolResult" in event.get("event", {}): + tool_result_events.append(event) + + assert len(tool_result_events) > 0 + tool_result_event = tool_result_events[0]["event"]["toolResult"] + + # Empty content should result in empty array wrapped in content key + content = json.loads(tool_result_event["content"]) + assert content == {"content": []} + + await nova_model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_unsupported_content_type(nova_model): + """Test that unsupported content types raise ValueError.""" + await nova_model.start() + + # Test with image content (unsupported) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_image)) + + # Test with document content (unsupported) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_doc)) + + # Test with mixed content (one unsupported) + tool_result_mixed: ToolResult = { + "toolUseId": "tool-777", + "status": "success", + "content": [{"text": "Valid text"}, {"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by Nova Sonic"): + await nova_model.send(ToolResultEvent(tool_result_mixed)) + + await nova_model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py new file mode 100644 index 000000000..5c9c0900d --- /dev/null +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -0,0 +1,918 @@ +"""Unit tests for OpenAI Realtime bidirectional streaming model. 
+ +Tests the unified BidiOpenAIRealtimeModel interface including: +- Model initialization and configuration +- Connection establishment with WebSocket +- Unified send() method with different content types +- Event receiving and conversion +- Connection lifecycle management +""" + +import base64 +import json +import unittest.mock + +import pytest + +from strands.experimental.bidi.models.model import BidiModelTimeoutError +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseCompleteEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, +) +from strands.types._events import ToolResultEvent +from strands.types.tools import ToolResult + + +@pytest.fixture +def mock_websocket(): + """Mock WebSocket connection.""" + mock_ws = unittest.mock.AsyncMock() + mock_ws.send = unittest.mock.AsyncMock() + mock_ws.close = unittest.mock.AsyncMock() + return mock_ws + + +@pytest.fixture +def mock_websockets_connect(mock_websocket): + """Mock websockets.connect function.""" + + async def async_connect(*args, **kwargs): + return mock_websocket + + with unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.websockets.connect") as mock_connect: + mock_connect.side_effect = async_connect + yield mock_connect, mock_websocket + + +@pytest.fixture +def model_name(): + return "gpt-realtime" + + +@pytest.fixture +def api_key(): + return "test-api-key" + + +@pytest.fixture +def model(mock_websockets_connect, api_key, model_name): + """Create a BidiOpenAIRealtimeModel instance.""" + return BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + + +@pytest.fixture +def tool_spec(): + return { + "description": "Calculate mathematical expressions", + "name": "calculator", + "inputSchema": {"json": {"type": "object", "properties": {"expression": {"type": "string"}}}}, + } + + +@pytest.fixture +def system_prompt(): + return "You are a helpful assistant" + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "Hello"}]}] + + +# Initialization Tests + + +def test_model_initialization(api_key, model_name, monkeypatch): + """Test model initialization with various configurations.""" + # Test default config + model_default = BidiOpenAIRealtimeModel(client_config={"api_key": "test-key"}) + assert model_default.model_id == "gpt-realtime" + assert model_default.api_key == "test-key" + + # Test with custom model + model_custom = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + assert model_custom.model_id == model_name + assert model_custom.api_key == api_key + + # Test with organization and project via environment variables + monkeypatch.setenv("OPENAI_ORGANIZATION", "org-123") + monkeypatch.setenv("OPENAI_PROJECT", "proj-456") + model_env = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + assert model_env.organization == "org-123" + assert model_env.project == "proj-456" + + # Test with env API key + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + model_env = BidiOpenAIRealtimeModel() + assert model_env.api_key == "env-key" + + +# Audio Configuration Tests + +
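For quick reference, the per-provider audio defaults asserted across the three suites in this diff differ mainly in sample rates and default voice. An illustrative summary, restating only values these tests assert (not an API of its own):

```python
# Defaults asserted by each provider's default-config test in this diff.
PROVIDER_AUDIO_DEFAULTS = {
    "nova_sonic": {"input_rate": 16000, "output_rate": 16000, "channels": 1, "voice": "matthew"},
    "gemini_live": {"input_rate": 16000, "output_rate": 24000, "channels": 1, "voice": None},  # no default voice
    "openai_realtime": {"input_rate": 24000, "output_rate": 24000, "channels": 1, "voice": "alloy"},
}
```

+def test_audio_config_defaults(api_key, model_name): + """Test default audio configuration.""" + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + + assert 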
model.config["audio"]["input_rate"] == 24000 + assert model.config["audio"]["output_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "alloy" + + +def test_audio_config_partial_override(api_key, model_name): + """Test partial audio configuration override.""" + provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + + # Overridden values + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["voice"] == "echo" + + # Default values preserved + assert model.config["audio"]["input_rate"] == 24000 + assert model.config["audio"]["channels"] == 1 + assert model.config["audio"]["format"] == "pcm" + + +def test_audio_config_full_override(api_key, model_name): + """Test full audio configuration override.""" + provider_config = { + "audio": { + "input_rate": 48000, + "output_rate": 48000, + "channels": 2, + "format": "pcm", + "voice": "shimmer", + } + } + model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + + assert model.config["audio"]["input_rate"] == 48000 + assert model.config["audio"]["output_rate"] == 48000 + assert model.config["audio"]["channels"] == 2 + assert model.config["audio"]["format"] == "pcm" + assert model.config["audio"]["voice"] == "shimmer" + + +def test_audio_config_extracts_voice_from_provider_config(api_key, model_name): + """Test that voice is extracted from provider_config when config audio not provided.""" + provider_config = {"audio": {"voice": "fable"}} + + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) + + # Should extract voice from provider_config + assert model.config["audio"]["voice"] == "fable" + + +def test_init_without_api_key_raises(monkeypatch): + """Test that initialization without API key raises error.""" + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with pytest.raises(ValueError, match="OpenAI API key is required"): + BidiOpenAIRealtimeModel() + + +# Connection Tests + + +@pytest.mark.asyncio +async def test_connection_lifecycle(mock_websockets_connect, model, system_prompt, tool_spec, messages): + """Test complete connection lifecycle with various configurations.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test basic connection + await model.start() + assert model._connection_id is not None + assert model._websocket == mock_ws + mock_connect.assert_called_once() + + # Test close + await model.stop() + mock_ws.close.assert_called_once() + + # Test connection with system prompt + await model.start(system_prompt=system_prompt) + calls = mock_ws.send.call_args_list + session_update = next( + (json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update"), None + ) + assert session_update is not None + assert system_prompt in session_update["session"]["instructions"] + await model.stop() + + # Test connection with tools + await model.start(tools=[tool_spec]) + calls = mock_ws.send.call_args_list + # Tools are sent in a separate session.update after initial connection + session_updates = [ + json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "session.update" + ] + assert len(session_updates) > 0 + # Check if any session update has tools + has_tools = 
any("tools" in update.get("session", {}) for update in session_updates) + assert has_tools + await model.stop() + + # Test connection with messages + await model.start(messages=messages) + calls = mock_ws.send.call_args_list + item_creates = [ + json.loads(call[0][0]) for call in calls if json.loads(call[0][0]).get("type") == "conversation.item.create" + ] + assert len(item_creates) > 0 + await model.stop() + + # Test connection with organization header (via environment) + # Note: This test needs to be in a separate test function to use monkeypatch properly + # Skipping inline environment test here - see test_connection_with_org_header + + +@pytest.mark.asyncio +async def test_connection_with_org_header(mock_websockets_connect, monkeypatch): + """Test connection with organization header from environment.""" + mock_connect, mock_ws = mock_websockets_connect + + monkeypatch.setenv("OPENAI_ORGANIZATION", "org-123") + model_org = BidiOpenAIRealtimeModel(client_config={"api_key": "test-key"}) + await model_org.start() + call_kwargs = mock_connect.call_args.kwargs + headers = call_kwargs.get("additional_headers", []) + org_header = [h for h in headers if h[0] == "OpenAI-Organization"] + assert len(org_header) == 1 + assert org_header[0][1] == "org-123" + await model_org.stop() + + +@pytest.mark.asyncio +async def test_connection_with_message_history(mock_websockets_connect, model): + """Test connection initialization with conversation history including tool calls.""" + _, mock_ws = mock_websockets_connect + + # Create message history with various content types + messages = [ + {"role": "user", "content": [{"text": "What's the weather?"}]}, + {"role": "assistant", "content": [{"text": "I'll check the weather for you."}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "call-123", "name": "get_weather", "input": {"location": "Seattle"}}} + ], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "call-123", "content": [{"text": "Sunny, 72°F"}]}}], + }, + {"role": "assistant", "content": [{"text": "It's sunny and 72 degrees."}]}, + ] + + # Start connection with message history + await model.start(messages=messages) + + # Get all sent events + calls = mock_ws.send.call_args_list + sent_events = [json.loads(call[0][0]) for call in calls] + + # Filter conversation.item.create events + item_creates = [e for e in sent_events if e.get("type") == "conversation.item.create"] + + # Should have 5 items: 2 messages, 1 function_call, 1 function_call_output, 1 message + assert len(item_creates) >= 5 + + # Verify message items + message_items = [e for e in item_creates if e.get("item", {}).get("type") == "message"] + assert len(message_items) >= 3 + + # Verify first user message + user_msg = message_items[0] + assert user_msg["item"]["role"] == "user" + assert user_msg["item"]["content"][0]["text"] == "What's the weather?" 
+ + # Verify function call item + function_call_items = [e for e in item_creates if e.get("item", {}).get("type") == "function_call"] + assert len(function_call_items) >= 1 + func_call = function_call_items[0] + assert func_call["item"]["call_id"] == "call-123" + assert func_call["item"]["name"] == "get_weather" + assert json.loads(func_call["item"]["arguments"]) == {"location": "Seattle"} + + # Verify function call output item + function_output_items = [e for e in item_creates if e.get("item", {}).get("type") == "function_call_output"] + assert len(function_output_items) >= 1 + func_output = function_output_items[0] + assert func_output["item"]["call_id"] == "call-123" + # Content is now preserved as JSON array + output = json.loads(func_output["item"]["output"]) + assert output == [{"text": "Sunny, 72°F"}] + + await model.stop() + + +@pytest.mark.asyncio +async def test_connection_edge_cases(mock_websockets_connect, api_key, model_name): + """Test connection error handling and edge cases.""" + mock_connect, mock_ws = mock_websockets_connect + + # Test connection error + model1 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + mock_connect.side_effect = Exception("Connection failed") + with pytest.raises(Exception, match="Connection failed"): + await model1.start() + + # Reset mock + async def async_connect(*args, **kwargs): + return mock_ws + + mock_connect.side_effect = async_connect + + # Test double connection + model2 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + await model2.start() + with pytest.raises(RuntimeError, match=r"call stop before starting again"): + await model2.start() + await model2.stop() + + # Test close when not connected + model3 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + await model3.stop() # Should not raise + + # Test close error + model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) + await model4.start() + mock_ws.close.side_effect = Exception("Close failed") + with pytest.raises(ExceptionGroup): + await model4.stop() + + +# Send Method Tests + + +@pytest.mark.asyncio +async def test_send_all_content_types(mock_websockets_connect, model): + """Test sending all content types through unified send() method.""" + _, mock_ws = mock_websockets_connect + await model.start() + + # Test text input + text_input = BidiTextInputEvent(text="Hello", role="user") + await model.send(text_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + response_create = [m for m in messages if m.get("type") == "response.create"] + assert len(item_create) > 0 + assert len(response_create) > 0 + + # Test audio input (base64 encoded) + audio_b64 = base64.b64encode(b"audio_bytes").decode("utf-8") + audio_input = BidiAudioInputEvent( + audio=audio_b64, + format="pcm", + sample_rate=24000, + channels=1, + ) + await model.send(audio_input) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + audio_append = [m for m in messages if m.get("type") == "input_audio_buffer.append"] + assert len(audio_append) > 0 + assert "audio" in audio_append[0] + # Audio should be passed through as base64 + assert audio_append[0]["audio"] == audio_b64 + + # Test tool result with text content + tool_result: ToolResult = { + "toolUseId": "tool-123", + "status": "success", + "content": 
[{"text": "Result: 42"}], + } + await model.send(ToolResultEvent(tool_result)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + assert len(item_create) > 0 + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-123" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Result: 42"}] + + # Test tool result with JSON content + tool_result_json: ToolResult = { + "toolUseId": "tool-456", + "status": "success", + "content": [{"json": {"result": 42, "status": "ok"}}], + } + await model.send(ToolResultEvent(tool_result_json)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-456" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"json": {"result": 42, "status": "ok"}}] + + # Test tool result with multiple content blocks + tool_result_multi: ToolResult = { + "toolUseId": "tool-789", + "status": "success", + "content": [{"text": "Part 1"}, {"json": {"data": "value"}}, {"text": "Part 2"}], + } + await model.send(ToolResultEvent(tool_result_multi)) + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "tool-789" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Part 1"}, {"json": {"data": "value"}}, {"text": "Part 2"}] + + # Test tool result with image content (should raise error) + tool_result_image: ToolResult = { + "toolUseId": "tool-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"image_data"}}}], + } + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result_image)) + + # Test tool result with document content (should raise error) + tool_result_doc: ToolResult = { + "toolUseId": "tool-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"doc_data"}}}], + } + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result_doc)) + + await model.stop() + + +@pytest.mark.asyncio +async def test_send_edge_cases(mock_websockets_connect, model): + """Test send() edge cases and error handling.""" + _, mock_ws = mock_websockets_connect + + # Test send when inactive + text_input = BidiTextInputEvent(text="Hello", role="user") + with pytest.raises(RuntimeError, match=r"call start before sending"): + await model.send(text_input) + mock_ws.send.assert_not_called() + + # Test image input (not supported, base64 encoded, no encoding parameter) + await model.start() + image_b64 = base64.b64encode(b"image_bytes").decode("utf-8") + image_input = BidiImageInputEvent( + image=image_b64, + mime_type="image/jpeg", + ) + with pytest.raises(ValueError, match=r"content not supported"): + await 
model.send(image_input) + + await model.stop() + + +# Receive Method Tests + + +@pytest.mark.asyncio +async def test_receive_lifecycle_events(mock_websocket, model): + audio_message = '{"type": "response.output_audio.delta", "delta": ""}' + mock_websocket.recv.return_value = audio_message + + await model.start() + model._connection_id = "c1" + + tru_events = [] + async for event in model.receive(): + tru_events.append(event) + if len(tru_events) >= 2: + break + + exp_events = [ + BidiConnectionStartEvent(connection_id="c1", model="gpt-realtime"), + BidiAudioStreamEvent( + audio="", + format="pcm", + sample_rate=24000, + channels=1, + ) + ] + assert tru_events == exp_events + + +@unittest.mock.patch("strands.experimental.bidi.models.openai_realtime.time.time") +@pytest.mark.asyncio +async def test_receive_timeout(mock_time, model): + mock_time.side_effect = [1, 2] + model.timeout_s = 1 + + await model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"timeout_s=<1>"): + async for _ in model.receive(): + pass + + +@pytest.mark.asyncio +async def test_event_conversion(model): + """Test conversion of all OpenAI event types to standard format.""" + await model.start() + + # Test audio output (now returns list with BidiAudioStreamEvent) + audio_event = {"type": "response.output_audio.delta", "delta": base64.b64encode(b"audio_data").decode()} + converted = model._convert_openai_event(audio_event) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiAudioStreamEvent) + assert converted[0].get("type") == "bidi_audio_stream" + assert converted[0].get("audio") == base64.b64encode(b"audio_data").decode() + assert converted[0].get("format") == "pcm" + + # Test text output (now returns list with BidiTranscriptStreamEvent) + text_event = {"type": "response.output_text.delta", "delta": "Hello from OpenAI"} + converted = model._convert_openai_event(text_event) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiTranscriptStreamEvent) + assert converted[0].get("type") == "bidi_transcript_stream" + assert converted[0].get("text") == "Hello from OpenAI" + assert converted[0].get("role") == "assistant" + assert converted[0].delta == {"text": "Hello from OpenAI"} + assert converted[0].is_final is False # Delta events are not final + + # Test function call sequence + item_added = { + "type": "response.output_item.added", + "item": {"type": "function_call", "call_id": "call-123", "name": "calculator"}, + } + model._convert_openai_event(item_added) + + args_delta = { + "type": "response.function_call_arguments.delta", + "call_id": "call-123", + "delta": '{"expression": "2+2"}', + } + model._convert_openai_event(args_delta) + + args_done = {"type": "response.function_call_arguments.done", "call_id": "call-123"} + converted = model._convert_openai_event(args_done) + # Now returns list with ToolUseStreamEvent + assert isinstance(converted, list) + assert len(converted) == 1 + # ToolUseStreamEvent has delta and current_tool_use, not a "type" field + assert "delta" in converted[0] + assert "toolUse" in converted[0]["delta"] + tool_use = converted[0]["delta"]["toolUse"] + assert tool_use["toolUseId"] == "call-123" + assert tool_use["name"] == "calculator" + assert tool_use["input"]["expression"] == "2+2" + + # Test voice activity (now returns list with BidiInterruptionEvent for speech_started) + speech_started = {"type": "input_audio_buffer.speech_started"} + converted = 
model._convert_openai_event(speech_started) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiInterruptionEvent) + assert converted[0].get("type") == "bidi_interruption" + assert converted[0].get("reason") == "user_speech" + + # Test response.cancelled event (should return ResponseCompleteEvent with interrupted reason) + response_cancelled = {"type": "response.cancelled", "response": {"id": "resp_123"}} + converted = model._convert_openai_event(response_cancelled) + assert isinstance(converted, list) + assert len(converted) == 1 + assert isinstance(converted[0], BidiResponseCompleteEvent) + assert converted[0].get("type") == "bidi_response_complete" + assert converted[0].get("response_id") == "resp_123" + assert converted[0].get("stop_reason") == "interrupted" + + # Test error handling - response_cancel_not_active should be suppressed + error_cancel_not_active = { + "type": "error", + "error": {"code": "response_cancel_not_active", "message": "No active response to cancel"}, + } + converted = model._convert_openai_event(error_cancel_not_active) + assert converted is None # Should be suppressed + + # Test error handling - other errors should be logged but return None + error_other = {"type": "error", "error": {"code": "some_other_error", "message": "Something went wrong"}} + converted = model._convert_openai_event(error_other) + assert converted is None + + await model.stop() + + +# Helper Method Tests + + +def test_config_building(model, system_prompt, tool_spec): + """Test building session config with various options.""" + # Test basic config + config_basic = model._build_session_config(None, None) + assert isinstance(config_basic, dict) + assert "instructions" in config_basic + assert "audio" in config_basic + + # Test with system prompt + config_prompt = model._build_session_config(system_prompt, None) + assert config_prompt["instructions"] == system_prompt + + # Test with tools + config_tools = model._build_session_config(None, [tool_spec]) + assert "tools" in config_tools + assert len(config_tools["tools"]) > 0 + + +def test_tool_conversion(model, tool_spec): + """Test tool conversion to OpenAI format.""" + # Test with tools + openai_tools = model._convert_tools_to_openai_format([tool_spec]) + assert len(openai_tools) == 1 + assert openai_tools[0]["type"] == "function" + assert openai_tools[0]["name"] == "calculator" + assert openai_tools[0]["description"] == "Calculate mathematical expressions" + + # Test empty list + openai_empty = model._convert_tools_to_openai_format([]) + assert openai_empty == [] + + +def test_helper_methods(model): + """Test various helper methods.""" + # Test _create_text_event (now returns BidiTranscriptStreamEvent) + text_event = model._create_text_event("Hello", "user") + assert isinstance(text_event, BidiTranscriptStreamEvent) + assert text_event.get("type") == "bidi_transcript_stream" + assert text_event.get("text") == "Hello" + assert text_event.get("role") == "user" + assert text_event.delta == {"text": "Hello"} + assert text_event.is_final is True # Done events are final + assert text_event.current_transcript == "Hello" + + # Test _create_voice_activity_event (now returns BidiInterruptionEvent for speech_started) + voice_event = model._create_voice_activity_event("speech_started") + assert isinstance(voice_event, BidiInterruptionEvent) + assert voice_event.get("type") == "bidi_interruption" + assert voice_event.get("reason") == "user_speech" + + # Other voice activities return None + assert 
model._create_voice_activity_event("speech_stopped") is None + + +@pytest.mark.asyncio +async def test_send_event_helper(mock_websockets_connect, model): + """Test _send_event helper method.""" + _, mock_ws = mock_websockets_connect + await model.start() + + test_event = {"type": "test.event", "data": "test"} + await model._send_event(test_event) + + calls = mock_ws.send.call_args_list + last_call = calls[-1] + sent_message = json.loads(last_call[0][0]) + assert sent_message == test_event + + await model.stop() + + +@pytest.mark.asyncio +async def test_custom_audio_sample_rate(mock_websockets_connect, api_key): + """Test that custom audio sample rate from provider_config is used in audio events.""" + _, mock_ws = mock_websockets_connect + + # Create model with custom sample rate + custom_sample_rate = 48000 + provider_config = {"audio": {"output_rate": custom_sample_rate}} + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the custom sample rate + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == custom_sample_rate + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() + + +@pytest.mark.asyncio +async def test_default_audio_sample_rate(mock_websockets_connect, api_key): + """Test that default audio sample rate is used when no custom config is provided.""" + _, mock_ws = mock_websockets_connect + + # Create model without custom audio config + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the default sample rate (24000) + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == 24000 # Default from DEFAULT_SAMPLE_RATE + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() + + +@pytest.mark.asyncio +async def test_partial_audio_config(mock_websockets_connect, api_key): + """Test that partial audio config doesn't break and falls back to defaults.""" + _, mock_ws = mock_websockets_connect + + # Create model with partial audio config (no output_rate specified) + provider_config = {"audio": {"output": {"voice": "alloy"}}} + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}, provider_config=provider_config) + + await model.start() + + # Simulate receiving an audio delta event from OpenAI + openai_audio_event = {"type": "response.output_audio.delta", "delta": "base64audiodata"} + + # Convert the event + converted_events = model._convert_openai_event(openai_audio_event) + + # Verify the audio event uses the default sample rate + assert converted_events is not None + assert len(converted_events) == 1 + audio_event = converted_events[0] + assert 
isinstance(audio_event, BidiAudioStreamEvent) + assert audio_event.sample_rate == 24000 # Falls back to default + assert audio_event.format == "pcm" + assert audio_event.channels == 1 + + await model.stop() + + +# Tool Result Content Tests + + +@pytest.mark.asyncio +async def test_tool_result_single_text_content(mock_websockets_connect, api_key): + """Test tool result with single text content block.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-123", + "status": "success", + "content": [{"text": "Simple text result"}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + assert len(item_create) > 0 + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-123" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"text": "Simple text result"}] + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_single_json_content(mock_websockets_connect, api_key): + """Test tool result with single JSON content block.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-456", + "status": "success", + "content": [{"json": {"temperature": 72, "condition": "sunny"}}], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-456" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [{"json": {"temperature": 72, "condition": "sunny"}}] + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_multiple_content_blocks(mock_websockets_connect, api_key): + """Test tool result with multiple content blocks (text and json).""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-789", + "status": "success", + "content": [ + {"text": "Weather data:"}, + {"json": {"temp": 72, "humidity": 65}}, + {"text": "Forecast: sunny"}, + ], + } + + await model.send(ToolResultEvent(tool_result)) + + # Verify the sent event + calls = mock_ws.send.call_args_list + messages = [json.loads(call[0][0]) for call in calls] + item_create = [m for m in messages if m.get("type") == "conversation.item.create"] + + item = item_create[-1].get("item", {}) + assert item.get("type") == "function_call_output" + assert item.get("call_id") == "call-789" + # Content is now preserved as JSON array + output = json.loads(item.get("output")) + assert output == [ + {"text": "Weather data:"}, + {"json": {"temp": 72, "humidity": 65}}, + {"text": "Forecast: sunny"}, + ] + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_image_content_raises_error(mock_websockets_connect, api_key): 
+ """Test that tool result with image content raises ValueError.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-999", + "status": "success", + "content": [{"image": {"format": "jpeg", "source": {"bytes": b"fake_image_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result)) + + await model.stop() + + +@pytest.mark.asyncio +async def test_tool_result_document_content_raises_error(mock_websockets_connect, api_key): + """Test that tool result with document content raises ValueError.""" + _, mock_ws = mock_websockets_connect + model = BidiOpenAIRealtimeModel(client_config={"api_key": api_key}) + await model.start() + + tool_result: ToolResult = { + "toolUseId": "call-888", + "status": "success", + "content": [{"document": {"format": "pdf", "source": {"bytes": b"fake_pdf_data"}}}], + } + + with pytest.raises(ValueError, match=r"Content type not supported by OpenAI Realtime API"): + await model.send(ToolResultEvent(tool_result)) + + await model.stop() diff --git a/tests/strands/experimental/bidi/types/__init__.py b/tests/strands/experimental/bidi/types/__init__.py new file mode 100644 index 000000000..a1330e552 --- /dev/null +++ b/tests/strands/experimental/bidi/types/__init__.py @@ -0,0 +1 @@ +"""Tests for bidirectional streaming types.""" diff --git a/tests/strands/experimental/bidi/types/test_events.py b/tests/strands/experimental/bidi/types/test_events.py new file mode 100644 index 000000000..1e609bd36 --- /dev/null +++ b/tests/strands/experimental/bidi/types/test_events.py @@ -0,0 +1,163 @@ +"""Tests for bidirectional streaming event types. + +This module tests JSON serialization for all bidirectional streaming event types. 
+""" + +import base64 +import json + +import pytest + +from strands.experimental.bidi.types.events import ( + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiErrorEvent, + BidiImageInputEvent, + BidiInterruptionEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) + + +@pytest.mark.parametrize( + "event_class,kwargs,expected_type", + [ + # Input events + (BidiTextInputEvent, {"text": "Hello", "role": "user"}, "bidi_text_input"), + ( + BidiAudioInputEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 16000, + "channels": 1, + }, + "bidi_audio_input", + ), + ( + BidiImageInputEvent, + {"image": base64.b64encode(b"image").decode("utf-8"), "mime_type": "image/jpeg"}, + "bidi_image_input", + ), + # Output events + ( + BidiConnectionStartEvent, + {"connection_id": "c1", "model": "m1"}, + "bidi_connection_start", + ), + (BidiResponseStartEvent, {"response_id": "r1"}, "bidi_response_start"), + ( + BidiAudioStreamEvent, + { + "audio": base64.b64encode(b"audio").decode("utf-8"), + "format": "pcm", + "sample_rate": 24000, + "channels": 1, + }, + "bidi_audio_stream", + ), + ( + BidiTranscriptStreamEvent, + { + "delta": {"text": "Hello"}, + "text": "Hello", + "role": "assistant", + "is_final": True, + "current_transcript": "Hello", + }, + "bidi_transcript_stream", + ), + (BidiInterruptionEvent, {"reason": "user_speech"}, "bidi_interruption"), + ( + BidiResponseCompleteEvent, + {"response_id": "r1", "stop_reason": "complete"}, + "bidi_response_complete", + ), + ( + BidiUsageEvent, + {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, + "bidi_usage", + ), + ( + BidiConnectionCloseEvent, + {"connection_id": "c1", "reason": "complete"}, + "bidi_connection_close", + ), + (BidiErrorEvent, {"error": ValueError("test"), "details": None}, "bidi_error"), + ], +) +def test_event_json_serialization(event_class, kwargs, expected_type): + """Test that all event types are JSON serializable and deserializable.""" + # Create event + event = event_class(**kwargs) + + # Verify type field + assert event["type"] == expected_type + + # Serialize to JSON + json_str = json.dumps(event) + print("event_class:", event_class) + print(json_str) + # Deserialize back + data = json.loads(json_str) + + # Verify type preserved + assert data["type"] == expected_type + + # Verify all non-private keys preserved + for key in event.keys(): + if not key.startswith("_"): + assert key in data + + +def test_transcript_stream_event_delta_pattern(): + """Test that BidiTranscriptStreamEvent follows ModelStreamEvent delta pattern.""" + # Test partial transcript (delta) + partial_event = BidiTranscriptStreamEvent( + delta={"text": "Hello"}, + text="Hello", + role="user", + is_final=False, + current_transcript=None, + ) + + assert partial_event.text == "Hello" + assert partial_event.role == "user" + assert partial_event.is_final is False + assert partial_event.current_transcript is None + assert partial_event.delta == {"text": "Hello"} + + # Test final transcript with accumulated text + final_event = BidiTranscriptStreamEvent( + delta={"text": " world"}, + text=" world", + role="user", + is_final=True, + current_transcript="Hello world", + ) + + assert final_event.text == " world" + assert final_event.role == "user" + assert final_event.is_final is True + assert final_event.current_transcript == "Hello world" + assert final_event.delta == {"text": " 
world"} + + +def test_transcript_stream_event_extends_model_stream_event(): + """Test that BidiTranscriptStreamEvent is a ModelStreamEvent.""" + from strands.types._events import ModelStreamEvent + + event = BidiTranscriptStreamEvent( + delta={"text": "test"}, + text="test", + role="assistant", + is_final=True, + current_transcript="test", + ) + + assert isinstance(event, ModelStreamEvent) diff --git a/tests/strands/experimental/hooks/test_bidi_hook_events.py b/tests/strands/experimental/hooks/test_bidi_hook_events.py new file mode 100644 index 000000000..4d49243b2 --- /dev/null +++ b/tests/strands/experimental/hooks/test_bidi_hook_events.py @@ -0,0 +1,169 @@ +"""Unit tests for BidiAgent hook events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) +from strands.types.tools import ToolResult, ToolUse + + +@pytest.fixture +def agent(): + return Mock() + + +@pytest.fixture +def tool(): + tool = Mock() + tool.tool_name = "test_tool" + return tool + + +@pytest.fixture +def tool_use(): + return ToolUse(name="test_tool", toolUseId="123", input={"param": "value"}) + + +@pytest.fixture +def tool_invocation_state(): + return {"param": "value"} + + +@pytest.fixture +def tool_result(): + return ToolResult(content=[{"text": "result"}], status="success", toolUseId="123") + + +@pytest.fixture +def message(): + return {"role": "user", "content": [{"text": "Hello"}]} + + +@pytest.fixture +def initialized_event(agent): + return BidiAgentInitializedEvent(agent=agent) + + +@pytest.fixture +def before_invocation_event(agent): + return BidiBeforeInvocationEvent(agent=agent) + + +@pytest.fixture +def after_invocation_event(agent): + return BidiAfterInvocationEvent(agent=agent) + + +@pytest.fixture +def message_added_event(agent, message): + return BidiMessageAddedEvent(agent=agent, message=message) + + +@pytest.fixture +def before_tool_event(agent, tool, tool_use, tool_invocation_state): + return BidiBeforeToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + ) + + +@pytest.fixture +def after_tool_event(agent, tool, tool_use, tool_invocation_state, tool_result): + return BidiAfterToolCallEvent( + agent=agent, + selected_tool=tool, + tool_use=tool_use, + invocation_state=tool_invocation_state, + result=tool_result, + ) + + +@pytest.fixture +def interruption_event(agent): + return BidiInterruptionEvent(agent=agent, reason="user_speech") + + +def test_event_should_reverse_callbacks( + initialized_event, + before_invocation_event, + after_invocation_event, + message_added_event, + before_tool_event, + after_tool_event, + interruption_event, +): + """Verify which events use reverse callback ordering.""" + # note that we ignore E712 (explicit booleans) for consistency/readability purposes + + assert initialized_event.should_reverse_callbacks == False # noqa: E712 + assert message_added_event.should_reverse_callbacks == False # noqa: E712 + assert interruption_event.should_reverse_callbacks == False # noqa: E712 + + assert before_invocation_event.should_reverse_callbacks == False # noqa: E712 + assert after_invocation_event.should_reverse_callbacks == True # noqa: E712 + + assert before_tool_event.should_reverse_callbacks == False # noqa: E712 + assert after_tool_event.should_reverse_callbacks == True # noqa: E712 + + 
+def test_interruption_event_with_response_id(agent): + """Verify BidiInterruptionEvent can include response ID.""" + event = BidiInterruptionEvent(agent=agent, reason="error", interrupted_response_id="resp_123") + + assert event.reason == "error" + assert event.interrupted_response_id == "resp_123" + + +def test_message_added_event_cannot_write_properties(message_added_event): + """Verify BidiMessageAddedEvent properties are read-only.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + message_added_event.agent = Mock() + with pytest.raises(AttributeError, match="Property message is not writable"): + message_added_event.message = {} + + +def test_before_tool_call_event_can_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent allows writing specific properties.""" + new_tool_use = ToolUse(name="new_tool", toolUseId="456", input={}) + before_tool_event.selected_tool = None # Should not raise + before_tool_event.tool_use = new_tool_use # Should not raise + before_tool_event.cancel_tool = "Cancelled by user" # Should not raise + + +def test_before_tool_call_event_cannot_write_properties(before_tool_event): + """Verify BidiBeforeToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + before_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + before_tool_event.invocation_state = {} + + +def test_after_tool_call_event_can_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent allows writing result property.""" + new_result = ToolResult(content=[{"text": "new result"}], status="success", toolUseId="456") + after_tool_event.result = new_result # Should not raise + + +def test_after_tool_call_event_cannot_write_properties(after_tool_event): + """Verify BidiAfterToolCallEvent protects certain properties.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + after_tool_event.agent = Mock() + with pytest.raises(AttributeError, match="Property selected_tool is not writable"): + after_tool_event.selected_tool = None + with pytest.raises(AttributeError, match="Property tool_use is not writable"): + after_tool_event.tool_use = ToolUse(name="new", toolUseId="456", input={}) + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + after_tool_event.invocation_state = {} + with pytest.raises(AttributeError, match="Property exception is not writable"): + after_tool_event.exception = Exception("test") diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index 6744aa00c..f4899f2ab 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -123,7 +123,7 @@ def test_deprecation_warning_on_import(captured_warnings): assert len(captured_warnings) == 1 assert issubclass(captured_warnings[0].category, DeprecationWarning) - assert "moved to production with updated names" in str(captured_warnings[0].message) + assert "are no longer experimental" in str(captured_warnings[0].message) def test_deprecation_warning_on_import_only_for_experimental(captured_warnings): diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 451d0dd09..0b5623ae0 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ 
b/tests/strands/session/test_repository_session_manager.py @@ -7,6 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.agent.state import AgentState from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock @@ -413,3 +414,141 @@ def test_fix_broken_tool_use_does_not_change_valid_message(session_manager): # Should remain unchanged since toolUse is in last message assert fixed_messages == messages + + +# ============================================================================ +# BidiAgent Session Tests +# ============================================================================ + + +@pytest.fixture +def mock_bidi_agent(): + """Create a mock BidiAgent for testing.""" + agent = Mock() + agent.agent_id = "bidi-agent-1" + agent.messages = [{"role": "user", "content": [{"text": "Hello from bidi!"}]}] + agent.state = AgentState({"key": "value"}) + # BidiAgent doesn't have _interrupt_state yet + return agent + + +def test_initialize_bidi_agent_creates_new(session_manager, mock_bidi_agent): + """Test initializing a new BidiAgent creates session data.""" + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify agent created in repository + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data is not None + assert agent_data.agent_id == "bidi-agent-1" + assert agent_data.conversation_manager_state == {} # Empty for BidiAgent + assert agent_data.state == {"key": "value"} + + # Verify message created + messages = session_manager.session_repository.list_messages("test-session", "bidi-agent-1") + assert len(messages) == 1 + assert messages[0].message["role"] == "user" + + +def test_initialize_bidi_agent_restores_existing(session_manager, mock_bidi_agent): + """Test initializing BidiAgent restores from existing session.""" + # Create existing session data + session_agent = SessionAgent( + agent_id="bidi-agent-1", + state={"restored": "state"}, + conversation_manager_state={}, # Empty for BidiAgent + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Add messages + msg1 = SessionMessage.from_message({"role": "user", "content": [{"text": "Message 1"}]}, 0) + msg2 = SessionMessage.from_message({"role": "assistant", "content": [{"text": "Response 1"}]}, 1) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg1) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg2) + + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify state restored + assert mock_bidi_agent.state.get() == {"restored": "state"} + + # Verify messages restored + assert len(mock_bidi_agent.messages) == 2 + assert mock_bidi_agent.messages[0]["role"] == "user" + assert mock_bidi_agent.messages[1]["role"] == "assistant" + + +def test_append_bidi_message(session_manager, mock_bidi_agent): + """Test appending messages to BidiAgent session.""" + # Initialize agent first + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Append new message + new_message = {"role": "assistant", "content": [{"text": "Response"}]} + session_manager.append_bidi_message(new_message, mock_bidi_agent) + 
+ # Verify message stored + messages = session_manager.session_repository.list_messages("test-session", "bidi-agent-1") + assert len(messages) == 2 # Initial + new + assert messages[1].message["role"] == "assistant" + + +def test_sync_bidi_agent(session_manager, mock_bidi_agent): + """Test syncing BidiAgent state to session.""" + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Update agent state + mock_bidi_agent.state = AgentState({"updated": "state"}) + + # Sync agent + session_manager.sync_bidi_agent(mock_bidi_agent) + + # Verify state updated in repository + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data.state == {"updated": "state"} + + +def test_bidi_agent_no_conversation_manager(session_manager, mock_bidi_agent): + """Test that BidiAgent session doesn't use conversation_manager.""" + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify conversation_manager_state is empty + agent_data = session_manager.session_repository.read_agent("test-session", "bidi-agent-1") + assert agent_data.conversation_manager_state == {} + + +def test_bidi_agent_unique_id_constraint(session_manager, mock_bidi_agent): + """Test that BidiAgent agent_id must be unique in session.""" + # Initialize first agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Try to initialize another agent with same ID + agent2 = Mock() + agent2.agent_id = "bidi-agent-1" # Same ID + agent2.messages = [] + agent2.state = AgentState({}) + + with pytest.raises(SessionException, match="The `agent_id` of an agent must be unique in a session."): + session_manager.initialize_bidi_agent(agent2) + + +def test_bidi_agent_messages_with_offset_zero(session_manager, mock_bidi_agent): + """Test that BidiAgent uses offset=0 for message restoration (no conversation_manager).""" + # Create session with messages + session_agent = SessionAgent( + agent_id="bidi-agent-1", + state={}, + conversation_manager_state={}, + ) + session_manager.session_repository.create_agent("test-session", session_agent) + + # Add 5 messages + for i in range(5): + msg = SessionMessage.from_message({"role": "user", "content": [{"text": f"Message {i}"}]}, i) + session_manager.session_repository.create_message("test-session", "bidi-agent-1", msg) + + # Initialize agent + session_manager.initialize_bidi_agent(mock_bidi_agent) + + # Verify all messages restored (offset=0, no removed_message_count) + assert len(mock_bidi_agent.messages) == 5 diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index 5984e33ab..ad92ba603 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,6 +4,7 @@ import pytest import strands +from strands import Agent from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry @@ -102,6 +103,7 @@ def tool_registry(weather_tool, temperature_tool, exception_tool, thread_tool, i @pytest.fixture def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() + mock_agent.__class__ = Agent mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry mock_agent._interrupt_state = _InterruptState() diff --git a/tests_integ/bidi/__init__.py b/tests_integ/bidi/__init__.py new file mode 100644 index 000000000..05da9afcb --- /dev/null +++ b/tests_integ/bidi/__init__.py @@ -0,0 +1 @@ +"""Integration tests 
for bidirectional streaming agents.""" diff --git a/tests_integ/bidi/conftest.py b/tests_integ/bidi/conftest.py new file mode 100644 index 000000000..0d453818a --- /dev/null +++ b/tests_integ/bidi/conftest.py @@ -0,0 +1,28 @@ +"""Pytest fixtures for bidirectional streaming integration tests.""" + +import logging + +import pytest + +from .generators.audio import AudioGenerator + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="session") +def audio_generator(): + """Provide AudioGenerator instance for tests.""" + return AudioGenerator(region="us-east-1") + + +@pytest.fixture(autouse=True) +def setup_logging(): + """Configure logging for tests.""" + logging.basicConfig( + level=logging.DEBUG, + format="%(levelname)s | %(name)s | %(message)s", + ) + # Reduce noise from some loggers + logging.getLogger("boto3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/tests_integ/bidi/context.py b/tests_integ/bidi/context.py new file mode 100644 index 000000000..f60379b60 --- /dev/null +++ b/tests_integ/bidi/context.py @@ -0,0 +1,369 @@ +"""Test context manager for bidirectional streaming tests. + +Provides a high-level interface for testing bidirectional streaming agents +with continuous background threads that mimic real-world usage patterns. +""" + +import asyncio +import base64 +import logging +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from strands.experimental.bidi.agent.agent import BidiAgent + + from .generators.audio import AudioGenerator + +logger = logging.getLogger(__name__) + +# Constants for timing and buffering +QUEUE_POLL_TIMEOUT = 0.05 # 50ms - balance between responsiveness and CPU usage +SILENCE_INTERVAL = 0.05 # 50ms - send silence every 50ms when queue empty +AUDIO_CHUNK_DELAY = 0.01 # 10ms - small delay between audio chunks +WAIT_POLL_INTERVAL = 0.1 # 100ms - how often to check for response completion + + +class BidirectionalTestContext: + """Manages threads and generators for bidirectional streaming tests. + + Mimics real-world usage with continuous background threads: + - Audio input thread (microphone simulation with silence padding) + - Event collection thread (captures all model outputs) + + Generators feed data into threads via queues for natural conversation flow. + + Example: + async with BidirectionalTestContext(agent, audio_generator) as ctx: + await ctx.say("What is 5 plus 3?") + await ctx.wait_for_response() + assert "8" in " ".join(ctx.get_text_outputs()) + """ + + def __init__( + self, + agent: "BidiAgent", + audio_generator: "AudioGenerator | None" = None, + silence_chunk_size: int = 1024, + audio_chunk_size: int = 1024, + ): + """Initialize test context. + + Args: + agent: BidiAgent instance. + audio_generator: AudioGenerator for text-to-speech. + silence_chunk_size: Size of silence chunks in bytes. + audio_chunk_size: Size of audio chunks for streaming. 
+ """ + self.agent = agent + self.audio_generator = audio_generator + self.silence_chunk_size = silence_chunk_size + self.audio_chunk_size = audio_chunk_size + + # Queue for thread communication + self.input_queue = asyncio.Queue() # Handles both audio and text input + + # Event storage (thread-safe) + self._event_queue = asyncio.Queue() # Events from collection thread + self.events = [] # Cached events for test access + self.last_event_time = None + + # Control flags + self.active = False + self.threads = [] + + async def __aenter__(self): + """Start context manager, agent session, and background threads.""" + # Start agent session + await self.agent.start() + logger.debug("Agent session started") + + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Stop context manager, cleanup threads, and end agent session.""" + # End agent session FIRST - this will cause receive() to exit cleanly + if self.agent._started: + await self.agent.stop() + logger.debug("Agent session stopped") + + # Then stop the context threads + await self.stop() + + return False + + async def start(self): + """Start all background threads.""" + self.active = True + self.last_event_time = time.monotonic() + + self.threads = [ + asyncio.create_task(self._input_thread()), + asyncio.create_task(self._event_collection_thread()), + ] + + logger.debug("Test context started with %d threads", len(self.threads)) + + async def stop(self): + """Stop all threads gracefully.""" + if not self.active: + logger.debug("stop() called but already stopped") + return + + logger.debug("stop() called - stopping threads") + self.active = False + + # Cancel all threads + for task in self.threads: + if not task.done(): + task.cancel() + + # Wait for cancellation + await asyncio.gather(*self.threads, return_exceptions=True) + + logger.debug("Test context stopped") + + # === User-facing methods === + + async def say(self, text: str): + """Convert text to audio and queue audio chunks to be sent to model. + + Args: + text: Text to convert to speech and send as audio. + + Raises: + ValueError: If audio generator is not available. + """ + if not self.audio_generator: + raise ValueError("Audio generator not available. Pass audio_generator to BidirectionalTestContext.") + + # Generate audio via Polly + audio_data = await self.audio_generator.generate_audio(text) + + # Split into chunks and queue each chunk + for i in range(0, len(audio_data), self.audio_chunk_size): + chunk = audio_data[i : i + self.audio_chunk_size] + chunk_event = self.audio_generator.create_audio_input_event(chunk) + await self.input_queue.put({"type": "audio_chunk", "data": chunk_event}) + + logger.debug("audio_bytes=<%d>, text_preview=<%s> | queued audio for text", len(audio_data), text[:50]) + + async def send(self, data: str | dict) -> None: + """Send data directly to model (text, image, etc.). + + Args: + data: Data to send to model. Can be: + - str: Text input + - dict: Custom event (e.g., image, audio) + """ + await self.input_queue.put({"type": "direct", "data": data}) + logger.debug("data_type=<%s> | queued direct send", type(data).__name__) + + async def wait_for_response( + self, + timeout: float = 15.0, + silence_threshold: float = 2.0, + min_events: int = 1, + ): + """Wait for model to finish responding. + + Uses silence detection (no events for silence_threshold seconds) + combined with minimum event count to determine response completion. + + Args: + timeout: Maximum time to wait in seconds. 
+ silence_threshold: Seconds of silence to consider response complete. + min_events: Minimum events before silence detection activates. + """ + start_time = time.monotonic() + initial_event_count = len(self.get_events()) # Drain queue + + while time.monotonic() - start_time < timeout: + # Drain queue to get latest events + current_events = self.get_events() + + # Check if we have minimum events + if len(current_events) - initial_event_count >= min_events: + # Check silence + elapsed_since_event = time.monotonic() - self.last_event_time + if elapsed_since_event >= silence_threshold: + logger.debug( + "event_count=<%d>, silence_duration=<%.1f> | response complete", + len(current_events) - initial_event_count, + elapsed_since_event, + ) + return + + await asyncio.sleep(WAIT_POLL_INTERVAL) + + logger.warning("timeout=<%s> | response timeout", timeout) + + def get_events(self, event_type: str | None = None) -> list[dict]: + """Get collected events, optionally filtered by type. + + Drains the event queue and caches events for subsequent calls. + + Args: + event_type: Optional event type to filter by (e.g., "textOutput"). + + Returns: + List of events, filtered if event_type specified. + """ + # Drain queue into cache (non-blocking) + while not self._event_queue.empty(): + try: + event = self._event_queue.get_nowait() + self.events.append(event) + self.last_event_time = time.monotonic() + except asyncio.QueueEmpty: + break + + if event_type: + return [e for e in self.events if event_type in e] + return self.events.copy() + + def get_text_outputs(self) -> list[str]: + """Extract text outputs from collected events. + + Handles both new TypedEvent format and legacy event formats. + + Returns: + List of text content strings. + """ + texts = [] + for event in self.get_events(): # Drain queue first + # Handle new TypedEvent format (bidi_transcript_stream) + if event.get("type") == "bidi_transcript_stream": + text = event.get("text", "") + if text: + texts.append(text) + # Handle legacy textOutput events (Nova Sonic, OpenAI) + elif "textOutput" in event: + text = event["textOutput"].get("text", "") + if text: + texts.append(text) + # Handle legacy transcript events (Gemini Live) + elif "transcript" in event: + text = event["transcript"].get("text", "") + if text: + texts.append(text) + return texts + + def get_audio_outputs(self) -> list[bytes]: + """Extract audio outputs from collected events. + + Returns: + List of audio data bytes. + """ + # Drain queue first to get latest events + events = self.get_events() + audio_data = [] + for event in events: + # Handle new TypedEvent format (bidi_audio_stream) + if event.get("type") == "bidi_audio_stream": + audio_b64 = event.get("audio") + if audio_b64: + # Decode base64 to bytes + audio_data.append(base64.b64decode(audio_b64)) + # Handle legacy audioOutput events + elif "audioOutput" in event: + data = event["audioOutput"].get("audioData") + if data: + audio_data.append(data) + return audio_data + + def get_tool_uses(self) -> list[dict]: + """Extract tool use events from collected events. + + Returns: + List of tool use events. + """ + # Drain queue first to get latest events + events = self.get_events() + return [event["toolUse"] for event in events if "toolUse" in event] + + def has_interruption(self) -> bool: + """Check if any interruption was detected. + + Returns: + True if interruption detected in events. 
+        """
+        # Drain the queue first, then check both the new TypedEvent format
+        # (bidi_interruption) and the legacy interruptionDetected format
+        events = self.get_events()
+        return any(
+            event.get("type") == "bidi_interruption" or "interruptionDetected" in event for event in events
+        )
+
+    def clear_events(self):
+        """Clear collected events (useful for multi-turn tests)."""
+        self.events.clear()
+        logger.debug("Events cleared")
+
+    # === Background threads ===
+
+    async def _input_thread(self):
+        """Continuously handle input to model.
+
+        - Sends queued audio chunks immediately
+        - Sends silence chunks periodically when queue is empty (simulates microphone)
+        - Sends direct data to model
+        """
+        try:
+            logger.debug("active=<%s> | input thread starting", self.active)
+            while self.active:
+                try:
+                    # Check for queued input (non-blocking with short timeout)
+                    input_item = await asyncio.wait_for(self.input_queue.get(), timeout=QUEUE_POLL_TIMEOUT)
+
+                    if input_item["type"] == "audio_chunk":
+                        # Send pre-generated audio chunk
+                        await self.agent.send(input_item["data"])
+                        await asyncio.sleep(AUDIO_CHUNK_DELAY)
+
+                    elif input_item["type"] == "direct":
+                        # Send data directly to agent
+                        await self.agent.send(input_item["data"])
+                        data_repr = (
+                            str(input_item["data"])[:50]
+                            if isinstance(input_item["data"], str)
+                            else type(input_item["data"]).__name__
+                        )
+                        logger.debug("data=<%s> | sent direct data", data_repr)
+
+                except asyncio.TimeoutError:
+                    # No input queued - send silence chunk to simulate continuous microphone input
+                    if self.audio_generator:
+                        silence = self._generate_silence_chunk()
+                        await self.agent.send(silence)
+                        await asyncio.sleep(SILENCE_INTERVAL)
+
+        except asyncio.CancelledError:
+            logger.debug("Input thread cancelled")
+            raise  # Re-raise to properly propagate cancellation
+        except Exception as e:
+            logger.exception("error=<%s> | input thread error", e)
+        finally:
+            logger.debug("active=<%s> | input thread stopped", self.active)
+
+    async def _event_collection_thread(self):
+        """Continuously collect events from model."""
+        try:
+            async for event in self.agent.receive():
+                if not self.active:
+                    break
+
+                # Thread-safe: put in queue instead of direct append
+                await self._event_queue.put(event)
+                logger.debug("event_type=<%s> | event collected", event.get("type", "unknown"))
+
+        except asyncio.CancelledError:
+            logger.debug("Event collection thread cancelled")
+            raise  # Re-raise to properly propagate cancellation
+        except Exception as e:
+            logger.error("error=<%s> | event collection thread error", e)
+
+    def _generate_silence_chunk(self) -> dict:
+        """Generate silence chunk for background audio.
+
+        Returns:
+            BidiAudioInputEvent with silence data.
+        """
+        silence = b"\x00" * self.silence_chunk_size
+        return self.audio_generator.create_audio_input_event(silence)
diff --git a/tests_integ/bidi/generators/__init__.py b/tests_integ/bidi/generators/__init__.py
new file mode 100644
index 000000000..1f13f0564
--- /dev/null
+++ b/tests_integ/bidi/generators/__init__.py
@@ -0,0 +1 @@
+"""Test data generators for bidirectional streaming integration tests."""
diff --git a/tests_integ/bidi/generators/audio.py b/tests_integ/bidi/generators/audio.py
new file mode 100644
index 000000000..4598817fd
--- /dev/null
+++ b/tests_integ/bidi/generators/audio.py
@@ -0,0 +1,159 @@
+"""Audio generation utilities using Amazon Polly for test audio input.
+
+Provides text-to-speech conversion for generating realistic audio test data
+without requiring physical audio devices or pre-recorded files.
+""" + +import base64 +import hashlib +import logging +from pathlib import Path +from typing import Literal + +import boto3 + +logger = logging.getLogger(__name__) + +# Audio format constants matching Nova Sonic requirements +NOVA_SONIC_SAMPLE_RATE = 16000 +NOVA_SONIC_CHANNELS = 1 +NOVA_SONIC_FORMAT = "pcm" + +# Polly configuration +POLLY_VOICE_ID = "Matthew" # US English male voice +POLLY_ENGINE = "neural" # Higher quality neural engine + +# Cache directory for generated audio +CACHE_DIR = Path(__file__).parent.parent / ".audio_cache" + + +class AudioGenerator: + """Generate test audio using Amazon Polly with caching.""" + + def __init__(self, region: str = "us-east-1"): + """Initialize audio generator with Polly client. + + Args: + region: AWS region for Polly service. + """ + self.polly_client = boto3.client("polly", region_name=region) + self._ensure_cache_dir() + + def _ensure_cache_dir(self) -> None: + """Create cache directory if it doesn't exist.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + + def _get_cache_key(self, text: str, voice_id: str) -> str: + """Generate cache key from text and voice.""" + content = f"{text}:{voice_id}".encode("utf-8") + return hashlib.md5(content).hexdigest() + + def _get_cache_path(self, cache_key: str) -> Path: + """Get cache file path for given key.""" + return CACHE_DIR / f"{cache_key}.pcm" + + async def generate_audio( + self, + text: str, + voice_id: str = POLLY_VOICE_ID, + use_cache: bool = True, + ) -> bytes: + """Generate audio from text using Polly with caching. + + Args: + text: Text to convert to speech. + voice_id: Polly voice ID to use. + use_cache: Whether to use cached audio if available. + + Returns: + Raw PCM audio bytes at 16kHz mono (Nova Sonic format). + """ + # Check cache first + if use_cache: + cache_key = self._get_cache_key(text, voice_id) + cache_path = self._get_cache_path(cache_key) + + if cache_path.exists(): + logger.debug("text_preview=<%s> | using cached audio", text[:50]) + return cache_path.read_bytes() + + # Generate audio with Polly + logger.debug("text_preview=<%s> | generating audio with polly", text[:50]) + + try: + response = self.polly_client.synthesize_speech( + Text=text, + OutputFormat="pcm", # Raw PCM format + VoiceId=voice_id, + Engine=POLLY_ENGINE, + SampleRate=str(NOVA_SONIC_SAMPLE_RATE), + ) + + # Read audio data + audio_data = response["AudioStream"].read() + + # Cache for future use + if use_cache: + cache_path.write_bytes(audio_data) + logger.debug("cache_path=<%s> | cached audio", cache_path) + + return audio_data + + except Exception as e: + logger.error("error=<%s> | polly audio generation failed", e) + raise + + def create_audio_input_event( + self, + audio_data: bytes, + format: Literal["pcm", "wav", "opus", "mp3"] = NOVA_SONIC_FORMAT, + sample_rate: int = NOVA_SONIC_SAMPLE_RATE, + channels: int = NOVA_SONIC_CHANNELS, + ) -> dict: + """Create BidiAudioInputEvent from raw audio data. + + Args: + audio_data: Raw audio bytes. + format: Audio format. + sample_rate: Sample rate in Hz. + channels: Number of audio channels. + + Returns: + BidiAudioInputEvent dict ready for agent.send(). 
+ """ + # Convert bytes to base64 string for JSON compatibility + audio_b64 = base64.b64encode(audio_data).decode("utf-8") + + return { + "type": "bidi_audio_input", + "audio": audio_b64, + "format": format, + "sample_rate": sample_rate, + "channels": channels, + } + + def clear_cache(self) -> None: + """Clear all cached audio files.""" + if CACHE_DIR.exists(): + for cache_file in CACHE_DIR.glob("*.pcm"): + cache_file.unlink() + logger.info("Audio cache cleared") + + +# Convenience function for quick audio generation +async def generate_test_audio(text: str, use_cache: bool = True) -> dict: + """Generate test audio input event from text. + + Convenience function that creates an AudioGenerator and returns + a ready-to-use BidiAudioInputEvent. + + Args: + text: Text to convert to speech. + use_cache: Whether to use cached audio. + + Returns: + BidiAudioInputEvent dict ready for agent.send(). + """ + generator = AudioGenerator() + audio_data = await generator.generate_audio(text, use_cache=use_cache) + return generator.create_audio_input_event(audio_data) diff --git a/tests_integ/bidi/hook_utils.py b/tests_integ/bidi/hook_utils.py new file mode 100644 index 000000000..ea51a029e --- /dev/null +++ b/tests_integ/bidi/hook_utils.py @@ -0,0 +1,76 @@ +"""Shared utilities for testing BidiAgent hooks.""" + +from strands.experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiAfterToolCallEvent, + BidiAgentInitializedEvent, + BidiBeforeInvocationEvent, + BidiBeforeToolCallEvent, + BidiInterruptionEvent, + BidiMessageAddedEvent, +) +from strands.hooks import HookProvider + + +class HookEventCollector(HookProvider): + """Hook provider that collects all emitted events for testing.""" + + def __init__(self): + self.events = [] + + def register_hooks(self, registry): + registry.add_callback(BidiAgentInitializedEvent, self.on_initialized) + registry.add_callback(BidiBeforeInvocationEvent, self.on_before_invocation) + registry.add_callback(BidiAfterInvocationEvent, self.on_after_invocation) + registry.add_callback(BidiBeforeToolCallEvent, self.on_before_tool_call) + registry.add_callback(BidiAfterToolCallEvent, self.on_after_tool_call) + registry.add_callback(BidiMessageAddedEvent, self.on_message_added) + registry.add_callback(BidiInterruptionEvent, self.on_interruption) + + def on_initialized(self, event: BidiAgentInitializedEvent): + self.events.append(("initialized", event)) + + def on_before_invocation(self, event: BidiBeforeInvocationEvent): + self.events.append(("before_invocation", event)) + + def on_after_invocation(self, event: BidiAfterInvocationEvent): + self.events.append(("after_invocation", event)) + + def on_before_tool_call(self, event: BidiBeforeToolCallEvent): + self.events.append(("before_tool_call", event)) + + def on_after_tool_call(self, event: BidiAfterToolCallEvent): + self.events.append(("after_tool_call", event)) + + def on_message_added(self, event: BidiMessageAddedEvent): + self.events.append(("message_added", event)) + + def on_interruption(self, event: BidiInterruptionEvent): + self.events.append(("interruption", event)) + + def get_event_types(self): + """Get list of event type names in order.""" + return [event_type for event_type, _ in self.events] + + def get_events_by_type(self, event_type): + """Get all events of a specific type.""" + return [event for et, event in self.events if et == event_type] + + def get_tool_calls(self): + """Get list of tool names that were called.""" + before_calls = self.get_events_by_type("before_tool_call") + return 
[event.tool_use["name"] for event in before_calls] + + def verify_tool_execution(self): + """Verify that tool execution hooks were properly paired.""" + before_calls = self.get_events_by_type("before_tool_call") + after_calls = self.get_events_by_type("after_tool_call") + + assert len(before_calls) == len(after_calls), "Before and after tool call hooks should be paired" + + before_tools = [event.tool_use["name"] for event in before_calls] + after_tools = [event.tool_use["name"] for event in after_calls] + + assert before_tools == after_tools, "Tool call order should match between before and after hooks" + + return before_tools diff --git a/tests_integ/bidi/test_bidi_hooks.py b/tests_integ/bidi/test_bidi_hooks.py new file mode 100644 index 000000000..cb7def664 --- /dev/null +++ b/tests_integ/bidi/test_bidi_hooks.py @@ -0,0 +1,210 @@ +"""Integration tests for BidiAgent hooks with real model providers.""" + +import pytest + +from strands import tool +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.hooks.events import ( + BidiAfterInvocationEvent, + BidiBeforeInvocationEvent, +) +from strands.hooks import HookProvider + +from .hook_utils import HookEventCollector + + +@pytest.mark.asyncio +class TestBidiAgentHooksLifecycle: + """Test BidiAgent hook lifecycle events.""" + + async def test_agent_initialization_emits_hook(self): + """Verify agent initialization emits BidiAgentInitializedEvent.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + # Should have emitted initialized event + assert "initialized" in collector.get_event_types() + init_events = collector.get_events_by_type("initialized") + assert len(init_events) == 1 + assert init_events[0].agent == agent + + async def test_session_lifecycle_emits_hooks(self): + """Verify session start/stop emits before/after invocation events.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + # Start session + await agent.start() + + # Should have emitted before_invocation + assert "before_invocation" in collector.get_event_types() + + # Stop session + await agent.stop() + + # Should have emitted after_invocation + assert "after_invocation" in collector.get_event_types() + + # Verify order: initialized -> before_invocation -> after_invocation + event_types = collector.get_event_types() + assert event_types.index("initialized") < event_types.index("before_invocation") + assert event_types.index("before_invocation") < event_types.index("after_invocation") + + async def test_message_added_hook_on_text_input(self): + """Verify sending text emits BidiMessageAddedEvent.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + + # Send text message + await agent.send("Hello, agent!") + + await agent.stop() + + # Should have emitted message_added event + message_events = collector.get_events_by_type("message_added") + assert len(message_events) >= 1 + + # Find the user message event + user_messages = [e for e in message_events if e.message["role"] == "user"] + assert len(user_messages) >= 1 + assert user_messages[0].message["content"][0]["text"] == "Hello, agent!" 
+ + +@pytest.mark.asyncio +class TestBidiAgentHooksWithTools: + """Test BidiAgent hook events with tool execution.""" + + async def test_tool_call_hooks_emitted(self): + """Verify tool execution emits before/after tool call events.""" + + @tool + def test_calculator(expression: str) -> str: + """Calculate a math expression.""" + return f"Result: {eval(expression)}" + + collector = HookEventCollector() + agent = BidiAgent(tools=[test_calculator], hooks=[collector]) + + # Note: This test verifies hook infrastructure is in place + # Actual tool execution would require model interaction + # which is tested in full integration tests + + # Verify hooks are registered + assert agent.hooks.has_callbacks() + + # Verify tool is registered + assert "test_calculator" in agent.tool_names + + +@pytest.mark.asyncio +class TestBidiAgentHooksEventData: + """Test BidiAgent hook event data integrity.""" + + async def test_hook_events_contain_agent_reference(self): + """Verify all hook events contain correct agent reference.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + await agent.send("Test message") + await agent.stop() + + # All events should reference the same agent + for _, event in collector.events: + assert hasattr(event, "agent") + assert event.agent == agent + + async def test_message_added_event_contains_message(self): + """Verify BidiMessageAddedEvent contains the actual message.""" + collector = HookEventCollector() + agent = BidiAgent(hooks=[collector]) + + await agent.start() + test_text = "Test message content" + await agent.send(test_text) + await agent.stop() + + # Find message_added events + message_events = collector.get_events_by_type("message_added") + assert len(message_events) >= 1 + + # Verify message content + user_messages = [e for e in message_events if e.message["role"] == "user"] + assert len(user_messages) >= 1 + assert user_messages[0].message["content"][0]["text"] == test_text + + +@pytest.mark.asyncio +class TestBidiAgentHooksOrdering: + """Test BidiAgent hook callback ordering.""" + + async def test_multiple_hooks_fire_in_order(self): + """Verify multiple hook providers fire in registration order.""" + call_order = [] + + class FirstHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("first")) + + class SecondHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("second")) + + class ThirdHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiBeforeInvocationEvent, lambda e: call_order.append("third")) + + agent = BidiAgent(hooks=[FirstHook(), SecondHook(), ThirdHook()]) + + await agent.start() + await agent.stop() + + # Verify order + assert call_order == ["first", "second", "third"] + + async def test_after_invocation_fires_in_reverse_order(self): + """Verify after invocation hooks fire in reverse order (cleanup).""" + call_order = [] + + class FirstHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("first")) + + class SecondHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("second")) + + class ThirdHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(BidiAfterInvocationEvent, lambda e: call_order.append("third")) + + agent = 
BidiAgent(hooks=[FirstHook(), SecondHook(), ThirdHook()]) + + await agent.start() + await agent.stop() + + # Verify reverse order for cleanup + assert call_order == ["third", "second", "first"] + + +@pytest.mark.asyncio +class TestBidiAgentHooksContextManager: + """Test BidiAgent hooks with async context manager.""" + + async def test_hooks_fire_with_context_manager(self): + """Verify hooks fire correctly when using async context manager.""" + collector = HookEventCollector() + + async with BidiAgent(hooks=[collector]) as agent: + await agent.send("Test message") + + # Verify lifecycle events + event_types = collector.get_event_types() + assert "initialized" in event_types + assert "before_invocation" in event_types + assert "after_invocation" in event_types + + # Verify order + assert event_types.index("before_invocation") < event_types.index("after_invocation") diff --git a/tests_integ/bidi/test_bidirectional_agent.py b/tests_integ/bidi/test_bidirectional_agent.py new file mode 100644 index 000000000..61cf78723 --- /dev/null +++ b/tests_integ/bidi/test_bidirectional_agent.py @@ -0,0 +1,246 @@ +"""Parameterized integration tests for bidirectional streaming. + +Tests fundamental functionality across multiple model providers (Nova Sonic, OpenAI, etc.) +including multi-turn conversations, audio I/O, text transcription, and tool execution. + +This demonstrates the provider-agnostic design of the bidirectional streaming system. +""" + +import asyncio +import logging +import os + +import pytest + +from strands import tool +from strands.experimental.bidi.agent.agent import BidiAgent +from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel + +from .context import BidirectionalTestContext +from .hook_utils import HookEventCollector + +logger = logging.getLogger(__name__) + + +# Simple calculator tool for testing +@tool +def calculator(operation: str, x: float, y: float) -> float: + """Perform basic arithmetic operations. 
+ + Args: + operation: The operation to perform (add, subtract, multiply, divide) + x: First number + y: Second number + + Returns: + Result of the operation + """ + if operation == "add": + return x + y + elif operation == "subtract": + return x - y + elif operation == "multiply": + return x * y + elif operation == "divide": + if y == 0: + raise ValueError("Cannot divide by zero") + return x / y + else: + raise ValueError(f"Unknown operation: {operation}") + + +# Provider configurations +PROVIDER_CONFIGS = { + "nova_sonic": { + "model_class": BidiNovaSonicModel, + "model_kwargs": {"region": "us-east-1"}, + "silence_duration": 2.5, # Nova Sonic needs 2+ seconds of silence + "env_vars": ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], + "skip_reason": "AWS credentials not available", + }, + "openai": { + "model_class": BidiOpenAIRealtimeModel, + "model_kwargs": { + "model": "gpt-4o-realtime-preview-2024-12-17", + "session": { + "output_modalities": ["audio"], # OpenAI only supports audio OR text, not both + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": 24000}, "voice": "alloy"}, + }, + }, + }, + "silence_duration": 1.0, # OpenAI has faster VAD + "env_vars": ["OPENAI_API_KEY"], + "skip_reason": "OPENAI_API_KEY not available", + }, + "gemini_live": { + "model_class": BidiGeminiLiveModel, + "model_kwargs": { + # Uses default model and config (audio output + transcription enabled) + }, + "silence_duration": 1.5, # Gemini has good VAD, similar to OpenAI + "env_vars": ["GOOGLE_AI_API_KEY"], + "skip_reason": "GOOGLE_AI_API_KEY not available", + }, +} + + +def check_provider_available(provider_name: str) -> tuple[bool, str]: + """Check if a provider's credentials are available. + + Args: + provider_name: Name of the provider to check. + + Returns: + Tuple of (is_available, skip_reason). + """ + config = PROVIDER_CONFIGS[provider_name] + env_vars = config["env_vars"] + + missing_vars = [var for var in env_vars if not os.getenv(var)] + + if missing_vars: + return False, f"{config['skip_reason']}: {', '.join(missing_vars)}" + + return True, "" + + +@pytest.fixture(params=list(PROVIDER_CONFIGS.keys())) +def provider_config(request): + """Provide configuration for each model provider. + + This fixture is parameterized to run tests against all available providers. + """ + provider_name = request.param + config = PROVIDER_CONFIGS[provider_name] + + # Check if provider is available + is_available, skip_reason = check_provider_available(provider_name) + if not is_available: + pytest.skip(skip_reason) + + return { + "name": provider_name, + **config, + } + + +@pytest.fixture +def hook_collector(): + """Provide a hook event collector for tracking all events.""" + return HookEventCollector() + + +@pytest.fixture +def agent_with_calculator(provider_config, hook_collector): + """Provide bidirectional agent with calculator tool for the given provider. + + Note: Session lifecycle (start/end) is handled by BidirectionalTestContext. + """ + model_class = provider_config["model_class"] + model_kwargs = provider_config["model_kwargs"] + + model = model_class(**model_kwargs) + return BidiAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant with access to a calculator tool. 
Keep responses brief.", + hooks=[hook_collector], + ) + + +@pytest.mark.asyncio +async def test_bidirectional_agent(agent_with_calculator, audio_generator, provider_config, hook_collector): + """Test multi-turn conversation with follow-up questions across providers. + + This test runs against all configured providers (Nova Sonic, OpenAI, etc.) + to validate provider-agnostic functionality. + + Validates: + - Session lifecycle (start/end via context manager) + - Audio input streaming + - Speech-to-text transcription + - Tool execution (calculator) with hook verification + - Multi-turn conversation flow + - Text-to-speech audio output + """ + provider_name = provider_config["name"] + silence_duration = provider_config["silence_duration"] + + logger.info("provider=<%s> | testing provider", provider_name) + + async with BidirectionalTestContext(agent_with_calculator, audio_generator) as ctx: + # Turn 1: Simple greeting to test basic audio I/O + await ctx.say("Hello, can you hear me?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) + await ctx.wait_for_response() + + text_outputs_turn1 = ctx.get_text_outputs() + + # Validate turn 1 - just check we got a response + assert len(text_outputs_turn1) > 0, f"[{provider_name}] No text output received in turn 1" + + logger.info("provider=<%s> | turn 1 complete received response", provider_name) + logger.info("provider=<%s>, response=<%s> | turn 1 response", provider_name, text_outputs_turn1[0][:100]) + + # Turn 2: Follow-up to test multi-turn conversation + await ctx.say("What's your name?") + # Wait for silence to trigger provider's VAD/silence detection + await asyncio.sleep(silence_duration) + await ctx.wait_for_response() + + text_outputs_turn2 = ctx.get_text_outputs() + + # Validate turn 2 - check we got more responses + assert len(text_outputs_turn2) > len(text_outputs_turn1), f"[{provider_name}] No new text output in turn 2" + + logger.info("provider=<%s> | turn 2 complete multi-turn conversation works", provider_name) + logger.info("provider=<%s>, response_count=<%d> | total responses", provider_name, len(text_outputs_turn2)) + + # Validate full conversation + # Validate audio outputs + audio_outputs = ctx.get_audio_outputs() + assert len(audio_outputs) > 0, f"[{provider_name}] No audio output received" + total_audio_bytes = sum(len(audio) for audio in audio_outputs) + + # Verify tool execution hooks if tools were called + tool_calls = hook_collector.get_tool_calls() + if len(tool_calls) > 0: + logger.info("provider=<%s> | tool execution detected", provider_name) + # Verify hooks are properly paired + verified_tools = hook_collector.verify_tool_execution() + logger.info( + "provider=<%s>, tools_called=<%s> | tool execution hooks verified", + provider_name, + verified_tools, + ) + else: + logger.info("provider=<%s> | no tools were called during conversation", provider_name) + + # Summary + logger.info("=" * 60) + logger.info("provider=<%s> | multi-turn conversation test passed", provider_name) + logger.info("provider=<%s> | test summary", provider_name) + logger.info("event_count=<%d> | total events", len(ctx.get_events())) + logger.info("text_response_count=<%d> | text responses", len(text_outputs_turn2)) + logger.info( + "audio_chunk_count=<%d>, audio_bytes=<%d> | audio chunks", + len(audio_outputs), + total_audio_bytes, + ) + logger.info( + "tool_calls=<%d> | tool execution count", + len(tool_calls), + ) + logger.info("=" * 60) diff --git a/tests_integ/bidi/wrappers/__init__.py 
b/tests_integ/bidi/wrappers/__init__.py new file mode 100644 index 000000000..6b8a64984 --- /dev/null +++ b/tests_integ/bidi/wrappers/__init__.py @@ -0,0 +1,4 @@ +"""Wrappers for bidirectional streaming integration tests. + +Includes fault injection and other transparent wrappers around real implementations. +"""
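+
+
+# Illustrative sketch only -- not one of the real wrappers shipped in this
+# package. It demonstrates the transparent-wrapper pattern described above:
+# delegate everything to a wrapped model while injecting a fault into a single
+# method. `fail_after` and the wrapped model's async `send` method are
+# assumptions made for the example.
+class _FaultInjectingModelSketch:
+    """Example: fail the wrapped model's send() after N successful calls."""
+
+    def __init__(self, wrapped, fail_after: int = 3):
+        self._wrapped = wrapped
+        self._fail_after = fail_after
+        self._calls = 0
+
+    async def send(self, event):
+        self._calls += 1
+        if self._calls > self._fail_after:
+            raise ConnectionError("injected fault for testing")
+        return await self._wrapped.send(event)
+
+    def __getattr__(self, name):
+        # Transparently forward all other attributes to the real implementation
+        return getattr(self._wrapped, name)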