Skip to content

Commit f283027

Browse files
GWealecopybara-github
authored andcommitted
feat: expose service URI flags
Adds the shared adk_services_options decorator to adk run and other commands so developers can pass session/artifact URIs from the CLI Has new warning for the unsupported memory service on adk run, and removes the legacy --session_db_url/--artifact_storage_uri flags with tests Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 836743358
1 parent 06e6fc9 commit f283027

File tree

3 files changed

+143
-74
lines changed

3 files changed

+143
-74
lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 74 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import os
2525
from pathlib import Path
2626
import tempfile
27+
import textwrap
2728
from typing import Optional
2829

2930
import click
@@ -354,7 +355,62 @@ def validate_exclusive(ctx, param, value):
354355
return value
355356

356357

358+
def adk_services_options():
359+
"""Decorator to add ADK services options to click commands."""
360+
361+
def decorator(func):
362+
@click.option(
363+
"--session_service_uri",
364+
help=textwrap.dedent(
365+
"""\
366+
Optional. The URI of the session service.
367+
- Leave unset to use the in-memory session service (default).
368+
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
369+
sessions. <agent_engine> can either be the full qualified resource
370+
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
371+
the resource id '123'.
372+
- Use 'memory://' to run with the in-memory session service.
373+
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
374+
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs."""
375+
),
376+
)
377+
@click.option(
378+
"--artifact_service_uri",
379+
type=str,
380+
help=textwrap.dedent(
381+
"""\
382+
Optional. The URI of the artifact service.
383+
- Leave unset to store artifacts under '.adk/artifacts' locally.
384+
- Use 'gs://<bucket_name>' to connect to the GCS artifact service.
385+
- Use 'memory://' to force the in-memory artifact service.
386+
- Use 'file://<path>' to store artifacts in a custom local directory."""
387+
),
388+
default=None,
389+
)
390+
@click.option(
391+
"--memory_service_uri",
392+
type=str,
393+
help=textwrap.dedent("""\
394+
Optional. The URI of the memory service.
395+
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
396+
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
397+
sessions. <agent_engine> can either be the full qualified resource
398+
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
399+
the resource id '123'.
400+
- Use 'memory://' to force the in-memory memory service."""),
401+
default=None,
402+
)
403+
@functools.wraps(func)
404+
def wrapper(*args, **kwargs):
405+
return func(*args, **kwargs)
406+
407+
return wrapper
408+
409+
return decorator
410+
411+
357412
@main.command("run", cls=HelpfulCommand)
413+
@adk_services_options()
358414
@click.option(
359415
"--save_session",
360416
type=bool,
@@ -409,6 +465,9 @@ def cli_run(
409465
session_id: Optional[str],
410466
replay: Optional[str],
411467
resume: Optional[str],
468+
session_service_uri: Optional[str] = None,
469+
artifact_service_uri: Optional[str] = None,
470+
memory_service_uri: Optional[str] = None,
412471
):
413472
"""Runs an interactive CLI for a certain agent.
414473
@@ -420,6 +479,14 @@ def cli_run(
420479
"""
421480
logs.log_to_tmp_folder()
422481

482+
# Validation warning for memory_service_uri (not supported for adk run)
483+
if memory_service_uri:
484+
click.secho(
485+
"WARNING: --memory_service_uri is not supported for adk run.",
486+
fg="yellow",
487+
err=True,
488+
)
489+
423490
agent_parent_folder = os.path.dirname(agent)
424491
agent_folder_name = os.path.basename(agent)
425492

@@ -431,6 +498,8 @@ def cli_run(
431498
saved_session_file=resume,
432499
save_session=save_session,
433500
session_id=session_id,
501+
session_service_uri=session_service_uri,
502+
artifact_service_uri=artifact_service_uri,
434503
)
435504
)
436505

@@ -865,63 +934,14 @@ def wrapper(*args, **kwargs):
865934
return decorator
866935

867936

