diff --git a/contributing/samples/bedrock_wikipedia_agent/README.md b/contributing/samples/bedrock_wikipedia_agent/README.md new file mode 100644 index 0000000..7872294 --- /dev/null +++ b/contributing/samples/bedrock_wikipedia_agent/README.md @@ -0,0 +1,73 @@ +# Bedrock Wikipedia Agent + +A minimal research agent built with **Google ADK** and **Amazon Bedrock** that +answers questions by searching Wikipedia. + +## What it demonstrates + +- Instantiating `BedrockModel` and passing it to an ADK `Agent` +- Defining Python functions as ADK tools (`wikipedia_search`, `wikipedia_get_article`) +- Running an agent turn with `Runner.run_async` +- Streaming and non-streaming response modes + +## Prerequisites + +```bash +pip install "google-adk-community[bedrock]" wikipedia-api +``` + +AWS credentials must be available via the standard boto3 credential chain: + +| Method | How | +|---|---| +| Environment variables | `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` + `AWS_REGION` | +| Credentials file | `~/.aws/credentials` | +| IAM role | EC2 instance profile, ECS task role, Lambda execution role | + +## Usage + +```bash +# Default question +python agent.py + +# Custom question +python agent.py "Who invented the World Wide Web?" + +# Streaming output +python agent.py --stream "What is quantum computing?" + +# Different model and region +python agent.py --model amazon.nova-pro-v1:0 --region us-west-2 "What is AWS Lambda?" +``` + +## Sample output + +``` +============================================================ +Model : us.anthropic.claude-haiku-4-5-20251001-v1:0 | Region : us-east-1 +Q: What is Python? +============================================================ + + [tool] wikipedia_search({'query': 'Python programming language'}) + [result] status=success + +A: +**Python** is a high-level, general-purpose programming language known for its emphasis on code readability through significant indentation. 
Here are the key characteristics: + +**Main Features:** +- **Design Philosophy**: Emphasizes code readability and clean syntax +- **Type System**: Dynamically type-checked +- **Memory Management**: Garbage-collected (automatic memory management) +- **Programming Paradigms**: Supports multiple styles including: + - Structured programming (particularly procedural) + - Object-oriented programming + - Functional programming + +**History:** +- Created by Guido van Rossum in the late 1980s as a successor to the ABC programming language +- Python 3.0, released in 2008, was a major revision + +Python has become one of the most popular programming languages due to its readability, versatility, and ease of learning, making it suitable for web development, data science, artificial intelligence, automation, and many other applications. + +**Source:** https://en.wikipedia.org/wiki/Python_(programming_language) +``` diff --git a/contributing/samples/bedrock_wikipedia_agent/agent.py b/contributing/samples/bedrock_wikipedia_agent/agent.py new file mode 100644 index 0000000..222cdac --- /dev/null +++ b/contributing/samples/bedrock_wikipedia_agent/agent.py @@ -0,0 +1,280 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wikipedia research agent powered by Amazon Bedrock. + +Demonstrates BedrockModel integration with Google ADK by building a simple +research assistant that answers questions using Wikipedia. 
+ +Usage:: + + # Default question + python agent.py + + # Custom question + python agent.py "Who invented the World Wide Web?" + + # Streaming mode + python agent.py --stream "What is quantum computing?" + + # Use a different Bedrock model + python agent.py --model amazon.nova-pro-v1:0 "What is AWS Lambda?" + +Prerequisites:: + + pip install google-adk-community[bedrock] wikipedia-api + +AWS credentials must be configured via one of: + - Environment variables (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY) + - AWS credentials file (~/.aws/credentials) + - IAM role (EC2 instance profile, ECS task role, Lambda execution role) +""" + +import argparse +import asyncio +from functools import lru_cache +import os + +import wikipediaapi +from google.adk.agents import Agent +from google.adk import Runner +from google.adk.sessions import InMemorySessionService +from google.genai import types + +from google.adk_community.models.bedrock_model import BedrockModel + +_APP_NAME = "bedrock_wikipedia_agent" + + +# --------------------------------------------------------------------------- +# Wikipedia tools +# --------------------------------------------------------------------------- + + +@lru_cache +def _get_wiki_client(language: str) -> wikipediaapi.Wikipedia: + return wikipediaapi.Wikipedia( + user_agent="google-adk-community-example/1.0", language=language + ) + + +def wikipedia_search(query: str, language: str = "en") -> dict: + """Search Wikipedia and return a summary for the best-matching article. + + Args: + query: The topic or question to search for on Wikipedia. + language: Wikipedia language code (default: ``"en"``). + + Returns: + A dict containing ``title``, ``snippet``, ``url``, and optionally + ``related`` articles. Returns a ``"no_results"`` status dict when no + matching article is found. 
+ """ + wiki = _get_wiki_client(language) + page = wiki.page(query) + if not page.exists(): + return { + "status": "no_results", + "query": query, + "message": f"No Wikipedia article found for: {query}", + } + + summary = page.summary + snippet = summary[:500] + "..." if len(summary) > 500 else summary + + related = [] + for _, link_page in list(page.links.items())[:3]: + if link_page.exists(): + s = link_page.summary + related.append({ + "title": link_page.title, + "snippet": s[:150] + "..." if len(s) > 150 else s, + }) + + return { + "status": "success", + "title": page.title, + "snippet": snippet, + "url": page.fullurl, + "related": related, + } + + +def wikipedia_get_article( + title: str, + summary_only: bool = True, + max_length: int = 3000, + language: str = "en", +) -> dict: + """Retrieve content from a Wikipedia article by its exact title. + + Args: + title: Exact Wikipedia article title (e.g. ``"Python (programming + language)"``). + summary_only: When ``True`` (default), return only the introductory + summary. Set to ``False`` for the full article text. + max_length: Maximum character length of full-text content (default 3000). + language: Wikipedia language code (default: ``"en"``). + + Returns: + A dict containing ``title``, ``content``, ``url``, and ``categories``. + Returns a ``"not_found"`` status dict when the article does not exist. + """ + wiki = _get_wiki_client(language) + page = wiki.page(title) + if not page.exists(): + return { + "status": "not_found", + "title": title, + "message": f"Wikipedia article not found: {title}", + } + + if summary_only: + content = page.summary + else: + content = page.text[:max_length] + if len(page.text) > max_length: + content += "\n\n[... 
content truncated]" + + return { + "status": "success", + "title": page.title, + "content": content, + "url": page.fullurl, + "categories": list(page.categories.keys())[:5], + } + + +# --------------------------------------------------------------------------- +# Agent factory +# --------------------------------------------------------------------------- + + +def build_agent(model_id: str, region: str) -> Agent: + """Create a Wikipedia research Agent backed by Bedrock. + + Args: + model_id: Bedrock model ID or cross-region inference profile. + region: AWS region for the Bedrock API endpoint. + + Returns: + A configured ADK :class:`~google.adk.agents.Agent`. + """ + return Agent( + model=BedrockModel(model=model_id, region_name=region, max_tokens=2048), + name="wikipedia_research_agent", + description="Answers questions using Wikipedia via Amazon Bedrock.", + instruction=( + "You are a concise research assistant. " + "Use wikipedia_search to find relevant articles and " + "wikipedia_get_article to retrieve detail when needed. " + "Always cite the Wikipedia URL in your final answer." + ), + tools=[wikipedia_search, wikipedia_get_article], + ) + + +# --------------------------------------------------------------------------- +# Runner +# --------------------------------------------------------------------------- + + +async def ask(question: str, model_id: str, region: str, stream: bool) -> None: + """Send a single question to the agent and print the response. + + Args: + question: The user's question. + model_id: Bedrock model ID to use. + region: AWS region. + stream: When ``True``, stream partial text deltas to stdout. 
+ """ + agent = build_agent(model_id, region) + session_service = InMemorySessionService() + runner = Runner( + agent=agent, + app_name=_APP_NAME, + session_service=session_service, + ) + session = await session_service.create_session( + app_name=_APP_NAME, user_id="user" + ) + + print(f"\n{'='*60}") + print(f"Model : {model_id} | Region : {region}") + print(f"Q: {question}") + print(f"{'='*60}\n") + + async for event in runner.run_async( + user_id="user", + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part.from_text(text=question)] + ), + ): + if not event.content or not event.content.parts: + continue + for part in event.content.parts: + if part.function_call: + print( + f" [tool] {part.function_call.name}({part.function_call.args})" + ) + elif part.function_response: + status = (part.function_response.response or {}).get("status", "?") + print(f" [result] status={status}") + elif part.text: + if stream and not event.is_final_response(): + print(part.text, end="", flush=True) + elif event.is_final_response(): + if stream: + print() # newline after streamed output + print(f"\nA:\n{part.text}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Wikipedia research agent powered by Amazon Bedrock + ADK" + ) + parser.add_argument( + "question", + nargs="?", + default="What is Amazon Bedrock?", + help="Question to answer (default: 'What is Amazon Bedrock?')", + ) + parser.add_argument( + "--model", + default="us.anthropic.claude-haiku-4-5-20251001-v1:0", + help="Bedrock model ID (default: us.anthropic.claude-haiku-4-5-20251001-v1:0)", + ) + parser.add_argument( + "--region", + default=os.environ.get("AWS_REGION", "us-east-1"), + help="AWS region (default: AWS_REGION env var or us-east-1)", + ) + parser.add_argument( + "--stream", + 
action="store_true", + help="Stream text output to stdout", + ) + args = parser.parse_args() + + asyncio.run(ask(args.question, args.model, args.region, args.stream)) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..b1bdbdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,9 +41,13 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +bedrock = [ + "boto3>=1.40.0", +] test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", + "boto3>=1.40.0", ] diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..9ac57e5 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -13,6 +13,8 @@ # limitations under the License. from . import memory +from . import models from . import sessions from . import version + __version__ = version.__version__ diff --git a/src/google/adk_community/models/README.md b/src/google/adk_community/models/README.md new file mode 100644 index 0000000..c969727 --- /dev/null +++ b/src/google/adk_community/models/README.md @@ -0,0 +1,137 @@ +# `google.adk_community.models` + +Community-contributed model integrations for [Google ADK](https://google.github.io/adk-docs/). + +## Available models + +| Class | Provider | Install extra | +|---|---|---| +| `BedrockModel` | Amazon Bedrock (Converse API) | `bedrock` | + +--- + +## `BedrockModel` + +Native Amazon Bedrock integration via the +[Converse API](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html). +Supports all Bedrock-hosted models and cross-region inference profiles. + +### Supported models + +Any model available via the Bedrock Converse API is supported, including cross-region +inference profiles (`us.*`, `eu.*`, `ap.*`). 
See the +[Amazon Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html) +for the full list of supported models. + +### Installation + +```bash +pip install "google-adk-community[bedrock]" +``` + +### Quick start + +```python +from google.adk.agents import Agent +from google.adk_community.models import BedrockModel + +agent = Agent( + model=BedrockModel(model="us.anthropic.claude-haiku-4-5-20251001-v1:0"), + name="my_agent", + instruction="You are a helpful assistant.", + tools=[...], +) +``` + +Because `BedrockModel` is registered with `LLMRegistry` on import, you can +also pass the model ID as a plain string after importing the module: + +```python +import google.adk_community.models # triggers LLMRegistry registration + +agent = Agent(model="us.anthropic.claude-haiku-4-5-20251001-v1:0", ...) +``` + +### AWS authentication + +`BedrockModel` resolves credentials through the standard +[boto3 credential chain](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html): + +1. **Environment variables** — `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_SESSION_TOKEN` +2. **AWS credentials file** — `~/.aws/credentials` +3. **IAM role** — EC2 instance profile, ECS task role, Lambda execution role + +The AWS region is resolved in this order: + +1. `region_name` constructor argument +2. `region_name` of the supplied `boto_session` +3. `AWS_REGION` environment variable +4. `AWS_DEFAULT_REGION` environment variable +5. 
Fallback: `us-east-1` + +#### Custom boto3 session (assumed role, named profile, …) + +```python +import boto3 +from google.adk_community.models import BedrockModel + +session = boto3.Session(profile_name="my-prod-profile") +# or an assumed-role session: +# sts = boto3.client("sts") +# creds = sts.assume_role(RoleArn="arn:aws:iam::123456789:role/MyRole", +# RoleSessionName="adk-session")["Credentials"] +# session = boto3.Session( +# aws_access_key_id=creds["AccessKeyId"], +# aws_secret_access_key=creds["SecretAccessKey"], +# aws_session_token=creds["SessionToken"], +# ) + +model = BedrockModel( + model="us.anthropic.claude-haiku-4-5-20251001-v1:0", + boto_session=session, +) +``` + +> **Note:** `boto_session` and `region_name` are mutually exclusive. +> Pass the region when constructing the `boto3.Session` instead. + +### Configuration reference + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `model` | `str` | `us.anthropic.claude-haiku-4-5-20251001-v1:0` | Bedrock model ID or inference profile | +| `region_name` | `str \| None` | `None` | AWS region (see resolution order above) | +| `max_tokens` | `int` | `4096` | Maximum tokens to generate | +| `guardrail_id` | `str \| None` | `None` | Bedrock Guardrail identifier | +| `guardrail_version` | `str \| None` | `None` | Bedrock Guardrail version (`"1"`, `"DRAFT"`, …) | +| `boto_session` | `boto3.Session \| None` | `None` | Pre-configured boto3 session | + +### Guardrails + +```python +model = BedrockModel( + model="us.anthropic.claude-haiku-4-5-20251001-v1:0", + guardrail_id="abc123def456", + guardrail_version="1", +) +``` + +Responses blocked by a guardrail are returned with +`finish_reason = FinishReason.SAFETY`. + +### Streaming + +Streaming is enabled by default when the ADK runner requests it. Each text +delta is yielded as a partial `LlmResponse` (`partial=True`), followed by a +single aggregated final response (`partial=False`). 
+ +### Error handling + +`BedrockModel` propagates `botocore.exceptions.ClientError` unchanged so +callers can inspect `error.response["Error"]["Code"]`. Common codes: + +| Code | Meaning | +|---|---| +| `ThrottlingException` | Rate limit exceeded — add retry/back-off logic | +| `ValidationException` | Invalid request — often a context-window overflow | +| `AccessDeniedException` | IAM principal lacks model access — check the [Bedrock console](https://console.aws.amazon.com/bedrock/home#/modelaccess) | diff --git a/src/google/adk_community/models/__init__.py b/src/google/adk_community/models/__init__.py new file mode 100644 index 0000000..aa0b71a --- /dev/null +++ b/src/google/adk_community/models/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.models.registry import LLMRegistry + +from .bedrock_model import BedrockModel + +LLMRegistry.register(BedrockModel) + +__all__ = ["BedrockModel"] diff --git a/src/google/adk_community/models/bedrock_model.py b/src/google/adk_community/models/bedrock_model.py new file mode 100644 index 0000000..600deff --- /dev/null +++ b/src/google/adk_community/models/bedrock_model.py @@ -0,0 +1,754 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Amazon Bedrock model integration for Google ADK. + +Provides native integration with Amazon Bedrock via the Converse API, +supporting all Bedrock-hosted models including Amazon Nova, Anthropic Claude, +Meta Llama, Mistral, Cohere, and more. +""" + +from __future__ import annotations + +import asyncio +from functools import cached_property +import json +import logging +import os +from typing import Any +from typing import AsyncGenerator +from typing import NoReturn +from typing import Optional +from typing import TYPE_CHECKING + +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_response import LlmResponse +from google.genai import types +from pydantic import Field +from pydantic import model_validator +from typing_extensions import override + +from google.adk_community import version as _community_version + +if TYPE_CHECKING: + import boto3 + from botocore.exceptions import ClientError + from google.adk.models.llm_request import LlmRequest + +__all__ = ["BedrockModel"] + +logger = logging.getLogger("google_adk." + __name__) + +DEFAULT_MAX_TOKENS = 4096 +DEFAULT_REGION = "us-east-1" + +# Bedrock ValidationException messages that signal context window overflow. 
+_CONTEXT_OVERFLOW_MESSAGES = ( + "Input is too long for requested model", + "input length and `max_tokens` exceed context limit", + "too many total text bytes", + "prompt is too long", +) + +# Bedrock stopReason -> ADK FinishReason +_FINISH_REASON_MAP: dict[str, types.FinishReason] = { + "end_turn": types.FinishReason.STOP, + "tool_use": types.FinishReason.STOP, + "stop_sequence": types.FinishReason.STOP, + "max_tokens": types.FinishReason.MAX_TOKENS, + "guardrail_intervened": types.FinishReason.SAFETY, +} + +# image/* MIME type -> Bedrock image format string +_IMAGE_FORMAT_MAP: dict[str, str] = { + "image/jpeg": "jpeg", + "image/jpg": "jpeg", + "image/png": "png", + "image/gif": "gif", + "image/webp": "webp", +} + + +# --------------------------------------------------------------------------- +# ADK -> Bedrock conversion helpers +# --------------------------------------------------------------------------- + + +def _part_to_bedrock_block(part: types.Part) -> dict[str, Any] | None: + """Convert a single ADK Part to a Bedrock content block. + + Args: + part: An ADK types.Part object. + + Returns: + A Bedrock-compatible content block dict, or None if the part type is + not supported. + """ + if part.text is not None: + return {"text": part.text} + + if part.function_call: + assert part.function_call.name + return { + "toolUse": { + "toolUseId": part.function_call.id or "", + "name": part.function_call.name, + "input": part.function_call.args or {}, + } + } + + if part.function_response: + response_data = part.function_response.response + # Normalise the tool result into a plain text content block. + if isinstance(response_data, dict): + if "result" in response_data: + content_text = str(response_data["result"]) + elif "content" in response_data: + items = response_data["content"] + # Handle list-of-dicts content produced by some tool wrappers. 
+ if isinstance(items, list): + texts = [ + item.get("text", str(item)) + if isinstance(item, dict) + else str(item) + for item in items + ] + content_text = "\n".join(texts) + else: + content_text = str(items) + else: + content_text = json.dumps(response_data) + else: + content_text = str(response_data) + + return { + "toolResult": { + "toolUseId": part.function_response.id or "", + "content": [{"text": content_text}], + "status": "success", + } + } + + if part.inline_data and part.inline_data.data and part.inline_data.mime_type: + mime = part.inline_data.mime_type.lower().split(";")[0].strip() + img_format = _IMAGE_FORMAT_MAP.get(mime) + if img_format: + return { + "image": { + "format": img_format, + "source": {"bytes": part.inline_data.data}, + } + } + logger.warning( + "mime_type=<%s> | unsupported inline_data MIME type, skipping", mime + ) + return None + + logger.warning("Unsupported ADK Part type, skipping: %s", part) + return None + + +def _content_to_bedrock_message( + content: types.Content, +) -> dict[str, Any] | None: + """Convert an ADK Content object to a Bedrock Converse API message. + + Args: + content: An ADK types.Content object. + + Returns: + A Bedrock message dict, or None if the resulting content list is empty. + """ + role = "assistant" if content.role in ("model", "assistant") else "user" + bedrock_content: list[dict[str, Any]] = [] + for part in content.parts or []: + block = _part_to_bedrock_block(part) + if block is not None: + bedrock_content.append(block) + if not bedrock_content: + return None + return {"role": role, "content": bedrock_content} + + +def _bedrock_block_to_part(block: dict[str, Any]) -> types.Part | None: + """Convert a Bedrock response content block to an ADK Part. + + Args: + block: A Bedrock content block dict from a converse response. + + Returns: + An ADK types.Part, or None if the block type is not handled. 
+ """ + if "text" in block: + return types.Part.from_text(text=block["text"]) + + if "toolUse" in block: + tool_use = block["toolUse"] + part = types.Part.from_function_call( + name=tool_use["name"], + args=tool_use.get("input", {}), + ) + part.function_call.id = tool_use.get("toolUseId", "") + return part + + return None + + +def _function_declaration_to_tool_spec( + func_decl: types.FunctionDeclaration, +) -> dict[str, Any]: + """Convert an ADK FunctionDeclaration to a Bedrock toolSpec dict. + + Args: + func_decl: An ADK types.FunctionDeclaration. + + Returns: + A Bedrock toolSpec dict suitable for the toolConfig.tools list. + """ + assert func_decl.name + + properties: dict[str, Any] = {} + required_params: list[str] = [] + + if func_decl.parameters_json_schema: + input_schema = func_decl.parameters_json_schema + else: + if func_decl.parameters and func_decl.parameters.properties: + for key, schema in func_decl.parameters.properties.items(): + prop = schema.model_dump(exclude_none=True) + # Normalise type enum to a lowercase string. + if "type" in prop: + t = prop["type"] + prop["type"] = ( + t.value.lower() if hasattr(t, "value") else str(t).lower() + ) + properties[key] = prop + if func_decl.parameters and func_decl.parameters.required: + required_params = func_decl.parameters.required + + input_schema: dict[str, Any] = { + "type": "object", + "properties": properties, + } + if required_params: + input_schema["required"] = required_params + + return { + "toolSpec": { + "name": func_decl.name, + "description": func_decl.description or "", + "inputSchema": {"json": input_schema}, + } + } + + +# --------------------------------------------------------------------------- +# BedrockModel +# --------------------------------------------------------------------------- + + +class BedrockModel(BaseLlm): + """Amazon Bedrock integration for Google ADK via the Converse API. 
+ + Supports all models available on Amazon Bedrock including: + - Amazon Nova (``amazon.nova-pro-v1:0``, ``amazon.nova-lite-v1:0``, …) + - Anthropic Claude (``anthropic.claude-3-5-sonnet-20241022-v2:0``, …) + - Meta Llama (``meta.llama3-70b-instruct-v1:0``, …) + - Mistral AI, Cohere, AI21, DeepSeek, and more + - Cross-region inference profiles (``us.*``, ``eu.*``, ``ap.*`` prefixes) + + Example usage:: + + from google.adk.agents import Agent + from google.adk_community.models.bedrock_model import BedrockModel + + agent = Agent( + model=BedrockModel(model="us.anthropic.claude-haiku-4-5-20251001-v1:0"), + ... + ) + + AWS credentials are resolved via the standard boto3 credential chain: + environment variables, ``~/.aws/credentials``, or an IAM instance/task role. + + To use a custom boto3 session (e.g. assumed role):: + + import boto3 + session = boto3.Session(profile_name="my-profile") + model = BedrockModel( + model="us.anthropic.claude-haiku-4-5-20251001-v1:0", + boto_session=session, + ) + + Attributes: + model: Bedrock model ID or cross-region inference profile ID. + region_name: AWS region. Resolved from ``AWS_REGION`` / ``AWS_DEFAULT_REGION`` + environment variables, then falls back to ``us-east-1``. + max_tokens: Maximum tokens to generate. Defaults to 4096. + guardrail_id: Optional Bedrock Guardrail identifier. + guardrail_version: Optional Bedrock Guardrail version (e.g. ``"1"`` or + ``"DRAFT"``). + boto_session: Optional pre-configured :class:`boto3.Session`. Takes + precedence over ``region_name`` when both are supplied; however, + supplying both raises a ``ValueError``. + """ + + model: str = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + region_name: Optional[str] = None + max_tokens: int = DEFAULT_MAX_TOKENS + guardrail_id: Optional[str] = None + guardrail_version: Optional[str] = None + # boto_session is excluded from pydantic serialisation because boto3.Session + # is not a pydantic-serialisable type. 
+ boto_session: Optional[Any] = Field(default=None, exclude=True) + + @model_validator(mode="after") + def _validate_boto_session_and_region(self) -> "BedrockModel": + if self.boto_session is not None and self.region_name is not None: + raise ValueError( + "Cannot specify both `boto_session` and `region_name`. " + "Pass `region_name` when constructing the boto3.Session instead." + ) + return self + + @classmethod + @override + def supported_models(cls) -> list[str]: + """Return regex patterns that match Bedrock model IDs. + + Covers: + - Cross-region inference profiles: ``us.*``, ``eu.*``, ``ap.*`` + - Direct model IDs for all major providers available on Bedrock + """ + return [ + # Cross-region inference profiles + r"(us|eu|ap)\.(amazon|anthropic|meta|mistral|cohere|ai21|deepseek|writer)\..+", + # Direct model IDs + r"(amazon|anthropic|meta|mistral|cohere|ai21|deepseek|writer)\..+", + ] + + @override + async def generate_content_async( + self, + llm_request: "LlmRequest", + stream: bool = False, + ) -> AsyncGenerator[LlmResponse, None]: + """Generate content using the Amazon Bedrock Converse API. + + Args: + llm_request: The ADK LlmRequest containing messages, tools, and config. + stream: When ``True``, streams partial responses via ``converse_stream``. + + Yields: + :class:`~google.adk.models.llm_response.LlmResponse` objects. + In streaming mode multiple partial responses are yielded, followed by a + final aggregated response with ``partial=False``. 
+ """ + request = self._build_request(llm_request) + logger.debug( + "model=<%s> | sending request to Bedrock", + llm_request.model or self.model, + ) + + if stream: + async for response in self._generate_streaming(request): + yield response + else: + yield await self._generate_non_streaming(request) + + # ------------------------------------------------------------------ + # Request building + # ------------------------------------------------------------------ + + def _build_request(self, llm_request: "LlmRequest") -> dict[str, Any]: + """Build a Bedrock Converse API request dict from an LlmRequest. + + Args: + llm_request: The ADK LlmRequest. + + Returns: + A dict ready to be unpacked into ``client.converse(**request)`` or + ``client.converse_stream(**request)``. + """ + # --- messages --- + messages: list[dict[str, Any]] = [] + for content in llm_request.contents or []: + msg = _content_to_bedrock_message(content) + if msg: + messages.append(msg) + + # --- system prompt --- + system_blocks: list[dict[str, Any]] = [] + if llm_request.config and llm_request.config.system_instruction: + system_blocks = [{"text": llm_request.config.system_instruction}] + + # --- tools --- + tool_config: dict[str, Any] | None = None + if ( + llm_request.config + and llm_request.config.tools + and llm_request.config.tools[0].function_declarations + ): + tools = [ + _function_declaration_to_tool_spec(fd) + for fd in llm_request.config.tools[0].function_declarations + ] + tool_config = {"tools": tools, "toolChoice": {"auto": {}}} + + # --- inference config --- + inference_config: dict[str, Any] = {"maxTokens": self.max_tokens} + if llm_request.config: + cfg = llm_request.config + if cfg.temperature is not None: + inference_config["temperature"] = cfg.temperature + if cfg.top_p is not None: + inference_config["topP"] = cfg.top_p + if cfg.stop_sequences: + inference_config["stopSequences"] = cfg.stop_sequences + if cfg.max_output_tokens is not None: + inference_config["maxTokens"] = 
cfg.max_output_tokens + + request: dict[str, Any] = { + "modelId": llm_request.model or self.model, + "messages": messages, + "system": system_blocks, + "inferenceConfig": inference_config, + } + if tool_config: + request["toolConfig"] = tool_config + if self.guardrail_id and self.guardrail_version: + request["guardrailConfig"] = { + "guardrailIdentifier": self.guardrail_id, + "guardrailVersion": self.guardrail_version, + "trace": "enabled", + } + + return request + + # ------------------------------------------------------------------ + # Non-streaming + # ------------------------------------------------------------------ + + async def _generate_non_streaming( + self, request: dict[str, Any] + ) -> LlmResponse: + """Call Bedrock ``converse`` and return a single LlmResponse. + + Args: + request: A Bedrock Converse API request dict. + + Returns: + An :class:`~google.adk.models.llm_response.LlmResponse`. + + Raises: + botocore.exceptions.ClientError: Re-raised with model/region context + appended to the exception message. + """ + from botocore.exceptions import ClientError + + try: + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + None, lambda: self._client.converse(**request) + ) + except ClientError as e: + self._handle_client_error(e) + return self._parse_converse_response(response) + + def _parse_converse_response(self, response: dict[str, Any]) -> LlmResponse: + """Convert a Bedrock ``converse`` response dict to an LlmResponse. + + Args: + response: The raw response dict returned by ``client.converse()``. + + Returns: + An :class:`~google.adk.models.llm_response.LlmResponse`. 
+ """ + message = response["output"]["message"] + parts: list[types.Part] = [] + for block in message.get("content", []): + part = _bedrock_block_to_part(block) + if part: + parts.append(part) + + stop_reason = response.get("stopReason", "end_turn") + finish_reason = _FINISH_REASON_MAP.get(stop_reason, types.FinishReason.STOP) + + usage_metadata = None + if "usage" in response: + usage = response["usage"] + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get("inputTokens", 0), + candidates_token_count=usage.get("outputTokens", 0), + total_token_count=usage.get("totalTokens", 0), + ) + + return LlmResponse( + content=types.Content(role="model", parts=parts), + finish_reason=finish_reason, + usage_metadata=usage_metadata, + ) + + # ------------------------------------------------------------------ + # Streaming + # ------------------------------------------------------------------ + + async def _generate_streaming( + self, request: dict[str, Any] + ) -> AsyncGenerator[LlmResponse, None]: + """Call Bedrock ``converse_stream`` and yield partial + final LlmResponse. + + The synchronous boto3 streaming call is offloaded to a thread; events are + forwarded to the async caller via a queue, following the same thread/queue + bridge pattern used by the Strands Agents SDK. + + Args: + request: A Bedrock Converse API request dict. + + Yields: + Partial :class:`~google.adk.models.llm_response.LlmResponse` objects + (``partial=True``) for each text delta, followed by a single final + response (``partial=False``) with all accumulated content. 
+ """ + loop = asyncio.get_running_loop() + queue: asyncio.Queue[dict[str, Any] | None | Exception] = asyncio.Queue() + + def _stream_in_thread() -> None: + from botocore.exceptions import ClientError + + try: + response = self._client.converse_stream(**request) + for chunk in response["stream"]: + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except ClientError as e: + self._handle_client_error(e) + raise + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + task = asyncio.create_task(asyncio.to_thread(_stream_in_thread)) + + # --- accumulation state --- + text_buffer = "" + # block_index -> {"id": str, "name": str, "args": str} + function_calls: dict[int, dict[str, Any]] = {} + finish_reason: types.FinishReason | None = None + usage_metadata: types.GenerateContentResponseUsageMetadata | None = None + + while True: + chunk = await queue.get() + if chunk is None: + break + if isinstance(chunk, Exception): + raise chunk + + # messageStart — carries role, not needed for output. + if "messageStart" in chunk: + pass + + # contentBlockStart — toolUse blocks announce name/id here. + elif "contentBlockStart" in chunk: + cbs = chunk["contentBlockStart"] + block_index: int = cbs.get("contentBlockIndex", 0) + start = cbs.get("start", {}) + if "toolUse" in start: + function_calls[block_index] = { + "id": start["toolUse"]["toolUseId"], + "name": start["toolUse"]["name"], + "args": "", + } + + # contentBlockDelta — text or toolUse input fragments. 
+ elif "contentBlockDelta" in chunk: + cbd = chunk["contentBlockDelta"] + block_index = cbd.get("contentBlockIndex", 0) + delta = cbd.get("delta", {}) + + if "text" in delta: + text = delta["text"] + text_buffer += text + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text=text)], + ), + partial=True, + ) + + elif "toolUse" in delta: + if block_index in function_calls: + function_calls[block_index]["args"] += delta["toolUse"].get( + "input", "" + ) + + # contentBlockStop — end of a content block; no action needed here. + elif "contentBlockStop" in chunk: + pass + + # messageStop — carries the final stop reason. + elif "messageStop" in chunk: + stop_reason = chunk["messageStop"].get("stopReason", "end_turn") + finish_reason = _FINISH_REASON_MAP.get( + stop_reason, types.FinishReason.STOP + ) + + # metadata — usage token counts. + elif "metadata" in chunk: + meta = chunk["metadata"] + if "usage" in meta: + usage = meta["usage"] + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get("inputTokens", 0), + candidates_token_count=usage.get("outputTokens", 0), + total_token_count=usage.get("totalTokens", 0), + ) + + await task + + # --- build final aggregated response --- + parts: list[types.Part] = [] + if text_buffer: + parts.append(types.Part.from_text(text=text_buffer)) + + for _, fc in sorted(function_calls.items()): + try: + args = json.loads(fc["args"]) if fc["args"] else {} + except json.JSONDecodeError: + logger.warning( + "tool_name=<%s> | failed to parse tool input JSON, using empty" + " dict", + fc["name"], + ) + args = {} + part = types.Part.from_function_call(name=fc["name"], args=args) + part.function_call.id = fc["id"] + parts.append(part) + + yield LlmResponse( + content=types.Content(role="model", parts=parts), + partial=False, + finish_reason=finish_reason, + usage_metadata=usage_metadata, + ) + + # ------------------------------------------------------------------ + # 
Error handling + # ------------------------------------------------------------------ + + def _handle_client_error(self, error: "ClientError") -> NoReturn: + """Enrich and re-raise a botocore ClientError with diagnostic context. + + Logs the model ID and region, then re-raises the original exception so + callers can inspect ``error.response["Error"]["Code"]`` as usual. + + Common error codes surfaced here: + + * ``ThrottlingException`` — the Bedrock service is rate-limiting the + request; callers should implement back-off / retry logic. + * ``ValidationException`` with a context-overflow message — the combined + prompt and ``max_tokens`` exceeds the model's context window. + * ``AccessDeniedException`` — the IAM principal does not have access to + the requested model; check model-access settings in the Bedrock console. + + Args: + error: The :class:`botocore.exceptions.ClientError` to handle. + + Raises: + botocore.exceptions.ClientError: Always re-raises *error*. + """ + code = error.response["Error"]["Code"] + message = str(error) + region = getattr( + getattr(self._client, "meta", None), "region_name", "unknown" + ) + + logger.error( + "model=<%s> region=<%s> error_code=<%s> | Bedrock ClientError: %s", + self.model, + region, + code, + message, + ) + + if code in ("ThrottlingException", "throttlingException"): + logger.warning( + "model=<%s> | request throttled; consider adding retry logic", + self.model, + ) + + if code == "ValidationException" and any( + msg in message for msg in _CONTEXT_OVERFLOW_MESSAGES + ): + logger.warning( + "model=<%s> | context window overflow — reduce prompt length or" + " max_tokens", + self.model, + ) + + if code == "AccessDeniedException" and "model" in message.lower(): + logger.warning( + "model=<%s> | access denied — enable model access at " + "https://console.aws.amazon.com/bedrock/home#/modelaccess", + self.model, + ) + + raise error + + # ------------------------------------------------------------------ + # boto3 client + # 
------------------------------------------------------------------ + + @cached_property + def _client(self) -> Any: + """Create and cache a boto3 ``bedrock-runtime`` client. + + Resolves the AWS region in priority order: + 1. ``region_name`` attribute + 2. Region from ``boto_session`` (if provided) + 3. ``AWS_REGION`` environment variable + 4. ``AWS_DEFAULT_REGION`` environment variable + 5. Hard-coded fallback ``us-east-1`` + + Returns: + A boto3 ``bedrock-runtime`` client. + + Raises: + ImportError: If ``boto3`` is not installed. + """ + try: + import boto3 + except ImportError as e: + raise ImportError( + "BedrockModel requires the boto3 package.\n" + "Install it with: pip install google-adk-community[bedrock]\n" + "Or: pip install boto3" + ) from e + + from botocore.config import Config as BotocoreConfig + + session: boto3.Session = self.boto_session or boto3.Session() + region = ( + self.region_name + or session.region_name + or os.environ.get("AWS_REGION") + or os.environ.get("AWS_DEFAULT_REGION") + or DEFAULT_REGION + ) + user_agent = f"google-adk-community/{_community_version.__version__}" + client_config = BotocoreConfig(user_agent_extra=user_agent) + logger.debug("region=<%s> | creating bedrock-runtime client", region) + return session.client( + "bedrock-runtime", region_name=region, config=client_config + ) diff --git a/tests/unittests/models/__init__.py b/tests/unittests/models/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/tests/unittests/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/models/test_bedrock_model.py b/tests/unittests/models/test_bedrock_model.py new file mode 100644 index 0000000..026289d --- /dev/null +++ b/tests/unittests/models/test_bedrock_model.py @@ -0,0 +1,932 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for BedrockModel.""" + +import json +import os +import sys +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.genai import types +import pytest + +from google.adk_community.models.bedrock_model import _bedrock_block_to_part +from google.adk_community.models.bedrock_model import _content_to_bedrock_message +from google.adk_community.models.bedrock_model import _function_declaration_to_tool_spec +from google.adk_community.models.bedrock_model import _part_to_bedrock_block +from google.adk_community.models.bedrock_model import BedrockModel + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_llm_request( + contents=None, + system_instruction=None, + function_declarations=None, + temperature=None, + max_output_tokens=None, + top_p=None, + stop_sequences=None, + model=None, +): + """Build a minimal LlmRequest-like mock.""" + config = MagicMock() + config.system_instruction = system_instruction + config.temperature = temperature + config.top_p = top_p + config.stop_sequences = stop_sequences + config.max_output_tokens = max_output_tokens + config.response_schema = None + + if function_declarations: + tool = MagicMock() + tool.function_declarations = function_declarations + config.tools = [tool] + else: + config.tools = [] + + req = MagicMock() + req.contents = contents or [] + req.config = config + req.model = model + return req + + +def _make_model(model_id="us.anthropic.claude-haiku-4-5-20251001-v1:0", **kw): + """Return a BedrockModel with a mocked boto3 client.""" + m = BedrockModel(model=model_id, **kw) + return m + + +def _mock_client(model: BedrockModel) -> MagicMock: + """Patch the cached _client property and return the mock.""" + client = MagicMock() + # Override the cached_property by writing directly to __dict__ + model.__dict__["_client"] = client 
+ return client + + +# --------------------------------------------------------------------------- +# supported_models +# --------------------------------------------------------------------------- + + +class TestSupportedModels: + + def test_cross_region_inference_profile(self): + import re + + patterns = BedrockModel.supported_models() + model_id = "us.anthropic.claude-haiku-4-5-20251001-v1:0" + assert any(re.fullmatch(p, model_id) for p in patterns) + + def test_direct_model_id(self): + import re + + patterns = BedrockModel.supported_models() + model_id = "anthropic.claude-3-5-sonnet-20241022-v2:0" + assert any(re.fullmatch(p, model_id) for p in patterns) + + def test_amazon_nova(self): + import re + + patterns = BedrockModel.supported_models() + assert any(re.fullmatch(p, "amazon.nova-pro-v1:0") for p in patterns) + + def test_meta_llama(self): + import re + + patterns = BedrockModel.supported_models() + assert any( + re.fullmatch(p, "meta.llama3-70b-instruct-v1:0") for p in patterns + ) + + def test_eu_inference_profile(self): + import re + + patterns = BedrockModel.supported_models() + assert any( + re.fullmatch(p, "eu.anthropic.claude-3-5-sonnet-20241022-v2:0") + for p in patterns + ) + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +class TestValidation: + + def test_boto_session_and_region_raises(self): + mock_session = MagicMock() + with pytest.raises(ValueError, match="Cannot specify both"): + BedrockModel( + model="amazon.nova-pro-v1:0", + boto_session=mock_session, + region_name="us-east-1", + ) + + +# --------------------------------------------------------------------------- +# ADK -> Bedrock conversion helpers +# --------------------------------------------------------------------------- + + +class TestPartToBedrockBlock: + + def test_text_part(self): + part = types.Part.from_text(text="hello") + block = 
_part_to_bedrock_block(part) + assert block == {"text": "hello"} + + def test_function_call_part(self): + part = types.Part.from_function_call(name="my_tool", args={"key": "val"}) + part.function_call.id = "tool-123" + block = _part_to_bedrock_block(part) + assert block == { + "toolUse": { + "toolUseId": "tool-123", + "name": "my_tool", + "input": {"key": "val"}, + } + } + + def test_function_response_with_result(self): + part = types.Part.from_function_response( + name="my_tool", response={"result": "42"} + ) + part.function_response.id = "tool-123" + block = _part_to_bedrock_block(part) + assert block is not None + assert block["toolResult"]["toolUseId"] == "tool-123" + assert block["toolResult"]["status"] == "success" + assert block["toolResult"]["content"][0]["text"] == "42" + + def test_function_response_with_content_list(self): + part = types.Part.from_function_response( + name="my_tool", + response={"content": [{"text": "line1"}, {"text": "line2"}]}, + ) + part.function_response.id = "tool-456" + block = _part_to_bedrock_block(part) + assert "line1" in block["toolResult"]["content"][0]["text"] + assert "line2" in block["toolResult"]["content"][0]["text"] + + def test_image_part_jpeg(self): + part = types.Part( + inline_data=types.Blob(mime_type="image/jpeg", data=b"\xff\xd8") + ) + block = _part_to_bedrock_block(part) + assert block is not None + assert block["image"]["format"] == "jpeg" + assert block["image"]["source"]["bytes"] == b"\xff\xd8" + + def test_image_part_png(self): + part = types.Part( + inline_data=types.Blob(mime_type="image/png", data=b"\x89PNG") + ) + block = _part_to_bedrock_block(part) + assert block["image"]["format"] == "png" + + def test_unsupported_mime_returns_none(self): + part = types.Part( + inline_data=types.Blob(mime_type="audio/mp3", data=b"\x00\x01") + ) + block = _part_to_bedrock_block(part) + assert block is None + + +class TestContentToBedrockMessage: + + def test_user_role(self): + content = types.Content( + role="user", 
parts=[types.Part.from_text(text="hi")] + ) + msg = _content_to_bedrock_message(content) + assert msg["role"] == "user" + assert msg["content"] == [{"text": "hi"}] + + def test_model_role_mapped_to_assistant(self): + content = types.Content( + role="model", parts=[types.Part.from_text(text="reply")] + ) + msg = _content_to_bedrock_message(content) + assert msg["role"] == "assistant" + + def test_empty_parts_returns_none(self): + content = types.Content(role="user", parts=[]) + msg = _content_to_bedrock_message(content) + assert msg is None + + +class TestBedrockBlockToADKPart: + + def test_text_block(self): + part = _bedrock_block_to_part({"text": "hello"}) + assert part is not None + assert part.text == "hello" + + def test_tool_use_block(self): + block = { + "toolUse": { + "toolUseId": "t1", + "name": "search", + "input": {"query": "aws"}, + } + } + part = _bedrock_block_to_part(block) + assert part is not None + assert part.function_call.name == "search" + assert part.function_call.args == {"query": "aws"} + assert part.function_call.id == "t1" + + def test_unknown_block_returns_none(self): + part = _bedrock_block_to_part({"unknownField": "value"}) + assert part is None + + +class TestFunctionDeclarationToToolSpec: + + def test_basic_declaration(self): + fd = types.FunctionDeclaration( + name="get_weather", + description="Get the weather", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "location": types.Schema( + type=types.Type.STRING, description="City name" + ) + }, + required=["location"], + ), + ) + spec = _function_declaration_to_tool_spec(fd) + assert spec["toolSpec"]["name"] == "get_weather" + assert spec["toolSpec"]["description"] == "Get the weather" + schema = spec["toolSpec"]["inputSchema"]["json"] + assert schema["type"] == "object" + assert "location" in schema["properties"] + assert schema["required"] == ["location"] + + def test_no_parameters(self): + fd = types.FunctionDeclaration(name="ping", description="Ping the 
service") + spec = _function_declaration_to_tool_spec(fd) + assert spec["toolSpec"]["name"] == "ping" + assert spec["toolSpec"]["inputSchema"]["json"]["type"] == "object" + + +# --------------------------------------------------------------------------- +# _build_request +# --------------------------------------------------------------------------- + + +class TestBuildRequest: + + def test_basic_text_request(self): + model = _make_model() + req = _make_llm_request( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]) + ], + system_instruction="You are helpful.", + model="us.anthropic.claude-haiku-4-5-20251001-v1:0", + ) + payload = model._build_request(req) + assert payload["modelId"] == "us.anthropic.claude-haiku-4-5-20251001-v1:0" + assert payload["messages"][0]["role"] == "user" + assert payload["system"] == [{"text": "You are helpful."}] + assert payload["inferenceConfig"]["maxTokens"] == 4096 + + def test_inference_config_overrides(self): + model = _make_model() + req = _make_llm_request( + temperature=0.5, + top_p=0.9, + max_output_tokens=512, + stop_sequences=["STOP"], + ) + payload = model._build_request(req) + cfg = payload["inferenceConfig"] + assert cfg["temperature"] == 0.5 + assert cfg["topP"] == 0.9 + assert cfg["maxTokens"] == 512 + assert cfg["stopSequences"] == ["STOP"] + + def test_tool_config_included(self): + fd = types.FunctionDeclaration(name="echo", description="Echo input") + model = _make_model() + req = _make_llm_request(function_declarations=[fd]) + payload = model._build_request(req) + assert "toolConfig" in payload + tools = payload["toolConfig"]["tools"] + assert tools[0]["toolSpec"]["name"] == "echo" + + def test_guardrail_config_included(self): + model = _make_model(guardrail_id="abc123", guardrail_version="1") + req = _make_llm_request() + payload = model._build_request(req) + assert payload["guardrailConfig"]["guardrailIdentifier"] == "abc123" + assert payload["guardrailConfig"]["guardrailVersion"] == 
"1" + + def test_no_guardrail_when_not_set(self): + model = _make_model() + req = _make_llm_request() + payload = model._build_request(req) + assert "guardrailConfig" not in payload + + def test_model_id_from_request_overrides_default(self): + model = _make_model(model_id="amazon.nova-lite-v1:0") + req = _make_llm_request(model="amazon.nova-pro-v1:0") + payload = model._build_request(req) + assert payload["modelId"] == "amazon.nova-pro-v1:0" + + +# --------------------------------------------------------------------------- +# Non-streaming generation +# --------------------------------------------------------------------------- + + +class TestGenerateNonStreaming: + + @pytest.mark.asyncio + async def test_text_response(self): + model = _make_model() + client = _mock_client(model) + client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"text": "Hello, world!"}], + } + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}, + } + + req = _make_llm_request( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]) + ] + ) + responses = [] + async for r in model.generate_content_async(req, stream=False): + responses.append(r) + + assert len(responses) == 1 + resp = responses[0] + assert resp.content.parts[0].text == "Hello, world!" 
+ assert resp.usage_metadata.total_token_count == 15 + + @pytest.mark.asyncio + async def test_tool_use_response(self): + model = _make_model() + client = _mock_client(model) + client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{ + "toolUse": { + "toolUseId": "tool-1", + "name": "get_weather", + "input": {"location": "Seattle"}, + } + }], + } + }, + "stopReason": "tool_use", + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + } + + req = _make_llm_request() + responses = [] + async for r in model.generate_content_async(req, stream=False): + responses.append(r) + + resp = responses[0] + fc = resp.content.parts[0].function_call + assert fc.name == "get_weather" + assert fc.args == {"location": "Seattle"} + assert fc.id == "tool-1" + + @pytest.mark.asyncio + async def test_max_tokens_finish_reason(self): + model = _make_model() + client = _mock_client(model) + client.converse.return_value = { + "output": { + "message": {"role": "assistant", "content": [{"text": "truncated"}]} + }, + "stopReason": "max_tokens", + "usage": {"inputTokens": 5, "outputTokens": 4096, "totalTokens": 4101}, + } + + req = _make_llm_request() + responses = [] + async for r in model.generate_content_async(req, stream=False): + responses.append(r) + + from google.genai import types as gtypes + + assert responses[0].finish_reason == gtypes.FinishReason.MAX_TOKENS + + +# --------------------------------------------------------------------------- +# Streaming generation +# --------------------------------------------------------------------------- + + +class TestGenerateStreaming: + + def _build_stream_chunks(self, text="Hello!", tool_use=None): + """Helper to build a realistic list of Bedrock stream chunks.""" + chunks = [{"messageStart": {"role": "assistant"}}] + if text: + chunks += [ + {"contentBlockStart": {"contentBlockIndex": 0, "start": {}}}, + { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": {"text": text}, 
+ } + }, + {"contentBlockStop": {"contentBlockIndex": 0}}, + ] + if tool_use: + chunks += [ + { + "contentBlockStart": { + "contentBlockIndex": 1, + "start": { + "toolUse": { + "toolUseId": tool_use["id"], + "name": tool_use["name"], + } + }, + } + }, + { + "contentBlockDelta": { + "contentBlockIndex": 1, + "delta": { + "toolUse": {"input": json.dumps(tool_use["input"])} + }, + } + }, + {"contentBlockStop": {"contentBlockIndex": 1}}, + ] + chunks += [ + {"messageStop": {"stopReason": "end_turn"}}, + { + "metadata": { + "usage": { + "inputTokens": 10, + "outputTokens": 8, + "totalTokens": 18, + } + } + }, + ] + return chunks + + @pytest.mark.asyncio + async def test_streaming_text(self): + model = _make_model() + client = _mock_client(model) + chunks = self._build_stream_chunks(text="Hi there!") + + def fake_converse_stream(**kwargs): + return {"stream": iter(chunks)} + + client.converse_stream.side_effect = fake_converse_stream + + req = _make_llm_request( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + responses = [] + async for r in model.generate_content_async(req, stream=True): + responses.append(r) + + # Partial chunks + final aggregated response + partials = [r for r in responses if r.partial] + finals = [r for r in responses if not r.partial] + assert len(partials) >= 1 + assert len(finals) == 1 + assert finals[0].content.parts[0].text == "Hi there!" 
+ assert finals[0].usage_metadata.total_token_count == 18 + + @pytest.mark.asyncio + async def test_streaming_tool_use(self): + model = _make_model() + client = _mock_client(model) + chunks = self._build_stream_chunks( + text="", + tool_use={"id": "t1", "name": "search", "input": {"q": "bedrock"}}, + ) + + def fake_converse_stream(**kwargs): + return {"stream": iter(chunks)} + + client.converse_stream.side_effect = fake_converse_stream + + req = _make_llm_request() + responses = [] + async for r in model.generate_content_async(req, stream=True): + responses.append(r) + + final = [r for r in responses if not r.partial][0] + fc_parts = [p for p in final.content.parts if p.function_call] + assert len(fc_parts) == 1 + assert fc_parts[0].function_call.name == "search" + assert fc_parts[0].function_call.args == {"q": "bedrock"} + assert fc_parts[0].function_call.id == "t1" + + +# --------------------------------------------------------------------------- +# boto3 client — importerror, session, region resolution +# --------------------------------------------------------------------------- + + +class TestClient: + + def test_boto3_import_error(self): + model = _make_model() + # Ensure _client is not already cached. 
+ assert "_client" not in model.__dict__ + with patch.dict("sys.modules", {"boto3": None}): + with pytest.raises(ImportError, match="BedrockModel requires"): + _ = model._client + + def test_custom_boto_session_used(self): + mock_session = MagicMock() + mock_session.region_name = "eu-central-1" + mock_session.client.return_value = MagicMock() + + model = BedrockModel( + model="amazon.nova-pro-v1:0", + boto_session=mock_session, + ) + _ = model._client + + mock_session.client.assert_called_once_with( + "bedrock-runtime", region_name="eu-central-1", config=ANY + ) + + def test_region_from_aws_region_env(self): + with patch.dict(os.environ, {"AWS_REGION": "ap-northeast-1"}, clear=False): + with patch("boto3.Session") as mock_cls: + mock_sess = MagicMock() + mock_sess.region_name = None + mock_cls.return_value = mock_sess + mock_sess.client.return_value = MagicMock() + + model = BedrockModel(model="amazon.nova-pro-v1:0") + _ = model._client + + mock_sess.client.assert_called_once_with( + "bedrock-runtime", region_name="ap-northeast-1", config=ANY + ) + + def test_region_fallback_to_default(self): + env_clean = { + k: v + for k, v in os.environ.items() + if k not in ("AWS_REGION", "AWS_DEFAULT_REGION") + } + with patch.dict(os.environ, env_clean, clear=True): + with patch("boto3.Session") as mock_cls: + mock_sess = MagicMock() + mock_sess.region_name = None + mock_cls.return_value = mock_sess + mock_sess.client.return_value = MagicMock() + + model = BedrockModel(model="amazon.nova-pro-v1:0") + _ = model._client + + mock_sess.client.assert_called_once_with( + "bedrock-runtime", region_name="us-east-1", config=ANY + ) + + def test_user_agent_set_on_client(self): + with patch("boto3.Session") as mock_cls: + mock_sess = MagicMock() + mock_sess.region_name = "us-east-1" + mock_cls.return_value = mock_sess + mock_sess.client.return_value = MagicMock() + + model = BedrockModel(model="amazon.nova-pro-v1:0") + _ = model._client + + _, kwargs = mock_sess.client.call_args + config 
= kwargs.get("config") + assert config is not None + assert "google-adk-community" in config.user_agent_extra + + def test_boto_session_excluded_from_serialisation(self): + mock_session = MagicMock() + model = BedrockModel( + model="amazon.nova-pro-v1:0", + boto_session=mock_session, + ) + dumped = model.model_dump() + assert "boto_session" not in dumped + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +def _make_client_error(code: str, message: str) -> "MagicMock": + """Build a mock botocore ClientError.""" + from botocore.exceptions import ClientError + + error_response = {"Error": {"Code": code, "Message": message}} + return ClientError(error_response, "converse") + + +class TestClientErrorHandling: + + @pytest.mark.asyncio + async def test_throttling_exception_propagates(self): + model = _make_model() + client = _mock_client(model) + from botocore.exceptions import ClientError + + client.converse.side_effect = _make_client_error( + "ThrottlingException", "Rate exceeded" + ) + + req = _make_llm_request() + with pytest.raises(ClientError) as exc_info: + async for _ in model.generate_content_async(req, stream=False): + pass + + assert exc_info.value.response["Error"]["Code"] == "ThrottlingException" + + @pytest.mark.asyncio + async def test_context_overflow_exception_propagates(self): + model = _make_model() + client = _mock_client(model) + from botocore.exceptions import ClientError + + client.converse.side_effect = _make_client_error( + "ValidationException", "Input is too long for requested model" + ) + + req = _make_llm_request() + with pytest.raises(ClientError) as exc_info: + async for _ in model.generate_content_async(req, stream=False): + pass + + assert exc_info.value.response["Error"]["Code"] == "ValidationException" + + @pytest.mark.asyncio + async def test_streaming_throttling_propagates(self): + model = _make_model() + 
    client = _mock_client(model)
    from botocore.exceptions import ClientError

    client.converse_stream.side_effect = _make_client_error(
        "ThrottlingException", "Rate exceeded"
    )

    req = _make_llm_request()
    # The error must be re-raised on the async consumer side of the stream.
    with pytest.raises(ClientError):
      async for _ in model.generate_content_async(req, stream=True):
        pass


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


class TestEdgeCases:
  """Boundary conditions for request building and response conversion."""

  def test_empty_contents_builds_empty_messages(self):
    """An empty contents list must produce an empty Bedrock messages array."""
    model = _make_model()
    req = _make_llm_request(contents=[])
    payload = model._build_request(req)
    assert payload["messages"] == []

  def test_content_with_all_unsupported_parts_skipped(self):
    # A Content where every Part is unsupported → message is None → not added.
    model = _make_model()
    content = types.Content(role="user", parts=[types.Part()])
    msg = _content_to_bedrock_message(content)
    assert msg is None

  @pytest.mark.asyncio
  async def test_guardrail_intervened_finish_reason_non_streaming(self):
    """The guardrail_intervened stop reason maps to FinishReason.SAFETY."""
    model = _make_model()
    client = _mock_client(model)
    client.converse.return_value = {
        "output": {
            "message": {"role": "assistant", "content": [{"text": "blocked"}]}
        },
        "stopReason": "guardrail_intervened",
        "usage": {"inputTokens": 5, "outputTokens": 1, "totalTokens": 6},
    }

    responses = []
    async for r in model.generate_content_async(
        _make_llm_request(), stream=False
    ):
      responses.append(r)

    assert responses[0].finish_reason == types.FinishReason.SAFETY

  def test_function_response_dict_fallback_json_dumps(self):
    """Dict with neither 'result' nor 'content' key → json.dumps path."""
    part = types.Part.from_function_response(
        name="my_tool", response={"foo": "bar", "baz": 123}
    )
    part.function_response.id = "tool-789"
    block = _part_to_bedrock_block(part)
    assert block is not None
    text = block["toolResult"]["content"][0]["text"]
    # The serialised payload must parse back to the original dict.
    assert json.loads(text) == {"foo": "bar", "baz": 123}

  def test_function_declaration_with_parameters_json_schema(self):
    """When parameters_json_schema is set, it should be used directly."""
    schema = {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }
    fd = types.FunctionDeclaration(
        name="get_weather",
        description="Get weather",
        parameters_json_schema=schema,
    )
    spec = _function_declaration_to_tool_spec(fd)
    assert spec["toolSpec"]["inputSchema"]["json"] == schema

  def test_empty_text_part_preserved(self):
    """Empty string text part should not be silently dropped."""
    part = types.Part.from_text(text="")
    block = _part_to_bedrock_block(part)
    assert block == {"text": ""}

  def test_region_name_directly_specified(self):
    """A region_name passed to the model is forwarded to boto3's client()."""
    with patch("boto3.Session") as mock_cls:
      mock_sess = MagicMock()
      mock_cls.return_value = mock_sess
      mock_sess.client.return_value = MagicMock()

      model = BedrockModel(
          model="amazon.nova-pro-v1:0", region_name="ap-northeast-2"
      )
      # Client creation is lazy; accessing _client triggers it.
      _ = model._client

      mock_sess.client.assert_called_once_with(
          "bedrock-runtime", region_name="ap-northeast-2", config=ANY
      )


