diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 8f78c15f9b..0ff5e9c20b 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -38,6 +38,9 @@ from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager +from ..memory.base_memory_service import BaseMemoryService +from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..runners import Runner from .adk_web_server import AdkWebServer from .service_registry import load_services_module @@ -79,6 +82,7 @@ def get_fast_api_app( artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = True, + memory_service: Optional[BaseMemoryService] = None, eval_storage_uri: Optional[str] = None, allow_origins: Optional[list[str]] = None, web: bool, @@ -161,13 +165,13 @@ def get_fast_api_app( load_services_module(agents_dir) # Build the Memory service - try: + if memory_service: + pass + else: memory_service = create_memory_service_from_options( base_dir=agents_dir, memory_service_uri=memory_service_uri, ) - except ValueError as exc: - raise click.ClickException(str(exc)) from exc # Build the Session service session_service = create_session_service_from_options( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d6ccf6e218..a96fee8457 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1717,5 +1717,75 @@ async def run_async_session_not_found(self, **kwargs): assert "Session not found" in response.json()["detail"] +def test_get_fast_api_app_with_custom_memory_service( + mock_session_service, + mock_artifact_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + """Test that custom memory_service is used directly when provided.""" + custom_memory_service = MagicMock() + + with ( + patch.object(signal, "signal", autospec=True, return_value=None), + patch.object( + fast_api_module, + "create_session_service_from_options", + autospec=True, + return_value=mock_session_service, + ), + patch.object( + fast_api_module, + "create_artifact_service_from_options", + autospec=True, + return_value=mock_artifact_service, + ), + patch.object( + fast_api_module, + "create_memory_service_from_options", + autospec=True, + ) as mock_create_memory_service, + patch.object( + fast_api_module, + "AgentLoader", + autospec=True, + return_value=mock_agent_loader, + ), + patch.object( + fast_api_module, + "LocalEvalSetsManager", + autospec=True, + return_value=mock_eval_sets_manager, + ), + patch.object( + fast_api_module, + "LocalEvalSetResultsManager", + autospec=True, + return_value=mock_eval_set_results_manager, + ), + patch.object( + fast_api_module, + "load_services_module", + autospec=True, + return_value=None, + ), + ): + app = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + memory_service=custom_memory_service, + allow_origins=["*"], + a2a=False, + host="127.0.0.1", + port=8000, + ) + + mock_create_memory_service.assert_not_called() + + if __name__ == "__main__": pytest.main(["-xvs", __file__])