868-
def adk_services_options():
869-
"""Decorator to add ADK services options to click commands."""
870-
871-
def decorator(func):
872-
@click.option(
873-
"--session_service_uri",
874-
help=(
875-
"""Optional. The URI of the session service.
876-
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
877-
sessions. <agent_engine> can either be the full qualified resource
878-
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
879-
the resource id '123'.
880-
- Use 'sqlite://<path_to_sqlite_file>' to connect to an aio-sqlite
881-
based session service, which is good for local development.
882-
- Use 'postgresql://<user>:<password>@<host>:<port>/<database_name>'
883-
to connect to a PostgreSQL DB.
884-
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls
885-
for more details on other database URIs supported by SQLAlchemy."""
886-
),
887-
)
888-
@click.option(
889-
"--artifact_service_uri",
890-
type=str,
891-
help=(
892-
"Optional. The URI of the artifact service,"
893-
" supported URIs: gs://<bucket name> for GCS artifact service."
894-
),
895-
default=None,
896-
)
897-
@click.option(
898-
"--memory_service_uri",
899-
type=str,
900-
help=("""Optional. The URI of the memory service.
901-
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
902-
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
903-
sessions. <agent_engine> can either be the full qualified resource
904-
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
905-
the resource id '123'."""),
906-
default=None,
907-
)
908-
@functools.wraps(func)
909-
def wrapper(*args, **kwargs):
910-
return func(*args, **kwargs)
911-
912-
return wrapper
913-
914-
return decorator
915-
916-
917937
def deprecated_adk_services_options():
918938
"""Deprecated ADK services options."""
919939

