From e3c44bfd7f52e1810fe6dc724cd2d59aaf9b86cb Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 10 Apr 2026 17:10:30 -0400 Subject: [PATCH] feat: add support for all models to agent config --- src/strands/experimental/agent_config.py | 147 ++++- src/strands/models/bedrock.py | 26 + src/strands/models/llamacpp.py | 20 + src/strands/models/mistral.py | 22 + src/strands/models/model.py | 23 + src/strands/models/ollama.py | 21 + src/strands/models/sagemaker.py | 21 + .../strands/experimental/test_agent_config.py | 549 ++++++++++++++++++ 8 files changed, 824 insertions(+), 5 deletions(-) diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index e6fb94118..26510e6c6 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -9,9 +9,33 @@ agent = config_to_agent("config.json") # Add tools that need code-based instantiation agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))]) + +The ``model`` field supports two formats: + +**String format (backward compatible — defaults to Bedrock):** + {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + +**Object format (supports all providers):** + { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + "client_args": {"api_key": "$ANTHROPIC_API_KEY"} + } + } + +Environment variable references (``$VAR`` or ``${VAR}``) in model config values are resolved +automatically before provider instantiation. + +Note: The following constructor parameters cannot be specified from JSON because they require +code-based instantiation: ``boto_session`` (Bedrock, SageMaker), ``client`` (OpenAI, Gemini), +``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives. """ import json +import os +import re from pathlib import Path from typing import Any @@ -27,8 +51,25 @@ "properties": { "name": {"description": "Name of the agent", "type": ["string", "null"], "default": None}, "model": { - "description": "The model ID to use for this agent. If not specified, uses the default model.", - "type": ["string", "null"], + "description": ( + "The model to use for this agent. Can be a string (Bedrock model_id) " + "or an object with a 'provider' field for any supported provider." + ), + "oneOf": [ + {"type": "string"}, + {"type": "null"}, + { + "type": "object", + "properties": { + "provider": { + "description": "The model provider name", + "type": "string", + } + }, + "required": ["provider"], + "additionalProperties": True, + }, + ], "default": None, }, "prompt": { @@ -50,6 +91,87 @@ # Pre-compile validator for better performance _VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA) +# Pattern for matching environment variable references +_ENV_VAR_PATTERN = re.compile(r"^\$\{([^}]+)\}$|^\$([A-Za-z_][A-Za-z0-9_]*)$") + +# Provider name to model class name — resolved via strands.models lazy __getattr__ +PROVIDER_MAP: dict[str, str] = { + "bedrock": "BedrockModel", + "anthropic": "AnthropicModel", + "openai": "OpenAIModel", + "gemini": "GeminiModel", + "ollama": "OllamaModel", + "litellm": "LiteLLMModel", + "mistral": "MistralModel", + "llamaapi": "LlamaAPIModel", + "llamacpp": "LlamaCppModel", + "sagemaker": "SageMakerAIModel", + "writer": "WriterModel", + "openai_responses": "OpenAIResponsesModel", +} + + +def _resolve_env_vars(value: Any) -> Any: + """Recursively resolve environment variable references in config values. + + String values matching ``$VAR_NAME`` or ``${VAR_NAME}`` are replaced with the + corresponding environment variable value. Dicts and lists are traversed recursively. + + Args: + value: The value to resolve. Can be a string, dict, list, or any other type. + + Returns: + The resolved value with environment variable references replaced. + + Raises: + ValueError: If a referenced environment variable is not set. + """ + if isinstance(value, str): + match = _ENV_VAR_PATTERN.match(value) + if match: + var_name = match.group(1) or match.group(2) + env_value = os.environ.get(var_name) + if env_value is None: + raise ValueError(f"Environment variable '{var_name}' is not set") + return env_value + return value + if isinstance(value, dict): + return {k: _resolve_env_vars(v) for k, v in value.items()} + if isinstance(value, list): + return [_resolve_env_vars(item) for item in value] + return value + + +def _create_model_from_dict(model_config: dict[str, Any]) -> Any: + """Create a Model instance from a provider config dict. + + Routes the config to the appropriate model class based on the ``provider`` field, + then delegates to the class's ``from_dict`` method. All imports are lazy to avoid + requiring optional dependencies that are not installed. + + Args: + model_config: Dict containing at least a ``provider`` key and provider-specific params. + + Returns: + A configured Model instance for the specified provider. + + Raises: + ValueError: If the provider name is not recognized. + ImportError: If the provider's optional dependencies are not installed. + """ + config = model_config.copy() + provider = config.pop("provider") + + class_name = PROVIDER_MAP.get(provider) + if class_name is None: + supported = ", ".join(sorted(PROVIDER_MAP.keys())) + raise ValueError(f"Unknown model provider: '{provider}'. Supported providers: {supported}") + + from .. import models + + model_cls = getattr(models, class_name) + return model_cls.from_dict(config) + def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any: """Create an Agent from a configuration file or dictionary. @@ -83,6 +205,12 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A Create agent from dictionary: >>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]} >>> agent = config_to_agent(config) + + Create agent with object model config: + >>> config = { + ... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "$OPENAI_API_KEY"}} + ... } + >>> agent = config_to_agent(config) """ # Parse configuration if isinstance(config, str): @@ -114,11 +242,20 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e # Prepare Agent constructor arguments - agent_kwargs = {} + agent_kwargs: dict[str, Any] = {} + + # Handle model field — string vs object format + model_value = config_dict.get("model") + if isinstance(model_value, dict): + # Object format: resolve env vars and create Model instance via factory + resolved_config = _resolve_env_vars(model_value) + agent_kwargs["model"] = _create_model_from_dict(resolved_config) + elif model_value is not None: + # String format (backward compat): pass directly as model_id to Agent + agent_kwargs["model"] = model_value - # Map configuration keys to Agent constructor parameters + # Map remaining configuration keys to Agent constructor parameters config_mapping = { - "model": "model", "prompt": "system_prompt", "tools": "tools", "name": "name", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index bfb7b1ede..c3c246fc3 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -127,6 +127,32 @@ class BedrockConfig(TypedDict, total=False): temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "BedrockModel": + """Create a BedrockModel from a configuration dictionary. + + Handles extraction of ``region_name``, ``endpoint_url``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. + + Returns: + A configured BedrockModel instance. + """ + kwargs: dict[str, Any] = {} + + if "region_name" in config: + kwargs["region_name"] = config.pop("region_name") + if "endpoint_url" in config: + kwargs["endpoint_url"] = config.pop("endpoint_url") + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + + kwargs.update(config) + return cls(**kwargs) + def __init__( self, *, diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index c52509816..36da4ca03 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -131,6 +131,26 @@ class LlamaCppConfig(TypedDict, total=False): model_id: str params: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "LlamaCppModel": + """Create a LlamaCppModel from a configuration dictionary. + + Handles extraction of ``base_url`` and ``timeout`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. + + Returns: + A configured LlamaCppModel instance. + """ + kwargs: dict[str, Any] = {} + if "base_url" in config: + kwargs["base_url"] = config.pop("base_url") + if "timeout" in config: + kwargs["timeout"] = config.pop("timeout") + kwargs.update(config) + return cls(**kwargs) + def __init__( self, base_url: str = "http://localhost:8080", diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..9bdb8ced9 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -53,6 +53,28 @@ class MistralConfig(TypedDict, total=False): top_p: float | None stream: bool | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "MistralModel": + """Create a MistralModel from a configuration dictionary. + + Handles extraction of ``api_key`` and ``client_args`` as separate constructor parameters. + + Args: + config: Model configuration dictionary. + + Returns: + A configured MistralModel instance. + """ + api_key = config.pop("api_key", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if api_key is not None: + kwargs["api_key"] = api_key + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + def __init__( self, api_key: str | None = None, diff --git a/src/strands/models/model.py b/src/strands/models/model.py index f084d24d5..0c3a5c7b8 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -1,5 +1,7 @@ """Abstract base class for Agent model providers.""" +from __future__ import annotations + import abc import logging from collections.abc import AsyncGenerator, AsyncIterable @@ -51,6 +53,27 @@ def stateful(self) -> bool: """ return False + @classmethod + def from_dict(cls, config: dict[str, Any]) -> Model: + """Create a Model instance from a configuration dictionary. + + The default implementation extracts ``client_args`` (if present) and passes + all remaining keys as keyword arguments to the constructor. Subclasses with + non-standard constructor signatures should override this method. + + Args: + config: Provider-specific configuration dictionary. + + Returns: + A configured Model instance. + """ + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["client_args"] = client_args + kwargs.update(config) + return cls(**kwargs) + @abc.abstractmethod # pragma: no cover def update_config(self, **model_config: Any) -> None: diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..37b7090b1 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -56,6 +56,27 @@ class OllamaConfig(TypedDict, total=False): temperature: float | None top_p: float | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "OllamaModel": + """Create an OllamaModel from a configuration dictionary. + + Handles extraction of ``host`` as a positional argument and mapping of + ``client_args`` to the ``ollama_client_args`` constructor parameter. + + Args: + config: Model configuration dictionary. + + Returns: + A configured OllamaModel instance. + """ + host = config.pop("host", None) + client_args = config.pop("client_args", None) + kwargs: dict[str, Any] = {} + if client_args is not None: + kwargs["ollama_client_args"] = client_args + kwargs.update(config) + return cls(host, **kwargs) + def __init__( self, host: str | None, diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 775969290..424bac85f 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -133,6 +133,27 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): target_variant: str | None | None additional_args: dict[str, Any] | None + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "SageMakerAIModel": + """Create a SageMakerAIModel from a configuration dictionary. + + Handles extraction of ``endpoint_config``, ``payload_config``, and conversion of + ``boto_client_config`` from a plain dict to ``botocore.config.Config``. + + Args: + config: Model configuration dictionary. + + Returns: + A configured SageMakerAIModel instance. + """ + kwargs: dict[str, Any] = {} + kwargs["endpoint_config"] = config.pop("endpoint_config", {}) + kwargs["payload_config"] = config.pop("payload_config", {}) + if "boto_client_config" in config: + raw = config.pop("boto_client_config") + kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw + return cls(**kwargs) + def __init__( self, endpoint_config: SageMakerAIEndpointConfig, diff --git a/tests/strands/experimental/test_agent_config.py b/tests/strands/experimental/test_agent_config.py index e6188079b..e60a24b94 100644 --- a/tests/strands/experimental/test_agent_config.py +++ b/tests/strands/experimental/test_agent_config.py @@ -3,10 +3,21 @@ import json import os import tempfile +from typing import Any +from unittest.mock import MagicMock, patch import pytest from strands.experimental import config_to_agent +from strands.experimental.agent_config import ( + PROVIDER_MAP, + _create_model_from_dict, + _resolve_env_vars, +) + +# ============================================================================= +# Backward compatibility tests (existing) +# ============================================================================= def test_config_to_agent_with_dict(): @@ -170,3 +181,541 @@ def test_config_to_agent_with_tool(): config = {"model": "test-model", "tools": ["tests.fixtures.say_tool:say"]} agent = config_to_agent(config) assert "say" in agent.tool_names + + +# ============================================================================= +# Environment variable resolution tests +# ============================================================================= + + +class TestResolveEnvVars: + """Tests for the _resolve_env_vars utility function.""" + + def test_resolve_dollar_prefix(self): + """Test resolving $VAR_NAME format.""" + with patch.dict(os.environ, {"MY_API_KEY": "secret123"}): + assert _resolve_env_vars("$MY_API_KEY") == "secret123" + + def test_resolve_braced_format(self): + """Test resolving ${VAR_NAME} format.""" + with patch.dict(os.environ, {"MY_API_KEY": "secret456"}): + assert _resolve_env_vars("${MY_API_KEY}") == "secret456" + + def test_resolve_nested_dict(self): + """Test recursive resolution in nested dicts.""" + with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): + data = {"outer": {"inner": "$KEY1"}, "flat": "${KEY2}"} + result = _resolve_env_vars(data) + assert result == {"outer": {"inner": "val1"}, "flat": "val2"} + + def test_resolve_list(self): + """Test recursive resolution in lists.""" + with patch.dict(os.environ, {"KEY1": "val1", "KEY2": "val2"}): + data = ["$KEY1", "${KEY2}", "literal"] + result = _resolve_env_vars(data) + assert result == ["val1", "val2", "literal"] + + def test_missing_env_var_raises(self): + """Test that missing env vars raise ValueError.""" + with patch.dict(os.environ, {}, clear=True): + # Ensure the var is not set + os.environ.pop("NONEXISTENT_VAR", None) + with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): + _resolve_env_vars("$NONEXISTENT_VAR") + + def test_missing_braced_env_var_raises(self): + """Test that missing braced env vars raise ValueError.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("NONEXISTENT_VAR", None) + with pytest.raises(ValueError, match="Environment variable 'NONEXISTENT_VAR' is not set"): + _resolve_env_vars("${NONEXISTENT_VAR}") + + def test_non_env_string_unchanged(self): + """Test that regular strings are returned unchanged.""" + assert _resolve_env_vars("just-a-string") == "just-a-string" + + def test_non_string_values_unchanged(self): + """Test that non-string values pass through unchanged.""" + assert _resolve_env_vars(42) == 42 + assert _resolve_env_vars(True) is True + assert _resolve_env_vars(3.14) == 3.14 + assert _resolve_env_vars(None) is None + + def test_deeply_nested_resolution(self): + """Test env var resolution in deeply nested structures.""" + with patch.dict(os.environ, {"DEEP_VAL": "found"}): + data = {"a": {"b": {"c": [{"d": "$DEEP_VAL"}]}}} + result = _resolve_env_vars(data) + assert result == {"a": {"b": {"c": [{"d": "found"}]}}} + + +# ============================================================================= +# Schema validation tests — dual-format model field +# ============================================================================= + + +class TestSchemaValidation: + """Tests for the updated AGENT_CONFIG_SCHEMA that supports both string and object model formats.""" + + def test_string_model_valid(self): + """Test that string model format still passes validation.""" + config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + agent = config_to_agent(config) + assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + def test_object_model_valid(self): + """Test that object model format passes schema validation.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 10000, + } + } + agent = config_to_agent(config) + assert agent.model is mock_model + + def test_object_model_missing_provider_raises(self): + """Test that object model without provider raises validation error.""" + config = {"model": {"model_id": "some-model"}} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_allows_additional_properties(self): + """Test that object model format allows provider-specific properties.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + "client_args": {"api_key": "test"}, + "custom_field": "allowed", + } + } + # Should not raise + config_to_agent(config) + + def test_null_model_still_valid(self): + """Test that null model is still accepted for default behavior.""" + config = {"model": None} + agent = config_to_agent(config) + # Should use default model + assert agent is not None + + def test_model_wrong_type_raises(self): + """Test that model field with invalid type raises validation error.""" + config = {"model": 12345} + with pytest.raises(ValueError, match="Configuration validation error"): + config_to_agent(config) + + def test_object_model_from_file(self): + """Test object model format loaded from a JSON file.""" + mock_model = MagicMock() + config_data = { + "model": { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + } + temp_path = "" + try: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: + json.dump(config_data, f) + f.flush() + temp_path = f.name + + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + agent = config_to_agent(temp_path) + assert agent.model is mock_model + finally: + if os.path.exists(temp_path): + os.remove(temp_path) + + +# ============================================================================= +# Provider factory tests — all 12 providers +# ============================================================================= + + +class TestProviderMap: + """Test that all 12 providers are registered in PROVIDER_MAP.""" + + EXPECTED_PROVIDERS = [ + "bedrock", + "anthropic", + "openai", + "gemini", + "ollama", + "litellm", + "mistral", + "llamaapi", + "llamacpp", + "sagemaker", + "writer", + "openai_responses", + ] + + def test_all_providers_registered(self): + """Test that all 12 providers are in PROVIDER_MAP.""" + for provider in self.EXPECTED_PROVIDERS: + assert provider in PROVIDER_MAP, f"Provider '{provider}' not found in PROVIDER_MAP" + + def test_no_extra_providers(self): + """Test that only the expected 12 providers are registered.""" + assert set(PROVIDER_MAP.keys()) == set(self.EXPECTED_PROVIDERS) + + +class TestCreateModelFromConfig: + """Tests for _create_model_from_dict dispatching to cls.from_dict.""" + + def test_unknown_provider_raises(self): + """Test that an unknown provider name raises ValueError.""" + with pytest.raises(ValueError, match="Unknown model provider: 'nonexistent'"): + _create_model_from_dict({"provider": "nonexistent", "model_id": "x"}) + + def _patch_model_class(self, class_name): + """Patch a model class on the strands.models module and return the mock.""" + mock_cls = MagicMock() + mock_cls.from_dict.return_value = MagicMock() + return patch(f"strands.models.{class_name}", mock_cls, create=True), mock_cls + + def test_dispatches_to_from_dict(self): + """Test that _create_model_from_dict calls cls.from_dict on the resolved model class.""" + mock_model = MagicMock() + mock_cls = MagicMock() + mock_cls.from_dict.return_value = mock_model + + with patch("strands.models.AnthropicModel", mock_cls, create=True): + result = _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 8192, + "client_args": {"api_key": "test-key"}, + } + ) + mock_cls.from_dict.assert_called_once() + call_config = mock_cls.from_dict.call_args[0][0] + assert call_config["model_id"] == "claude-sonnet-4-20250514" + assert call_config["max_tokens"] == 8192 + assert call_config["client_args"] == {"api_key": "test-key"} + assert "provider" not in call_config + assert result is mock_model + + def test_does_not_mutate_input(self): + """Test that _create_model_from_dict does not mutate the input dict.""" + original = {"provider": "anthropic", "model_id": "test"} + original_copy = original.copy() + + mock_cls = MagicMock() + mock_cls.from_dict.return_value = MagicMock() + with patch("strands.models.AnthropicModel", mock_cls, create=True): + _create_model_from_dict(original) + + assert original == original_copy + + @pytest.mark.parametrize( + "provider,class_name", + list(PROVIDER_MAP.items()), + ) + def test_all_providers_dispatch(self, provider, class_name): + """Test that each registered provider dispatches to the correct class.""" + patcher, mock_cls = self._patch_model_class(class_name) + with patcher: + _create_model_from_dict({"provider": provider, "model_id": "test"}) + mock_cls.from_dict.assert_called_once() + + +# ============================================================================= +# Model from_dict tests — provider-specific parameter handling +# ============================================================================= + + +class TestModelFromConfig: + """Tests for from_dict on model classes with non-standard constructors. + + Patches __init__ on each model class to capture the arguments passed by from_dict + without actually initializing the model (which would require real provider dependencies). + """ + + def test_bedrock_from_dict_boto_client_config_conversion(self): + """Test that BedrockModel.from_dict converts boto_client_config dict to BotocoreConfig.""" + from botocore.config import Config as BotocoreConfig + + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-west-2", + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-west-2" + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + assert call_kwargs["model_id"] == "test-model" + + def test_bedrock_from_dict_without_boto_client_config(self): + """Test BedrockModel.from_dict without boto_client_config.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "region_name": "us-east-1", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["region_name"] == "us-east-1" + assert "boto_client_config" not in call_kwargs + + def test_bedrock_from_dict_endpoint_url(self): + """Test BedrockModel.from_dict with endpoint_url.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + BedrockModel.from_dict( + { + "model_id": "test-model", + "endpoint_url": "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com", + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_url"] == "https://vpce-1234.bedrock-runtime.us-west-2.vpce.amazonaws.com" + + def test_ollama_from_dict_host_and_client_args_mapping(self): + """Test that OllamaModel.from_dict routes host and maps client_args to ollama_client_args.""" + from strands.models.ollama import OllamaModel + + with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict( + { + "model_id": "llama3", + "host": "http://localhost:11434", + "client_args": {"timeout": 30}, + } + ) + call_args = mock_init.call_args + assert call_args[0][0] == "http://localhost:11434" # host is positional + assert call_args[1]["ollama_client_args"] == {"timeout": 30} + assert call_args[1]["model_id"] == "llama3" + + def test_ollama_from_dict_default_host(self): + """Test OllamaModel.from_dict with no host specified defaults to None.""" + from strands.models.ollama import OllamaModel + + with patch.object(OllamaModel, "__init__", return_value=None) as mock_init: + OllamaModel.from_dict({"model_id": "llama3"}) + call_args = mock_init.call_args + assert call_args[0][0] is None # host defaults to None + + def test_mistral_from_dict_api_key_extraction(self): + """Test that MistralModel.from_dict extracts api_key separately.""" + from strands.models.mistral import MistralModel + + with patch.object(MistralModel, "__init__", return_value=None) as mock_init: + MistralModel.from_dict( + { + "model_id": "mistral-large-latest", + "api_key": "test-key", + "client_args": {"timeout": 60}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["api_key"] == "test-key" + assert call_kwargs["client_args"] == {"timeout": 60} + assert call_kwargs["model_id"] == "mistral-large-latest" + + def test_llamacpp_from_dict_base_url_and_timeout(self): + """Test that LlamaCppModel.from_dict extracts base_url and timeout.""" + from strands.models.llamacpp import LlamaCppModel + + with patch.object(LlamaCppModel, "__init__", return_value=None) as mock_init: + LlamaCppModel.from_dict( + { + "model_id": "default", + "base_url": "http://myhost:8080", + "timeout": 30.0, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["base_url"] == "http://myhost:8080" + assert call_kwargs["timeout"] == 30.0 + assert call_kwargs["model_id"] == "default" + + def test_sagemaker_from_dict_dict_params(self): + """Test that SageMakerAIModel.from_dict receives endpoint_config and payload_config as dicts.""" + from strands.models.sagemaker import SageMakerAIModel + + with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep", "region_name": "us-west-2"}, + "payload_config": {"max_tokens": 1024, "stream": True}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["endpoint_config"] == {"endpoint_name": "my-ep", "region_name": "us-west-2"} + assert call_kwargs["payload_config"] == {"max_tokens": 1024, "stream": True} + + def test_sagemaker_from_dict_boto_client_config_conversion(self): + """Test that SageMakerAIModel.from_dict converts boto_client_config dict to BotocoreConfig.""" + from botocore.config import Config as BotocoreConfig + + from strands.models.sagemaker import SageMakerAIModel + + with patch.object(SageMakerAIModel, "__init__", return_value=None) as mock_init: + SageMakerAIModel.from_dict( + { + "endpoint_config": {"endpoint_name": "my-ep"}, + "payload_config": {"max_tokens": 1024}, + "boto_client_config": {"read_timeout": 300}, + } + ) + call_kwargs = mock_init.call_args[1] + assert isinstance(call_kwargs["boto_client_config"], BotocoreConfig) + + def test_default_from_dict_client_args_pattern(self): + """Test the default from_dict (inherited) handles client_args + remaining kwargs.""" + from strands.models.bedrock import BedrockModel + + with patch.object(BedrockModel, "__init__", return_value=None) as mock_init: + # BedrockModel overrides from_dict, so use AnthropicModel which inherits the default + from strands.models.anthropic import AnthropicModel + + with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: + AnthropicModel.from_dict( + { + "model_id": "claude-sonnet-4-20250514", + "max_tokens": 4096, + "client_args": {"api_key": "test"}, + "params": {"temperature": 0.5}, + } + ) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["client_args"] == {"api_key": "test"} + assert call_kwargs["model_id"] == "claude-sonnet-4-20250514" + assert call_kwargs["max_tokens"] == 4096 + assert call_kwargs["params"] == {"temperature": 0.5} + + def test_default_from_dict_without_client_args(self): + """Test the default from_dict works without client_args.""" + from strands.models.anthropic import AnthropicModel + + with patch.object(AnthropicModel, "__init__", return_value=None) as mock_init: + AnthropicModel.from_dict({"model_id": "test-model", "max_tokens": 1024}) + call_kwargs = mock_init.call_args[1] + assert call_kwargs["model_id"] == "test-model" + assert call_kwargs["max_tokens"] == 1024 + assert "client_args" not in call_kwargs + + +# ============================================================================= +# Error handling tests +# ============================================================================= + + +class TestErrorHandling: + """Tests for error handling in model creation.""" + + def test_missing_optional_dependency(self): + """Test clear error when provider dependency is not installed.""" + mock_cls = MagicMock() + mock_cls.from_dict.side_effect = ImportError("No module named 'anthropic'") + + with patch("strands.models.AnthropicModel", mock_cls, create=True): + with pytest.raises(ImportError, match="anthropic"): + _create_model_from_dict( + { + "provider": "anthropic", + "model_id": "claude-sonnet-4-20250514", + } + ) + + def test_unknown_provider_error_message(self): + """Test that unknown provider gives helpful error message.""" + with pytest.raises(ValueError, match="Unknown model provider: 'my_custom_provider'"): + _create_model_from_dict({"provider": "my_custom_provider"}) + + +# ============================================================================= +# Integration: config_to_agent with object model +# ============================================================================= + + +class TestConfigToAgentObjectModel: + """Tests for config_to_agent using the object model format end-to-end.""" + + def test_object_model_creates_agent(self): + """Test that object model config creates an agent with the correct model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + }, + "prompt": "You are helpful", + } + agent = config_to_agent(config) + assert agent.model is mock_model + assert agent.system_prompt == "You are helpful" + + def test_object_model_env_var_resolution(self): + """Test that env vars are resolved in object model config before provider creation.""" + mock_model = MagicMock() + with patch.dict(os.environ, {"TEST_API_KEY": "resolved-key"}): + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ) as mock_create: + config = { + "model": { + "provider": "openai", + "model_id": "gpt-4o", + "client_args": {"api_key": "$TEST_API_KEY"}, + } + } + config_to_agent(config) + # Verify the env var was resolved before passing to the factory + call_args = mock_create.call_args[0][0] + assert call_args["client_args"]["api_key"] == "resolved-key" + + def test_string_model_backward_compat(self): + """Test that string model still works as Bedrock model_id.""" + config = {"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + agent = config_to_agent(config) + # String model is passed directly to Agent, which interprets it as Bedrock model_id + assert agent.model.config["model_id"] == "us.anthropic.claude-sonnet-4-20250514-v1:0" + + def test_object_model_with_kwargs_override(self): + """Test that kwargs can still override when using object model.""" + mock_model = MagicMock() + with patch( + "strands.experimental.agent_config._create_model_from_dict", + return_value=mock_model, + ): + config = { + "model": {"provider": "openai", "model_id": "gpt-4o"}, + "prompt": "Original prompt", + } + agent = config_to_agent(config, system_prompt="Override prompt") + assert agent.system_prompt == "Override prompt"