diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 48467ff4b74..949f9ab0ca0 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -1320,6 +1320,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: invocation_context.agent.name ) + callback_context = CallbackContext( + invocation_context, event_actions=model_response_event.actions + ) + if response := await invocation_context.plugin_manager.run_on_model_request_callback( + callback_context=callback_context, + llm_request=llm_request, + ): + yield response + return + # Calls the LLM. llm = self.__get_llm(invocation_context) diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 54bfab2ed28..3dbef0670aa 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -250,6 +250,26 @@ async def before_model_callback( """ pass + async def on_model_request_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Callback executed immediately before a request is sent to the model. + + This hook is fired after all `before_model_callback`s have completed and + the request has been finalized (e.g. labels injected). It is the correct + place to observe the exact `LlmRequest` that will be sent to the model. + + Args: + callback_context: The context for the current agent call. + llm_request: The final request object to be sent to the model. + + Returns: + An optional LlmResponse. If an LlmResponse is returned, it will be used + instead of calling the model. Returning `None` allows the model call + to proceed normally. + """ + pass + async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse ) -> Optional[LlmResponse]: diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 36d92bf781d..e65a5f3a483 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -3726,13 +3726,13 @@ async def after_agent_callback( ) @_safe_callback - async def before_model_callback( + async def on_model_request_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest, ) -> None: - """Callback before LLM call. + """Callback immediately before LLM call. Logs the LLM request details including: 1. Prompt content diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 5566349516a..7a157263693 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -49,6 +49,7 @@ "before_tool_callback", "after_tool_callback", "before_model_callback", + "on_model_request_callback", "after_model_callback", "on_tool_error_callback", "on_model_error_callback", @@ -245,6 +246,16 @@ async def run_before_model_callback( llm_request=llm_request, ) + async def run_on_model_request_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Runs the `on_model_request_callback` for all plugins.""" + return await self._run_callbacks( + "on_model_request_callback", + callback_context=callback_context, + llm_request=llm_request, + ) + async def run_after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse ) -> Optional[LlmResponse]: diff --git a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py index bc19c458f92..50832c3d2a2 100644 --- a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py +++ b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py @@ -63,6 +63,7 @@ class SpanCapturingPlugin(BasePlugin): def __init__(self): self.name = 'span_capturing_plugin' self.before_capture = _SpanCapture() + self.request_capture = _SpanCapture() self.after_capture = _SpanCapture() self.error_capture = _SpanCapture() @@ -80,6 +81,15 @@ async def before_model_callback( return self._short_circuit_response return None + async def on_model_request_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + self.request_capture.capture() + return None + async def after_model_callback( self, *, @@ -149,6 +159,11 @@ def test_before_and_after_callbacks_share_same_span(): f' before={plugin.before_capture.span_id:#x},' f' after={plugin.after_capture.span_id:#x}' ) + assert plugin.before_capture.span_id == plugin.request_capture.span_id, ( + 'before_model_callback and on_model_request_callback saw different spans:' + f' before={plugin.before_capture.span_id:#x},' + f' request={plugin.request_capture.span_id:#x}' + ) def test_callbacks_same_trace_id(): diff --git a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py index f2f7e35b054..8f0fcea615d 100644 --- a/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py @@ -41,12 +41,14 @@ class MockPlugin(BasePlugin): before_model_text = 'before_model_text from MockPlugin' after_model_text = 'after_model_text from MockPlugin' on_model_error_text = 'on_model_error_text from MockPlugin' + on_model_request_text = 'on_model_request_text from MockPlugin' def __init__(self, name='mock_plugin'): self.name = name self.enable_before_model_callback = False self.enable_after_model_callback = False self.enable_on_model_error_callback = False + self.enable_on_model_request_callback = False self.before_model_response = LlmResponse( content=testing_utils.ModelContent( [types.Part.from_text(text=self.before_model_text)] @@ -62,6 +64,11 @@ def __init__(self, name='mock_plugin'): [types.Part.from_text(text=self.on_model_error_text)] ) ) + self.on_model_request_response = LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text=self.on_model_request_text)] + ) + ) async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest @@ -88,6 +95,13 @@ async def on_model_error_callback( return None return self.on_model_error_response + async def on_model_request_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + if not self.enable_on_model_request_callback: + return None + return self.on_model_request_response + CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content' @@ -138,6 +152,22 @@ def test_before_model_fallback_canonical_callback(mock_plugin): ] +def test_on_model_request_callback_with_plugin(mock_plugin): + """Tests that the model response is overridden by on_model_request_callback from the plugin.""" + responses = ['model_response'] + mock_model = testing_utils.MockModel.create(responses=responses) + mock_plugin.enable_on_model_request_callback = True + agent = Agent( + name='root_agent', + model=mock_model, + ) + + runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin]) + assert testing_utils.simplify_events(runner.run('test')) == [ + ('root_agent', mock_plugin.on_model_request_text), + ] + + def test_before_model_callback_fallback_model(mock_plugin): """Tests that the model response is executed normally when both plugin and canonical agent model callback return empty response.""" responses = ['model_response'] diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index aa7c17fb017..9f71810f425 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -73,6 +73,9 @@ async def on_tool_error_callback(self, **kwargs) -> str: async def before_model_callback(self, **kwargs) -> str: return "overridden_before_model" + async def on_model_request_callback(self, **kwargs) -> str: + return "overridden_on_model_request" + async def after_model_callback(self, **kwargs) -> str: return "overridden_after_model" @@ -158,6 +161,12 @@ async def test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_model_request_callback( + callback_context=mock_context, llm_request=mock_context + ) + is None + ) assert ( await plugin.after_model_callback( callback_context=mock_context, llm_response=mock_context @@ -240,6 +249,12 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): ) == "overridden_before_model" ) + assert ( + await plugin.on_model_request_callback( + callback_context=mock_callback_context, llm_request=mock_llm_request + ) + == "overridden_on_model_request" + ) assert ( await plugin.after_model_callback( callback_context=mock_callback_context, llm_response=mock_llm_response diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index 6c72a2a6650..fe2ebc3e82f 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -85,6 +85,9 @@ async def on_tool_error_callback(self, **kwargs): async def before_model_callback(self, **kwargs): return await self._handle_callback("before_model_callback") + async def on_model_request_callback(self, **kwargs): + return await self._handle_callback("on_model_request_callback") + async def after_model_callback(self, **kwargs): return await self._handle_callback("after_model_callback") @@ -244,6 +247,9 @@ async def test_all_callbacks_are_supported( await service.run_before_model_callback( callback_context=mock_context, llm_request=mock_context ) + await service.run_on_model_request_callback( + callback_context=mock_context, llm_request=mock_context + ) await service.run_after_model_callback( callback_context=mock_context, llm_response=mock_context ) @@ -265,6 +271,7 @@ async def test_all_callbacks_are_supported( "after_tool_callback", "on_tool_error_callback", "before_model_callback", + "on_model_request_callback", "after_model_callback", "on_model_error_callback", ]