Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..memory.base_memory_service import BaseMemoryService
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from ..runners import Runner
from .adk_web_server import AdkWebServer
from .service_registry import load_services_module
Expand Down Expand Up @@ -79,6 +82,7 @@ def get_fast_api_app(
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
use_local_storage: bool = True,
memory_service: Optional[BaseMemoryService] = None,
eval_storage_uri: Optional[str] = None,
allow_origins: Optional[list[str]] = None,
web: bool,
Expand Down Expand Up @@ -161,13 +165,13 @@ def get_fast_api_app(
load_services_module(agents_dir)

# Build the Memory service
try:
if memory_service:
pass
else:
memory_service = create_memory_service_from_options(
base_dir=agents_dir,
memory_service_uri=memory_service_uri,
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc

# Build the Session service
session_service = create_session_service_from_options(
Expand Down
70 changes: 70 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,5 +1717,75 @@ async def run_async_session_not_found(self, **kwargs):
assert "Session not found" in response.json()["detail"]


def test_get_fast_api_app_with_custom_memory_service(
mock_session_service,
mock_artifact_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""Test that custom memory_service is used directly when provided."""
custom_memory_service = MagicMock()

with (
patch.object(signal, "signal", autospec=True, return_value=None),
patch.object(
fast_api_module,
"create_session_service_from_options",
autospec=True,
return_value=mock_session_service,
),
patch.object(
fast_api_module,
"create_artifact_service_from_options",
autospec=True,
return_value=mock_artifact_service,
),
patch.object(
fast_api_module,
"create_memory_service_from_options",
autospec=True,
) as mock_create_memory_service,
patch.object(
fast_api_module,
"AgentLoader",
autospec=True,
return_value=mock_agent_loader,
),
patch.object(
fast_api_module,
"LocalEvalSetsManager",
autospec=True,
return_value=mock_eval_sets_manager,
),
patch.object(
fast_api_module,
"LocalEvalSetResultsManager",
autospec=True,
return_value=mock_eval_set_results_manager,
),
patch.object(
fast_api_module,
"load_services_module",
autospec=True,
return_value=None,
),
):
app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
memory_service=custom_memory_service,
allow_origins=["*"],
a2a=False,
host="127.0.0.1",
port=8000,
)

mock_create_memory_service.assert_not_called()


if __name__ == "__main__":
pytest.main(["-xvs", __file__])