diff --git a/src/google/adk/code_executors/built_in_code_executor.py b/src/google/adk/code_executors/built_in_code_executor.py index 50a0b9f4f6..a4e3203461 100644 --- a/src/google/adk/code_executors/built_in_code_executor.py +++ b/src/google/adk/code_executors/built_in_code_executor.py @@ -20,6 +20,7 @@ from ..agents.invocation_context import InvocationContext from ..models import LlmRequest from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_code_executor import BaseCodeExecutor from .code_execution_utils import CodeExecutionInput from .code_execution_utils import CodeExecutionResult @@ -42,7 +43,8 @@ def execute_code( def process_llm_request(self, llm_request: LlmRequest) -> None: """Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool.""" - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py index fe8556088f..b3fe665455 100644 --- a/src/google/adk/events/event_actions.py +++ b/src/google/adk/events/event_actions.py @@ -47,6 +47,38 @@ class EventCompaction(BaseModel): """The compacted content of the events.""" +class RewindAuditReceipt(BaseModel): # type: ignore[misc] + """Audit receipt metadata emitted for rewind operations.""" + + model_config = ConfigDict( + extra='forbid', + alias_generator=alias_generators.to_camel, + populate_by_name=True, + ) + """The pydantic model config.""" + + rewind_before_invocation_id: str + """The invocation ID that the rewind operation targeted.""" + + boundary_after_invocation_id: Optional[str] = None + """The last invocation ID retained before the rewind boundary, if any.""" + + events_before_rewind: int + """The number of events present before appending the rewind event.""" + + events_after_rewind: int + """The number of pre-existing events retained after rewind filtering.""" + + history_before_hash: str + """Canonical hash of the full pre-rewind event history.""" + + history_after_hash: str + """Canonical hash of the retained pre-rewind event history.""" + + receipt_hash: str + """Tamper-evident hash over the rewind receipt summary.""" + + class EventActions(BaseModel): """Represents the actions attached to an event.""" @@ -108,3 +140,6 @@ class EventActions(BaseModel): rewind_before_invocation_id: Optional[str] = None """The invocation id to rewind to. This is only set for rewind event.""" + + rewind_audit_receipt: Optional[RewindAuditReceipt] = None + """Structured receipt proving rewind boundaries and history digests.""" diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 7bb18efae3..2218c8742b 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -65,6 +65,11 @@ 'wait_for_completion', }) +_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation' +# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow +# at most 5 direct_memories per request. +_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5 + def _supports_generate_memories_metadata() -> bool: """Returns whether installed Vertex SDK supports config.metadata.""" @@ -160,6 +165,11 @@ def __init__( not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ + if not agent_engine_id: + raise ValueError( + 'agent_engine_id is required for VertexAiMemoryBankService.' + ) + self._project = project self._location = location self._agent_engine_id = agent_engine_id @@ -219,7 +229,22 @@ async def add_memory( memories: Sequence[MemoryEntry], custom_metadata: Mapping[str, object] | None = None, ) -> None: - """Adds explicit memory items via Vertex memories.create.""" + """Adds explicit memory items using Vertex Memory Bank. + + By default, this writes directly via `memories.create`. + If `custom_metadata["enable_consolidation"]` is set to True, this uses + `memories.generate` with `direct_memories_source` so provided memories are + consolidated server-side. + """ + if _is_consolidation_enabled(custom_metadata): + await self._add_memories_via_generate_direct_memories_source( + app_name=app_name, + user_id=user_id, + memories=memories, + custom_metadata=custom_metadata, + ) + return + await self._add_memories_via_create( app_name=app_name, user_id=user_id, @@ -235,9 +260,6 @@ async def _add_events_to_memory_from_events( events_to_process: Sequence[Event], custom_metadata: Mapping[str, object] | None = None, ) -> None: - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - direct_events = [] for event in events_to_process: if _should_filter_out_event(event.content): @@ -272,9 +294,6 @@ async def _add_memories_via_create( custom_metadata: Mapping[str, object] | None = None, ) -> None: """Adds direct memory items without server-side extraction.""" - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - normalized_memories = _normalize_memories_for_create(memories) api_client = self._get_api_client() for index, memory in enumerate(normalized_memories): @@ -300,11 +319,41 @@ async def _add_memories_via_create( logger.info('Create memory response received.') logger.debug('Create memory response: %s', operation) + async def _add_memories_via_generate_direct_memories_source( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + """Adds memories via generate API with direct_memories_source.""" + normalized_memories = _normalize_memories_for_create(memories) + memory_texts = [ + _memory_entry_to_fact(m, index=i) + for i, m in enumerate(normalized_memories) + ] + api_client = self._get_api_client() + config = _build_generate_memories_config(custom_metadata) + for memory_batch in _iter_memory_batches(memory_texts): + operation = await api_client.agent_engines.memories.generate( + name='reasoningEngines/' + self._agent_engine_id, + direct_memories_source={ + 'direct_memories': [ + {'fact': memory_text} for memory_text in memory_batch + ] + }, + scope={ + 'app_name': app_name, + 'user_id': user_id, + }, + config=config, + ) + logger.info('Generate direct memory response received.') + logger.debug('Generate direct memory response: %s', operation) + @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - if not self._agent_engine_id: - raise ValueError('Agent Engine ID is required for Memory Bank.') - api_client = self._get_api_client() retrieved_memories_iterator = ( await api_client.agent_engines.memories.retrieve( @@ -379,6 +428,8 @@ def _build_generate_memories_config( metadata_by_key: dict[str, object] = {} for key, value in custom_metadata.items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'ttl': if value is None: continue @@ -456,6 +507,8 @@ def _build_create_memory_config( metadata_by_key: dict[str, object] = {} custom_revision_labels: dict[str, str] = {} for key, value in (custom_metadata or {}).items(): + if key == _ENABLE_CONSOLIDATION_KEY: + continue if key == 'metadata': if value is None: continue @@ -641,6 +694,32 @@ def _extract_revision_labels( return revision_labels +def _is_consolidation_enabled( + custom_metadata: Mapping[str, object] | None, +) -> bool: + """Returns whether direct memories should be consolidated via generate API.""" + if not custom_metadata: + return False + enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY) + if enable_consolidation is None: + return False + if not isinstance(enable_consolidation, bool): + raise TypeError( + f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.' + ) + return enable_consolidation + + +def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]: + """Returns memory slices that comply with direct_memories limits.""" + memory_batches: list[Sequence[str]] = [] + for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL): + memory_batches.append( + memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL] + ) + return memory_batches + + def _build_vertex_metadata( metadata_by_key: Mapping[str, object], ) -> dict[str, object]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bc0251a81e..561405b21e 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -15,7 +15,9 @@ from __future__ import annotations import asyncio +import hashlib import inspect +import json import logging from pathlib import Path import queue @@ -47,6 +49,7 @@ from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event from .events.event import EventActions +from .events.event_actions import RewindAuditReceipt from .flows.llm_flows import contents from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService @@ -591,6 +594,11 @@ async def rewind_async( artifact_delta = await self._compute_artifact_delta_for_rewind( session, rewind_event_index ) + rewind_audit_receipt = self._build_rewind_audit_receipt( + session=session, + rewind_event_index=rewind_event_index, + rewind_before_invocation_id=rewind_before_invocation_id, + ) # Create rewind event rewind_event = Event( @@ -600,6 +608,7 @@ async def rewind_async( rewind_before_invocation_id=rewind_before_invocation_id, state_delta=state_delta, artifact_delta=artifact_delta, + rewind_audit_receipt=rewind_audit_receipt, ), ) @@ -607,6 +616,67 @@ async def rewind_async( await self.session_service.append_event(session=session, event=rewind_event) + def _build_rewind_audit_receipt( + self, + *, + session: Session, + rewind_event_index: int, + rewind_before_invocation_id: str, + ) -> RewindAuditReceipt: + """Builds a deterministic audit receipt for a rewind operation.""" + events_before = session.events + events_after = session.events[:rewind_event_index] + boundary_after_invocation_id = None + if rewind_event_index > 0: + boundary_after_invocation_id = session.events[ + rewind_event_index - 1 + ].invocation_id + + history_before_hash = self._hash_rewind_events(events_before) + history_after_hash = self._hash_rewind_events(events_after) + + receipt_payload = { + 'rewind_before_invocation_id': rewind_before_invocation_id, + 'boundary_after_invocation_id': boundary_after_invocation_id, + 'events_before_rewind': len(events_before), + 'events_after_rewind': len(events_after), + 'history_before_hash': history_before_hash, + 'history_after_hash': history_after_hash, + } + receipt_hash = self._hash_rewind_payload(receipt_payload) + + return RewindAuditReceipt( + **receipt_payload, + receipt_hash=receipt_hash, + ) + + def _hash_rewind_events(self, events: List[Event]) -> str: + """Hashes event summaries for deterministic rewind audit receipts.""" + summarized_events = [ + { + 'event_id': event.id, + 'invocation_id': event.invocation_id, + 'author': event.author, + 'state_delta': event.actions.state_delta, + 'artifact_delta': event.actions.artifact_delta, + 'rewind_before_invocation_id': ( + event.actions.rewind_before_invocation_id + ), + } + for event in events + ] + return self._hash_rewind_payload({'events': summarized_events}) + + def _hash_rewind_payload(self, payload: dict[str, Any]) -> str: + """Returns a canonical SHA-256 digest for rewind audit payloads.""" + canonical_json = json.dumps( + payload, + sort_keys=True, + separators=(',', ':'), + ensure_ascii=True, + ) + return hashlib.sha256(canonical_json.encode('utf-8')).hexdigest() + async def _compute_state_delta_for_rewind( self, session: Session, rewind_event_index: int ) -> dict[str, Any]: diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 4f7a0d7f35..c114fdb46d 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -54,14 +55,16 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Enterprise Web Search tool cannot be used with other tools in' ' Gemini 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] llm_request.config.tools.append( types.Tool(enterprise_web_search=types.EnterpriseWebSearch()) ) diff --git a/src/google/adk/tools/google_maps_grounding_tool.py b/src/google/adk/tools/google_maps_grounding_tool.py index bade0a3385..d4b105ec1e 100644 --- a/src/google/adk/tools/google_maps_grounding_tool.py +++ b/src/google/adk/tools/google_maps_grounding_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -49,13 +50,14 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError( 'Google Maps grounding tool cannot be used with Gemini 1.x models.' ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_maps=types.GoogleMaps()) ) diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index 406ad2189e..1c11e091de 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -67,6 +68,7 @@ async def process_llm_request( if self.model is not None: llm_request.model = self.model + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): @@ -77,7 +79,7 @@ async def process_llm_request( llm_request.config.tools.append( types.Tool(google_search_retrieval=types.GoogleSearchRetrieval()) ) - elif is_gemini_model(llm_request.model): + elif is_gemini_model(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(google_search=types.GoogleSearch()) ) diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index 206819a9be..4d564ca164 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -24,6 +24,7 @@ from typing_extensions import override from ...utils.model_name_utils import is_gemini_2_or_above +from ...utils.model_name_utils import is_gemini_model_id_check_disabled from ..tool_context import ToolContext from .base_retrieval_tool import BaseRetrievalTool @@ -63,7 +64,8 @@ async def process_llm_request( llm_request: LlmRequest, ) -> None: # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. - if is_gemini_2_or_above(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + if is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config = ( types.GenerateContentConfig() if not llm_request.config diff --git a/src/google/adk/tools/url_context_tool.py b/src/google/adk/tools/url_context_tool.py index fcdf76dab5..5e923e7447 100644 --- a/src/google/adk/tools/url_context_tool.py +++ b/src/google/adk/tools/url_context_tool.py @@ -21,6 +21,7 @@ from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_2_or_above +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -46,11 +47,12 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + model_check_disabled = is_gemini_model_id_check_disabled() llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): raise ValueError('Url context tool cannot be used in Gemini 1.x.') - elif is_gemini_2_or_above(llm_request.model): + elif is_gemini_2_or_above(llm_request.model) or model_check_disabled: llm_request.config.tools.append( types.Tool(url_context=types.UrlContext()) ) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index 91fe60e553..46104c5ed4 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -24,6 +24,7 @@ from ..agents.readonly_context import ReadonlyContext from ..utils.model_name_utils import is_gemini_1_model from ..utils.model_name_utils import is_gemini_model +from ..utils.model_name_utils import is_gemini_model_id_check_disabled from .base_tool import BaseTool from .tool_context import ToolContext @@ -141,14 +142,16 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - if is_gemini_model(llm_request.model): + model_check_disabled = is_gemini_model_id_check_disabled() + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tools = llm_request.config.tools or [] + + if is_gemini_model(llm_request.model) or model_check_disabled: if is_gemini_1_model(llm_request.model) and llm_request.config.tools: raise ValueError( 'Vertex AI search tool cannot be used with other tools in Gemini' ' 1.x.' ) - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.tools = llm_request.config.tools or [] # Build the search config (can be overridden by subclasses) vertex_ai_search_config = self._build_vertex_ai_search_config( diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 4960b0b78f..57103fb2c7 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -22,6 +22,19 @@ from packaging.version import InvalidVersion from packaging.version import Version +from .env_utils import is_env_enabled + +_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR = 'ADK_DISABLE_GEMINI_MODEL_ID_CHECK' + + +def is_gemini_model_id_check_disabled() -> bool: + """Returns True when Gemini model-id validation should be bypassed. + + This opt-in environment variable is intended for internal usage where model + ids may not follow the public ``gemini-*`` naming convention. + """ + return is_env_enabled(_DISABLE_GEMINI_MODEL_ID_CHECK_ENV_VAR) + def extract_model_name(model_string: str) -> str: """Extract the actual model name from either simple or path-based format. diff --git a/tests/unittests/code_executors/test_built_in_code_executor.py b/tests/unittests/code_executors/test_built_in_code_executor.py index 58f54c7cef..cbf128fba9 100644 --- a/tests/unittests/code_executors/test_built_in_code_executor.py +++ b/tests/unittests/code_executors/test_built_in_code_executor.py @@ -97,6 +97,22 @@ def test_process_llm_request_non_gemini_2_model( ) +def test_process_llm_request_non_gemini_2_model_with_disabled_check( + built_in_executor: BuiltInCodeExecutor, + monkeypatch, +): + """Tests non-Gemini models pass when model-id check is disabled.""" + monkeypatch.setenv("ADK_DISABLE_GEMINI_MODEL_ID_CHECK", "true") + llm_request = LlmRequest(model="internal-model-v1") + + built_in_executor.process_llm_request(llm_request) + + assert llm_request.config is not None + assert llm_request.config.tools == [ + types.Tool(code_execution=types.ToolCodeExecution()) + ] + + def test_process_llm_request_no_model_name( built_in_executor: BuiltInCodeExecutor, ): diff --git a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py index 6f342a08b1..c498b8335b 100644 --- a/tests/unittests/memory/test_vertex_ai_memory_bank_service.py +++ b/tests/unittests/memory/test_vertex_ai_memory_bank_service.py @@ -230,6 +230,14 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +def test_initialize_without_agent_engine_id_error(): + with pytest.raises( + ValueError, + match='agent_engine_id is required for VertexAiMemoryBankService', + ): + mock_vertex_ai_memory_bank_service(agent_engine_id=None) + + @pytest.mark.asyncio async def test_add_session_to_memory(mock_vertexai_client): memory_service = mock_vertex_ai_memory_bank_service() @@ -481,6 +489,7 @@ async def test_add_memory_calls_create( ), ], custom_metadata={ + 'enable_consolidation': False, 'ttl': '6000s', 'source': 'agent', }, @@ -518,6 +527,139 @@ async def test_add_memory_calls_create( vertex_common_types.AgentEngineMemoryConfig(**create_config) +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_calls_generate_direct_source( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + 'source': 'agent', + }, + ) + + expected_config = {'wait_for_completion': False} + if _supports_generate_memories_metadata(): + expected_config['metadata'] = {'source': {'string_value': 'agent'}} + + mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config=expected_config, + ) + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + generate_config = ( + mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[ + 'config' + ] + ) + vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config) + + +@pytest.mark.asyncio +async def test_add_memory_enable_consolidation_batches_generate_calls( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact two')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact three')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact four')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact five')]) + ), + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact six')]) + ), + ], + custom_metadata={ + 'enable_consolidation': True, + }, + ) + + mock_vertexai_client.agent_engines.memories.generate.assert_has_awaits([ + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact one'}, + {'fact': 'fact two'}, + {'fact': 'fact three'}, + {'fact': 'fact four'}, + {'fact': 'fact five'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + mock.call( + name='reasoningEngines/123', + direct_memories_source={ + 'direct_memories': [ + {'fact': 'fact six'}, + ] + }, + scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID}, + config={'wait_for_completion': False}, + ), + ]) + assert mock_vertexai_client.agent_engines.memories.generate.await_count == 2 + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_add_memory_invalid_enable_consolidation_type_raises( + mock_vertexai_client, +): + memory_service = mock_vertex_ai_memory_bank_service() + with pytest.raises( + TypeError, + match=r'custom_metadata\["enable_consolidation"\] must be a bool', + ): + await memory_service.add_memory( + app_name=MOCK_SESSION.app_name, + user_id=MOCK_SESSION.user_id, + memories=[ + MemoryEntry( + content=types.Content(parts=[types.Part(text='fact one')]) + ) + ], + custom_metadata={'enable_consolidation': 'yes'}, + ) + mock_vertexai_client.agent_engines.memories.generate.assert_not_called() + mock_vertexai_client.agent_engines.memories.create.assert_not_called() + + @pytest.mark.asyncio async def test_add_memory_calls_create_with_memory_entry_metadata( mock_vertexai_client, diff --git a/tests/unittests/runners/test_runner_rewind.py b/tests/unittests/runners/test_runner_rewind.py index 035d28437b..562b53bcb0 100644 --- a/tests/unittests/runners/test_runner_rewind.py +++ b/tests/unittests/runners/test_runner_rewind.py @@ -154,6 +154,15 @@ async def test_rewind_async_with_state_and_artifacts(self): ) is None ) + rewind_receipt = session.events[-1].actions.rewind_audit_receipt + assert rewind_receipt is not None + assert rewind_receipt.rewind_before_invocation_id == "invocation2" + assert rewind_receipt.boundary_after_invocation_id == "invocation1" + assert rewind_receipt.events_before_rewind == 3 + assert rewind_receipt.events_after_rewind == 1 + assert rewind_receipt.history_before_hash + assert rewind_receipt.history_after_hash + assert rewind_receipt.receipt_hash @pytest.mark.asyncio async def test_rewind_async_not_first_invocation(self): @@ -246,3 +255,40 @@ async def test_rewind_async_not_first_invocation(self): session_id=session_id, filename="f2", ) == types.Part.from_text(text="f2v0") + + @pytest.mark.asyncio + async def test_rewind_receipt_hash_is_deterministic(self): + """Tests that rewind receipt hashes are stable for the same history.""" + runner = self.runner + user_id = "test_user" + session_id = "test_session" + session = await runner.session_service.create_session( + app_name=runner.app_name, user_id=user_id, session_id=session_id + ) + + for invocation_id in ("invocation1", "invocation2", "invocation3"): + await runner.session_service.append_event( + session=session, + event=Event( + invocation_id=invocation_id, + author="agent", + actions=EventActions(state_delta={invocation_id: invocation_id}), + ), + ) + + first_receipt = runner._build_rewind_audit_receipt( + session=session, + rewind_event_index=1, + rewind_before_invocation_id="invocation2", + ) + second_receipt = runner._build_rewind_audit_receipt( + session=session, + rewind_event_index=1, + rewind_before_invocation_id="invocation2", + ) + + assert ( + first_receipt.history_before_hash == second_receipt.history_before_hash + ) + assert first_receipt.history_after_hash == second_receipt.history_after_hash + assert first_receipt.receipt_hash == second_receipt.receipt_hash diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 3b5aa26f8a..0a86d07c63 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -145,3 +145,43 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ) ] assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + + +def test_vertex_rag_retrieval_for_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + responses = [ + 'response1', + ] + mockModel = testing_utils.MockModel.create(responses=responses) + mockModel.model = 'internal-model-v1' + + agent = Agent( + name='root_agent', + model=mockModel, + tools=[ + VertexAiRagRetrieval( + name='rag_retrieval', + description='rag_retrieval', + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ], + ) + ], + ) + runner = testing_utils.InMemoryRunner(agent) + runner.run('test1') + + assert len(mockModel.requests) == 1 + assert len(mockModel.requests[0].config.tools) == 1 + assert mockModel.requests[0].config.tools == [ + types.Tool( + retrieval=types.Retrieval( + vertex_rag_store=types.VertexRagStore( + rag_corpora=[ + 'projects/123456789/locations/us-central1/ragCorpora/1234567890' + ] + ) + ) + ) + ] + assert 'rag_retrieval' not in mockModel.requests[0].tools_dict diff --git a/tests/unittests/tools/test_enterprise_web_search_tool.py b/tests/unittests/tools/test_enterprise_web_search_tool.py index ed4715963e..7b28d858fd 100644 --- a/tests/unittests/tools/test_enterprise_web_search_tool.py +++ b/tests/unittests/tools/test_enterprise_web_search_tool.py @@ -76,6 +76,25 @@ async def test_process_llm_request_failure_with_non_gemini_models(): assert 'is not supported for model' in str(exc_info.value) +@pytest.mark.asyncio +async def test_process_llm_request_non_gemini_with_disabled_check(monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = EnterpriseWebSearchTool() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + tool_context = await _create_tool_context() + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert ( + llm_request.config.tools[0].enterprise_web_search + == types.EnterpriseWebSearch() + ) + + @pytest.mark.asyncio async def test_process_llm_request_failure_with_multiple_tools_gemini_1_models(): tool = EnterpriseWebSearchTool() diff --git a/tests/unittests/tools/test_google_maps_grounding_tool.py b/tests/unittests/tools/test_google_maps_grounding_tool.py new file mode 100644 index 0000000000..0cd2c4fa6c --- /dev/null +++ b/tests/unittests/tools/test_google_maps_grounding_tool.py @@ -0,0 +1,92 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.llm_request import LlmRequest +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.google_maps_grounding_tool import GoogleMapsGroundingTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + + +async def _create_tool_context() -> ToolContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=agent, + session=session, + session_service=session_service, + ) + return ToolContext(invocation_context=invocation_context) + + +class TestGoogleMapsGroundingTool: + """Tests for GoogleMapsGroundingTool.""" + + @pytest.mark.asyncio + async def test_process_llm_request_with_gemini_2_model(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='gemini-2.5-pro', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_raises_error(self): + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='claude-3-sonnet', config=types.GenerateContentConfig() + ) + + with pytest.raises( + ValueError, + match='Google maps tool is not supported for model claude-3-sonnet', + ): + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_and_disabled_check( + self, monkeypatch + ): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleMapsGroundingTool() + tool_context = await _create_tool_context() + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_maps is not None diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py index ad5d46b59e..d71061b883 100644 --- a/tests/unittests/tools/test_google_search_tool.py +++ b/tests/unittests/tools/test_google_search_tool.py @@ -268,6 +268,27 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = GoogleSearchTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].google_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_url_context_tool.py b/tests/unittests/tools/test_url_context_tool.py index 53ee7e6277..8fd44b59cb 100644 --- a/tests/unittests/tools/test_url_context_tool.py +++ b/tests/unittests/tools/test_url_context_tool.py @@ -190,6 +190,27 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = UrlContextTool() + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + assert llm_request.config.tools[0].url_context is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 3ade634da6..b15d3a1f64 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -376,6 +376,29 @@ async def test_process_llm_request_with_non_gemini_model_raises_error(self): tool_context=tool_context, llm_request=llm_request ) + @pytest.mark.asyncio + async def test_process_llm_request_with_non_gemini_model_and_disabled_check( + self, monkeypatch + ): + """Test non-Gemini model can pass when model-id check is disabled.""" + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + tool = VertexAiSearchTool(data_store_id='test_data_store') + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model='internal-model-v1', config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1 + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + @pytest.mark.asyncio async def test_process_llm_request_with_path_based_non_gemini_model_raises_error( self, diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index cbac37e3f7..2af1584b05 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -18,6 +18,7 @@ from google.adk.utils.model_name_utils import is_gemini_1_model from google.adk.utils.model_name_utils import is_gemini_2_or_above from google.adk.utils.model_name_utils import is_gemini_model +from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled class TestExtractModelName: @@ -318,3 +319,15 @@ def test_path_vs_simple_model_consistency(self): f'Inconsistent Gemini 2.0+ classification for {simple_model} vs' f' {path_model}' ) + + +class TestGeminiModelIdCheckFlag: + """Tests for Gemini model-id check override flag.""" + + def test_default_is_disabled(self, monkeypatch): + monkeypatch.delenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', raising=False) + assert is_gemini_model_id_check_disabled() is False + + def test_true_enables_check_bypass(self, monkeypatch): + monkeypatch.setenv('ADK_DISABLE_GEMINI_MODEL_ID_CHECK', 'true') + assert is_gemini_model_id_check_disabled() is True