Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise(
class_method = _to_proto(schema_dict)
class_method[_MODE_KEY_IN_SCHEMA] = mode
if hasattr(agent, "agent_card"):
class_method[_A2A_AGENT_CARD] = getattr(
agent, "agent_card"
).model_dump_json()
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
getattr(agent, "agent_card")
)
class_methods_spec.append(class_method)

return class_methods_spec
Expand Down
9 changes: 6 additions & 3 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,10 +1834,13 @@ def _create_config(
agent_card = getattr(agent, "agent_card")
if agent_card:
try:
agent_engine_spec["agent_card"] = agent_card.model_dump(
exclude_none=True
from google.protobuf import json_format
import json

agent_engine_spec["agent_card"] = json.loads(
json_format.MessageToJson(agent_card)
)
except TypeError as e:
except Exception as e:
raise ValueError(
f"Failed to convert agent card to dict (serialization error): {e}"
) from e
Expand Down
38 changes: 24 additions & 14 deletions vertexai/agent_engines/_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,28 @@
try:
from a2a.types import (
AgentCard,
TransportProtocol,
AgentInterface,
Message,
TaskIdParams,
TaskQueryParams,
)
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
from a2a.client import ClientConfig, ClientFactory

AgentCard = AgentCard
AgentInterface = AgentInterface
TransportProtocol = TransportProtocol
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT
Message = Message
ClientConfig = ClientConfig
ClientFactory = ClientFactory
TaskIdParams = TaskIdParams
TaskQueryParams = TaskQueryParams
except (ImportError, AttributeError):
AgentCard = None
AgentInterface = None
TransportProtocol = None
PROTOCOL_VERSION_CURRENT = None
Message = None
ClientConfig = None
ClientFactory = None
Expand Down Expand Up @@ -1735,17 +1740,20 @@ async def _method(self, **kwargs) -> Any:
a2a_agent_card = AgentCard(**json.loads(agent_card))

# A2A + AE integration currently only supports Rest API.
if (
a2a_agent_card.preferred_transport
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
):
if a2a_agent_card.supported_interfaces and a2a_agent_card.supported_interfaces[0].protocol_binding != TransportProtocol.HTTP_JSON:
raise ValueError(
"Only HTTP+JSON is supported for preferred transport on agent card "
"Only HTTP+JSON is supported for primary interface on agent card "
)

# Set preferred transport to HTTP+JSON if not set.
if not hasattr(a2a_agent_card, "preferred_transport"):
a2a_agent_card.preferred_transport = TransportProtocol.http_json
# Set primary interface to HTTP+JSON if not set.
if not a2a_agent_card.supported_interfaces:
a2a_agent_card.supported_interfaces = []
a2a_agent_card.supported_interfaces.append(
AgentInterface(
protocol_binding=TransportProtocol.HTTP_JSON,
protocol_version=PROTOCOL_VERSION_CURRENT,
)
)

# AE cannot support streaming yet. Turn off streaming for now.
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
Expand All @@ -1759,12 +1767,13 @@ async def _method(self, **kwargs) -> Any:

# agent_card is set on the class_methods before set_up is invoked.
# Ensure that the agent_card url is set correctly before the client is created.
a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
a2a_agent_card.supported_interfaces[0].url = url

# Using a2a client, inject the auth token from the global config.
config = ClientConfig(
supported_transports=[
TransportProtocol.http_json,
TransportProtocol.HTTP_JSON,
],
use_client_preference=True,
httpx_client=httpx.AsyncClient(
Expand Down Expand Up @@ -1977,9 +1986,10 @@ def _generate_class_methods_spec_or_raise(
class_method[_MODE_KEY_IN_SCHEMA] = mode
# A2A agent card is a special case, when running in A2A mode,
if hasattr(agent_engine, "agent_card"):
class_method[_A2A_AGENT_CARD] = getattr(
agent_engine, "agent_card"
).model_dump_json()
from google.protobuf import json_format
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
getattr(agent_engine, "agent_card")
)
class_methods_spec.append(class_method)

return class_methods_spec
Expand Down
173 changes: 139 additions & 34 deletions vertexai/preview/reasoning_engines/templates/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def create_agent_card(
provided.
"""
# pylint: disable=g-import-not-at-top
from a2a.types import AgentCard, AgentCapabilities, TransportProtocol
from a2a.types import AgentCard, AgentCapabilities, AgentInterface
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT

# Check if a dictionary was provided.
if agent_card:
Expand All @@ -98,14 +99,20 @@ def create_agent_card(
return AgentCard(
name=agent_name,
description=description,
url="http://localhost:9999/",
version="1.0.0",
default_input_modes=default_input_modes or ["text/plain"],
default_output_modes=default_output_modes or ["application/json"],
capabilities=AgentCapabilities(streaming=streaming),
capabilities=AgentCapabilities(
streaming=streaming, extended_agent_card=True
),
skills=skills,
preferred_transport=TransportProtocol.http_json, # Http Only.
supports_authenticated_extended_card=True,
supported_interfaces=[
AgentInterface(
url="http://localhost:9999/",
protocol_binding=TransportProtocol.HTTP_JSON,
protocol_version=PROTOCOL_VERSION_CURRENT,
)
],
)

# Raise an error if insufficient data is provided.
Expand Down Expand Up @@ -162,6 +169,21 @@ async def cancel(
)


def _is_version_enabled(agent_card: "AgentCard", version: str) -> bool:
"""Checks if a specific version compatibility should be enabled for the A2aAgent."""
from a2a.utils.constants import TransportProtocol

if not agent_card.supported_interfaces:
return False
for interface in agent_card.supported_interfaces:
if (
interface.protocol_version == version
and interface.protocol_binding == TransportProtocol.HTTP_JSON
):
return True
return False


class A2aAgent:
"""A class to initialize and set up an Agent-to-Agent application."""

Expand All @@ -181,14 +203,15 @@ def __init__(
"""Initializes the A2A agent."""
# pylint: disable=g-import-not-at-top
from google.cloud.aiplatform import initializer
from a2a.types import TransportProtocol
from a2a.utils.constants import TransportProtocol

if (
agent_card.preferred_transport
and agent_card.preferred_transport != TransportProtocol.http_json
agent_card.supported_interfaces
and agent_card.supported_interfaces[0].protocol_binding
!= TransportProtocol.HTTP_JSON
):
raise ValueError(
"Only HTTP+JSON is supported for preferred transport on agent card "
"Only HTTP+JSON is supported for the primary interface on agent card "
)

self._tmpl_attrs: dict[str, Any] = {
Expand Down Expand Up @@ -244,7 +267,21 @@ def set_up(self):
agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine")
version = "v1beta1"

self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
if not self.agent_card.supported_interfaces:
from a2a.types import AgentInterface
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT

self.agent_card.supported_interfaces.append(
AgentInterface(
url=new_url,
protocol_binding=TransportProtocol.HTTP_JSON,
protocol_version=PROTOCOL_VERSION_CURRENT,
)
)
else:
# primary interface must be HTTP+JSON
self.agent_card.supported_interfaces[0].url = new_url
self._tmpl_attrs["agent_card"] = self.agent_card

# Create the agent executor if a builder is provided.
Expand Down Expand Up @@ -286,17 +323,30 @@ def set_up(self):

# a2a_rest_adapter is used to register the A2A API routes in the
# Reasoning Engine API router.
self.a2a_rest_adapter = RESTAdapter(
agent_card=self.agent_card,
http_handler=self._tmpl_attrs.get("request_handler"),
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
)
if _is_version_enabled(self.agent_card, "1.0"):
self.a2a_rest_adapter = RESTAdapter(
agent_card=self.agent_card,
http_handler=self._tmpl_attrs.get("request_handler"),
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
)

# rest_handler is used to handle the A2A API requests.
self.rest_handler = RESTHandler(
agent_card=self.agent_card,
request_handler=self._tmpl_attrs.get("request_handler"),
)
# rest_handler is used to handle the A2A API requests.
self.rest_handler = RESTHandler(
agent_card=self.agent_card,
request_handler=self._tmpl_attrs.get("request_handler"),
)

# v0.3 handlers will be deprecated in the future.
if _is_version_enabled(self.agent_card, "0.3"):
from a2a.compat.v0_3.rest_adapter import REST03Adapter
from a2a.compat.v0_3.rest_handler import REST03Handler
import functools

self.v03_rest_adapter = REST03Adapter(
agent_card=self.agent_card,
http_handler=self._tmpl_attrs.get("request_handler"),
extended_agent_card=self._tmpl_attrs.get("extended_agent_card"),
)

async def on_message_send(
self,
Expand Down Expand Up @@ -330,18 +380,25 @@ async def handle_authenticated_agent_card(

def register_operations(self) -> Dict[str, List[str]]:
"""Registers the operations of the A2A Agent."""
routes = {
"a2a_extension": [
"on_message_send",
"on_get_task",
"on_cancel_task",
]
}
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
routes["a2a_extension"].append("on_message_send_stream")
routes["a2a_extension"].append("on_resubscribe_to_task")
if self.agent_card.supports_authenticated_extended_card:
routes["a2a_extension"].append("handle_authenticated_agent_card")
routes = {"a2a_extension": []}

if _is_version_enabled(self.agent_card, "1.0"):
routes["a2a_extension"].extend(
[
"on_message_send",
"on_get_task",
"on_cancel_task",
]
)
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
routes["a2a_extension"].append("on_message_send_stream")
routes["a2a_extension"].append("on_subscribe_to_task")
if (
self.agent_card.capabilities
and self.agent_card.capabilities.extended_agent_card
):
routes["a2a_extension"].append("handle_authenticated_agent_card")

return routes

async def on_message_send_stream(
Expand All @@ -353,11 +410,59 @@ async def on_message_send_stream(
async for chunk in self.rest_handler.on_message_send_stream(request, context):
yield chunk

async def on_resubscribe_to_task(
async def on_subscribe_to_task(
self,
request: "Request",
context: "ServerCallContext",
) -> AsyncIterator[str]:
"""Handles A2A task resubscription requests via SSE."""
async for chunk in self.rest_handler.on_resubscribe_to_task(request, context):
async for chunk in self.rest_handler.on_subscribe_to_task(request, context):
yield chunk

def __getstate__(self):
"""Serializes AgentCard proto to a dictionary."""
from google.protobuf import json_format
import json

state = self.__dict__.copy()

def _to_dict_if_proto(obj):
if hasattr(obj, "DESCRIPTOR"):
return {
"__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj))
}
return obj

state["agent_card"] = _to_dict_if_proto(state.get("agent_card"))
if "_tmpl_attrs" in state:
tmpl_attrs = state["_tmpl_attrs"].copy()
tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card"))
tmpl_attrs["extended_agent_card"] = _to_dict_if_proto(
tmpl_attrs.get("extended_agent_card")
)
state["_tmpl_attrs"] = tmpl_attrs

return state

def __setstate__(self, state):
"""Deserializes AgentCard proto from a dictionary."""
from google.protobuf import json_format
from a2a.types import AgentCard

def _from_dict_if_proto(obj):
if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj:
agent_card = AgentCard()
json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card)
return agent_card
return obj

state["agent_card"] = _from_dict_if_proto(state.get("agent_card"))
if "_tmpl_attrs" in state:
state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto(
state["_tmpl_attrs"].get("agent_card")
)
state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto(
state["_tmpl_attrs"].get("extended_agent_card")
)

self.__dict__.update(state)
Loading