920940
def warn(alternative_param, ctx, param, value):
921941
if value:
922942
click.echo(
923943
click.style(
924-
f"WARNING: Deprecated option {param.name} is used. Please use"
944+
f"WARNING: Deprecated option --{param.name} is used. Please use"
925945
f" {alternative_param} instead.",
926946
fg="yellow",
927947
),
@@ -1116,6 +1136,8 @@ def cli_web(
11161136
11171137
adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
11181138
"""
1139+
session_service_uri = session_service_uri or session_db_url
1140+
artifact_service_uri = artifact_service_uri or artifact_storage_uri
11191141
logs.setup_adk_logger(getattr(logging, log_level.upper()))
11201142

11211143
@asynccontextmanager
@@ -1140,8 +1162,6 @@ async def _lifespan(app: FastAPI):
11401162
fg="green",
11411163
)
11421164

1143-
session_service_uri = session_service_uri or session_db_url
1144-
artifact_service_uri = artifact_service_uri or artifact_storage_uri
11451165
app = get_fast_api_app(
11461166
agents_dir=agents_dir,
11471167
session_service_uri=session_service_uri,
@@ -1215,10 +1235,10 @@ def cli_api_server(
12151235
12161236
adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir
12171237
"""
1218-
logs.setup_adk_logger(getattr(logging, log_level.upper()))
1219-
12201238
session_service_uri = session_service_uri or session_db_url
12211239
artifact_service_uri = artifact_service_uri or artifact_storage_uri
1240+
logs.setup_adk_logger(getattr(logging, log_level.upper()))
1241+
12221242
config = uvicorn.Config(
12231243
get_fast_api_app(
12241244
agents_dir=agents_dir,

src/google/adk/cli/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
from ...agents.base_agent import BaseAgent
2020
from ...agents.llm_agent import LlmAgent
21+
from .dot_adk_folder import DotAdkFolder
2122
from .state import create_empty_state
2223

2324
__all__ = [
2425
'create_empty_state',
26+
'DotAdkFolder',
2527
]

tests/unittests/cli/utils/test_cli_tools_click.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401
7676

7777
# Fixtures
7878
@pytest.fixture(autouse=True)
79-
def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None:
79+
def _mute_click(request, monkeypatch: pytest.MonkeyPatch) -> None:
8080
"""Suppress click output during tests."""
81+
# Allow tests to opt-out of muting by using the 'unmute_click' marker
82+
if "unmute_click" in request.keywords:
83+
return
8184
monkeypatch.setattr(click, "echo", lambda *a, **k: None)
8285
# Keep secho for error messages
8386
# monkeypatch.setattr(click, "secho", lambda *a, **k: None)
@@ -121,32 +124,70 @@ def test_cli_create_cmd_invokes_run_cmd(
121124
cli_tools_click.main,
122125
["create", "--model", "gemini", "--api_key", "key123", str(app_dir)],
123126
)
124-
assert result.exit_code == 0
127+
assert result.exit_code == 0, (result.output, repr(result.exception))
125128
assert rec.calls, "cli_create.run_cmd must be called"
126129

127130

128131
# cli run
129-
@pytest.mark.asyncio
130-
async def test_cli_run_invokes_run_cli(
131-
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
132+
@pytest.mark.parametrize(
133+
"cli_args,expected_session_uri,expected_artifact_uri",
134+
[
135+
pytest.param(
136+
[
137+
"--session_service_uri",
138+
"memory://",
139+
"--artifact_service_uri",
140+
"memory://",
141+
],
142+
"memory://",
143+
"memory://",
144+
id="memory_scheme_uris",
145+
),
146+
pytest.param(
147+
[],
148+
None,
149+
None,
150+
id="default_uris_none",
151+
),
152+
],
153+
)
154+
def test_cli_run_service_uris(
155+
tmp_path: Path,
156+
monkeypatch: pytest.MonkeyPatch,
157+
cli_args: list,
158+
expected_session_uri: str,
159+
expected_artifact_uri: str,
132160
) -> None:
133-
"""`adk run` should call run_cli via asyncio.run with correct parameters."""
134-
rec = _Recorder()
135-
monkeypatch.setattr(cli_tools_click, "run_cli", lambda **kwargs: rec(kwargs))
136-
monkeypatch.setattr(
137-
cli_tools_click.asyncio, "run", lambda coro: coro
138-
) # pass-through
139-
140-
# create dummy agent directory
161+
"""`adk run` should forward service URIs correctly to run_cli."""
141162
agent_dir = tmp_path / "agent"
142163
agent_dir.mkdir()
143164
(agent_dir / "__init__.py").touch()
144165
(agent_dir / "agent.py").touch()
145166

167+
# Capture the coroutine's locals before closing it
168+
captured_locals = []
169+
170+
def capture_asyncio_run(coro):
171+
# Extract the locals before closing the coroutine
172+
if coro.cr_frame is not None:
173+
captured_locals.append(dict(coro.cr_frame.f_locals))
174+
coro.close() # Properly close the coroutine to avoid warnings
175+
176+
monkeypatch.setattr(cli_tools_click.asyncio, "run", capture_asyncio_run)
177+
146178
runner = CliRunner()
147-
result = runner.invoke(cli_tools_click.main, ["run", str(agent_dir)])
148-
assert result.exit_code == 0
149-
assert rec.calls and rec.calls[0][0][0]["agent_folder_name"] == "agent"
179+
result = runner.invoke(
180+
cli_tools_click.main,
181+
["run", *cli_args, str(agent_dir)],
182+
)
183+
assert result.exit_code == 0, (result.output, repr(result.exception))
184+
assert len(captured_locals) == 1, "Expected asyncio.run to be called once"
185+
186+
# Verify the kwargs passed to run_cli
187+
coro_locals = captured_locals[0]
188+
assert coro_locals.get("session_service_uri") == expected_session_uri
189+
assert coro_locals.get("artifact_service_uri") == expected_artifact_uri
190+
assert coro_locals["agent_folder_name"] == "agent"
150191

151192

152193
# cli deploy cloud_run
@@ -520,10 +561,13 @@ def test_cli_web_passes_service_uris(
520561
assert called_kwargs.get("memory_service_uri") == "rag://mycorpus"
521562

522563

523-
def test_cli_web_passes_deprecated_uris(
524-
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder
564+
@pytest.mark.unmute_click
565+
def test_cli_web_warns_and_maps_deprecated_uris(
566+
tmp_path: Path,
567+
_patch_uvicorn: _Recorder,
568+
monkeypatch: pytest.MonkeyPatch,
525569
) -> None:
526-
"""`adk web` should use deprecated URIs if new ones are not provided."""
570+
"""`adk web` should accept deprecated URI flags with warnings."""
527571
agents_dir = tmp_path / "agents"
528572
agents_dir.mkdir()
529573

@@ -542,11 +586,14 @@ def test_cli_web_passes_deprecated_uris(
542586
"gs://deprecated",
543587
],
544588
)
589+
545590
assert result.exit_code == 0
546-
assert mock_get_app.calls
547591
called_kwargs = mock_get_app.calls[0][1]
548592
assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db"
549593
assert called_kwargs.get("artifact_service_uri") == "gs://deprecated"
594+
# Check output for deprecation warnings (CliRunner captures both stdout and stderr)
595+
assert "--session_db_url" in result.output
596+
assert "--artifact_storage_uri" in result.output
550597

551598

552599
def test_cli_eval_with_eval_set_file_path(

0 commit comments

Comments
 (0)