diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..8f52d3a 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import memory +from . import plugins from . import sessions from . import version __version__ = version.__version__ diff --git a/src/google/adk_community/plugins/__init__.py b/src/google/adk_community/plugins/__init__.py new file mode 100644 index 0000000..a08f9fb --- /dev/null +++ b/src/google/adk_community/plugins/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community plugins for Google ADK.""" + +from .llm_resilience_plugin import LlmResiliencePlugin + +__all__ = [ + "LlmResiliencePlugin", +] diff --git a/src/google/adk_community/plugins/llm_resilience_plugin.py b/src/google/adk_community/plugins/llm_resilience_plugin.py new file mode 100644 index 0000000..6cdd2e7 --- /dev/null +++ b/src/google/adk_community/plugins/llm_resilience_plugin.py @@ -0,0 +1,352 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LlmResiliencePlugin - retry with exponential backoff and model fallbacks.""" + +from __future__ import annotations + +import asyncio +import logging +import random +from typing import Iterable +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from google.adk.agents.invocation_context import InvocationContext + +try: + import httpx +except Exception: # pragma: no cover - httpx might not be installed in all envs + httpx = None # type: ignore + +from google.genai import types + +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.base_plugin import BasePlugin + +logger = logging.getLogger("google_adk_community." + __name__) + + +def _extract_status_code(err: Exception) -> Optional[int]: + """Best-effort extraction of HTTP status codes from common client libraries.""" + status = getattr(err, "status_code", None) + if isinstance(status, int): + return status + # httpx specific + if httpx is not None: + if isinstance(err, httpx.HTTPStatusError): + try: + return int(err.response.status_code) + except (AttributeError, ValueError, TypeError): + return None + # Fallback: look for nested response + resp = getattr(err, "response", None) + if resp is not None: + code = getattr(resp, "status_code", None) + if isinstance(code, int): + return code + return None + + +def _is_transient_error(err: Exception) -> bool: + """Check if an error is transient and should trigger retry.""" + # Retry on common transient classes and HTTP status codes + transient_http = {429, 500, 502, 503, 504} + status = _extract_status_code(err) + if status is not None and status in transient_http: + return True + + # httpx transient + if httpx is not None and isinstance( + err, (httpx.ReadTimeout, httpx.ConnectError, httpx.RemoteProtocolError) + ): + return True + + # asyncio timeouts and cancellations often warrant retry/fallback at callsite + if isinstance(err, (asyncio.TimeoutError,)): + return True + + return False + + +class LlmResiliencePlugin(BasePlugin): + """A plugin that adds retry with exponential backoff and model fallbacks. + + Behavior: + - Intercepts model errors via on_model_error_callback + - Retries the same model up to max_retries with exponential backoff + jitter + - If still failing and fallback_models configured, tries them in order + - Returns the first successful LlmResponse or None to propagate the error + + Notes: + - Live (bidirectional) mode errors are not intercepted by BaseLlmFlow's error + handler; this plugin currently targets generate_content_async flow. + - In SSE streaming mode, the plugin returns a single final LlmResponse. + + Example: + >>> from google.adk.runners import Runner + >>> from google.adk_community.plugins import LlmResiliencePlugin + >>> + >>> runner = Runner( + ... app_name="my_app", + ... agent=my_agent, + ... plugins=[ + ... LlmResiliencePlugin( + ... max_retries=3, + ... backoff_initial=1.0, + ... fallback_models=["gemini-1.5-flash"], + ... ) + ... ], + ... ) + """ + + def __init__( + self, + *, + name: str = "llm_resilience_plugin", + max_retries: int = 3, + backoff_initial: float = 1.0, + backoff_multiplier: float = 2.0, + max_backoff: float = 10.0, + jitter: float = 0.2, + retry_on_exceptions: Optional[tuple[type[BaseException], ...]] = None, + fallback_models: Optional[Iterable[str]] = None, + ) -> None: + """Initialize the LlmResiliencePlugin. + + Args: + name: Plugin name identifier. + max_retries: Maximum number of retry attempts on the same model. + backoff_initial: Initial backoff delay in seconds. + backoff_multiplier: Multiplier for exponential backoff. + max_backoff: Maximum backoff delay in seconds. + jitter: Jitter factor (0.0 to 1.0) to add randomness to backoff. + retry_on_exceptions: Optional tuple of exception types to retry on. + If None, uses built-in transient error detection. + fallback_models: Optional list of model names to try if primary fails. + """ + super().__init__(name) + if max_retries < 0: + raise ValueError("max_retries must be >= 0") + if backoff_initial <= 0: + raise ValueError("backoff_initial must be > 0") + if backoff_multiplier < 1.0: + raise ValueError("backoff_multiplier must be >= 1.0") + if max_backoff <= 0: + raise ValueError("max_backoff must be > 0") + if jitter < 0: + raise ValueError("jitter must be >= 0") + + self.max_retries = max_retries + self.backoff_initial = backoff_initial + self.backoff_multiplier = backoff_multiplier + self.max_backoff = max_backoff + self.jitter = jitter + self.retry_on_exceptions = retry_on_exceptions + self.fallback_models = list(fallback_models or []) + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + """Handle model errors with retry and fallback logic.""" + # Decide whether to handle this error: + # Retry if error is in retry_on_exceptions OR is a transient error + if self.retry_on_exceptions and isinstance(error, self.retry_on_exceptions): + # User explicitly wants to retry on this exception type. + pass + elif not _is_transient_error(error): + # Not an explicit exception and not a transient error, so don't handle. + return None + + # Attempt retries on the same model + response = await self._retry_same_model( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Try fallbacks in order + if self.fallback_models: + response = await self._try_fallbacks( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Let the original error propagate if all attempts failed + return None + + def _get_invocation_context( + self, callback_context: CallbackContext | InvocationContext + ) -> InvocationContext: + """Extract InvocationContext from callback_context. + + Accepts both Context (CallbackContext alias) and InvocationContext via + duck typing. + + Args: + callback_context: The callback context passed to the plugin. + + Returns: + The underlying InvocationContext. + + Raises: + TypeError: If callback_context is not a recognized type. + """ + # If this looks like an InvocationContext (has agent and run_config), use it directly + if hasattr(callback_context, "agent") and hasattr( + callback_context, "run_config" + ): + return callback_context # type: ignore[return-value] + # Otherwise expect a Context-like object exposing the private _invocation_context + ic = getattr(callback_context, "_invocation_context", None) + if ic is None: + raise TypeError( + "callback_context must be Context or InvocationContext-like" + ) + return ic + + def _is_sse_streaming(self, invocation_context: InvocationContext) -> bool: + """Check if SSE streaming mode is enabled. + + Args: + invocation_context: The invocation context to check. + + Returns: + True if SSE streaming is enabled, False otherwise. + """ + streaming_mode = getattr( + invocation_context.run_config, "streaming_mode", None + ) + try: + from google.adk.agents.run_config import StreamingMode + + return streaming_mode == StreamingMode.SSE + except (ImportError, AttributeError): + return False + + async def _retry_same_model( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + stream = self._is_sse_streaming(invocation_context) + + agent = invocation_context.agent + llm = agent.canonical_model + + backoff = self.backoff_initial + for attempt in range(1, self.max_retries + 1): + sleep_time = min(self.max_backoff, backoff) + # add multiplicative (+/-) jitter + if self.jitter > 0: + jitter_delta = sleep_time * random.uniform(-self.jitter, self.jitter) + sleep_time = max(0.0, sleep_time + jitter_delta) + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + try: + final_response = await self._call_llm_and_get_final( + llm=llm, llm_request=llm_request, stream=stream + ) + logger.info( + "LLM retry succeeded on attempt %s for agent %s", + attempt, + agent.name, + ) + return final_response + except Exception as e: # continue to next attempt + logger.warning( + "LLM retry attempt %s failed: %s", attempt, repr(e), exc_info=False + ) + backoff *= self.backoff_multiplier + + return None + + async def _try_fallbacks( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + stream = self._is_sse_streaming(invocation_context) + + for model_name in self.fallback_models: + try: + fallback_llm = LLMRegistry.new_llm(model_name) + # Update request model hint for provider bridges that honor it + llm_request.model = model_name + final_response = await self._call_llm_and_get_final( + llm=fallback_llm, llm_request=llm_request, stream=stream + ) + logger.info("LLM fallback succeeded with model '%s'", model_name) + return final_response + except Exception as e: + logger.warning( + "LLM fallback model '%s' failed: %s", + model_name, + repr(e), + exc_info=False, + ) + continue + return None + + async def _call_llm_and_get_final( + self, *, llm, llm_request: LlmRequest, stream: bool + ) -> LlmResponse: + """Calls the given llm and returns the final non-partial LlmResponse.""" + import inspect + + final: Optional[LlmResponse] = None + agen_or_coro = llm.generate_content_async(llm_request, stream=stream) + + # If the provider raised before first yield, this may be a coroutine; handle gracefully + if inspect.isasyncgen(agen_or_coro) or hasattr(agen_or_coro, "__aiter__"): + agen = agen_or_coro + try: + async for resp in agen: + # Keep the latest response; in streaming mode, last one is non-partial + final = resp + finally: + # If the generator is an async generator, ensure it's closed properly + try: + await agen.aclose() # type: ignore[attr-defined] + except Exception: + pass + else: + # Await the coroutine; some LLMs may return a single response + result = await agen_or_coro + if isinstance(result, LlmResponse): + final = result + elif isinstance(result, types.Content): + final = LlmResponse(content=result, partial=False) + else: + # Unknown return type + raise TypeError("LLM generate_content_async returned unsupported type") + + if final is None: + # Edge case: provider yielded nothing. Create a minimal error response. + return LlmResponse(partial=False) + return final diff --git a/tests/unittests/plugins/__init__.py b/tests/unittests/plugins/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/tests/unittests/plugins/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/plugins/test_llm_resilience_plugin.py b/tests/unittests/plugins/test_llm_resilience_plugin.py new file mode 100644 index 0000000..e9aac8c --- /dev/null +++ b/tests/unittests/plugins/test_llm_resilience_plugin.py @@ -0,0 +1,273 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LlmResiliencePlugin.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk_community.plugins.llm_resilience_plugin import LlmResiliencePlugin +from google.genai import types + + +class AlwaysFailModel(BaseLlm): + model: str = "failing-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["failing-model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Always raise a timeout error to simulate transient failures + raise asyncio.TimeoutError("Simulated timeout in AlwaysFailModel") + yield # Make this a generator + + +class SimpleSuccessModel(BaseLlm): + model: str = "mock" + + @classmethod + def supported_models(cls) -> list[str]: + return ["mock"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Return a single final response regardless of stream flag + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="final response from mock")], + ), + partial=False, + ) + + +async def create_invocation_context(agent: LlmAgent) -> InvocationContext: + """Helper to create an InvocationContext for testing.""" + invocation_id = "test_id" + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + memory_service = InMemoryMemoryService() + invocation_context = InvocationContext( + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + plugin_manager=PluginManager(plugins=[]), + invocation_id=invocation_id, + agent=agent, + session=await session_service.create_session( + app_name="test_app", user_id="test_user" + ), + user_content=types.Content( + role="user", parts=[types.Part.from_text(text="")] + ), + run_config=RunConfig(), + ) + return invocation_context + + +class TestLlmResiliencePlugin(IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + # Register test models in the registry once + LLMRegistry.register(AlwaysFailModel) + LLMRegistry.register(SimpleSuccessModel) + + async def test_retry_success_on_same_model(self): + # Agent uses SimpleSuccessModel directly + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + # Build a minimal request + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]) + ] + ) + + # Simulate an initial transient error (e.g., 429/timeout) + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertIsNotNone(result.content) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_fallback_model_used_after_retries(self): + # Agent starts with a failing string model; plugin will fallback to "mock" + agent = LlmAgent(name="agent", model="failing-model") + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=1, fallback_models=["mock"]) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + # Trigger resilience with a transient error + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_non_transient_error_bubbles(self): + # Agent with success model, but error is non-transient → plugin should ignore + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + class NonTransientError(RuntimeError): + pass + + # Non-transient error: status code not transient and not Timeout + # The plugin should return None so that the original error propagates + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=NonTransientError("boom"), + ) + self.assertIsNone(result) + + async def test_custom_retry_on_exceptions(self): + """Test that custom exception types in retry_on_exceptions trigger retry.""" + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + + class CustomError(Exception): + pass + + # Plugin configured to retry on CustomError (which is NOT a transient error) + plugin = LlmResiliencePlugin( + max_retries=2, + retry_on_exceptions=(CustomError,), + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + # CustomError should trigger retry even though it's not transient + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=CustomError("custom failure"), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_retry_on_custom_exception_with_fail_then_succeed_model(self): + """Test retry with a model that fails once then succeeds on custom exception.""" + + class MyCustomError(Exception): + pass + + class CustomErrorModel(BaseLlm): + model: str = "custom-error-model" + call_count: int = 0 + + @classmethod + def supported_models(cls) -> list[str]: + return ["custom-error-model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + CustomErrorModel.call_count += 1 + if CustomErrorModel.call_count == 1: + raise MyCustomError("Custom error!") + + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Success!")], + ), + partial=False, + ) + + # Set call_count=1 to simulate the initial call already happened + # (which raised the error that triggered on_model_error_callback) + CustomErrorModel.call_count = 1 + LLMRegistry.register(CustomErrorModel) + + agent = LlmAgent(name="agent", model="custom-error-model") + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin( + max_retries=1, + retry_on_exceptions=(MyCustomError,), + ) + llm_request = LlmRequest(contents=[]) + + # The plugin should catch MyCustomError and retry. + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=MyCustomError(), + ) + + self.assertIsNotNone(result) + self.assertEqual(result.content.parts[0].text, "Success!") + self.assertEqual(CustomErrorModel.call_count, 2)