Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions vertexai/_genai/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,52 @@ def _GetAgentEngineMemoryRequestParameters_to_vertex(
return to_object


def _IngestEventsConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}

if getv(from_object, ["force_flush"]) is not None:
setv(parent_object, ["forceFlush"], getv(from_object, ["force_flush"]))

return to_object


def _IngestEventsRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["config"]) is not None:
_IngestEventsConfig_to_vertex(getv(from_object, ["config"]), to_object)

if getv(from_object, ["name"]) is not None:
setv(to_object, ["_url", "name"], getv(from_object, ["name"]))

if getv(from_object, ["stream_id"]) is not None:
setv(to_object, ["streamId"], getv(from_object, ["stream_id"]))

if getv(from_object, ["direct_contents_source"]) is not None:
setv(
to_object,
["directContentsSource"],
getv(from_object, ["direct_contents_source"]),
)

if getv(from_object, ["scope"]) is not None:
setv(to_object, ["scope"], getv(from_object, ["scope"]))

if getv(from_object, ["generation_trigger_config"]) is not None:
setv(
to_object,
["generationTriggerConfig"],
getv(from_object, ["generation_trigger_config"]),
)

return to_object


def _ListAgentEngineMemoryConfig_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -713,6 +759,69 @@ def get(
self._api_client._verify_response(return_value)
return return_value

def _ingest_events(
self,
*,
config: Optional[types.IngestEventsConfigOrDict] = None,
name: str,
stream_id: Optional[str] = None,
direct_contents_source: Optional[
types.IngestionDirectContentsSourceOrDict
] = None,
scope: Optional[dict[str, str]] = None,
generation_trigger_config: Optional[types.GenerationTriggerConfigOrDict] = None,
) -> types.MemoryBankIngestEventsOperation:
"""
Ingest events into a Memory Bank.
"""

parameter_model = types._IngestEventsRequestParameters(
config=config,
name=name,
stream_id=stream_id,
direct_contents_source=direct_contents_source,
scope=scope,
generation_trigger_config=generation_trigger_config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _IngestEventsRequestParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{name}/memories:ingestEvents".format_map(request_url_dict)
else:
path = "{name}/memories:ingestEvents"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("post", path, request_dict, http_options)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.MemoryBankIngestEventsOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

def _list(
self,
*,
Expand Down Expand Up @@ -1416,6 +1525,55 @@ def purge(
raise RuntimeError(f"Failed to purge memories: {operation.error}")
return operation

def ingest_events(
self,
*,
name: str,
scope: dict[str, str],
stream_id: str = None,
direct_contents_source: Optional[
types.IngestionDirectContentsSourceOrDict
] = None,
generation_trigger_config: Optional[types.GenerationTriggerConfigOrDict] = None,
config: Optional[types.IngestEventsConfigOrDict] = None,
) -> types.MemoryBankIngestEventsOperation:
"""Ingests events into an Agent Engine.

Args:
name (str):
Required. The name of the Agent Engine to ingest events into.
scope (dict[str, str]):
Required. The scope of the events to ingest. For example,
{"user_id": "123"}.
config (IngestEventsConfig):
Optional. The configuration for the ingest events operation.

Returns:
AgentEngineIngestEventsOperation:
The operation for ingesting the events.
"""
if config is None:
config = types.IngestEventsConfig()
elif isinstance(config, dict):
config = types.IngestEventsConfig.model_validate(config)
operation = self._ingest_events(
name=name,
scope=scope,
stream_id=stream_id,
generation_trigger_config=generation_trigger_config,
direct_contents_source=direct_contents_source,
config=config,
)
if config.wait_for_completion and not operation.done:
operation = _agent_engines_utils._await_operation(
operation_name=operation.name,
get_operation_fn=self._get_memory_operation,
poll_interval_seconds=0.5,
)
if operation.error:
raise RuntimeError(f"Failed to ingest events: {operation.error}")
return operation


class AsyncMemories(_api_module.BaseModule):

Expand Down Expand Up @@ -1679,6 +1837,71 @@ async def get(
self._api_client._verify_response(return_value)
return return_value

async def _ingest_events(
self,
*,
config: Optional[types.IngestEventsConfigOrDict] = None,
name: str,
stream_id: Optional[str] = None,
direct_contents_source: Optional[
types.IngestionDirectContentsSourceOrDict
] = None,
scope: Optional[dict[str, str]] = None,
generation_trigger_config: Optional[types.GenerationTriggerConfigOrDict] = None,
) -> types.MemoryBankIngestEventsOperation:
"""
Ingest events into a Memory Bank.
"""

parameter_model = types._IngestEventsRequestParameters(
config=config,
name=name,
stream_id=stream_id,
direct_contents_source=direct_contents_source,
scope=scope,
generation_trigger_config=generation_trigger_config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _IngestEventsRequestParameters_to_vertex(parameter_model)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{name}/memories:ingestEvents".format_map(request_url_dict)
else:
path = "{name}/memories:ingestEvents"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
"post", path, request_dict, http_options
)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.MemoryBankIngestEventsOperation._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

async def _list(
self,
*,
Expand Down
44 changes: 44 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .common import _GetEvaluationSetParameters
from .common import _GetMultimodalDatasetOperationParameters
from .common import _GetMultimodalDatasetParameters
from .common import _IngestEventsRequestParameters
from .common import _ListAgentEngineMemoryRequestParameters
from .common import _ListAgentEngineMemoryRevisionsRequestParameters
from .common import _ListAgentEngineRequestParameters
Expand Down Expand Up @@ -475,6 +476,9 @@
from .common import GenerateUserScenariosResponse
from .common import GenerateUserScenariosResponseDict
from .common import GenerateUserScenariosResponseOrDict
from .common import GenerationTriggerConfig
from .common import GenerationTriggerConfigDict
from .common import GenerationTriggerConfigOrDict
from .common import GetAgentEngineConfig
from .common import GetAgentEngineConfigDict
from .common import GetAgentEngineConfigOrDict
Expand Down Expand Up @@ -519,6 +523,15 @@
from .common import GetPromptConfigOrDict
from .common import IdentityType
from .common import Importance
from .common import IngestEventsConfig
from .common import IngestEventsConfigDict
from .common import IngestEventsConfigOrDict
from .common import IngestionDirectContentsSource
from .common import IngestionDirectContentsSourceDict
from .common import IngestionDirectContentsSourceEvent
from .common import IngestionDirectContentsSourceEventDict
from .common import IngestionDirectContentsSourceEventOrDict
from .common import IngestionDirectContentsSourceOrDict
from .common import IntermediateExtractedMemory
from .common import IntermediateExtractedMemoryDict
from .common import IntermediateExtractedMemoryOrDict
Expand Down Expand Up @@ -646,13 +659,22 @@
from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict
from .common import MemoryBankCustomizationConfigMemoryTopicOrDict
from .common import MemoryBankCustomizationConfigOrDict
from .common import MemoryBankIngestEventsOperation
from .common import MemoryBankIngestEventsOperationDict
from .common import MemoryBankIngestEventsOperationOrDict
from .common import MemoryConjunctionFilter
from .common import MemoryConjunctionFilterDict
from .common import MemoryConjunctionFilterOrDict
from .common import MemoryDict
from .common import MemoryFilter
from .common import MemoryFilterDict
from .common import MemoryFilterOrDict
from .common import MemoryGenerationTriggerConfig
from .common import MemoryGenerationTriggerConfigDict
from .common import MemoryGenerationTriggerConfigGenerationTriggerRule
from .common import MemoryGenerationTriggerConfigGenerationTriggerRuleDict
from .common import MemoryGenerationTriggerConfigGenerationTriggerRuleOrDict
from .common import MemoryGenerationTriggerConfigOrDict
from .common import MemoryMetadataMergeStrategy
from .common import MemoryMetadataValue
from .common import MemoryMetadataValueDict
Expand Down Expand Up @@ -1580,6 +1602,12 @@
"MemoryBankCustomizationConfig",
"MemoryBankCustomizationConfigDict",
"MemoryBankCustomizationConfigOrDict",
"MemoryGenerationTriggerConfigGenerationTriggerRule",
"MemoryGenerationTriggerConfigGenerationTriggerRuleDict",
"MemoryGenerationTriggerConfigGenerationTriggerRuleOrDict",
"MemoryGenerationTriggerConfig",
"MemoryGenerationTriggerConfigDict",
"MemoryGenerationTriggerConfigOrDict",
"ReasoningEngineContextSpecMemoryBankConfigGenerationConfig",
"ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict",
"ReasoningEngineContextSpecMemoryBankConfigGenerationConfigOrDict",
Expand Down Expand Up @@ -1724,6 +1752,21 @@
"GetAgentEngineMemoryConfig",
"GetAgentEngineMemoryConfigDict",
"GetAgentEngineMemoryConfigOrDict",
"IngestionDirectContentsSourceEvent",
"IngestionDirectContentsSourceEventDict",
"IngestionDirectContentsSourceEventOrDict",
"IngestionDirectContentsSource",
"IngestionDirectContentsSourceDict",
"IngestionDirectContentsSourceOrDict",
"GenerationTriggerConfig",
"GenerationTriggerConfigDict",
"GenerationTriggerConfigOrDict",
"IngestEventsConfig",
"IngestEventsConfigDict",
"IngestEventsConfigOrDict",
"MemoryBankIngestEventsOperation",
"MemoryBankIngestEventsOperationDict",
"MemoryBankIngestEventsOperationOrDict",
"ListAgentEngineMemoryConfig",
"ListAgentEngineMemoryConfigDict",
"ListAgentEngineMemoryConfigOrDict",
Expand Down Expand Up @@ -2205,6 +2248,7 @@
"_DeleteAgentEngineMemoryRequestParameters",
"_GenerateAgentEngineMemoriesRequestParameters",
"_GetAgentEngineMemoryRequestParameters",
"_IngestEventsRequestParameters",
"_ListAgentEngineMemoryRequestParameters",
"_GetAgentEngineMemoryOperationParameters",
"_GetAgentEngineGenerateMemoriesOperationParameters",
Expand Down
Loading
Loading