class TestStreamingEdgeCases:
  """Streaming conversion paths that mix text, tool use and errors."""

  def _build_text_and_tool_chunks(self):
    """Build chunks with both text and tool_use in a single response."""
    return [
        {"messageStart": {"role": "assistant"}},
        # Text block
        {"contentBlockStart": {"contentBlockIndex": 0, "start": {}}},
        {
            "contentBlockDelta": {
                "contentBlockIndex": 0,
                "delta": {"text": "Let me search."},
            }
        },
        {"contentBlockStop": {"contentBlockIndex": 0}},
        # Tool use block
        {
            "contentBlockStart": {
                "contentBlockIndex": 1,
                "start": {"toolUse": {"toolUseId": "t1", "name": "search"}},
            }
        },
        {
            "contentBlockDelta": {
                "contentBlockIndex": 1,
                "delta": {"toolUse": {"input": '{"q": "adk"}'}},
            }
        },
        {"contentBlockStop": {"contentBlockIndex": 1}},
        {"messageStop": {"stopReason": "tool_use"}},
        {
            "metadata": {
                "usage": {
"inputTokens": 15, + "outputTokens": 12, + "totalTokens": 27, + } + } + }, + ] + + @pytest.mark.asyncio + async def test_streaming_text_and_tool_combined(self): + """Streaming response with both text and tool_use in one turn.""" + model = _make_model() + client = _mock_client(model) + chunks = self._build_text_and_tool_chunks() + + def fake_converse_stream(**kwargs): + return {"stream": iter(chunks)} + + client.converse_stream.side_effect = fake_converse_stream + + req = _make_llm_request() + responses = [] + async for r in model.generate_content_async(req, stream=True): + responses.append(r) + + final = [r for r in responses if not r.partial][0] + assert len(final.content.parts) == 2 + assert final.content.parts[0].text == "Let me search." + assert final.content.parts[1].function_call.name == "search" + assert final.content.parts[1].function_call.args == {"q": "adk"} + assert final.content.parts[1].function_call.id == "t1" + + @pytest.mark.asyncio + async def test_streaming_malformed_tool_json_fallback(self): + """Malformed tool input JSON should fall back to empty dict.""" + model = _make_model() + client = _mock_client(model) + chunks = [ + {"messageStart": {"role": "assistant"}}, + { + "contentBlockStart": { + "contentBlockIndex": 0, + "start": {"toolUse": {"toolUseId": "t1", "name": "broken"}}, + } + }, + { + "contentBlockDelta": { + "contentBlockIndex": 0, + "delta": {"toolUse": {"input": "{invalid json!!"}}, + } + }, + {"contentBlockStop": {"contentBlockIndex": 0}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": { + "inputTokens": 5, + "outputTokens": 3, + "totalTokens": 8, + } + } + }, + ] + + def fake_converse_stream(**kwargs): + return {"stream": iter(chunks)} + + client.converse_stream.side_effect = fake_converse_stream + + req = _make_llm_request() + responses = [] + async for r in model.generate_content_async(req, stream=True): + responses.append(r) + + final = [r for r in responses if not r.partial][0] + assert 
final.content.parts[0].function_call.name == "broken" + assert final.content.parts[0].function_call.args == {} + + @pytest.mark.asyncio + async def test_streaming_non_client_error_propagates(self): + """Non-ClientError exceptions in stream thread should propagate.""" + model = _make_model() + client = _mock_client(model) + + client.converse_stream.side_effect = ConnectionError("Network unreachable") + + req = _make_llm_request() + with pytest.raises(ConnectionError, match="Network unreachable"): + async for _ in model.generate_content_async(req, stream=True): + pass