Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,16 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
invocation_context.agent.name
)

callback_context = CallbackContext(
invocation_context, event_actions=model_response_event.actions
)
if response := await invocation_context.plugin_manager.run_on_model_request_callback(
callback_context=callback_context,
llm_request=llm_request,
):
yield response
return

# Calls the LLM.
llm = self.__get_llm(invocation_context)

Expand Down
20 changes: 20 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,26 @@ async def before_model_callback(
"""
pass

async def on_model_request_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
"""Callback executed immediately before a request is sent to the model.

This hook is fired after all `before_model_callback`s have completed and
the request has been finalized (e.g. labels injected). It is the correct
place to observe the exact `LlmRequest` that will be sent to the model.

Args:
callback_context: The context for the current agent call.
llm_request: The final request object to be sent to the model.

Returns:
An optional LlmResponse. If an LlmResponse is returned, it will be used
instead of calling the model. Returning `None` allows the model call
to proceed normally.
"""
pass

async def after_model_callback(
self, *, callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3726,13 +3726,13 @@ async def after_agent_callback(
)

@_safe_callback
async def before_model_callback(
async def on_model_request_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> None:
"""Callback before LLM call.
"""Callback immediately before LLM call.

Logs the LLM request details including:
1. Prompt content
Expand Down
11 changes: 11 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"before_tool_callback",
"after_tool_callback",
"before_model_callback",
"on_model_request_callback",
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
Expand Down Expand Up @@ -245,6 +246,16 @@ async def run_before_model_callback(
llm_request=llm_request,
)

async def run_on_model_request_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
"""Runs the `on_model_request_callback` for all plugins."""
return await self._run_callbacks(
"on_model_request_callback",
callback_context=callback_context,
llm_request=llm_request,
)

async def run_after_model_callback(
self, *, callback_context: CallbackContext, llm_response: LlmResponse
) -> Optional[LlmResponse]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class SpanCapturingPlugin(BasePlugin):
def __init__(self):
self.name = 'span_capturing_plugin'
self.before_capture = _SpanCapture()
self.request_capture = _SpanCapture()
self.after_capture = _SpanCapture()
self.error_capture = _SpanCapture()

Expand All @@ -80,6 +81,15 @@ async def before_model_callback(
return self._short_circuit_response
return None

async def on_model_request_callback(
self,
*,
callback_context: CallbackContext,
llm_request: LlmRequest,
) -> Optional[LlmResponse]:
self.request_capture.capture()
return None

async def after_model_callback(
self,
*,
Expand Down Expand Up @@ -149,6 +159,11 @@ def test_before_and_after_callbacks_share_same_span():
f' before={plugin.before_capture.span_id:#x},'
f' after={plugin.after_capture.span_id:#x}'
)
assert plugin.before_capture.span_id == plugin.request_capture.span_id, (
'before_model_callback and on_model_request_callback saw different spans:'
f' before={plugin.before_capture.span_id:#x},'
f' request={plugin.request_capture.span_id:#x}'
)


def test_callbacks_same_trace_id():
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/flows/llm_flows/test_plugin_model_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ class MockPlugin(BasePlugin):
before_model_text = 'before_model_text from MockPlugin'
after_model_text = 'after_model_text from MockPlugin'
on_model_error_text = 'on_model_error_text from MockPlugin'
on_model_request_text = 'on_model_request_text from MockPlugin'

def __init__(self, name='mock_plugin'):
self.name = name
self.enable_before_model_callback = False
self.enable_after_model_callback = False
self.enable_on_model_error_callback = False
self.enable_on_model_request_callback = False
self.before_model_response = LlmResponse(
content=testing_utils.ModelContent(
[types.Part.from_text(text=self.before_model_text)]
Expand All @@ -62,6 +64,11 @@ def __init__(self, name='mock_plugin'):
[types.Part.from_text(text=self.on_model_error_text)]
)
)
self.on_model_request_response = LlmResponse(
content=testing_utils.ModelContent(
[types.Part.from_text(text=self.on_model_request_text)]
)
)

async def before_model_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
Expand All @@ -88,6 +95,13 @@ async def on_model_error_callback(
return None
return self.on_model_error_response

async def on_model_request_callback(
self, *, callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
if not self.enable_on_model_request_callback:
return None
return self.on_model_request_response


CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'

Expand Down Expand Up @@ -138,6 +152,22 @@ def test_before_model_fallback_canonical_callback(mock_plugin):
]


def test_on_model_request_callback_with_plugin(mock_plugin):
"""Tests that the model response is overridden by on_model_request_callback from the plugin."""
responses = ['model_response']
mock_model = testing_utils.MockModel.create(responses=responses)
mock_plugin.enable_on_model_request_callback = True
agent = Agent(
name='root_agent',
model=mock_model,
)

runner = testing_utils.InMemoryRunner(agent, plugins=[mock_plugin])
assert testing_utils.simplify_events(runner.run('test')) == [
('root_agent', mock_plugin.on_model_request_text),
]


def test_before_model_callback_fallback_model(mock_plugin):
"""Tests that the model response is executed normally when both plugin and canonical agent model callback return empty response."""
responses = ['model_response']
Expand Down
15 changes: 15 additions & 0 deletions tests/unittests/plugins/test_base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ async def on_tool_error_callback(self, **kwargs) -> str:
async def before_model_callback(self, **kwargs) -> str:
return "overridden_before_model"

async def on_model_request_callback(self, **kwargs) -> str:
return "overridden_on_model_request"

async def after_model_callback(self, **kwargs) -> str:
return "overridden_after_model"

Expand Down Expand Up @@ -158,6 +161,12 @@ async def test_base_plugin_default_callbacks_return_none():
)
is None
)
assert (
await plugin.on_model_request_callback(
callback_context=mock_context, llm_request=mock_context
)
is None
)
assert (
await plugin.after_model_callback(
callback_context=mock_context, llm_response=mock_context
Expand Down Expand Up @@ -240,6 +249,12 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
)
== "overridden_before_model"
)
assert (
await plugin.on_model_request_callback(
callback_context=mock_callback_context, llm_request=mock_llm_request
)
== "overridden_on_model_request"
)
assert (
await plugin.after_model_callback(
callback_context=mock_callback_context, llm_response=mock_llm_response
Expand Down
7 changes: 7 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ async def on_tool_error_callback(self, **kwargs):
async def before_model_callback(self, **kwargs):
return await self._handle_callback("before_model_callback")

async def on_model_request_callback(self, **kwargs):
return await self._handle_callback("on_model_request_callback")

async def after_model_callback(self, **kwargs):
return await self._handle_callback("after_model_callback")

Expand Down Expand Up @@ -244,6 +247,9 @@ async def test_all_callbacks_are_supported(
await service.run_before_model_callback(
callback_context=mock_context, llm_request=mock_context
)
await service.run_on_model_request_callback(
callback_context=mock_context, llm_request=mock_context
)
await service.run_after_model_callback(
callback_context=mock_context, llm_response=mock_context
)
Expand All @@ -265,6 +271,7 @@ async def test_all_callbacks_are_supported(
"after_tool_callback",
"on_tool_error_callback",
"before_model_callback",
"on_model_request_callback",
"after_model_callback",
"on_model_error_callback",
]
Expand Down