diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ac3cfeae5c..353d411d93 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -490,12 +490,23 @@ async def direct_call_tool( ): # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function. # See https://github.com/modelcontextprotocol/python-sdk#structured-output - if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured: - return structured['result'] - return structured - - mapped = [await self._map_tool_result_part(part) for part in result.content] - return mapped[0] if len(mapped) == 1 else mapped + return_value = ( + structured['result'] + if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured + else structured + ) + else: + mapped = [await self._map_tool_result_part(part) for part in result.content] + return_value = mapped[0] if len(mapped) == 1 else mapped + if result.meta: + # The following branching cannot be tested until FastMCP is updated to version 2.13.1 + # such that the MCP server can generate ToolResult and result.meta can be specified. + # TODO: Add tests for the following branching once FastMCP is updated. + return ( # pragma: no cover + messages.ToolReturn(return_value=return_value, metadata=result.meta) + ) + else: + return return_value async def call_tool( self, @@ -574,16 +585,24 @@ async def list_resource_templates(self) -> list[ResourceTemplate]: return [ResourceTemplate.from_mcp_sdk(t) for t in result.resourceTemplates] @overload - async def read_resource(self, uri: str) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ... + async def read_resource( + self, uri: str + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): ... @overload async def read_resource( self, uri: Resource - ) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: ... + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): ... async def read_resource( self, uri: str | Resource - ) -> str | messages.BinaryContent | list[str | messages.BinaryContent]: + ) -> ( + str | messages.TextContent | messages.BinaryContent | list[str | messages.TextContent | messages.BinaryContent] + ): """Read the contents of a specific resource by URI. Args: @@ -682,24 +701,29 @@ async def _sampling_callback( async def _map_tool_result_part( self, part: mcp_types.ContentBlock - ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: + ) -> str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]: # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values if isinstance(part, mcp_types.TextContent): text = part.text - if text.startswith(('[', '{')): - try: - return pydantic_core.from_json(text) - except ValueError: - pass - return text + if part.meta: + return messages.TextContent(content=text, metadata=part.meta) + else: + if text.startswith(('[', '{')): + try: + return pydantic_core.from_json(text) + except ValueError: + pass + return text elif isinstance(part, mcp_types.ImageContent): - return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) + return messages.BinaryContent( + data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta + ) elif isinstance(part, mcp_types.AudioContent): # NOTE: The FastMCP server doesn't support audio content. # See for more details. return messages.BinaryContent( - data=base64.b64decode(part.data), media_type=part.mimeType + data=base64.b64decode(part.data), media_type=part.mimeType, metadata=part.meta ) # pragma: no cover elif isinstance(part, mcp_types.EmbeddedResource): resource = part.resource @@ -711,12 +735,18 @@ async def _map_tool_result_part( def _get_content( self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents - ) -> str | messages.BinaryContent: + ) -> str | messages.TextContent | messages.BinaryContent: if isinstance(resource, mcp_types.TextResourceContents): - return resource.text + return ( + resource.text + if not resource.meta + else messages.TextContent(content=resource.text, metadata=resource.meta) + ) elif isinstance(resource, mcp_types.BlobResourceContents): return messages.BinaryContent( - data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream' + data=base64.b64decode(resource.blob), + media_type=resource.mimeType or 'application/octet-stream', + metadata=resource.meta, ) else: assert_never(resource) @@ -1178,10 +1208,12 @@ def __eq__(self, value: object, /) -> bool: ToolResult = ( str + | messages.TextContent | messages.BinaryContent + | messages.ToolReturn | dict[str, Any] | list[Any] - | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] + | Sequence[str | messages.TextContent | messages.BinaryContent | dict[str, Any] | list[Any]] ) """The result type of an MCP tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index ac0fb0da6d..6fcb28ff2f 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -466,6 +466,33 @@ def format(self) -> DocumentFormat: raise ValueError(f'Unknown document media type: {media_type}') from e +@dataclass(repr=False) +class TextContent: + """A plain text response from a model with optional metadata.""" + + content: str + """The text content of the response.""" + + _: KW_ONLY + + provider_details: dict[str, Any] | None = None + """Additional data returned by the provider that can't be mapped to standard fields. + + This is used for data that is required to be sent back to APIs, as well as data users may want to access programmatically.""" + + metadata: Any = None + """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" + + kind: Literal['text'] = 'text' + """Type identifier, this is available as a discriminator.""" + + def has_content(self) -> bool: + """Return `True` if the text content is non-empty.""" + return bool(self.content) + + __repr__ = _utils.dataclasses_no_defaults_repr + + @dataclass(init=False, repr=False) class BinaryContent: """Binary content, e.g. an audio or image file.""" @@ -486,6 +513,9 @@ class BinaryContent: - `OpenAIChatModel`, `OpenAIResponsesModel`: `BinaryContent.vendor_metadata['detail']` is used as `detail` setting for images """ + metadata: Any = None + """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" + _identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field( compare=False, default=None ) @@ -500,6 +530,7 @@ def __init__( media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str, identifier: str | None = None, vendor_metadata: dict[str, Any] | None = None, + metadata: Any = None, kind: Literal['binary'] = 'binary', # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. _identifier: str | None = None, @@ -508,6 +539,7 @@ def __init__( self.media_type = media_type self._identifier = identifier or _identifier self.vendor_metadata = vendor_metadata + self.metadata = metadata self.kind = kind @staticmethod @@ -519,6 +551,7 @@ def narrow_type(bc: BinaryContent) -> BinaryContent | BinaryImage: media_type=bc.media_type, identifier=bc.identifier, vendor_metadata=bc.vendor_metadata, + metadata=bc.metadata, ) else: return bc @@ -622,11 +655,17 @@ def __init__( identifier: str | None = None, vendor_metadata: dict[str, Any] | None = None, # Required for inline-snapshot which expects all dataclass `__init__` methods to take all field names as kwargs. + metadata: Any = None, kind: Literal['binary'] = 'binary', _identifier: str | None = None, ): super().__init__( - data=data, media_type=media_type, identifier=identifier or _identifier, vendor_metadata=vendor_metadata + data=data, + media_type=media_type, + identifier=identifier or _identifier, + vendor_metadata=vendor_metadata, + metadata=metadata, + kind=kind, ) if not self.is_image: @@ -657,7 +696,7 @@ class CachePoint: MultiModalContent = ImageUrl | AudioUrl | DocumentUrl | VideoUrl | BinaryContent -UserContent: TypeAlias = str | MultiModalContent | CachePoint +UserContent: TypeAlias = str | TextContent | MultiModalContent | CachePoint @dataclass(repr=False) diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index ff03460904..25f9446800 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -37,6 +37,7 @@ UserPromptPart, VideoUrl, _utils, + messages, usage, ) from pydantic_ai._run_context import RunContext @@ -628,6 +629,8 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, messages.TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): format = item.format if item.is_document: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4da92018fd..d4474756de 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -31,6 +31,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -374,6 +375,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion] for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') content.append( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 89290ea3ce..c37938464a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -30,6 +30,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -601,6 +602,8 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: for item in part.content: if isinstance(item, str): content.append({'text': item}) + elif isinstance(item, TextContent): + content.append({'text': item.content}) elif isinstance(item, BinaryContent): inline_data_dict: BlobDict = {'data': item.data, 'mime_type': item.media_type} part_dict: PartDict = {'inline_data': inline_data_dict} diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 790b30bec3..2ac49116de 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -30,6 +30,7 @@ ModelResponseStreamEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -433,6 +434,8 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: for item in part.content: if isinstance(item, str): content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, TextContent): + content.append(ChatCompletionInputMessageChunk(type='text', text=item.content)) # type: ignore elif isinstance(item, ImageUrl): url = ChatCompletionInputURL(url=item.url) # type: ignore content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 10af284ee8..c3a9ecdba4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -40,6 +40,7 @@ PartStartEvent, RetryPromptPart, SystemPromptPart, + TextContent, TextPart, ThinkingPart, ToolCallPart, @@ -914,6 +915,13 @@ async def _map_user_prompt(self, part: UserPromptPart) -> chat.ChatCompletionUse for item in part.content: if isinstance(item, str): content.append(ChatCompletionContentPartTextParam(text=item, type='text')) + elif isinstance(item, TextContent): + content.append( + ChatCompletionContentPartTextParam( + text=item.content, + type='text', + ) + ) elif isinstance(item, ImageUrl): image_url: ImageURL = {'url': item.url} if metadata := item.vendor_metadata: @@ -1754,6 +1762,8 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa for item in part.content: if isinstance(item, str): content.append(responses.ResponseInputTextParam(text=item, type='input_text')) + elif isinstance(item, TextContent): + content.append(responses.ResponseInputTextParam(text=item.content, type='input_text')) elif isinstance(item, BinaryContent): if item.is_image: detail: Literal['auto', 'low', 'high'] = 'auto' diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 8ba9b9997f..b4fe3dea04 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -47,6 +47,36 @@ async def get_weather_forecast(location: str) -> str: return f'The weather in {location} is sunny and 26 degrees Celsius.' +@mcp.tool(structured_output=False, annotations=ToolAnnotations(title='Collatz Conjecture sequence generator')) +async def get_collatz_conjecture(n: int) -> TextContent: + """Generate the Collatz conjecture sequence for a given number. + This tool attaches response metadata. + + Args: + n: The starting number for the Collatz sequence. + Returns: + A list representing the Collatz sequence with attached metadata. + """ + if n <= 0: + raise ValueError('Starting number for the Collatz conjecture must be a positive integer.') + + input_param_n = n # store the original input value + + sequence = [n] + while n != 1: + if n % 2 == 0: + n = n // 2 + else: + n = 3 * n + 1 + sequence.append(n) + + return TextContent( + type='text', + text=str(sequence), + _meta={'pydantic_ai': {'tool': 'collatz_conjecture', 'n': input_param_n, 'length': len(sequence)}}, + ) + + @mcp.tool() async def get_image_resource() -> EmbeddedResource: data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes() diff --git a/tests/test_agent.py b/tests/test_agent.py index c912334434..0226ed172d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3785,6 +3785,7 @@ def test_binary_content_serializable(): 'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, + 'metadata': None, 'kind': 'binary', 'identifier': 'f7ff9e', }, @@ -3800,7 +3801,13 @@ def test_binary_content_serializable(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 56, @@ -3862,7 +3869,13 @@ def test_image_url_serializable_missing_media_type(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 51, @@ -3931,7 +3944,13 @@ def test_image_url_serializable(): }, { 'parts': [ - {'content': 'success (no tool calls)', 'id': None, 'part_kind': 'text', 'provider_details': None} + { + 'content': 'success (no tool calls)', + 'id': None, + 'part_kind': 'text', + 'metadata': None, + 'provider_details': None, + } ], 'usage': { 'input_tokens': 51, @@ -3978,6 +3997,7 @@ def test_tool_return_part_binary_content_serialization(): 'data': 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzgAAAAASUVORK5CYII=', 'media_type': 'image/png', 'vendor_metadata': None, + 'metadata': None, '_identifier': None, 'kind': 'binary', } diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 221ad37548..0cb3acfc17 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -22,6 +22,7 @@ ToolCallPart, ToolReturnPart, UserPromptPart, + messages, ) from pydantic_ai.agent import Agent from pydantic_ai.exceptions import ( @@ -95,7 +96,7 @@ async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -105,6 +106,21 @@ async def test_stdio_server(run_context: RunContext[int]): assert result == snapshot(32.0) +async def test_tool_response_single_text_part_metadata(run_context: RunContext[int]): + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + async with server: + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] + assert len(tools) == snapshot(19) + assert tools[2].name == 'get_collatz_conjecture' + assert isinstance(tools[2].description, str) + assert tools[2].description.startswith('Generate the Collatz conjecture sequence for a given number.') + + result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7}) + assert isinstance(result, messages.TextContent) + assert result.content == snapshot('[7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]') + assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}}) + + async def test_reentrant_context_manager(): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: @@ -156,7 +172,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: tools = await server.get_tools(run_context) - assert len(tools) == snapshot(18) + assert len(tools) == snapshot(19) async def test_process_tool_call(run_context: RunContext[int]) -> int: diff --git a/tests/test_messages.py b/tests/test_messages.py index d6d9617247..ae806fe962 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -457,6 +457,7 @@ def test_file_part_serialization_roundtrip(): 'data': 'ZmFrZQ==', 'media_type': 'image/jpeg', 'identifier': 'c053ec', + 'metadata': None, 'vendor_metadata': None, 'kind': 'binary', }, @@ -605,6 +606,7 @@ def test_binary_content_validation_with_optional_identifier(): 'data': b'fake', 'vendor_metadata': None, 'kind': 'binary', + 'metadata': None, 'media_type': 'image/jpeg', 'identifier': 'c053ec', } @@ -621,6 +623,7 @@ def test_binary_content_validation_with_optional_identifier(): 'data': b'fake', 'vendor_metadata': None, 'kind': 'binary', + 'metadata': None, 'media_type': 'image/png', 'identifier': 'foo', }