From 42a1a0cf063a2a0bf7485e9c8eee31af0fdd585b Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 27 Mar 2026 05:06:25 -0700 Subject: [PATCH] feat: update sdk to support a2a 1.0 PiperOrigin-RevId: 890388363 --- vertexai/_genai/_agent_engines_utils.py | 6 +- vertexai/_genai/agent_engines.py | 9 +- vertexai/agent_engines/_agent_engines.py | 38 ++-- .../reasoning_engines/templates/a2a.py | 173 ++++++++++++++---- 4 files changed, 172 insertions(+), 54 deletions(-) diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index ba388f4d38..f55a91d472 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise( class_method = _to_proto(schema_dict) class_method[_MODE_KEY_IN_SCHEMA] = mode if hasattr(agent, "agent_card"): - class_method[_A2A_AGENT_CARD] = getattr( - agent, "agent_card" - ).model_dump_json() + class_method[_A2A_AGENT_CARD] = json_format.MessageToJson( + getattr(agent, "agent_card") + ) class_methods_spec.append(class_method) return class_methods_spec diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index cf6e96c943..fa5d1a25e4 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -1834,10 +1834,13 @@ def _create_config( agent_card = getattr(agent, "agent_card") if agent_card: try: - agent_engine_spec["agent_card"] = agent_card.model_dump( - exclude_none=True + from google.protobuf import json_format + import json + + agent_engine_spec["agent_card"] = json.loads( + json_format.MessageToJson(agent_card) ) - except TypeError as e: + except Exception as e: raise ValueError( f"Failed to convert agent card to dict (serialization error): {e}" ) from e diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index dd4e35269d..b7d3a80ca9 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -119,15 +119,18 @@ try: from a2a.types import ( AgentCard, - TransportProtocol, + AgentInterface, Message, TaskIdParams, TaskQueryParams, ) + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT from a2a.client import ClientConfig, ClientFactory AgentCard = AgentCard + AgentInterface = AgentInterface TransportProtocol = TransportProtocol + PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT Message = Message ClientConfig = ClientConfig ClientFactory = ClientFactory @@ -135,7 +138,9 @@ TaskQueryParams = TaskQueryParams except (ImportError, AttributeError): AgentCard = None + AgentInterface = None TransportProtocol = None + PROTOCOL_VERSION_CURRENT = None Message = None ClientConfig = None ClientFactory = None @@ -1735,17 +1740,20 @@ async def _method(self, **kwargs) -> Any: a2a_agent_card = AgentCard(**json.loads(agent_card)) # A2A + AE integration currently only supports Rest API. - if ( - a2a_agent_card.preferred_transport - and a2a_agent_card.preferred_transport != TransportProtocol.http_json - ): + if a2a_agent_card.supported_interfaces and a2a_agent_card.supported_interfaces[0].protocol_binding != TransportProtocol.HTTP_JSON: raise ValueError( - "Only HTTP+JSON is supported for preferred transport on agent card " + "Only HTTP+JSON is supported for primary interface on agent card " ) - # Set preferred transport to HTTP+JSON if not set. - if not hasattr(a2a_agent_card, "preferred_transport"): - a2a_agent_card.preferred_transport = TransportProtocol.http_json + # Set primary interface to HTTP+JSON if not set. + if not a2a_agent_card.supported_interfaces: + a2a_agent_card.supported_interfaces = [] + a2a_agent_card.supported_interfaces.append( + AgentInterface( + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ) # AE cannot support streaming yet. Turn off streaming for now. if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming: @@ -1759,12 +1767,13 @@ async def _method(self, **kwargs) -> Any: # agent_card is set on the class_methods before set_up is invoked. # Ensure that the agent_card url is set correctly before the client is created. - a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a" + url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a" + a2a_agent_card.supported_interfaces[0].url = url # Using a2a client, inject the auth token from the global config. config = ClientConfig( supported_transports=[ - TransportProtocol.http_json, + TransportProtocol.HTTP_JSON, ], use_client_preference=True, httpx_client=httpx.AsyncClient( @@ -1977,9 +1986,10 @@ def _generate_class_methods_spec_or_raise( class_method[_MODE_KEY_IN_SCHEMA] = mode # A2A agent card is a special case, when running in A2A mode, if hasattr(agent_engine, "agent_card"): - class_method[_A2A_AGENT_CARD] = getattr( - agent_engine, "agent_card" - ).model_dump_json() + from google.protobuf import json_format + class_method[_A2A_AGENT_CARD] = json_format.MessageToJson( + getattr(agent_engine, "agent_card") + ) class_methods_spec.append(class_method) return class_methods_spec diff --git a/vertexai/preview/reasoning_engines/templates/a2a.py b/vertexai/preview/reasoning_engines/templates/a2a.py index 724e2af41e..3d0a35c260 100644 --- a/vertexai/preview/reasoning_engines/templates/a2a.py +++ b/vertexai/preview/reasoning_engines/templates/a2a.py @@ -87,7 +87,8 @@ def create_agent_card( provided. """ # pylint: disable=g-import-not-at-top - from a2a.types import AgentCard, AgentCapabilities, TransportProtocol + from a2a.types import AgentCard, AgentCapabilities, AgentInterface + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT # Check if a dictionary was provided. if agent_card: @@ -98,14 +99,20 @@ def create_agent_card( return AgentCard( name=agent_name, description=description, - url="http://localhost:9999/", version="1.0.0", default_input_modes=default_input_modes or ["text/plain"], default_output_modes=default_output_modes or ["application/json"], - capabilities=AgentCapabilities(streaming=streaming), + capabilities=AgentCapabilities( + streaming=streaming, extended_agent_card=True + ), skills=skills, - preferred_transport=TransportProtocol.http_json, # Http Only. - supports_authenticated_extended_card=True, + supported_interfaces=[ + AgentInterface( + url="http://localhost:9999/", + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ], ) # Raise an error if insufficient data is provided. @@ -162,6 +169,21 @@ async def cancel( ) +def _is_version_enabled(agent_card: "AgentCard", version: str) -> bool: + """Checks if a specific version compatibility should be enabled for the A2aAgent.""" + from a2a.utils.constants import TransportProtocol + + if not agent_card.supported_interfaces: + return False + for interface in agent_card.supported_interfaces: + if ( + interface.protocol_version == version + and interface.protocol_binding == TransportProtocol.HTTP_JSON + ): + return True + return False + + class A2aAgent: """A class to initialize and set up an Agent-to-Agent application.""" @@ -181,14 +203,15 @@ def __init__( """Initializes the A2A agent.""" # pylint: disable=g-import-not-at-top from google.cloud.aiplatform import initializer - from a2a.types import TransportProtocol + from a2a.utils.constants import TransportProtocol if ( - agent_card.preferred_transport - and agent_card.preferred_transport != TransportProtocol.http_json + agent_card.supported_interfaces + and agent_card.supported_interfaces[0].protocol_binding + != TransportProtocol.HTTP_JSON ): raise ValueError( - "Only HTTP+JSON is supported for preferred transport on agent card " + "Only HTTP+JSON is supported for the primary interface on agent card " ) self._tmpl_attrs: dict[str, Any] = { @@ -244,7 +267,21 @@ def set_up(self): agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine") version = "v1beta1" - self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a" + new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a" + if not self.agent_card.supported_interfaces: + from a2a.types import AgentInterface + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT + + self.agent_card.supported_interfaces.append( + AgentInterface( + url=new_url, + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ) + else: + # primary interface must be HTTP+JSON + self.agent_card.supported_interfaces[0].url = new_url self._tmpl_attrs["agent_card"] = self.agent_card # Create the agent executor if a builder is provided. @@ -286,17 +323,30 @@ def set_up(self): # a2a_rest_adapter is used to register the A2A API routes in the # Reasoning Engine API router. - self.a2a_rest_adapter = RESTAdapter( - agent_card=self.agent_card, - http_handler=self._tmpl_attrs.get("request_handler"), - extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), - ) + if _is_version_enabled(self.agent_card, "1.0"): + self.a2a_rest_adapter = RESTAdapter( + agent_card=self.agent_card, + http_handler=self._tmpl_attrs.get("request_handler"), + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) - # rest_handler is used to handle the A2A API requests. - self.rest_handler = RESTHandler( - agent_card=self.agent_card, - request_handler=self._tmpl_attrs.get("request_handler"), - ) + # rest_handler is used to handle the A2A API requests. + self.rest_handler = RESTHandler( + agent_card=self.agent_card, + request_handler=self._tmpl_attrs.get("request_handler"), + ) + + # v0.3 handlers will be deprecated in the future. + if _is_version_enabled(self.agent_card, "0.3"): + from a2a.compat.v0_3.rest_adapter import REST03Adapter + from a2a.compat.v0_3.rest_handler import REST03Handler + import functools + + self.v03_rest_adapter = REST03Adapter( + agent_card=self.agent_card, + http_handler=self._tmpl_attrs.get("request_handler"), + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) async def on_message_send( self, @@ -330,18 +380,25 @@ async def handle_authenticated_agent_card( def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the A2A Agent.""" - routes = { - "a2a_extension": [ - "on_message_send", - "on_get_task", - "on_cancel_task", - ] - } - if self.agent_card.capabilities and self.agent_card.capabilities.streaming: - routes["a2a_extension"].append("on_message_send_stream") - routes["a2a_extension"].append("on_resubscribe_to_task") - if self.agent_card.supports_authenticated_extended_card: - routes["a2a_extension"].append("handle_authenticated_agent_card") + routes = {"a2a_extension": []} + + if _is_version_enabled(self.agent_card, "1.0"): + routes["a2a_extension"].extend( + [ + "on_message_send", + "on_get_task", + "on_cancel_task", + ] + ) + if self.agent_card.capabilities and self.agent_card.capabilities.streaming: + routes["a2a_extension"].append("on_message_send_stream") + routes["a2a_extension"].append("on_subscribe_to_task") + if ( + self.agent_card.capabilities + and self.agent_card.capabilities.extended_agent_card + ): + routes["a2a_extension"].append("handle_authenticated_agent_card") + return routes async def on_message_send_stream( @@ -353,11 +410,59 @@ async def on_message_send_stream( async for chunk in self.rest_handler.on_message_send_stream(request, context): yield chunk - async def on_resubscribe_to_task( + async def on_subscribe_to_task( self, request: "Request", context: "ServerCallContext", ) -> AsyncIterator[str]: """Handles A2A task resubscription requests via SSE.""" - async for chunk in self.rest_handler.on_resubscribe_to_task(request, context): + async for chunk in self.rest_handler.on_subscribe_to_task(request, context): yield chunk + + def __getstate__(self): + """Serializes AgentCard proto to a dictionary.""" + from google.protobuf import json_format + import json + + state = self.__dict__.copy() + + def _to_dict_if_proto(obj): + if hasattr(obj, "DESCRIPTOR"): + return { + "__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj)) + } + return obj + + state["agent_card"] = _to_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + tmpl_attrs = state["_tmpl_attrs"].copy() + tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card")) + tmpl_attrs["extended_agent_card"] = _to_dict_if_proto( + tmpl_attrs.get("extended_agent_card") + ) + state["_tmpl_attrs"] = tmpl_attrs + + return state + + def __setstate__(self, state): + """Deserializes AgentCard proto from a dictionary.""" + from google.protobuf import json_format + from a2a.types import AgentCard + + def _from_dict_if_proto(obj): + if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj: + agent_card = AgentCard() + json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card) + return agent_card + return obj + + state["agent_card"] = _from_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("agent_card") + ) + state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("extended_agent_card") + ) + + self.__dict__.update(state)