From 5c22eb59b45356eb4593a24c409e3fd8661fb60f Mon Sep 17 00:00:00 2001 From: truppy Date: Thu, 19 Feb 2026 14:08:56 +0530 Subject: [PATCH 1/5] feat(plugins): add LlmResiliencePlugin with retries and model fallbacks Adds a new plugin for handling transient LLM errors with: - Configurable retries with exponential backoff + jitter - Transient error detection (HTTP 429/500/502/503/504, httpx timeouts) - Optional model fallbacks when primary model fails - Support for both async generator and coroutine LLM providers Usage: from google.adk_community.plugins import LlmResiliencePlugin runner = Runner( ..., plugins=[LlmResiliencePlugin(max_retries=3, fallback_models=['gemini-1.5-flash'])] ) --- src/google/adk_community/__init__.py | 1 + src/google/adk_community/plugins/__init__.py | 21 + .../plugins/llm_resilience_plugin.py | 359 ++++++++++++++++++ tests/unittests/plugins/__init__.py | 13 + .../plugins/test_llm_resilience_plugin.py | 185 +++++++++ 5 files changed, 579 insertions(+) create mode 100644 src/google/adk_community/plugins/__init__.py create mode 100644 src/google/adk_community/plugins/llm_resilience_plugin.py create mode 100644 tests/unittests/plugins/__init__.py create mode 100644 tests/unittests/plugins/test_llm_resilience_plugin.py diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..8f52d3a 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import memory +from . import plugins from . import sessions from . import version __version__ = version.__version__ diff --git a/src/google/adk_community/plugins/__init__.py b/src/google/adk_community/plugins/__init__.py new file mode 100644 index 0000000..a08f9fb --- /dev/null +++ b/src/google/adk_community/plugins/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community plugins for Google ADK.""" + +from .llm_resilience_plugin import LlmResiliencePlugin + +__all__ = [ + "LlmResiliencePlugin", +] diff --git a/src/google/adk_community/plugins/llm_resilience_plugin.py b/src/google/adk_community/plugins/llm_resilience_plugin.py new file mode 100644 index 0000000..8ca9508 --- /dev/null +++ b/src/google/adk_community/plugins/llm_resilience_plugin.py @@ -0,0 +1,359 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LlmResiliencePlugin - retry with exponential backoff and model fallbacks.""" + +from __future__ import annotations + +import asyncio +import logging +import random +from typing import Iterable +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from google.adk.agents.invocation_context import InvocationContext + +try: + import httpx +except Exception: # pragma: no cover - httpx might not be installed in all envs + httpx = None # type: ignore + +from google.genai import types + +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.base_plugin import BasePlugin + +logger = logging.getLogger("google_adk_community." + __name__) + + +def _extract_status_code(err: Exception) -> Optional[int]: + """Best-effort extraction of HTTP status codes from common client libraries.""" + status = getattr(err, "status_code", None) + if isinstance(status, int): + return status + # httpx specific + if httpx is not None: + if isinstance(err, httpx.HTTPStatusError): + try: + return int(err.response.status_code) + except Exception: + return None + # Fallback: look for nested response + resp = getattr(err, "response", None) + if resp is not None: + code = getattr(resp, "status_code", None) + if isinstance(code, int): + return code + return None + + +def _is_transient_error(err: Exception) -> bool: + """Check if an error is transient and should trigger retry.""" + # Retry on common transient classes and HTTP status codes + transient_http = {429, 500, 502, 503, 504} + status = _extract_status_code(err) + if status is not None and status in transient_http: + return True + + # httpx transient + if httpx is not None and isinstance( + err, (httpx.ReadTimeout, httpx.ConnectError, httpx.RemoteProtocolError) + ): + return True + + # asyncio timeouts and cancellations often warrant retry/fallback at callsite + if isinstance(err, (asyncio.TimeoutError,)): + return True + + return False + + +class LlmResiliencePlugin(BasePlugin): + """A plugin that adds retry with exponential backoff and model fallbacks. + + Behavior: + - Intercepts model errors via on_model_error_callback + - Retries the same model up to max_retries with exponential backoff + jitter + - If still failing and fallback_models configured, tries them in order + - Returns the first successful LlmResponse or None to propagate the error + + Notes: + - Live (bidirectional) mode errors are not intercepted by BaseLlmFlow's error + handler; this plugin currently targets generate_content_async flow. + - In SSE streaming mode, the plugin returns a single final LlmResponse. + + Example: + >>> from google.adk.runners import Runner + >>> from google.adk_community.plugins import LlmResiliencePlugin + >>> + >>> runner = Runner( + ... app_name="my_app", + ... agent=my_agent, + ... plugins=[ + ... LlmResiliencePlugin( + ... max_retries=3, + ... backoff_initial=1.0, + ... fallback_models=["gemini-1.5-flash"], + ... ) + ... ], + ... ) + """ + + def __init__( + self, + *, + name: str = "llm_resilience_plugin", + max_retries: int = 3, + backoff_initial: float = 1.0, + backoff_multiplier: float = 2.0, + max_backoff: float = 10.0, + jitter: float = 0.2, + retry_on_exceptions: Optional[tuple[type[BaseException], ...]] = None, + fallback_models: Optional[Iterable[str]] = None, + ) -> None: + """Initialize the LlmResiliencePlugin. + + Args: + name: Plugin name identifier. + max_retries: Maximum number of retry attempts on the same model. + backoff_initial: Initial backoff delay in seconds. + backoff_multiplier: Multiplier for exponential backoff. + max_backoff: Maximum backoff delay in seconds. + jitter: Jitter factor (0.0 to 1.0) to add randomness to backoff. + retry_on_exceptions: Optional tuple of exception types to retry on. + If None, uses built-in transient error detection. + fallback_models: Optional list of model names to try if primary fails. + """ + super().__init__(name) + if max_retries < 0: + raise ValueError("max_retries must be >= 0") + if backoff_initial <= 0: + raise ValueError("backoff_initial must be > 0") + if backoff_multiplier < 1.0: + raise ValueError("backoff_multiplier must be >= 1.0") + if max_backoff <= 0: + raise ValueError("max_backoff must be > 0") + if jitter < 0: + raise ValueError("jitter must be >= 0") + + self.max_retries = max_retries + self.backoff_initial = backoff_initial + self.backoff_multiplier = backoff_multiplier + self.max_backoff = max_backoff + self.jitter = jitter + self.retry_on_exceptions = retry_on_exceptions + self.fallback_models = list(fallback_models or []) + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + """Handle model errors with retry and fallback logic.""" + # Decide whether to handle this error + if self.retry_on_exceptions is not None and not isinstance( + error, self.retry_on_exceptions + ): + # If user provided an explicit exception tuple and it doesn't match, + # optionally still retry on transient HTTP-ish errors. + if not _is_transient_error(error): + return None + else: + # If user did not provide explicit list, rely on our transient heuristic + if not _is_transient_error(error): + # Non-transient error → don't handle + return None + + # Attempt retries on the same model + response = await self._retry_same_model( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Try fallbacks in order + if self.fallback_models: + response = await self._try_fallbacks( + callback_context=callback_context, llm_request=llm_request + ) + if response is not None: + return response + + # Let the original error propagate if all attempts failed + return None + + def _get_invocation_context( + self, callback_context: CallbackContext | InvocationContext + ) -> InvocationContext: + """Extract InvocationContext from callback_context. + + Accepts both Context (CallbackContext alias) and InvocationContext via + duck typing. + + Args: + callback_context: The callback context passed to the plugin. + + Returns: + The underlying InvocationContext. + + Raises: + TypeError: If callback_context is not a recognized type. + """ + # If this looks like an InvocationContext (has agent and run_config), use it directly + if hasattr(callback_context, "agent") and hasattr( + callback_context, "run_config" + ): + return callback_context # type: ignore[return-value] + # Otherwise expect a Context-like object exposing the private _invocation_context + ic = getattr(callback_context, "_invocation_context", None) + if ic is None: + raise TypeError( + "callback_context must be Context or InvocationContext-like" + ) + return ic + + async def _retry_same_model( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + # Determine streaming mode + streaming_mode = getattr( + invocation_context.run_config, "streaming_mode", None + ) + stream = False + try: + # Only SSE streaming is supported in generate_content_async + from google.adk.agents.run_config import StreamingMode + + stream = streaming_mode == StreamingMode.SSE + except (ImportError, AttributeError): + pass + + agent = invocation_context.agent + llm = agent.canonical_model + + backoff = self.backoff_initial + for attempt in range(1, self.max_retries + 1): + sleep_time = min(self.max_backoff, backoff) + # add multiplicative (+/-) jitter + if self.jitter > 0: + jitter_delta = sleep_time * random.uniform(-self.jitter, self.jitter) + sleep_time = max(0.0, sleep_time + jitter_delta) + if sleep_time > 0: + await asyncio.sleep(sleep_time) + + try: + final_response = await self._call_llm_and_get_final( + llm=llm, llm_request=llm_request, stream=stream + ) + logger.info( + "LLM retry succeeded on attempt %s for agent %s", + attempt, + agent.name, + ) + return final_response + except Exception as e: # continue to next attempt + logger.warning( + "LLM retry attempt %s failed: %s", attempt, repr(e), exc_info=False + ) + backoff *= self.backoff_multiplier + + return None + + async def _try_fallbacks( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + # Determine streaming mode + streaming_mode = getattr( + invocation_context.run_config, "streaming_mode", None + ) + stream = False + try: + from google.adk.agents.run_config import StreamingMode + + stream = streaming_mode == StreamingMode.SSE + except (ImportError, AttributeError): + pass + + for model_name in self.fallback_models: + try: + fallback_llm = LLMRegistry.new_llm(model_name) + # Update request model hint for provider bridges that honor it + llm_request.model = model_name + final_response = await self._call_llm_and_get_final( + llm=fallback_llm, llm_request=llm_request, stream=stream + ) + logger.info("LLM fallback succeeded with model '%s'", model_name) + return final_response + except Exception as e: + logger.warning( + "LLM fallback model '%s' failed: %s", + model_name, + repr(e), + exc_info=False, + ) + continue + return None + + async def _call_llm_and_get_final( + self, *, llm, llm_request: LlmRequest, stream: bool + ) -> LlmResponse: + """Calls the given llm and returns the final non-partial LlmResponse.""" + import inspect + + final: Optional[LlmResponse] = None + agen_or_coro = llm.generate_content_async(llm_request, stream=stream) + + # If the provider raised before first yield, this may be a coroutine; handle gracefully + if inspect.isasyncgen(agen_or_coro) or hasattr(agen_or_coro, "__aiter__"): + agen = agen_or_coro + try: + async for resp in agen: + # Keep the latest response; in streaming mode, last one is non-partial + final = resp + finally: + # If the generator is an async generator, ensure it's closed properly + try: + await agen.aclose() # type: ignore[attr-defined] + except Exception: + pass + else: + # Await the coroutine; some LLMs may return a single response + result = await agen_or_coro + if isinstance(result, LlmResponse): + final = result + elif isinstance(result, types.Content): + final = LlmResponse(content=result, partial=False) + else: + # Unknown return type + raise TypeError("LLM generate_content_async returned unsupported type") + + if final is None: + # Edge case: provider yielded nothing. Create a minimal error response. + return LlmResponse(partial=False) + return final diff --git a/tests/unittests/plugins/__init__.py b/tests/unittests/plugins/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/tests/unittests/plugins/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/plugins/test_llm_resilience_plugin.py b/tests/unittests/plugins/test_llm_resilience_plugin.py new file mode 100644 index 0000000..e3de7d7 --- /dev/null +++ b/tests/unittests/plugins/test_llm_resilience_plugin.py @@ -0,0 +1,185 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for LlmResiliencePlugin.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.models.registry import LLMRegistry +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk_community.plugins.llm_resilience_plugin import LlmResiliencePlugin +from google.genai import types + + +class AlwaysFailModel(BaseLlm): + model: str = "failing-model" + + @classmethod + def supported_models(cls) -> list[str]: + return ["failing-model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Always raise a timeout error to simulate transient failures + raise asyncio.TimeoutError("Simulated timeout in AlwaysFailModel") + yield # Make this a generator + + +class SimpleSuccessModel(BaseLlm): + model: str = "mock" + + @classmethod + def supported_models(cls) -> list[str]: + return ["mock"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + # Return a single final response regardless of stream flag + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="final response from mock")], + ), + partial=False, + ) + + +async def create_invocation_context(agent: LlmAgent) -> InvocationContext: + """Helper to create an InvocationContext for testing.""" + invocation_id = "test_id" + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + memory_service = InMemoryMemoryService() + invocation_context = InvocationContext( + artifact_service=artifact_service, + session_service=session_service, + memory_service=memory_service, + plugin_manager=PluginManager(plugins=[]), + invocation_id=invocation_id, + agent=agent, + session=await session_service.create_session( + app_name="test_app", user_id="test_user" + ), + user_content=types.Content( + role="user", parts=[types.Part.from_text(text="")] + ), + run_config=RunConfig(), + ) + return invocation_context + + +class TestLlmResiliencePlugin(IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + # Register test models in the registry once + LLMRegistry.register(AlwaysFailModel) + LLMRegistry.register(SimpleSuccessModel) + + async def test_retry_success_on_same_model(self): + # Agent uses SimpleSuccessModel directly + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + # Build a minimal request + llm_request = LlmRequest( + contents=[ + types.Content(role="user", parts=[types.Part.from_text(text="hi")]) + ] + ) + + # Simulate an initial transient error (e.g., 429/timeout) + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertIsNotNone(result.content) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_fallback_model_used_after_retries(self): + # Agent starts with a failing string model; plugin will fallback to "mock" + agent = LlmAgent(name="agent", model="failing-model") + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=1, fallback_models=["mock"]) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + # Trigger resilience with a transient error + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=asyncio.TimeoutError(), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertFalse(result.partial) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) + + async def test_non_transient_error_bubbles(self): + # Agent with success model, but error is non-transient → plugin should ignore + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin(max_retries=2) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + class NonTransientError(RuntimeError): + pass + + # Non-transient error: status code not transient and not Timeout + # The plugin should return None so that the original error propagates + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=NonTransientError("boom"), + ) + self.assertIsNone(result) From d9008a20be8cddb90f8caf414091f67ab781be92 Mon Sep 17 00:00:00 2001 From: truppy Date: Thu, 19 Feb 2026 15:19:04 +0530 Subject: [PATCH 2/5] fix: correct retry_on_exceptions logic to properly handle custom exceptions The previous logic required _is_transient_error() to be true in all cases, effectively ignoring the retry_on_exceptions parameter. Now the plugin will retry if the error is either in retry_on_exceptions OR is a transient error. Added test case to verify custom exception types trigger retry correctly. --- .../plugins/llm_resilience_plugin.py | 21 +++++------ .../plugins/test_llm_resilience_plugin.py | 35 +++++++++++++++++++ 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/google/adk_community/plugins/llm_resilience_plugin.py b/src/google/adk_community/plugins/llm_resilience_plugin.py index 8ca9508..df94a82 100644 --- a/src/google/adk_community/plugins/llm_resilience_plugin.py +++ b/src/google/adk_community/plugins/llm_resilience_plugin.py @@ -168,19 +168,14 @@ async def on_model_error_callback( error: Exception, ) -> Optional[LlmResponse]: """Handle model errors with retry and fallback logic.""" - # Decide whether to handle this error - if self.retry_on_exceptions is not None and not isinstance( - error, self.retry_on_exceptions - ): - # If user provided an explicit exception tuple and it doesn't match, - # optionally still retry on transient HTTP-ish errors. - if not _is_transient_error(error): - return None - else: - # If user did not provide explicit list, rely on our transient heuristic - if not _is_transient_error(error): - # Non-transient error → don't handle - return None + # Decide whether to handle this error: + # Retry if error is in retry_on_exceptions OR is a transient error + if self.retry_on_exceptions and isinstance(error, self.retry_on_exceptions): + # User explicitly wants to retry on this exception type. + pass + elif not _is_transient_error(error): + # Not an explicit exception and not a transient error, so don't handle. + return None # Attempt retries on the same model response = await self._retry_same_model( diff --git a/tests/unittests/plugins/test_llm_resilience_plugin.py b/tests/unittests/plugins/test_llm_resilience_plugin.py index e3de7d7..b45937d 100644 --- a/tests/unittests/plugins/test_llm_resilience_plugin.py +++ b/tests/unittests/plugins/test_llm_resilience_plugin.py @@ -183,3 +183,38 @@ class NonTransientError(RuntimeError): error=NonTransientError("boom"), ) self.assertIsNone(result) + + async def test_custom_retry_on_exceptions(self): + """Test that custom exception types in retry_on_exceptions trigger retry.""" + agent = LlmAgent(name="agent", model=SimpleSuccessModel()) + invocation_context = await create_invocation_context(agent) + + class CustomError(Exception): + pass + + # Plugin configured to retry on CustomError (which is NOT a transient error) + plugin = LlmResiliencePlugin( + max_retries=2, + retry_on_exceptions=(CustomError,), + ) + + llm_request = LlmRequest( + contents=[ + types.Content( + role="user", parts=[types.Part.from_text(text="hello")] + ) + ] + ) + + # CustomError should trigger retry even though it's not transient + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=CustomError("custom failure"), + ) + + self.assertIsNotNone(result) + self.assertIsInstance(result, LlmResponse) + self.assertEqual( + result.content.parts[0].text.strip(), "final response from mock" + ) From 2078ed58e7778a56e9e023340ebaaab0131b67f1 Mon Sep 17 00:00:00 2001 From: truppy Date: Thu, 19 Feb 2026 15:21:41 +0530 Subject: [PATCH 3/5] fix: narrow exception handler in _extract_status_code Catch only AttributeError, ValueError, TypeError instead of broad Exception. --- src/google/adk_community/plugins/llm_resilience_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk_community/plugins/llm_resilience_plugin.py b/src/google/adk_community/plugins/llm_resilience_plugin.py index df94a82..9a8b8d3 100644 --- a/src/google/adk_community/plugins/llm_resilience_plugin.py +++ b/src/google/adk_community/plugins/llm_resilience_plugin.py @@ -52,7 +52,7 @@ def _extract_status_code(err: Exception) -> Optional[int]: if isinstance(err, httpx.HTTPStatusError): try: return int(err.response.status_code) - except Exception: + except (AttributeError, ValueError, TypeError): return None # Fallback: look for nested response resp = getattr(err, "response", None) From b4554c9e9d28d14b8146a3190a98ea80bed5fa68 Mon Sep 17 00:00:00 2001 From: truppy Date: Thu, 19 Feb 2026 15:25:49 +0530 Subject: [PATCH 4/5] refactor: extract streaming mode detection to _is_sse_streaming helper Removes duplicated logic in _retry_same_model and _try_fallbacks (DRY). --- .../plugins/llm_resilience_plugin.py | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/src/google/adk_community/plugins/llm_resilience_plugin.py b/src/google/adk_community/plugins/llm_resilience_plugin.py index 9a8b8d3..6cdd2e7 100644 --- a/src/google/adk_community/plugins/llm_resilience_plugin.py +++ b/src/google/adk_community/plugins/llm_resilience_plugin.py @@ -225,25 +225,33 @@ def _get_invocation_context( ) return ic - async def _retry_same_model( - self, - *, - callback_context: CallbackContext | InvocationContext, - llm_request: LlmRequest, - ) -> Optional[LlmResponse]: - invocation_context = self._get_invocation_context(callback_context) - # Determine streaming mode + def _is_sse_streaming(self, invocation_context: InvocationContext) -> bool: + """Check if SSE streaming mode is enabled. + + Args: + invocation_context: The invocation context to check. + + Returns: + True if SSE streaming is enabled, False otherwise. + """ streaming_mode = getattr( invocation_context.run_config, "streaming_mode", None ) - stream = False try: - # Only SSE streaming is supported in generate_content_async from google.adk.agents.run_config import StreamingMode - stream = streaming_mode == StreamingMode.SSE + return streaming_mode == StreamingMode.SSE except (ImportError, AttributeError): - pass + return False + + async def _retry_same_model( + self, + *, + callback_context: CallbackContext | InvocationContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + invocation_context = self._get_invocation_context(callback_context) + stream = self._is_sse_streaming(invocation_context) agent = invocation_context.agent llm = agent.canonical_model @@ -283,17 +291,7 @@ async def _try_fallbacks( llm_request: LlmRequest, ) -> Optional[LlmResponse]: invocation_context = self._get_invocation_context(callback_context) - # Determine streaming mode - streaming_mode = getattr( - invocation_context.run_config, "streaming_mode", None - ) - stream = False - try: - from google.adk.agents.run_config import StreamingMode - - stream = streaming_mode == StreamingMode.SSE - except (ImportError, AttributeError): - pass + stream = self._is_sse_streaming(invocation_context) for model_name in self.fallback_models: try: From f83739d0b855c8dcb65f3c75e50c4ea03acaf3d7 Mon Sep 17 00:00:00 2001 From: truppy Date: Thu, 19 Feb 2026 15:31:45 +0530 Subject: [PATCH 5/5] test: add comprehensive retry_on_exceptions test with fail-then-succeed model Add test_retry_on_custom_exception_with_fail_then_succeed_model that verifies the actual retry mechanism works end-to-end with a custom exception type. Uses a CustomErrorModel that fails on first call and succeeds on retry. --- .../plugins/test_llm_resilience_plugin.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/unittests/plugins/test_llm_resilience_plugin.py b/tests/unittests/plugins/test_llm_resilience_plugin.py index b45937d..e9aac8c 100644 --- a/tests/unittests/plugins/test_llm_resilience_plugin.py +++ b/tests/unittests/plugins/test_llm_resilience_plugin.py @@ -218,3 +218,56 @@ class CustomError(Exception): self.assertEqual( result.content.parts[0].text.strip(), "final response from mock" ) + + async def test_retry_on_custom_exception_with_fail_then_succeed_model(self): + """Test retry with a model that fails once then succeeds on custom exception.""" + + class MyCustomError(Exception): + pass + + class CustomErrorModel(BaseLlm): + model: str = "custom-error-model" + call_count: int = 0 + + @classmethod + def supported_models(cls) -> list[str]: + return ["custom-error-model"] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + CustomErrorModel.call_count += 1 + if CustomErrorModel.call_count == 1: + raise MyCustomError("Custom error!") + + yield LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="Success!")], + ), + partial=False, + ) + + # Set call_count=1 to simulate the initial call already happened + # (which raised the error that triggered on_model_error_callback) + CustomErrorModel.call_count = 1 + LLMRegistry.register(CustomErrorModel) + + agent = LlmAgent(name="agent", model="custom-error-model") + invocation_context = await create_invocation_context(agent) + plugin = LlmResiliencePlugin( + max_retries=1, + retry_on_exceptions=(MyCustomError,), + ) + llm_request = LlmRequest(contents=[]) + + # The plugin should catch MyCustomError and retry. + result = await plugin.on_model_error_callback( + callback_context=invocation_context, + llm_request=llm_request, + error=MyCustomError(), + ) + + self.assertIsNotNone(result) + self.assertEqual(result.content.parts[0].text, "Success!") + self.assertEqual(CustomErrorModel.call_count, 2)