diff --git a/CONTRIBUTION_NOTE.txt b/CONTRIBUTION_NOTE.txt new file mode 100644 index 0000000000..2d22413ee1 --- /dev/null +++ b/CONTRIBUTION_NOTE.txt @@ -0,0 +1,92 @@ +LlmResiliencePlugin Contribution Note +===================================== + +What we implemented +------------------- +1) New plugin: + - Added src/google/adk/plugins/llm_resilience_plugin.py + - Provides retry + backoff + jitter + optional model fallbacks for LLM errors. + +2) Plugin export: + - Updated src/google/adk/plugins/__init__.py + - Exported LlmResiliencePlugin in __all__ for discoverability. + +3) Unit tests: + - Added tests/unittests/plugins/test_llm_resilience_plugin.py + - Covered: + - retry success on same model + - fallback model after retries + - non-transient errors bubbling correctly + +4) Usage sample: + - Added samples/resilient_agent.py + - Demonstrates plugin setup and recovery behavior. + +5) PR narrative and testing evidence: + - Updated PR_BODY.md to match repository PR template: + - issue/description + - testing plan + - manual E2E output + - checklist + + +Why this contribution is meaningful +----------------------------------- +1) Solves a real reliability gap: + Production agents frequently face transient failures (timeouts, 429, 5xx). + This change centralizes resilience behavior and removes repeated ad-hoc retry code. + +2) Low-risk architecture: + The feature is plugin-based and opt-in. + Existing users are unaffected unless they configure the plugin. + +3) Practical for maintainers and users: + Includes tests and a runnable sample, reducing review friction and making adoption easier. + +4) Aligns with ADK extensibility: + Keeps resilience logic at the plugin layer without changing core runner/flow behavior. + + +Key design reasons +------------------ +1) on_model_error_callback hook: + Best fit for intercepting model failures and deciding retry/fallback behavior. + +2) Exponential backoff with jitter: + Reduces retry storms and aligns with standard distributed-system reliability practices. + +3) Model fallback support: + Improves chance of successful completion when a single provider/model is degraded. + +4) Robust provider response handling: + Supports async-generator and coroutine style returns to handle provider differences. + +5) Type-safety/cycle-safe update: + Added TYPE_CHECKING import pattern for InvocationContext to avoid runtime issues. + + +Validation performed +-------------------- +1) Formatting: + - isort applied to changed Python files + - pyink applied to changed Python files + +2) Unit tests: + - Command: + .venv/Scripts/python -m pytest tests/unittests/plugins/test_llm_resilience_plugin.py -v + - Result: 3 passed + +3) Manual E2E sample run: + - Command: + .venv/Scripts/python samples/resilient_agent.py + - Observed: + LLM retry attempt 1 failed: TimeoutError('Simulated transient failure') + Collected 1 events + MODEL: Recovered on retry! + + +Scope and limitations +--------------------- +- This PR focuses on LLM call resilience only. +- Live bidirectional streaming paths are out of scope for this change. +- Future enhancements can add per-exception policies and circuit-breaker style controls. diff --git a/PR_BODY.md b/PR_BODY.md new file mode 100644 index 0000000000..35c6c4f0e4 --- /dev/null +++ b/PR_BODY.md @@ -0,0 +1,90 @@ +# feat(plugins): LlmResiliencePlugin – configurable retries/backoff and model fallbacks + +### Link to Issue or Description of Change + +**1. Link to an existing issue (if applicable):** + +- Closes: N/A +- Related: #1214 +- Related: #2561 +- Related discussions: #2292, #3199 + +**2. Or, if no issue exists, describe the change:** + +**Problem:** +Production agents need first-class resilience to transient LLM/API failures +(timeouts, HTTP 429/5xx). Today, retry/fallback logic is often ad-hoc and +duplicated across projects. + +**Solution:** +Introduce an opt-in plugin, `LlmResiliencePlugin`, that handles transient LLM +errors with configurable retries (exponential backoff + jitter) and optional +model fallbacks, without modifying core runner/flow logic. + +### Summary + +- Added `src/google/adk/plugins/llm_resilience_plugin.py`. +- Exported `LlmResiliencePlugin` in `src/google/adk/plugins/__init__.py`. +- Added unit tests in + `tests/unittests/plugins/test_llm_resilience_plugin.py`: + - `test_retry_success_on_same_model` + - `test_fallback_model_used_after_retries` + - `test_non_transient_error_bubbles` +- Added `samples/resilient_agent.py` demo. + +### Testing Plan + +**Unit Tests:** + +- [x] I have added or updated unit tests for my change. +- [x] All unit tests pass locally. + +Command run: + +```shell +.venv/Scripts/python -m pytest tests/unittests/plugins/test_llm_resilience_plugin.py -v +``` + +Result summary: + +```text +collected 3 items +tests/unittests/plugins/test_llm_resilience_plugin.py::TestLlmResiliencePlugin::test_fallback_model_used_after_retries PASSED +tests/unittests/plugins/test_llm_resilience_plugin.py::TestLlmResiliencePlugin::test_non_transient_error_bubbles PASSED +tests/unittests/plugins/test_llm_resilience_plugin.py::TestLlmResiliencePlugin::test_retry_success_on_same_model PASSED +3 passed +``` + +**Manual End-to-End (E2E) Tests:** + +Run sample: + +```shell +.venv/Scripts/python samples/resilient_agent.py +``` + +Observed output: + +```text +LLM retry attempt 1 failed: TimeoutError('Simulated transient failure') +Collected 1 events +MODEL: Recovered on retry! +``` + +### Checklist + +- [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. +- [x] I have performed a self-review of my own code. +- [x] I have commented my code, particularly in hard-to-understand areas. +- [x] I have added tests that prove my fix is effective or that my feature works. +- [x] New and existing unit tests pass locally with my changes. +- [x] I have manually tested my changes end-to-end. +- [x] Any dependent changes have been merged and published in downstream modules. (N/A; no dependent changes) + +### Additional context + +- Non-breaking: users opt in via + `Runner(..., plugins=[LlmResiliencePlugin(...)])`. +- Transient detection currently targets common HTTP/timeouts and can be extended + in follow-ups (e.g., per-exception policy, circuit breaking). +- Live bidirectional streaming paths are out of scope for this PR. diff --git a/samples/resilient_agent.py b/samples/resilient_agent.py new file mode 100644 index 0000000000..ab8880eacd --- /dev/null +++ b/samples/resilient_agent.py @@ -0,0 +1,105 @@ +# Sample: Using LlmResiliencePlugin for robust model calls +# +# Run with: +# PYTHONPATH=$(pwd)/src python samples/resilient_agent.py +# +# This demonstrates: +# - Configuring LlmResiliencePlugin for retries and fallbacks +# - Running a minimal in-memory agent with a mocked model + +from __future__ import annotations + +import asyncio +from typing import ClassVar + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.llm_resilience_plugin import LlmResiliencePlugin +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +class DemoFailThenSucceedModel(BaseLlm): + model: str = "demo-fail-succeed" + attempts: ClassVar[int] = ( + 0 # Class variable for shared state across instances + ) + + @classmethod + def supported_models(cls) -> list[str]: + return ["demo-fail-succeed"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ): + # Fail for the first attempt, then succeed + DemoFailThenSucceedModel.attempts += 1 + if DemoFailThenSucceedModel.attempts < 2: + raise TimeoutError("Simulated transient failure") + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Recovered on retry!")], + ), + partial=False, + ) + + +# Register test models +LLMRegistry.register(DemoFailThenSucceedModel) + + +async def main(): + # Agent with the failing-then-succeed model + agent = LlmAgent(name="resilient_agent", model="demo-fail-succeed") + + # Build services and runner in-memory + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + memory_service = InMemoryMemoryService() + + runner = Runner( + app_name="resilience_demo", + agent=agent, + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + plugins=[ + LlmResiliencePlugin( + max_retries=2, + backoff_initial=0.1, + backoff_multiplier=2.0, + jitter=0.1, + fallback_models=["mock"], # Demonstration; not used here + ) + ], + ) + + # Create a session and run once + session = await session_service.create_session( + app_name="resilience_demo", user_id="demo" + ) + events = [] + async for ev in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ), + ): + events.append(ev) + + print("Collected", len(events), "events") + for e in events: + if e.content and e.content.parts and e.content.parts[0].text: + print("MODEL:", e.content.parts[0].text.strip()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index 45caf16038..d1853c5ab8 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -14,6 +14,7 @@ from .base_plugin import BasePlugin from .debug_logging_plugin import DebugLoggingPlugin +from .llm_resilience_plugin import LlmResiliencePlugin from .logging_plugin import LoggingPlugin from .plugin_manager import PluginManager from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin @@ -24,4 +25,5 @@ 'LoggingPlugin', 'PluginManager', 'ReflectAndRetryToolPlugin', + 'LlmResiliencePlugin', ] diff --git a/src/google/adk/plugins/llm_resilience_plugin.py b/src/google/adk/plugins/llm_resilience_plugin.py new file mode 100644 index 0000000000..4e0f3fc48d --- /dev/null +++ b/src/google/adk/plugins/llm_resilience_plugin.py @@ -0,0 +1,327 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import random +from typing import Iterable +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..agents.invocation_context import InvocationContext + +try: + import httpx +except Exception: # pragma: no cover - httpx might not be installed in all envs + httpx = None # type: ignore + +from google.genai import types + +from ..agents.callback_context import CallbackContext +from ..models.llm_request import LlmRequest +from ..models.llm_response import LlmResponse +from ..models.registry import LLMRegistry +from ..plugins.base_plugin import BasePlugin + +logger = logging.getLogger("google_adk." + __name__) + + +def _extract_status_code(err: Exception) -> Optional[int]: + # Best-effort extraction of HTTP status codes from common client libraries + # (httpx, google api errors, etc.) + status = getattr(err, "status_code", None) + if isinstance(status, int): + return status + # httpx specific + if httpx is not None: + if isinstance(err, httpx.HTTPStatusError): + try: + return int(err.response.status_code) + except Exception: + return None + # Fallback: look for nested response + resp = getattr(err, "response", None) + if resp is not None: + code = getattr(resp, "status_code", None) + if isinstance(code, int): + return code + return None + + +def _is_transient_error(err: Exception) -> bool: + # Retry on common transient classes and HTTP status codes + transient_http = {429, 500, 502, 503, 504} + status = _extract_status_code(err) + if status is not None and status in transient_http: + return True + + # httpx transient + if httpx is not None and isinstance( + err, (httpx.ReadTimeout, httpx.ConnectError, httpx.RemoteProtocolError) + ): + return True + + # asyncio timeouts and cancellations often warrant retry/fallback at callsite + if isinstance(err, (asyncio.TimeoutError,)): + return True + + return False + + +class LlmResiliencePlugin(BasePlugin): + """A plugin that adds retry with exponential backoff and model fallbacks. + + Behavior: + - Intercepts model errors via on_model_error_callback + - Retries the same model up to max_retries with exponential backoff + jitter + - If still failing and fallback_models configured, tries them in order + - Returns the first successful LlmResponse or None to propagate the error + + Notes: + - Live (bidirectional) mode errors are not intercepted by BaseLlmFlow's error + handler; this plugin currently targets generate_content_async flow. + - In SSE streaming mode, the plugin returns a single final LlmResponse. + """ + + def __init__( + self, + *, + name: str = "llm_resilience_plugin", + max_retries: int = 3, + backoff_initial: float = 1.0, + backoff_multiplier: float = 2.0, + max_backoff: float = 10.0, + jitter: float = 0.2, + retry_on_exceptions: Optional[tuple[type[BaseException], ...]] = None, + fallback_models: Optional[Iterable[str]] = None, + ) -> None: + super().__init__(name) + if max_retries < 0: + raise ValueError("max_retries must be >= 0") + if backoff_initial <= 0: + raise ValueError("backoff_initial must be > 0") + if backoff_multiplier < 1.0: + raise ValueError("backoff_multiplier must be >= 1.0") + if max_backoff <= 0: + raise ValueError("max_backoff must be > 0") + if jitter < 0: + raise ValueError("jitter must be >= 0") + + self.max_retries = max_retries + self.backoff_initial = backoff_initial + self.backoff_multiplier = backoff_multiplier + self.max_backoff = max_backoff + self.jitter = jitter + self.retry_on_exceptions = retry_on_exceptions + self.fallback_models = list(fallback_models or []) + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + # Decide whether to handle this error + if self.retry_on_exceptions is not None and not isinstance( + error, self.retry_on_exceptions + ): + # If user provided an explicit exception tuple and it doesn't match, + # optionally still retry on transient HTTP-ish errors. + if not _is_transient_error(error): + return None + else: + # If user did not provide explicit list, rely on our transient heuristic + if not _is_transient_error(error): + # Non-transient error → don't handle + return None + + # Attempt retries on the same model + response = await self._retry_same_model( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Try fallbacks in order + if self.fallback_models: + response = await self._try_fallbacks( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Let the original error propagate if all attempts failed + return None + + def _get_invocation_context( + self, callback_context: CallbackContext | InvocationContext + ) -> InvocationContext: + """Extract InvocationContext from callback_context. + + Accepts both Context (CallbackContext alias) and InvocationContext via + duck typing. + + Args: + callback_context: The callback context passed to the plugin. + + Returns: + The underlying InvocationContext. + + Raises: + TypeError: If callback_context is not a recognized type. + """ + # If this looks like an InvocationContext (has agent and run_config), use it directly + if hasattr(callback_context, "agent") and hasattr( + callback_context, "run_config" + ): + return callback_context # type: ignore[return-value] + # Otherwise expect a Context-like object exposing the private _invocation_context + ic = getattr(callback_context, "_invocation_context", None) + if ic is None: + raise TypeError( + "callback_context must be Context or InvocationContext-like" + ) + return ic + + async def _retry_same_model( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + # Determine streaming mode + streaming_mode = getattr( + invocation_context.run_config, "streaming_mode", None + ) + stream = False + try: + # Only SSE streaming is supported in generate_content_async + from ..agents.run_config import StreamingMode # local import to avoid cycles + + stream = streaming_mode == StreamingMode.SSE + except (ImportError, AttributeError): + pass + + agent = invocation_context.agent + llm = agent.canonical_model + + backoff = self.backoff_initial + for attempt in range(1, self.max_retries + 1): + sleep_time = min(self.max_backoff, backoff) + # add multiplicative (+/-) jitter + if self.jitter > 0: + jitter_delta = sleep_time * random.uniform(-self.jitter, self.jitter) + sleep_time = max(0.0, sleep_time + jitter_delta) + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + try: + final_response = await self._call_llm_and_get_final( + llm=llm, llm_request=llm_request, stream=stream + ) + logger.info( + "LLM retry succeeded on attempt %s for agent %s", + attempt, + agent.name, + ) + return final_response + except Exception as e: # continue to next attempt + logger.warning( + "LLM retry attempt %s failed: %s", attempt, repr(e), exc_info=False + ) + backoff *= self.backoff_multiplier + + return None + + async def _try_fallbacks( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + # Determine streaming mode + streaming_mode = getattr( + invocation_context.run_config, "streaming_mode", None + ) + stream = False + try: + from ..agents.run_config import StreamingMode + + stream = streaming_mode == StreamingMode.SSE + except (ImportError, AttributeError): + pass + + for model_name in self.fallback_models: + try: + fallback_llm = LLMRegistry.new_llm(model_name) + # Update request model hint for provider bridges that honor it + llm_request.model = model_name + final_response = await self._call_llm_and_get_final( + llm=fallback_llm, llm_request=llm_request, stream=stream + ) + logger.info("LLM fallback succeeded with model '%s'", model_name) + return final_response + except Exception as e: + logger.warning( + "LLM fallback model '%s' failed: %s", + model_name, + repr(e), + exc_info=False, + ) + continue + return None + + async def _call_llm_and_get_final( + self, *, llm, llm_request: LlmRequest, stream: bool + ) -> LlmResponse: + """Calls the given llm and returns the final non-partial LlmResponse.""" + import inspect + + final: Optional[LlmResponse] = None + agen_or_coro = llm.generate_content_async(llm_request, stream=stream) + + # If the provider raised before first yield, this may be a coroutine; handle gracefully + if inspect.isasyncgen(agen_or_coro) or hasattr(agen_or_coro, "__aiter__"): + agen = agen_or_coro + try: + async for resp in agen: + # Keep the latest response; in streaming mode, last one is non-partial + final = resp + finally: + # If the generator is an async generator, ensure it's closed properly + try: + await agen.aclose() # type: ignore[attr-defined] + except Exception: + pass + else: + # Await the coroutine; some LLMs may return a single response + result = await agen_or_coro + if isinstance(result, LlmResponse): + final = result + elif isinstance(result, types.Content): + final = LlmResponse(content=result, partial=False) + else: + # Unknown return type + raise TypeError("LLM generate_content_async returned unsupported type") + + if final is None: + # Edge case: provider yielded nothing. Create a minimal error response. + return LlmResponse(partial=False) + return final diff --git a/tests/unittests/plugins/test_llm_resilience_plugin.py b/tests/unittests/plugins/test_llm_resilience_plugin.py new file mode 100644 index 0000000000..3b79e9c6cc --- /dev/null +++ b/tests/unittests/plugins/test_llm_resilience_plugin.py @@ -0,0 +1,155 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator +from typing import Optional +from unittest import IsolatedAsyncioTestCase + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.llm_resilience_plugin import LlmResiliencePlugin +from google.genai import types + +from ..testing_utils import create_invocation_context + + +class AlwaysFailModel(BaseLlm): + model: str = "failing-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["failing-model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Always raise a timeout error to simulate transient failures + raise asyncio.TimeoutError("Simulated timeout in AlwaysFailModel") + + +class SimpleSuccessModel(BaseLlm): + model: str = "mock" + + @classmethod + def supported_models(cls) -> list[str]: + return ["mock"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Return a single final response regardless of stream flag + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="final response from mock")], + ), + partial=False, + ) + + +class TestLlmResiliencePlugin(IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + # Register test models in the registry once + LLMRegistry.register(AlwaysFailModel) + LLMRegistry.register(SimpleSuccessModel) + + async def test_retry_success_on_same_model(self): + # Agent uses SimpleSuccessModel directly + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + # Build a minimal request + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]) + ] + ) + + # Simulate an initial transient error (e.g., 429/timeout) + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertIsNotNone(result.content) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_fallback_model_used_after_retries(self): + # Agent starts with a failing string model; plugin will fallback to "mock" + agent = LlmAgent(name="agent", model="failing-model") + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=1, fallback_models=["mock"]) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + # Trigger resilience with a transient error + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_non_transient_error_bubbles(self): + # Agent with success model, but error is non-transient → plugin should ignore + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + class NonTransientError(RuntimeError): + pass + + # Non-transient error: status code not transient and not Timeout + # The plugin should return None so that the original error propagates + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=NonTransientError("boom"), + ) + self.assertIsNone(result)