Skip to content

Commit 1b85aec

Browse files
yeesiancopybara-github
authored andcommitted
feat: Allow list of events to be passed to AdkApp when querying
PiperOrigin-RevId: 844834249
1 parent df0976e commit 1b85aec

File tree

1 file changed

+34
-2
lines changed
  • vertexai/agent_engines/templates

1 file changed

+34
-2
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ async def async_stream_query(
932932
message: Union[str, Dict[str, Any]],
933933
user_id: str,
934934
session_id: Optional[str] = None,
935+
session_events: Optional[Dict[str, Any]] = None,
935936
run_config: Optional[Dict[str, Any]] = None,
936937
**kwargs,
937938
) -> AsyncIterable[Dict[str, Any]]:
@@ -944,7 +945,11 @@ async def async_stream_query(
944945
Required. The ID of the user.
945946
session_id (str):
946947
Optional. The ID of the session. If not provided, a new
947-
session will be created for the user.
948+
session will be created for the user. If this is specified, then
949+
`session_events` will be ignored.
950+
session_events (Optional[List[Dict[str, Any]]]):
951+
Optional. The session events to use for the query. This will be
952+
used to initialize the session if `session_id` is not provided.
948953
run_config (Optional[Dict[str, Any]]):
949954
Optional. The run config to use for the query. If you want to
950955
pass in a `run_config` pydantic object, you can pass in a dict
@@ -974,6 +979,18 @@ async def async_stream_query(
974979
if not session_id:
975980
session = await self.async_create_session(user_id=user_id)
976981
session_id = session.id
982+
if session_events is not None:
983+
# We allow for session_events to be an empty list.
984+
from google.adk.events.event import Event
985+
986+
session_service = self._tmpl_attrs.get("session_service")
987+
async for event in session_events:
988+
if not isinstance(event, Event):
989+
event = Event.model_validate(event)
990+
await session_service.append_event(
991+
session=session,
992+
event=event,
993+
)
977994

978995
run_config = _validate_run_config(run_config)
979996
if run_config:
@@ -1009,6 +1026,7 @@ def stream_query(
10091026
message: Union[str, Dict[str, Any]],
10101027
user_id: str,
10111028
session_id: Optional[str] = None,
1029+
session_events: Optional[Dict[str, Any]] = None,
10121030
run_config: Optional[Dict[str, Any]] = None,
10131031
**kwargs,
10141032
):
@@ -1023,7 +1041,11 @@ def stream_query(
10231041
Required. The ID of the user.
10241042
session_id (str):
10251043
Optional. The ID of the session. If not provided, a new
1026-
session will be created for the user.
1044+
session will be created for the user. If this is specified, then
1045+
`session_events` will be ignored.
1046+
session_events (Optional[Dict[str, Any]]):
1047+
Optional. The session events to use for the query. This will be
1048+
used to initialize the session if `session_id` is not provided.
10271049
run_config (Optional[Dict[str, Any]]):
10281050
Optional. The run config to use for the query. If you want to
10291051
pass in a `run_config` pydantic object, you can pass in a dict
@@ -1063,6 +1085,16 @@ def stream_query(
10631085
if not session_id:
10641086
session = self.create_session(user_id=user_id)
10651087
session_id = session.id
1088+
if session_events is not None:
1089+
# We allow for session_events to be an empty list.
1090+
from google.adk.events.event import Event
1091+
1092+
session_service = self._tmpl_attrs.get("session_service")
1093+
for event_dict in session_events:
1094+
await session_service.append_event(
1095+
session=session,
1096+
event=Event.model_validate(event_dict),
1097+
)
10661098
run_config = _validate_run_config(run_config)
10671099
if run_config:
10681100
for event in self._tmpl_attrs.get("runner").run(

0 commit comments

Comments
 (0)