diff --git a/src/ModelContextProtocol.Core/AIContentExtensions.cs b/src/ModelContextProtocol.Core/AIContentExtensions.cs index b1ba32bf4..cffc554b3 100644 --- a/src/ModelContextProtocol.Core/AIContentExtensions.cs +++ b/src/ModelContextProtocol.Core/AIContentExtensions.cs @@ -128,7 +128,10 @@ public static class AIContentExtensions { if (sm.Content?.Select(b => b.ToAIContent()).OfType().ToList() is { Count: > 0 } aiContents) { - messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents)); + ChatRole role = aiContents.All(c => c is FunctionResultContent) ? ChatRole.Tool : + sm.Role is Role.Assistant ? ChatRole.Assistant : + ChatRole.User; + messages.Add(new ChatMessage(role, aiContents)); } } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 86cefcf10..33470f989 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -260,6 +260,84 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages() Assert.Equal("endTurn", result.StopReason); } + [Fact] + public async Task CreateSamplingHandler_ShouldUseToolRoleForToolResultMessages() + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = [new TextContentBlock { Text = "What is the weather in Paris?" }] + }, + new SamplingMessage + { + Role = Role.Assistant, + Content = [new ToolUseContentBlock + { + Id = "call_weather_123", + Name = "get_weather", + Input = JsonSerializer.SerializeToElement(new { location = "Paris" }, McpJsonUtilities.DefaultOptions) + }] + }, + new SamplingMessage + { + Role = Role.User, + Content = [new ToolResultContentBlock + { + ToolUseId = "call_weather_123", + Content = [new TextContentBlock { Text = "Weather: 18°C, sunny" }] + }] + } + ], + MaxTokens = 100 + }; + + IEnumerable? capturedMessages = null; + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = [new TextContent("The weather in Paris is 18°C and sunny.")] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Callback, ChatOptions?, CancellationToken>((messages, _, _) => capturedMessages = messages.ToList()) + .Returns(expectedResponse); + + var handler = mockChatClient.Object.CreateSamplingHandler(); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.NotNull(capturedMessages); + var messagesList = capturedMessages.ToList(); + Assert.Equal(3, messagesList.Count); + + // First message should be User role (text message) + Assert.Equal(ChatRole.User, messagesList[0].Role); + Assert.IsType(messagesList[0].Contents.Single()); + + // Second message should be Assistant role (tool use) + Assert.Equal(ChatRole.Assistant, messagesList[1].Role); + Assert.IsType(messagesList[1].Contents.Single()); + + // Third message should be Tool role (tool result) - this is the bug fix + Assert.Equal(ChatRole.Tool, messagesList[2].Role); + Assert.IsType(messagesList[2].Contents.Single()); + } + [Fact] public async Task ListToolsAsync_AllToolsReturned() {