Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ModelContextProtocol.Core/AIContentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ public static class AIContentExtensions
{
if (sm.Content?.Select(b => b.ToAIContent()).OfType<AIContent>().ToList() is { Count: > 0 } aiContents)
{
messages.Add(new ChatMessage(sm.Role is Role.Assistant ? ChatRole.Assistant : ChatRole.User, aiContents));
ChatRole role = aiContents.All(c => c is FunctionResultContent) ? ChatRole.Tool :
sm.Role is Role.Assistant ? ChatRole.Assistant :
ChatRole.User;
messages.Add(new ChatMessage(role, aiContents));
}
}

Expand Down
78 changes: 78 additions & 0 deletions tests/ModelContextProtocol.Tests/Client/McpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,84 @@ public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
Assert.Equal("endTurn", result.StopReason);
}

[Fact]
public async Task CreateSamplingHandler_ShouldUseToolRoleForToolResultMessages()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = [new TextContentBlock { Text = "What is the weather in Paris?" }]
},
new SamplingMessage
{
Role = Role.Assistant,
Content = [new ToolUseContentBlock
{
Id = "call_weather_123",
Name = "get_weather",
Input = JsonSerializer.SerializeToElement(new { location = "Paris" }, McpJsonUtilities.DefaultOptions)
}]
},
new SamplingMessage
{
Role = Role.User,
Content = [new ToolResultContentBlock
{
ToolUseId = "call_weather_123",
Content = [new TextContentBlock { Text = "Weather: 18°C, sunny" }]
}]
}
],
MaxTokens = 100
};

IEnumerable<ChatMessage>? capturedMessages = null;
var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
Role = ChatRole.Assistant,
Contents = [new TextContent("The weather in Paris is 18°C and sunny.")]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Callback<IEnumerable<ChatMessage>, ChatOptions?, CancellationToken>((messages, _, _) => capturedMessages = messages.ToList())
.Returns(expectedResponse);

var handler = mockChatClient.Object.CreateSamplingHandler();

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.NotNull(capturedMessages);
var messagesList = capturedMessages.ToList();
Assert.Equal(3, messagesList.Count);

// First message should be User role (text message)
Assert.Equal(ChatRole.User, messagesList[0].Role);
Assert.IsType<TextContent>(messagesList[0].Contents.Single());

// Second message should be Assistant role (tool use)
Assert.Equal(ChatRole.Assistant, messagesList[1].Role);
Assert.IsType<FunctionCallContent>(messagesList[1].Contents.Single());

// Third message should be Tool role (tool result) - this is the bug fix
Assert.Equal(ChatRole.Tool, messagesList[2].Role);
Assert.IsType<FunctionResultContent>(messagesList[2].Contents.Single());
}

[Fact]
public async Task ListToolsAsync_AllToolsReturned()
{
Expand Down
Loading