From a069f6492a25d0abc8a4fec58cc09a2265f15408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Wed, 10 Dec 2025 18:49:20 -0600 Subject: [PATCH 1/6] Add UseMcpClient --- .../McpChatClientBuilderExtensions.cs | 230 ++++++++++++ .../ModelContextProtocol.csproj | 1 + .../HttpServerIntegrationTests.cs | 6 +- .../SseServerIntegrationTestFixture.cs | 69 +++- .../SseServerIntegrationTests.cs | 2 +- .../StatelessServerIntegrationTests.cs | 2 +- .../StreamableHttpServerIntegrationTests.cs | 2 +- .../UseMcpClientWithTestSseServerTests.cs | 332 ++++++++++++++++++ .../Program.cs | 13 +- 9 files changed, 634 insertions(+), 23 deletions(-) create mode 100644 src/ModelContextProtocol/McpChatClientBuilderExtensions.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs new file mode 100644 index 000000000..7d04bfb00 --- /dev/null +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -0,0 +1,230 @@ +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Client; +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace ModelContextProtocol; + +/// +/// Extension methods for adding MCP client support to chat clients. +/// +public static class McpChatClientBuilderExtensions +{ + /// + /// Adds a chat client to the chat client pipeline that creates an for each + /// in and augments it with the tools from MCP servers as instances. + /// + /// The to configure. + /// The to use, or to create a new instance. + /// The to use, or to resolve from services. + /// The for method chaining. + /// + /// + /// When a HostedMcpServerTool is encountered in the tools collection, the client + /// connects to the MCP server, retrieves available tools, and expands them into callable AI functions. + /// Connections are cached by server address to avoid redundant connections. + /// + /// + /// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers. + /// + /// + public static ChatClientBuilder UseMcpClient( + this ChatClientBuilder builder, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + return builder.Use((innerClient, services) => + { + loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!; + var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory); + return chatClient; + }); + } + + private class McpChatClient : DelegatingChatClient + { + private readonly ILoggerFactory? _loggerFactory; + private readonly ILogger _logger; + private readonly HttpClient _httpClient; + private readonly bool _ownsHttpClient; + private ConcurrentDictionary>? _mcpClientTasks = null; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + /// An optional to use when connecting to MCP servers. If not provided, a new instance will be created. + /// An to use for logging information about function invocation. + public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) + : base(innerClient) + { + _loggerFactory = loggerFactory; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _httpClient = httpClient ?? new HttpClient(); + _ownsHttpClient = httpClient is null; + } + + /// + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + if (options?.Tools is { Count: > 0 }) + { + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); + options = options.Clone(); + options.Tools = downstreamTools; + } + + return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (options?.Tools is { Count: > 0 }) + { + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); + options = options.Clone(); + options.Tools = downstreamTools; + } + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + private async Task?> BuildDownstreamAIToolsAsync(IList? inputTools, CancellationToken cancellationToken) + { + List? downstreamTools = null; + foreach (var tool in inputTools ?? []) + { + if (tool is not HostedMcpServerTool mcpTool) + { + // For other tools, we want to keep them in the list of tools. + downstreamTools ??= new List(); + downstreamTools.Add(tool); + continue; + } + + if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || + (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) + { + throw new InvalidOperationException( + $"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'"); + } + + // List all MCP functions from the specified MCP server. + // This will need some caching in a real-world scenario to avoid repeated calls. + var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false); + var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + + // Add the listed functions to our list of tools we'll pass to the inner client. + foreach (var mcpFunction in mcpFunctions) + { + if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name)) + { + _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name); + continue; + } + + downstreamTools ??= new List(); + switch (mcpTool.ApprovalMode) + { + case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval: + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); + break; + case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval: + downstreamTools.Add(mcpFunction); + break; + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); + break; + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: + downstreamTools.Add(mcpFunction); + break; + default: + // Default to always require approval if no specific mode is set. + downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); + break; + } + } + } + + return downstreamTools; + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + // Dispose of the HTTP client if it was created by this client. + if (_ownsHttpClient) + { + _httpClient?.Dispose(); + } + + if (_mcpClientTasks is not null) + { + // Dispose of all cached MCP clients. + foreach (var clientTask in _mcpClientTasks.Values) + { +#if NETSTANDARD2_0 + if (clientTask.Status == TaskStatus.RanToCompletion) +#else + if (clientTask.IsCompletedSuccessfully) +#endif + { + _ = clientTask.Result.DisposeAsync(); + } + } + + _mcpClientTasks.Clear(); + } + } + + base.Dispose(disposing); + } + + private Task CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken) + { + if (_mcpClientTasks is null) + { + _mcpClientTasks = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); + } + + // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token. + // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently. + return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None)); + } + + private async Task CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken) + { + var serverAddressKey = serverAddress.ToString(); + try + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = serverAddress, + Name = serverName, + AdditionalHeaders = authorizationToken is not null + // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. + ? new Dictionary() { { "Authorization", $"Bearer {authorizationToken}" } } + : null, + }, _httpClient, _loggerFactory); + + return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch + { + // Remove the failed task from cache so subsequent requests can retry + _mcpClientTasks?.TryRemove(serverAddressKey, out _); + throw; + } + } + } +} diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index b69108ab2..fe394a056 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -23,6 +23,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index ce4f3b56a..562bedae8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -4,11 +4,11 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture +public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture { - protected readonly SseServerIntegrationTestFixture _fixture; + protected readonly SseServerWithXunitLoggerFixture _fixture; - public HttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) + public HttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) : base(testOutputHelper) { _fixture = fixture; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index c382c4385..7044acd30 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -7,23 +7,18 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class SseServerIntegrationTestFixture : IAsyncDisposable +public abstract class SseServerIntegrationTestFixture : IAsyncDisposable { private readonly KestrelInMemoryTransport _inMemoryTransport = new(); - private readonly Task _serverTask; private readonly CancellationTokenSource _stopCts = new(); - // XUnit's ITestOutputHelper is created per test, while this fixture is used for - // multiple tests, so this dispatches the output to the current test. - private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new() { Endpoint = new("http://localhost:5000/"), }; - public SseServerIntegrationTestFixture() + protected SseServerIntegrationTestFixture() { var socketsHttpHandler = new SocketsHttpHandler { @@ -39,8 +34,10 @@ public SseServerIntegrationTestFixture() BaseAddress = new("http://localhost:5000/"), }; - _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); + _serverTask = Program.MainAsync([], CreateLoggerProvider(), _inMemoryTransport, _stopCts.Token); } + + protected abstract ILoggerProvider CreateLoggerProvider(); public HttpClient HttpClient { get; } @@ -53,21 +50,17 @@ public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerF TestContext.Current.CancellationToken); } - public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) + public virtual void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) { - _delegatingTestOutputHelper.CurrentTestOutputHelper = output; DefaultTransportOptions = clientTransportOptions; } - public void TestCompleted() + public virtual void TestCompleted() { - _delegatingTestOutputHelper.CurrentTestOutputHelper = null; } - public async ValueTask DisposeAsync() + public virtual async ValueTask DisposeAsync() { - _delegatingTestOutputHelper.CurrentTestOutputHelper = null; - HttpClient.Dispose(); _stopCts.Cancel(); @@ -82,3 +75,49 @@ public async ValueTask DisposeAsync() _stopCts.Dispose(); } } + +/// +/// SSE server fixture that routes logs to xUnit test output. +/// +public class SseServerWithXunitLoggerFixture : SseServerIntegrationTestFixture +{ + // XUnit's ITestOutputHelper is created per test, while this fixture is used for + // multiple tests, so this dispatches the output to the current test. + private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); + + protected override ILoggerProvider CreateLoggerProvider() + => new XunitLoggerProvider(_delegatingTestOutputHelper); + + public override void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = output; + base.Initialize(output, clientTransportOptions); + } + + public override void TestCompleted() + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + base.TestCompleted(); + } + + public override async ValueTask DisposeAsync() + { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + await base.DisposeAsync(); + } +} + +/// +/// Fixture for tests that need to inspect server logs using MockLoggerProvider. +/// Use for tests that just need xUnit output. +/// +public class SseServerWithMockLoggerFixture : SseServerIntegrationTestFixture +{ + private readonly MockLoggerProvider _mockLoggerProvider = new(); + + protected override ILoggerProvider CreateLoggerProvider() + => _mockLoggerProvider; + + public IEnumerable<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> ServerLogs + => _mockLoggerProvider.LogMessages; +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index eb7db0110..5339235af 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -4,7 +4,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) +public class SseServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) : HttpServerIntegrationTests(fixture, testOutputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index 2ce63a1bc..6937e4be6 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -2,7 +2,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) +public class StatelessServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) { protected override HttpClientTransportOptions ClientTransportOptions => new() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index b2b0b5499..63c6dc77b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -3,7 +3,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) +public class StreamableHttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) : HttpServerIntegrationTests(fixture, testOutputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs new file mode 100644 index 000000000..f4a3ae023 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs @@ -0,0 +1,332 @@ +using System.Runtime.CompilerServices; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; +using ModelContextProtocol.Tests.Utils; +using Moq; +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class UseMcpClientWithTestSseServerTests : LoggedTest, IClassFixture +{ + private readonly HttpClientTransportOptions _transportOptions; + private readonly SseServerWithMockLoggerFixture _fixture; + + public UseMcpClientWithTestSseServerTests(SseServerWithMockLoggerFixture fixture, ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + _transportOptions = new HttpClientTransportOptions() + { + Endpoint = new("http://localhost:5000/sse"), + Name = "TestSseServer", + }; + + _fixture = fixture; + _fixture.Initialize(testOutputHelper, _transportOptions); + } + + public override void Dispose() + { + _fixture.TestCompleted(); + base.Dispose(); + } + + private sealed class CallbackState + { + public ChatOptions? CapturedOptions { get; set; } + } + + private IChatClient CreateTestChatClient(out CallbackState callbackState) + { + var state = new CallbackState(); + + var mockInnerClient = new Mock(); + mockInnerClient + .Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>( + (msgs, opts, ct) => state.CapturedOptions = opts) + .ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Dummy response")])); + + mockInnerClient + .Setup(c => c.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, ChatOptions?, CancellationToken>( + (msgs, opts, ct) => state.CapturedOptions = opts) + .Returns(GetStreamingResponseAsync()); + + callbackState = state; + return mockInnerClient.Object.AsBuilder() + .UseMcpClient(_fixture.HttpClient, LoggerFactory) + .Build(); + + static async IAsyncEnumerable GetStreamingResponseAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + yield return new ChatResponseUpdate(ChatRole.Assistant, "Dummy response"); + } + } + + private async Task GetResponseAsync(IChatClient client, ChatOptions options, bool streaming) + { + if (streaming) + { + await foreach (var _ in client.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken)) + { } + } + else + { + _ = await client.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + } + } + + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task UseMcpClient_ShouldProduceTools(bool streaming, bool useUrl) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + var mcpTool = useUrl ? + new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) : + new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint.ToString()); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_DoesNotConflictWithRegularTools(bool streaming) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + var regularTool = AIFunctionFactory.Create(() => "regular tool result", "RegularTool"); + var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint); + var options = new ChatOptions + { + Tools = + [ + regularTool, + mcpTool + ] + }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(4, toolNames.Count); + Assert.Contains("RegularTool", toolNames); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool streaming) + { + // Arrange + const string testToken = "test-bearer-token-12345"; + IChatClient sut = CreateTestChatClient(out var callbackState); + var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) + { + AuthorizationToken = testToken + }; + var options = new ChatOptions + { + Tools = [mcpTool] + }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + // We set TestSseServer to log IHeaderDictionary as json. + Assert.Contains(_fixture.ServerLogs, log => log.Message.Contains(@"""Authorization"":[""Bearer test-bearer-token-12345""]")); + } + + public static IEnumerable UseMcpClient_ApprovalsWorkCorrectly_TestData() + { + string[] allToolNames = ["echo", "echoSessionId", "sampleLLM"]; + foreach (var streaming in new[] { false, true }) + { + yield return new object?[] { streaming, new HostedMcpServerToolNeverRequireApprovalMode(), (string[])[], allToolNames }; + yield return new object?[] { streaming, new HostedMcpServerToolAlwaysRequireApprovalMode(), allToolNames, (string[])[] }; + yield return new object?[] { streaming, null, allToolNames, (string[])[] }; + // Specific mode with empty lists - all tools should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], []), allToolNames, (string[])[] }; + // Specific mode with one tool always requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode(["echo"], []), allToolNames, (string[])[] }; + // Specific mode with one tool never requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], ["echo"]), (string[])["echoSessionId", "sampleLLM"], (string[])["echo"] }; + } + } + + [Theory] + [MemberData(nameof(UseMcpClient_ApprovalsWorkCorrectly_TestData))] + public async Task UseMcpClient_ApprovalsWorkCorrectly( + bool streaming, + HostedMcpServerToolApprovalMode? approvalMode, + string[] expectedApprovalRequiredAIFunctions, + string[] expectedNormalAIFunctions) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) + { + ApprovalMode = approvalMode + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + Assert.Equal(3, callbackState.CapturedOptions.Tools.Count); + + var toolsRequiringApproval = callbackState.CapturedOptions.Tools + .Where(t => t is ApprovalRequiredAIFunction).Select(t => t.Name); + + var toolsNotRequiringApproval = callbackState.CapturedOptions.Tools + .Where(t => t is not ApprovalRequiredAIFunction).Select(t => t.Name); + + Assert.Equivalent(expectedApprovalRequiredAIFunctions, toolsRequiringApproval); + Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_SupportsConnectorIdAsServer(bool streaming) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + const string connectorId = "test-connector-123"; + var mcpTool = new HostedMcpServerTool(connectorId, _transportOptions.Endpoint); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_ThrowsInvalidOperationException_WhenServerAddressIsInvalid(bool streaming) + { + // Arrange + IChatClient sut = CreateTestChatClient(out _); + var mcpTool = new HostedMcpServerTool("test-server", "test-connector-123"); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => GetResponseAsync(sut, options, streaming)); + Assert.Contains("test-connector-123", exception.Message); + } + + [Theory] + [InlineData(false, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(true, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(false, (string[])["echo"], (string[])["echo"])] + [InlineData(true, (string[])["echo"], (string[])["echo"])] + [InlineData(false, (string[])[], (string[])[])] + [InlineData(true, (string[])[], (string[])[])] + public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, string[]? allowedTools, string[] expectedTools) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) + { + AllowedTools = allowedTools + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + await GetResponseAsync(sut, options, streaming); + + // Assert + Assert.NotNull(callbackState.CapturedOptions); + if (expectedTools.Length == 0) + { + // When all MCP tools are filtered out and no other tools exist, the Tools collection should be null + Assert.Null(callbackState.CapturedOptions.Tools); + } + else + { + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(expectedTools.Length, toolNames.Count); + Assert.Equivalent(expectedTools, toolNames); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) + { + // Arrange + IChatClient sut = CreateTestChatClient(out var callbackState); + var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act - First call + await GetResponseAsync(sut, options, streaming); + + // Assert - First call should succeed and produce tools + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var firstCallToolCount = callbackState.CapturedOptions.Tools.Count; + Assert.Equal(3, firstCallToolCount); + + // Act - Second call with same server address (should use cached client) + await GetResponseAsync(sut, options, streaming); + + // Assert - Second call should also succeed with same tools + Assert.NotNull(callbackState.CapturedOptions); + Assert.NotNull(callbackState.CapturedOptions.Tools); + var secondCallToolCount = callbackState.CapturedOptions.Tools.Count; + Assert.Equal(3, secondCallToolCount); + Assert.Equal(firstCallToolCount, secondCallToolCount); + + // Verify the tools are the same + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } +} diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index a29c30587..a6f55b345 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -38,7 +38,7 @@ private static void ConfigureOptions(McpServerOptions options) Console.WriteLine("Registering handlers."); - #region Helped method + #region Helper method static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { return new CreateMessageRequestParams @@ -421,7 +421,16 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide } builder.Services.AddMcpServer(ConfigureOptions) - .WithHttpTransport(); + .WithHttpTransport(httpOptions => + { + // Log headers for testing purposes + httpOptions.ConfigureSessionOptions = (httpContext, serverOptions, ct) => + { + var logger = httpContext.RequestServices.GetRequiredService>(); + logger.LogInformation(JsonSerializer.Serialize(httpContext.Request.Headers)); + return Task.CompletedTask; + }; + }); var app = builder.Build(); From 1b3de48b7341e235a2f71c312871aa60fde3d66d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Mon, 15 Dec 2025 18:29:07 -0600 Subject: [PATCH 2/6] Address feedback --- .../McpChatClientBuilderExtensions.cs | 97 +++++++++---------- .../UseMcpClientWithTestSseServerTests.cs | 40 +------- 2 files changed, 50 insertions(+), 87 deletions(-) diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs index 7d04bfb00..0171752fa 100644 --- a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -1,10 +1,10 @@ using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Client; -#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. namespace ModelContextProtocol; @@ -31,6 +31,7 @@ public static class McpChatClientBuilderExtensions /// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers. /// /// + [Experimental("MEAI001")] public static ChatClientBuilder UseMcpClient( this ChatClientBuilder builder, HttpClient? httpClient = null, @@ -44,13 +45,14 @@ public static ChatClientBuilder UseMcpClient( }); } - private class McpChatClient : DelegatingChatClient + [Experimental("MEAI001")] + private sealed class McpChatClient : DelegatingChatClient { private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly HttpClient _httpClient; private readonly bool _ownsHttpClient; - private ConcurrentDictionary>? _mcpClientTasks = null; + private readonly ConcurrentDictionary> _mcpClientTasks = []; /// /// Initializes a new instance of the class. @@ -97,29 +99,27 @@ public override async IAsyncEnumerable GetStreamingResponseA } } - private async Task?> BuildDownstreamAIToolsAsync(IList? inputTools, CancellationToken cancellationToken) + private async Task> BuildDownstreamAIToolsAsync(IList inputTools, CancellationToken cancellationToken) { - List? downstreamTools = null; - foreach (var tool in inputTools ?? []) + List downstreamTools = []; + foreach (var tool in inputTools) { if (tool is not HostedMcpServerTool mcpTool) { // For other tools, we want to keep them in the list of tools. - downstreamTools ??= new List(); downstreamTools.Add(tool); continue; } if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || - (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) + (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) { - throw new InvalidOperationException( - $"MCP server address must be an absolute HTTP or HTTPS URI. Invalid address: '{mcpTool.ServerAddress}'"); + throw new InvalidOperationException( + $"Invalid http(s) address: '{mcpTool.ServerAddress}'. MCP server address must be an absolute https(s) URL."); } // List all MCP functions from the specified MCP server. - // This will need some caching in a real-world scenario to avoid repeated calls. - var mcpClient = await CreateMcpClientAsync(parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false); + var mcpClient = await CreateMcpClientAsync(mcpTool.ServerAddress, parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false); var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); // Add the listed functions to our list of tools we'll pass to the inner client. @@ -127,25 +127,20 @@ public override async IAsyncEnumerable GetStreamingResponseA { if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name)) { - _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name); + if (_logger.IsEnabled(LogLevel.Information)) + { + _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name); + } continue; } - downstreamTools ??= new List(); switch (mcpTool.ApprovalMode) { - case HostedMcpServerToolAlwaysRequireApprovalMode alwaysRequireApproval: - downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); - break; - case HostedMcpServerToolNeverRequireApprovalMode neverRequireApproval: - downstreamTools.Add(mcpFunction); - break; - case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.AlwaysRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: - downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); - break; + case HostedMcpServerToolNeverRequireApprovalMode: case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: downstreamTools.Add(mcpFunction); break; + default: // Default to always require approval if no specific mode is set. downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); @@ -173,11 +168,7 @@ protected override void Dispose(bool disposing) // Dispose of all cached MCP clients. foreach (var clientTask in _mcpClientTasks.Values) { -#if NETSTANDARD2_0 if (clientTask.Status == TaskStatus.RanToCompletion) -#else - if (clientTask.IsCompletedSuccessfully) -#endif { _ = clientTask.Result.DisposeAsync(); } @@ -190,41 +181,45 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - private Task CreateMcpClientAsync(Uri serverAddress, string serverName, string? authorizationToken) + private async Task CreateMcpClientAsync(string key, Uri serverAddress, string serverName, string? authorizationToken) { - if (_mcpClientTasks is null) - { - _mcpClientTasks = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); - } - // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token. // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently. - return _mcpClientTasks.GetOrAdd(serverAddress.ToString(), _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None)); - } +#if NET + // Avoid closure allocation. + Task task = _mcpClientTasks.GetOrAdd(key, + static (_, state) => state.self.CreateMcpClientCoreAsync(state.serverAddress, state.serverName, state.authorizationToken, CancellationToken.None), + (self: this, serverAddress, serverName, authorizationToken)); +#else + Task task = _mcpClientTasks.GetOrAdd(key, + _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None)); +#endif - private async Task CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken) - { - var serverAddressKey = serverAddress.ToString(); try { - var transport = new HttpClientTransport(new HttpClientTransportOptions - { - Endpoint = serverAddress, - Name = serverName, - AdditionalHeaders = authorizationToken is not null - // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. - ? new Dictionary() { { "Authorization", $"Bearer {authorizationToken}" } } - : null, - }, _httpClient, _loggerFactory); - - return await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); + return await task.ConfigureAwait(false); } catch { - // Remove the failed task from cache so subsequent requests can retry - _mcpClientTasks?.TryRemove(serverAddressKey, out _); + // Remove the failed task from cache so subsequent requests can retry. + _mcpClientTasks.TryRemove(key, out _); throw; } } + + private Task CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken) + { + var transport = new HttpClientTransport(new HttpClientTransportOptions + { + Endpoint = serverAddress, + Name = serverName, + AdditionalHeaders = authorizationToken is not null + // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. + ? new Dictionary() { { "Authorization", $"Bearer {authorizationToken}" } } + : null, + }, _httpClient, _loggerFactory); + + return McpClient.CreateAsync(transport, cancellationToken: cancellationToken); + } } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs index f4a3ae023..671db18ba 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs @@ -219,30 +219,6 @@ public async Task UseMcpClient_ApprovalsWorkCorrectly( Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval); } - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task UseMcpClient_SupportsConnectorIdAsServer(bool streaming) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - const string connectorId = "test-connector-123"; - var mcpTool = new HostedMcpServerTool(connectorId, _transportOptions.Endpoint); - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(3, toolNames.Count); - Assert.Contains("echo", toolNames); - Assert.Contains("echoSessionId", toolNames); - Assert.Contains("sampleLLM", toolNames); - } - [Theory] [InlineData(false)] [InlineData(true)] @@ -280,18 +256,10 @@ public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, str // Assert Assert.NotNull(callbackState.CapturedOptions); - if (expectedTools.Length == 0) - { - // When all MCP tools are filtered out and no other tools exist, the Tools collection should be null - Assert.Null(callbackState.CapturedOptions.Tools); - } - else - { - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(expectedTools.Length, toolNames.Count); - Assert.Equivalent(expectedTools, toolNames); - } + Assert.NotNull(callbackState.CapturedOptions.Tools); + var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(expectedTools.Length, toolNames.Count); + Assert.Equivalent(expectedTools, toolNames); } [Theory] From 298e76eb17b64674a8a1deefd0404d1a533df4b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Wed, 7 Jan 2026 17:00:40 -0600 Subject: [PATCH 3/6] Add LRU cache Add more tests Add Retry logic Add configureTransportOptions --- .../McpChatClientBuilderExtensions.cs | 185 ++-- .../McpClientTasksLruCache.cs | 88 ++ .../UseMcpClientTests.cs | 801 ++++++++++++++++++ 3 files changed, 1012 insertions(+), 62 deletions(-) create mode 100644 src/ModelContextProtocol/McpClientTasksLruCache.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs index 0171752fa..058a2fdc1 100644 --- a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -1,12 +1,11 @@ -using System.Collections.Concurrent; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Client; -namespace ModelContextProtocol; +namespace ModelContextProtocol.Client; /// /// Extension methods for adding MCP client support to chat clients. @@ -20,6 +19,7 @@ public static class McpChatClientBuilderExtensions /// The to configure. /// The to use, or to create a new instance. /// The to use, or to resolve from services. + /// An optional callback to configure the for each . /// The for method chaining. /// /// @@ -35,12 +35,13 @@ public static class McpChatClientBuilderExtensions public static ChatClientBuilder UseMcpClient( this ChatClientBuilder builder, HttpClient? httpClient = null, - ILoggerFactory? loggerFactory = null) + ILoggerFactory? loggerFactory = null, + Action? configureTransportOptions = null) { return builder.Use((innerClient, services) => { loggerFactory ??= (ILoggerFactory)services.GetService(typeof(ILoggerFactory))!; - var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory); + var chatClient = new McpChatClient(innerClient, httpClient, loggerFactory, configureTransportOptions); return chatClient; }); } @@ -52,7 +53,8 @@ private sealed class McpChatClient : DelegatingChatClient private readonly ILogger _logger; private readonly HttpClient _httpClient; private readonly bool _ownsHttpClient; - private readonly ConcurrentDictionary> _mcpClientTasks = []; + private readonly McpClientTasksLruCache _lruCache; + private readonly Action? _configureTransportOptions; /// /// Initializes a new instance of the class. @@ -60,22 +62,24 @@ private sealed class McpChatClient : DelegatingChatClient /// The underlying , or the next instance in a chain of clients. /// An optional to use when connecting to MCP servers. If not provided, a new instance will be created. /// An to use for logging information about function invocation. - public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null) + /// An optional callback to configure the for each . + public McpChatClient(IChatClient innerClient, HttpClient? httpClient = null, ILoggerFactory? loggerFactory = null, Action? configureTransportOptions = null) : base(innerClient) { _loggerFactory = loggerFactory; _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _httpClient = httpClient ?? new HttpClient(); _ownsHttpClient = httpClient is null; + _lruCache = new McpClientTasksLruCache(capacity: 20); + _configureTransportOptions = configureTransportOptions; } - /// public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { if (options?.Tools is { Count: > 0 }) { - var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false); options = options.Clone(); options.Tools = downstreamTools; } @@ -83,12 +87,11 @@ public override async Task GetResponseAsync( return await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); } - /// public override async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (options?.Tools is { Count: > 0 }) { - var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools, cancellationToken).ConfigureAwait(false); + var downstreamTools = await BuildDownstreamAIToolsAsync(options.Tools).ConfigureAwait(false); options = options.Clone(); options.Tools = downstreamTools; } @@ -99,51 +102,52 @@ public override async IAsyncEnumerable GetStreamingResponseA } } - private async Task> BuildDownstreamAIToolsAsync(IList inputTools, CancellationToken cancellationToken) + private async Task> BuildDownstreamAIToolsAsync(IList chatOptionsTools) { List downstreamTools = []; - foreach (var tool in inputTools) + foreach (var chatOptionsTool in chatOptionsTools) { - if (tool is not HostedMcpServerTool mcpTool) + if (chatOptionsTool is not HostedMcpServerTool hostedMcpTool) { // For other tools, we want to keep them in the list of tools. - downstreamTools.Add(tool); + downstreamTools.Add(chatOptionsTool); continue; } - if (!Uri.TryCreate(mcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || + if (!Uri.TryCreate(hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) || (parsedAddress.Scheme != Uri.UriSchemeHttp && parsedAddress.Scheme != Uri.UriSchemeHttps)) { throw new InvalidOperationException( - $"Invalid http(s) address: '{mcpTool.ServerAddress}'. MCP server address must be an absolute https(s) URL."); + $"Invalid http(s) address: '{hostedMcpTool.ServerAddress}'. MCP server address must be an absolute http(s) URL."); } - // List all MCP functions from the specified MCP server. - var mcpClient = await CreateMcpClientAsync(mcpTool.ServerAddress, parsedAddress, mcpTool.ServerName, mcpTool.AuthorizationToken).ConfigureAwait(false); - var mcpFunctions = await mcpClient.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + // Get MCP client and its tools from cache (both are fetched together on first access). + var (_, mcpTools) = await GetClientAndToolsAsync(hostedMcpTool, parsedAddress).ConfigureAwait(false); // Add the listed functions to our list of tools we'll pass to the inner client. - foreach (var mcpFunction in mcpFunctions) + foreach (var mcpTool in mcpTools) { - if (mcpTool.AllowedTools is not null && !mcpTool.AllowedTools.Contains(mcpFunction.Name)) + if (hostedMcpTool.AllowedTools is not null && !hostedMcpTool.AllowedTools.Contains(mcpTool.Name)) { if (_logger.IsEnabled(LogLevel.Information)) { - _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpFunction.Name); + _logger.LogInformation("MCP function '{FunctionName}' is not allowed by the tool configuration.", mcpTool.Name); } continue; } - switch (mcpTool.ApprovalMode) + var wrappedFunction = new McpRetriableAIFunction(mcpTool, hostedMcpTool, parsedAddress, this); + + switch (hostedMcpTool.ApprovalMode) { case HostedMcpServerToolNeverRequireApprovalMode: - case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpFunction.Name) is true: - downstreamTools.Add(mcpFunction); + case HostedMcpServerToolRequireSpecificApprovalMode specificApprovalMode when specificApprovalMode.NeverRequireApprovalToolNames?.Contains(mcpTool.Name) is true: + downstreamTools.Add(wrappedFunction); break; default: // Default to always require approval if no specific mode is set. - downstreamTools.Add(new ApprovalRequiredAIFunction(mcpFunction)); + downstreamTools.Add(new ApprovalRequiredAIFunction(wrappedFunction)); break; } } @@ -152,48 +156,29 @@ private async Task> BuildDownstreamAIToolsAsync(IList input return downstreamTools; } - /// protected override void Dispose(bool disposing) { if (disposing) { - // Dispose of the HTTP client if it was created by this client. if (_ownsHttpClient) { _httpClient?.Dispose(); } - if (_mcpClientTasks is not null) - { - // Dispose of all cached MCP clients. - foreach (var clientTask in _mcpClientTasks.Values) - { - if (clientTask.Status == TaskStatus.RanToCompletion) - { - _ = clientTask.Result.DisposeAsync(); - } - } - - _mcpClientTasks.Clear(); - } + _lruCache.Dispose(); } base.Dispose(disposing); } - private async Task CreateMcpClientAsync(string key, Uri serverAddress, string serverName, string? authorizationToken) + internal async Task<(McpClient Client, IList Tools)> GetClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri) { // Note: We don't pass cancellationToken to the factory because the cached task should not be tied to any single caller's cancellation token. // Instead, callers can cancel waiting for the task, but the connection attempt itself will complete independently. -#if NET - // Avoid closure allocation. - Task task = _mcpClientTasks.GetOrAdd(key, - static (_, state) => state.self.CreateMcpClientCoreAsync(state.serverAddress, state.serverName, state.authorizationToken, CancellationToken.None), - (self: this, serverAddress, serverName, authorizationToken)); -#else - Task task = _mcpClientTasks.GetOrAdd(key, - _ => CreateMcpClientCoreAsync(serverAddress, serverName, authorizationToken, CancellationToken.None)); -#endif + Task<(McpClient, IList Tools)> task = _lruCache.GetOrAdd( + hostedMcpTool.ServerAddress, + static (_, state) => state.self.CreateMcpClientAndToolsAsync(state.hostedMcpTool, state.serverAddressUri, CancellationToken.None), + (self: this, hostedMcpTool, serverAddressUri)); try { @@ -201,25 +186,101 @@ private async Task CreateMcpClientAsync(string key, Uri serverAddress } catch { - // Remove the failed task from cache so subsequent requests can retry. - _mcpClientTasks.TryRemove(key, out _); + bool result = RemoveMcpClientFromCache(hostedMcpTool.ServerAddress, out var removedTask); + Debug.Assert(result && removedTask!.Status != TaskStatus.RanToCompletion); throw; } } - private Task CreateMcpClientCoreAsync(Uri serverAddress, string serverName, string? authorizationToken, CancellationToken cancellationToken) + private async Task<(McpClient Client, IList Tools)> CreateMcpClientAndToolsAsync(HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, CancellationToken cancellationToken) { - var transport = new HttpClientTransport(new HttpClientTransportOptions + var transportOptions = new HttpClientTransportOptions { - Endpoint = serverAddress, - Name = serverName, - AdditionalHeaders = authorizationToken is not null + Endpoint = serverAddressUri, + Name = hostedMcpTool.ServerName, + AdditionalHeaders = hostedMcpTool.AuthorizationToken is not null // Update to pass all headers once https://github.com/dotnet/extensions/pull/7053 is available. - ? new Dictionary() { { "Authorization", $"Bearer {authorizationToken}" } } + ? new Dictionary() { { "Authorization", $"Bearer {hostedMcpTool.AuthorizationToken}" } } : null, - }, _httpClient, _loggerFactory); + }; + + _configureTransportOptions?.Invoke(new DummyHostedMcpServerTool(hostedMcpTool.ServerName, serverAddressUri), transportOptions); + + var transport = new HttpClientTransport(transportOptions, _httpClient, _loggerFactory); + var client = await McpClient.CreateAsync(transport, cancellationToken: cancellationToken).ConfigureAwait(false); + try + { + var tools = await client.ListToolsAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + return (client, tools); + } + catch + { + try + { + await client.DisposeAsync().ConfigureAwait(false); + } + catch { } // allow the original exception to propagate + + throw; + } + } + + internal bool RemoveMcpClientFromCache(string key, out Task<(McpClient Client, IList Tools)>? removedTask) + => _lruCache.TryRemove(key, out removedTask); + + /// + /// A temporary instance passed to the configureTransportOptions callback. + /// This prevents the callback from modifying the original tool instance. + /// + private sealed class DummyHostedMcpServerTool(string serverName, Uri serverAddress) + : HostedMcpServerTool(serverName, serverAddress); + } + + /// + /// An AI function wrapper that retries the invocation by recreating an MCP client when an occurs. + /// For example, this can happen if a session is revoked or a server error occurs. The retry evicts the cached MCP client. + /// + [Experimental("MEAI001")] + private sealed class McpRetriableAIFunction : DelegatingAIFunction + { + private readonly HostedMcpServerTool _hostedMcpTool; + private readonly Uri _serverAddressUri; + private readonly McpChatClient _chatClient; + + public McpRetriableAIFunction(AIFunction innerFunction, HostedMcpServerTool hostedMcpTool, Uri serverAddressUri, McpChatClient chatClient) + : base(innerFunction) + { + _hostedMcpTool = hostedMcpTool; + _serverAddressUri = serverAddressUri; + _chatClient = chatClient; + } + + protected override async ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + try + { + return await base.InvokeCoreAsync(arguments, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) { } + + bool result = _chatClient.RemoveMcpClientFromCache(_hostedMcpTool.ServerAddress, out var removedTask); + Debug.Assert(result && removedTask!.Status == TaskStatus.RanToCompletion); + _ = removedTask!.Result.Client.DisposeAsync().AsTask(); + + var freshTool = await GetCurrentToolAsync().ConfigureAwait(false); + return await freshTool.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); + } + + private async Task GetCurrentToolAsync() + { + Debug.Assert(Uri.TryCreate(_hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) && + (parsedAddress.Scheme == Uri.UriSchemeHttp || parsedAddress.Scheme == Uri.UriSchemeHttps), + "Server address should have been validated before construction"); - return McpClient.CreateAsync(transport, cancellationToken: cancellationToken); + var (client, tools) = await _chatClient.GetClientAndToolsAsync(_hostedMcpTool, _serverAddressUri!).ConfigureAwait(false); + + return tools.FirstOrDefault(t => t.Name == Name) ?? + throw new McpProtocolException($"Tool '{Name}' no longer exists on the MCP server.", McpErrorCode.InvalidParams); } } } diff --git a/src/ModelContextProtocol/McpClientTasksLruCache.cs b/src/ModelContextProtocol/McpClientTasksLruCache.cs new file mode 100644 index 000000000..9f9255a0c --- /dev/null +++ b/src/ModelContextProtocol/McpClientTasksLruCache.cs @@ -0,0 +1,88 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Client; + +/// +/// A thread-safe Least Recently Used (LRU) cache for MCP client and tools. +/// +internal sealed class McpClientTasksLruCache : IDisposable +{ + private readonly Dictionary Node, Task<(McpClient Client, IList Tools)> Task)> _cache; + private readonly LinkedList _lruList; + private readonly object _lock = new(); + private readonly int _capacity; + + public McpClientTasksLruCache(int capacity) + { + Debug.Assert(capacity > 0); + _capacity = capacity; + _cache = new Dictionary, Task<(McpClient, IList)>)>(capacity); + _lruList = []; + } + + public Task<(McpClient Client, IList Tools)> GetOrAdd(string key, Func)>> valueFactory, TState state) + { + lock (_lock) + { + if (_cache.TryGetValue(key, out var existing)) + { + _lruList.Remove(existing.Node); + _lruList.AddLast(existing.Node); + return existing.Task; + } + + var value = valueFactory(key, state); + var newNode = _lruList.AddLast(key); + _cache[key] = (newNode, value); + + // Evict oldest if over capacity + if (_cache.Count > _capacity) + { + string oldestKey = _lruList.First!.Value; + _lruList.RemoveFirst(); + (_, Task<(McpClient Client, IList Tools)> task) = _cache[oldestKey]; + _cache.Remove(oldestKey); + + // Dispose evicted MCP client + if (task.Status == TaskStatus.RanToCompletion) + { + _ = task.Result.Client.DisposeAsync().AsTask(); + } + } + + return value; + } + } + + public bool TryRemove(string key, [MaybeNullWhen(false)] out Task<(McpClient Client, IList Tools)>? task) + { + lock (_lock) + { + if (_cache.TryGetValue(key, out var entry)) + { + _cache.Remove(key); + _lruList.Remove(entry.Node); + task = entry.Task; + return true; + } + + task = null; + return false; + } + } + + public void Dispose() + { + lock (_lock) + { + foreach ((_, Task<(McpClient Client, IList Tools)> task) in _cache.Values) + { + if (task.Status == TaskStatus.RanToCompletion) + { + _ = task.Result.Client.DisposeAsync().AsTask(); + } + } + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs new file mode 100644 index 000000000..8fa884464 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs @@ -0,0 +1,801 @@ +using System.Runtime.CompilerServices; +using System.Text.Json; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using Moq; +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class UseMcpClientTests : KestrelInMemoryTest +{ + public UseMcpClientTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + private async Task StartServerAsync(Action? configureApp = null) + { + IMcpServerBuilder builder = Builder.Services.AddMcpServer(options => + { + options.Capabilities = new ServerCapabilities + { + Tools = new(), + Resources = new(), + Prompts = new(), + }; + options.ServerInstructions = "This is a test server with only stub functionality"; + options.Handlers = new() + { + ListToolsHandler = async (request, cancellationToken) => + { + return new ListToolsResult + { + Tools = + [ + new Tool + { + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The input to echo back." + } + }, + "required": ["message"] + } + """), + }, + new Tool + { + Name = "echoSessionId", + Description = "Echoes the session id back to the client.", + InputSchema = JsonElement.Parse(""" + { + "type": "object" + } + """), + }, + new Tool + { + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to send to the LLM" + }, + "maxTokens": { + "type": "number", + "description": "Maximum number of tokens to generate" + } + }, + "required": ["prompt", "maxTokens"] + } + """), + } + ] + }; + }, + CallToolHandler = async (request, cancellationToken) => + { + if (request.Params is null) + { + throw new McpProtocolException("Missing required parameter 'name'", McpErrorCode.InvalidParams); + } + if (request.Params.Name == "echo") + { + if (request.Params.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + { + throw new McpProtocolException("Missing required argument 'message'", McpErrorCode.InvalidParams); + } + return new CallToolResult + { + Content = [new TextContentBlock { Text = $"Echo: {message}" }] + }; + } + else if (request.Params.Name == "echoSessionId") + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = request.Server.SessionId ?? string.Empty }] + }; + } + else if (request.Params.Name == "sampleLLM") + { + if (request.Params.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + { + throw new McpProtocolException("Missing required arguments 'prompt' and 'maxTokens'", McpErrorCode.InvalidParams); + } + // Simple mock response for sampleLLM + return new CallToolResult + { + Content = [new TextContentBlock { Text = "LLM sampling result: Test response" }] + }; + } + else + { + throw new McpProtocolException($"Unknown tool: '{request.Params.Name}'", McpErrorCode.InvalidParams); + } + } + }; + }) + .WithHttpTransport(); + + var app = Builder.Build(); + configureApp?.Invoke(app); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + /// + /// Captures the arguments received by the leaf mock IChatClient. + /// + private sealed class LeafChatClientState + { + public ChatOptions? CapturedOptions { get; set; } + public List> CapturedMessages { get; set; } = []; + public int CallCount { get; set; } + public void Clear() + { + CapturedOptions = null; + CapturedMessages.Clear(); + CallCount = 0; + } + } + + private IChatClient CreateTestChatClient(out LeafChatClientState leafClientState, Action? configureTransportOptions = null) + { + var state = new LeafChatClientState(); + + var mockInnerClient = new Mock(); + mockInnerClient + .Setup(c => c.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns((IEnumerable messages, ChatOptions? options, CancellationToken ct) => + GetStreamingResponseAsync(messages, options, ct).ToChatResponseAsync(ct)); + + mockInnerClient + .Setup(c => c.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Returns(GetStreamingResponseAsync); + + leafClientState = state; + return mockInnerClient.Object.AsBuilder() + .UseMcpClient(HttpClient, LoggerFactory, configureTransportOptions) + // Placement is important, must be after UseMcpClient, otherwise, UseFunctionInvocation won't see the MCP tools. + .UseFunctionInvocation() + .Build(); + + async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? options, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + state.CapturedOptions = options; + state.CapturedMessages.Add(messages); + + // First call: request to invoke the echo tool + if (state.CallCount++ == 0 && options?.Tools is { Count: > 0 } tools) + { + Assert.Contains(tools, t => t.Name == "echo"); + yield return new ChatResponseUpdate(ChatRole.Assistant, + [ + new FunctionCallContent("call_123", "echo", new Dictionary { ["message"] = "test message" }) + ]); + } + else + { + // Subsequent calls: return final response + yield return new ChatResponseUpdate(ChatRole.Assistant, "Final response"); + } + } + } + + private static void AssertLeafClientMessagesWithInvocation(List> capturedMessages) + { + Assert.Equal(2, capturedMessages.Count); + var firstCall = capturedMessages[0]; + var msg = Assert.Single(firstCall); + Assert.Equal(ChatRole.User, msg.Role); + Assert.Equal("Test message", msg.Text); + + var secondCall = capturedMessages[1].ToList(); + Assert.Equal(3, secondCall.Count); + Assert.Equal(ChatRole.User, secondCall[0].Role); + Assert.Equal("Test message", secondCall[0].Text); + + Assert.Equal(ChatRole.Assistant, secondCall[1].Role); + var functionCall = Assert.IsType(Assert.Single(secondCall[1].Contents)); + Assert.Equal("call_123", functionCall.CallId); + Assert.Equal("echo", functionCall.Name); + + Assert.Equal(ChatRole.Tool, secondCall[2].Role); + var functionResult = Assert.IsType(Assert.Single(secondCall[2].Contents)); + Assert.Equal("call_123", functionResult.CallId); + Assert.Contains("Echo: test message", functionResult.Result?.ToString()); + } + + private static void AssertResponseWithInvocation(ChatResponse response) + { + Assert.NotNull(response); + Assert.Equal(3, response.Messages.Count); + + Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); + Assert.Single(response.Messages[0].Contents); + Assert.IsType(response.Messages[0].Contents[0]); + + Assert.Equal(ChatRole.Tool, response.Messages[1].Role); + Assert.Single(response.Messages[1].Contents); + Assert.IsType(response.Messages[1].Contents[0]); + + Assert.Equal(ChatRole.Assistant, response.Messages[2].Role); + Assert.Equal("Final response", response.Messages[2].Text); + } + + [Theory] + [InlineData(false, false)] + [InlineData(false, true)] + [InlineData(true, false)] + [InlineData(true, true)] + public async Task UseMcpClient_ShouldProduceTools(bool streaming, bool useUrl) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = useUrl ? + new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) : + new HostedMcpServerTool("serverName", HttpClient.BaseAddress!.ToString()); + mcpTool.ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_DoesNotConflictWithRegularTools(bool streaming) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var regularTool = AIFunctionFactory.Create(() => "regular tool result", "regularTool"); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!.ToString()) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions + { + Tools = [regularTool, mcpTool] + }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(4, toolNames.Count); + Assert.Contains("regularTool", toolNames); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + public static IEnumerable UseMcpClient_ApprovalMode_TestData() + { + string[] allToolNames = ["echo", "echoSessionId", "sampleLLM"]; + foreach (var streaming in new[] { false, true }) + { + yield return new object?[] { streaming, new HostedMcpServerToolNeverRequireApprovalMode(), (string[])[], allToolNames }; + yield return new object?[] { streaming, new HostedMcpServerToolAlwaysRequireApprovalMode(), allToolNames, (string[])[] }; + yield return new object?[] { streaming, null, allToolNames, (string[])[] }; + // Specific mode with empty lists - all tools should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], []), allToolNames, (string[])[] }; + // Specific mode with one tool always requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode(["echo"], []), allToolNames, (string[])[] }; + // Specific mode with one tool never requiring approval - the other two should default to requiring approval. + yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], ["echo"]), (string[])["echoSessionId", "sampleLLM"], (string[])["echo"] }; + } + } + + [Theory] + [MemberData(nameof(UseMcpClient_ApprovalMode_TestData))] + public async Task UseMcpClient_ApprovalMode(bool streaming, HostedMcpServerToolApprovalMode? approvalMode, string[] expectedApprovalRequiredAIFunctions, string[] expectedNormalAIFunctions) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = approvalMode + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + Assert.Equal(3, leafClientState.CapturedOptions.Tools.Count); + + var toolsRequiringApproval = leafClientState.CapturedOptions.Tools + .Where(t => t is ApprovalRequiredAIFunction).Select(t => t.Name); + var toolsNotRequiringApproval = leafClientState.CapturedOptions.Tools + .Where(t => t is not ApprovalRequiredAIFunction).Select(t => t.Name); + + Assert.Equivalent(expectedApprovalRequiredAIFunctions, toolsRequiringApproval); + Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval); + } + + public static IEnumerable UseMcpClient_HandleFunctionApprovalRequest_TestData() + { + foreach (var streaming in new[] { false, true }) + { + // Approval modes that will cause function approval requests + yield return new object?[] { streaming, null }; + yield return new object?[] { streaming, HostedMcpServerToolApprovalMode.AlwaysRequire }; + yield return new object?[] { streaming, HostedMcpServerToolApprovalMode.RequireSpecific(["echo"], null) }; + } + } + + [Theory] + [MemberData(nameof(UseMcpClient_HandleFunctionApprovalRequest_TestData))] + public async Task UseMcpClient_HandleFunctionApprovalRequest(bool streaming, HostedMcpServerToolApprovalMode? approvalMode) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = approvalMode + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + List chatHistory = []; + chatHistory.Add(new ChatMessage(ChatRole.User, "Test message")); + var response = streaming ? + await sut.GetStreamingResponseAsync(chatHistory, options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync(chatHistory, options, TestContext.Current.CancellationToken); + + chatHistory.AddRange(response.Messages); + var approvalRequest = Assert.Single(response.Messages.SelectMany(m => m.Contents).OfType()); + chatHistory.Add(new ChatMessage(ChatRole.User, [approvalRequest.CreateResponse(true)])); + + response = streaming ? + await sut.GetStreamingResponseAsync(chatHistory, options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync(chatHistory, options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(true, null, (string[])["echo", "echoSessionId", "sampleLLM"])] + [InlineData(false, (string[])["echo"], (string[])["echo"])] + [InlineData(true, (string[])["echo"], (string[])["echo"])] + [InlineData(false, (string[])[], (string[])[])] + [InlineData(true, (string[])[], (string[])[])] + public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, string[]? allowedTools, string[] expectedTools) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + AllowedTools = allowedTools, + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(expectedTools.Length, toolNames.Count); + Assert.Equivalent(expectedTools, toolNames); + + if (expectedTools.Contains("echo")) + { + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + else + { + var responseMsg = Assert.Single(response.Messages); + Assert.Equal(ChatRole.Assistant, responseMsg.Role); + Assert.Equal("Final response", responseMsg.Text); + + Assert.Single(leafClientState.CapturedMessages); + var firstCall = leafClientState.CapturedMessages[0]; + var leafClientMessage = Assert.Single(firstCall); + Assert.Equal(ChatRole.User, leafClientMessage.Role); + Assert.Equal("Test message", leafClientMessage.Text); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool streaming) + { + // Arrange + const string testToken = "test-bearer-token-12345"; + bool authReceivedForInitialize = false; + bool authReceivedForNotificationsInitialized = false; + bool authReceivedForToolsList = false; + bool authReceivedForToolsCall = false; + + await using var _ = await StartServerAsync( + configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST" && + context.Request.Headers.TryGetValue("Authorization", out var authHeader)) + { + Assert.Equal($"Bearer {testToken}", authHeader.ToString()); + + context.Request.EnableBuffering(); + JsonRpcRequest? rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + Assert.NotNull(rpcRequest); + + switch (rpcRequest.Method) + { + case "initialize": + authReceivedForInitialize = true; + break; + case "notifications/initialized": + authReceivedForNotificationsInitialized = true; + break; + case "tools/list": + authReceivedForToolsList = true; + break; + case "tools/call": + authReceivedForToolsCall = true; + break; + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + AuthorizationToken = testToken, + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions + { + Tools = [mcpTool] + }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.True(authReceivedForInitialize, "Authorization header was not captured in initial request"); + Assert.True(authReceivedForNotificationsInitialized, "Authorization header was not captured in notifications/initialized request"); + Assert.True(authReceivedForToolsList, "Authorization header was not captured in tools/list request"); + Assert.True(authReceivedForToolsCall, "Authorization header was not captured in tools/call request"); + + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Equal(3, toolNames.Count); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) + { + // Arrange + int initializeCallCount = 0; + await using var _ = await StartServerAsync(configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "initialize") + { + initializeCallCount++; + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var firstCallToolCount = leafClientState.CapturedOptions.Tools.Count; + Assert.Equal(3, firstCallToolCount); + var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + Assert.Equal(1, initializeCallCount); + + // Arrange + leafClientState.Clear(); + + // Act + var secondResponse = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(secondResponse); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + Assert.NotNull(leafClientState.CapturedOptions); + Assert.NotNull(leafClientState.CapturedOptions.Tools); + var secondCallToolCount = leafClientState.CapturedOptions.Tools.Count; + Assert.Equal(3, secondCallToolCount); + Assert.Equal(firstCallToolCount, secondCallToolCount); + toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); + Assert.Contains("echo", toolNames); + Assert.Contains("echoSessionId", toolNames); + Assert.Contains("sampleLLM", toolNames); + Assert.True(initializeCallCount == 1, "Initialize should not be called more than once because the MCP client is cached."); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_RetriesWhenSessionRevokedByServer(bool streaming) + { + // Arrange + string? firstSessionId = null; + string? secondSessionId = null; + + await using var app = await StartServerAsync( + configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "tools/call" && context.Request.Headers.TryGetValue("Mcp-Session-Id", out var sessionIdHeader)) + { + var sessionId = sessionIdHeader.ToString(); + + if (firstSessionId == null) + { + // First tool call - capture session and return 404 to revoke it + firstSessionId = sessionId; + context.Response.StatusCode = StatusCodes.Status404NotFound; + return; + } + else + { + // Second tool call - capture session and let it succeed + secondSessionId = sessionId; + } + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.NotNull(firstSessionId); + Assert.NotNull(secondSessionId); + Assert.NotEqual(firstSessionId, secondSessionId); + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_RetriesOnServerError(bool streaming) + { + int toolCallCount = 0; + await using var app = await StartServerAsync(configureApp: app => + { + app.Use(async (context, next) => + { + if (context.Request.Method == "POST") + { + context.Request.EnableBuffering(); + var rpcRequest = await JsonSerializer.DeserializeAsync( + context.Request.Body, + McpJsonUtilities.DefaultOptions, + context.RequestAborted); + context.Request.Body.Position = 0; + + if (rpcRequest?.Method == "tools/call" && ++toolCallCount == 1) + { + throw new Exception("Simulated server error."); + } + } + await next(); + }); + }); + + using IChatClient sut = CreateTestChatClient(out var leafClientState); + + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + Assert.Equal(2, toolCallCount); + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_ConfigureTransportOptions_CallbackIsInvoked(bool streaming) + { + // Arrange + HostedMcpServerTool? capturedTool = null; + HttpClientTransportOptions? capturedTransportOptions = null; + await using var _ = await StartServerAsync(); + + using IChatClient sut = CreateTestChatClient(out var leafClientState, (tool, transportOptions) => + { + capturedTool = tool; + capturedTransportOptions = transportOptions; + }); + + var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) + { + ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire, + AuthorizationToken = "test-auth-token-123" + }; + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); + + // Assert + AssertResponseWithInvocation(response); + AssertLeafClientMessagesWithInvocation(leafClientState.CapturedMessages); + + Assert.NotNull(capturedTool); + Assert.Equal("serverName", capturedTool.ServerName); + Assert.Equal(HttpClient.BaseAddress!.ToString(), capturedTool.ServerAddress); + Assert.Null(capturedTool.ServerDescription); + Assert.Null(capturedTool.AuthorizationToken); + Assert.Null(capturedTool.AllowedTools); + Assert.Null(capturedTool.ApprovalMode); + + Assert.NotNull(capturedTransportOptions); + Assert.Equal(HttpClient.BaseAddress, capturedTransportOptions.Endpoint); + Assert.Equal("serverName", capturedTransportOptions.Name); + Assert.Equal("Bearer test-auth-token-123", capturedTransportOptions.AdditionalHeaders!["Authorization"]); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task UseMcpClient_ThrowsInvalidOperationException_WhenServerAddressIsInvalid(bool streaming) + { + // Arrange + await using var _ = await StartServerAsync(); + using IChatClient sut = CreateTestChatClient(out var leafClientState); + var mcpTool = new HostedMcpServerTool("serverNameConnector", "test-connector-123"); + var options = new ChatOptions { Tools = [mcpTool] }; + + // Act & Assert + var exception = await Assert.ThrowsAsync(() => streaming ? + sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken)); + Assert.Contains("test-connector-123", exception.Message); + } +} From 9ecbbd4a2f426612f3d34c3f7a5840ce01d0e48e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Wed, 7 Jan 2026 20:56:48 -0600 Subject: [PATCH 4/6] Revert remnants of previous testing approach --- .../HttpServerIntegrationTests.cs | 6 +- .../SseServerIntegrationTestFixture.cs | 69 +--- .../SseServerIntegrationTests.cs | 2 +- .../StatelessServerIntegrationTests.cs | 2 +- .../StreamableHttpServerIntegrationTests.cs | 2 +- .../UseMcpClientWithTestSseServerTests.cs | 300 ------------------ .../Program.cs | 13 +- 7 files changed, 23 insertions(+), 371 deletions(-) delete mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 562bedae8..ce4f3b56a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -4,11 +4,11 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture +public abstract class HttpServerIntegrationTests : LoggedTest, IClassFixture { - protected readonly SseServerWithXunitLoggerFixture _fixture; + protected readonly SseServerIntegrationTestFixture _fixture; - public HttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) + public HttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : base(testOutputHelper) { _fixture = fixture; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 7044acd30..c382c4385 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -7,18 +7,23 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public abstract class SseServerIntegrationTestFixture : IAsyncDisposable +public class SseServerIntegrationTestFixture : IAsyncDisposable { private readonly KestrelInMemoryTransport _inMemoryTransport = new(); + private readonly Task _serverTask; private readonly CancellationTokenSource _stopCts = new(); + // XUnit's ITestOutputHelper is created per test, while this fixture is used for + // multiple tests, so this dispatches the output to the current test. + private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); + private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new() { Endpoint = new("http://localhost:5000/"), }; - protected SseServerIntegrationTestFixture() + public SseServerIntegrationTestFixture() { var socketsHttpHandler = new SocketsHttpHandler { @@ -34,10 +39,8 @@ protected SseServerIntegrationTestFixture() BaseAddress = new("http://localhost:5000/"), }; - _serverTask = Program.MainAsync([], CreateLoggerProvider(), _inMemoryTransport, _stopCts.Token); + _serverTask = Program.MainAsync([], new XunitLoggerProvider(_delegatingTestOutputHelper), _inMemoryTransport, _stopCts.Token); } - - protected abstract ILoggerProvider CreateLoggerProvider(); public HttpClient HttpClient { get; } @@ -50,17 +53,21 @@ public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerF TestContext.Current.CancellationToken); } - public virtual void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) + public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) { + _delegatingTestOutputHelper.CurrentTestOutputHelper = output; DefaultTransportOptions = clientTransportOptions; } - public virtual void TestCompleted() + public void TestCompleted() { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; } - public virtual async ValueTask DisposeAsync() + public async ValueTask DisposeAsync() { + _delegatingTestOutputHelper.CurrentTestOutputHelper = null; + HttpClient.Dispose(); _stopCts.Cancel(); @@ -75,49 +82,3 @@ public virtual async ValueTask DisposeAsync() _stopCts.Dispose(); } } - -/// -/// SSE server fixture that routes logs to xUnit test output. -/// -public class SseServerWithXunitLoggerFixture : SseServerIntegrationTestFixture -{ - // XUnit's ITestOutputHelper is created per test, while this fixture is used for - // multiple tests, so this dispatches the output to the current test. - private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - - protected override ILoggerProvider CreateLoggerProvider() - => new XunitLoggerProvider(_delegatingTestOutputHelper); - - public override void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) - { - _delegatingTestOutputHelper.CurrentTestOutputHelper = output; - base.Initialize(output, clientTransportOptions); - } - - public override void TestCompleted() - { - _delegatingTestOutputHelper.CurrentTestOutputHelper = null; - base.TestCompleted(); - } - - public override async ValueTask DisposeAsync() - { - _delegatingTestOutputHelper.CurrentTestOutputHelper = null; - await base.DisposeAsync(); - } -} - -/// -/// Fixture for tests that need to inspect server logs using MockLoggerProvider. -/// Use for tests that just need xUnit output. -/// -public class SseServerWithMockLoggerFixture : SseServerIntegrationTestFixture -{ - private readonly MockLoggerProvider _mockLoggerProvider = new(); - - protected override ILoggerProvider CreateLoggerProvider() - => _mockLoggerProvider; - - public IEnumerable<(string Category, LogLevel LogLevel, EventId EventId, string Message, Exception? Exception)> ServerLogs - => _mockLoggerProvider.LogMessages; -} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 5339235af..eb7db0110 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -4,7 +4,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class SseServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) +public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : HttpServerIntegrationTests(fixture, testOutputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index 6937e4be6..2ce63a1bc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -2,7 +2,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class StatelessServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) +public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) { protected override HttpClientTransportOptions ClientTransportOptions => new() diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 63c6dc77b..b2b0b5499 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -3,7 +3,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; -public class StreamableHttpServerIntegrationTests(SseServerWithXunitLoggerFixture fixture, ITestOutputHelper testOutputHelper) +public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : HttpServerIntegrationTests(fixture, testOutputHelper) { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs deleted file mode 100644 index 671db18ba..000000000 --- a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientWithTestSseServerTests.cs +++ /dev/null @@ -1,300 +0,0 @@ -using System.Runtime.CompilerServices; -using Microsoft.Extensions.AI; -using ModelContextProtocol.Client; -using ModelContextProtocol.Tests.Utils; -using Moq; -#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - -namespace ModelContextProtocol.AspNetCore.Tests; - -public class UseMcpClientWithTestSseServerTests : LoggedTest, IClassFixture -{ - private readonly HttpClientTransportOptions _transportOptions; - private readonly SseServerWithMockLoggerFixture _fixture; - - public UseMcpClientWithTestSseServerTests(SseServerWithMockLoggerFixture fixture, ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - _transportOptions = new HttpClientTransportOptions() - { - Endpoint = new("http://localhost:5000/sse"), - Name = "TestSseServer", - }; - - _fixture = fixture; - _fixture.Initialize(testOutputHelper, _transportOptions); - } - - public override void Dispose() - { - _fixture.TestCompleted(); - base.Dispose(); - } - - private sealed class CallbackState - { - public ChatOptions? CapturedOptions { get; set; } - } - - private IChatClient CreateTestChatClient(out CallbackState callbackState) - { - var state = new CallbackState(); - - var mockInnerClient = new Mock(); - mockInnerClient - .Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) - .Callback, ChatOptions?, CancellationToken>( - (msgs, opts, ct) => state.CapturedOptions = opts) - .ReturnsAsync(new ChatResponse([new ChatMessage(ChatRole.Assistant, "Dummy response")])); - - mockInnerClient - .Setup(c => c.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())) - .Callback, ChatOptions?, CancellationToken>( - (msgs, opts, ct) => state.CapturedOptions = opts) - .Returns(GetStreamingResponseAsync()); - - callbackState = state; - return mockInnerClient.Object.AsBuilder() - .UseMcpClient(_fixture.HttpClient, LoggerFactory) - .Build(); - - static async IAsyncEnumerable GetStreamingResponseAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) - { - yield return new ChatResponseUpdate(ChatRole.Assistant, "Dummy response"); - } - } - - private async Task GetResponseAsync(IChatClient client, ChatOptions options, bool streaming) - { - if (streaming) - { - await foreach (var _ in client.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken)) - { } - } - else - { - _ = await client.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); - } - } - - [Theory] - [InlineData(false, false)] - [InlineData(false, true)] - [InlineData(true, false)] - [InlineData(true, true)] - public async Task UseMcpClient_ShouldProduceTools(bool streaming, bool useUrl) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - var mcpTool = useUrl ? - new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) : - new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint.ToString()); - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(3, toolNames.Count); - Assert.Contains("echo", toolNames); - Assert.Contains("echoSessionId", toolNames); - Assert.Contains("sampleLLM", toolNames); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task UseMcpClient_DoesNotConflictWithRegularTools(bool streaming) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - var regularTool = AIFunctionFactory.Create(() => "regular tool result", "RegularTool"); - var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint); - var options = new ChatOptions - { - Tools = - [ - regularTool, - mcpTool - ] - }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(4, toolNames.Count); - Assert.Contains("RegularTool", toolNames); - Assert.Contains("echo", toolNames); - Assert.Contains("echoSessionId", toolNames); - Assert.Contains("sampleLLM", toolNames); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool streaming) - { - // Arrange - const string testToken = "test-bearer-token-12345"; - IChatClient sut = CreateTestChatClient(out var callbackState); - var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) - { - AuthorizationToken = testToken - }; - var options = new ChatOptions - { - Tools = [mcpTool] - }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(3, toolNames.Count); - Assert.Contains("echo", toolNames); - Assert.Contains("echoSessionId", toolNames); - Assert.Contains("sampleLLM", toolNames); - // We set TestSseServer to log IHeaderDictionary as json. - Assert.Contains(_fixture.ServerLogs, log => log.Message.Contains(@"""Authorization"":[""Bearer test-bearer-token-12345""]")); - } - - public static IEnumerable UseMcpClient_ApprovalsWorkCorrectly_TestData() - { - string[] allToolNames = ["echo", "echoSessionId", "sampleLLM"]; - foreach (var streaming in new[] { false, true }) - { - yield return new object?[] { streaming, new HostedMcpServerToolNeverRequireApprovalMode(), (string[])[], allToolNames }; - yield return new object?[] { streaming, new HostedMcpServerToolAlwaysRequireApprovalMode(), allToolNames, (string[])[] }; - yield return new object?[] { streaming, null, allToolNames, (string[])[] }; - // Specific mode with empty lists - all tools should default to requiring approval. - yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], []), allToolNames, (string[])[] }; - // Specific mode with one tool always requiring approval - the other two should default to requiring approval. - yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode(["echo"], []), allToolNames, (string[])[] }; - // Specific mode with one tool never requiring approval - the other two should default to requiring approval. - yield return new object?[] { streaming, new HostedMcpServerToolRequireSpecificApprovalMode([], ["echo"]), (string[])["echoSessionId", "sampleLLM"], (string[])["echo"] }; - } - } - - [Theory] - [MemberData(nameof(UseMcpClient_ApprovalsWorkCorrectly_TestData))] - public async Task UseMcpClient_ApprovalsWorkCorrectly( - bool streaming, - HostedMcpServerToolApprovalMode? approvalMode, - string[] expectedApprovalRequiredAIFunctions, - string[] expectedNormalAIFunctions) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) - { - ApprovalMode = approvalMode - }; - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - Assert.Equal(3, callbackState.CapturedOptions.Tools.Count); - - var toolsRequiringApproval = callbackState.CapturedOptions.Tools - .Where(t => t is ApprovalRequiredAIFunction).Select(t => t.Name); - - var toolsNotRequiringApproval = callbackState.CapturedOptions.Tools - .Where(t => t is not ApprovalRequiredAIFunction).Select(t => t.Name); - - Assert.Equivalent(expectedApprovalRequiredAIFunctions, toolsRequiringApproval); - Assert.Equivalent(expectedNormalAIFunctions, toolsNotRequiringApproval); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task UseMcpClient_ThrowsInvalidOperationException_WhenServerAddressIsInvalid(bool streaming) - { - // Arrange - IChatClient sut = CreateTestChatClient(out _); - var mcpTool = new HostedMcpServerTool("test-server", "test-connector-123"); - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act & Assert - var exception = await Assert.ThrowsAsync(() => GetResponseAsync(sut, options, streaming)); - Assert.Contains("test-connector-123", exception.Message); - } - - [Theory] - [InlineData(false, null, (string[])["echo", "echoSessionId", "sampleLLM"])] - [InlineData(true, null, (string[])["echo", "echoSessionId", "sampleLLM"])] - [InlineData(false, (string[])["echo"], (string[])["echo"])] - [InlineData(true, (string[])["echo"], (string[])["echo"])] - [InlineData(false, (string[])[], (string[])[])] - [InlineData(true, (string[])[], (string[])[])] - public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, string[]? allowedTools, string[] expectedTools) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint) - { - AllowedTools = allowedTools - }; - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act - await GetResponseAsync(sut, options, streaming); - - // Assert - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Equal(expectedTools.Length, toolNames.Count); - Assert.Equivalent(expectedTools, toolNames); - } - - [Theory] - [InlineData(false)] - [InlineData(true)] - public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) - { - // Arrange - IChatClient sut = CreateTestChatClient(out var callbackState); - var mcpTool = new HostedMcpServerTool(_transportOptions.Name!, _transportOptions.Endpoint); - var options = new ChatOptions { Tools = [mcpTool] }; - - // Act - First call - await GetResponseAsync(sut, options, streaming); - - // Assert - First call should succeed and produce tools - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var firstCallToolCount = callbackState.CapturedOptions.Tools.Count; - Assert.Equal(3, firstCallToolCount); - - // Act - Second call with same server address (should use cached client) - await GetResponseAsync(sut, options, streaming); - - // Assert - Second call should also succeed with same tools - Assert.NotNull(callbackState.CapturedOptions); - Assert.NotNull(callbackState.CapturedOptions.Tools); - var secondCallToolCount = callbackState.CapturedOptions.Tools.Count; - Assert.Equal(3, secondCallToolCount); - Assert.Equal(firstCallToolCount, secondCallToolCount); - - // Verify the tools are the same - var toolNames = callbackState.CapturedOptions.Tools.Select(t => t.Name).ToList(); - Assert.Contains("echo", toolNames); - Assert.Contains("echoSessionId", toolNames); - Assert.Contains("sampleLLM", toolNames); - } -} diff --git a/tests/ModelContextProtocol.TestSseServer/Program.cs b/tests/ModelContextProtocol.TestSseServer/Program.cs index a6f55b345..a29c30587 100644 --- a/tests/ModelContextProtocol.TestSseServer/Program.cs +++ b/tests/ModelContextProtocol.TestSseServer/Program.cs @@ -38,7 +38,7 @@ private static void ConfigureOptions(McpServerOptions options) Console.WriteLine("Registering handlers."); - #region Helper method + #region Helped method static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) { return new CreateMessageRequestParams @@ -421,16 +421,7 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide } builder.Services.AddMcpServer(ConfigureOptions) - .WithHttpTransport(httpOptions => - { - // Log headers for testing purposes - httpOptions.ConfigureSessionOptions = (httpContext, serverOptions, ct) => - { - var logger = httpContext.RequestServices.GetRequiredService>(); - logger.LogInformation(JsonSerializer.Serialize(httpContext.Request.Headers)); - return Task.CompletedTask; - }; - }); + .WithHttpTransport(); var app = builder.Build(); From 104e721cc874435f45cefd08c8a1c9b9d666c6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Wed, 7 Jan 2026 21:34:51 -0600 Subject: [PATCH 5/6] Use Experimental ID MCP5002 instead of MEAI001 --- src/Common/Experimentals.cs | 3 +++ src/ModelContextProtocol/McpChatClientBuilderExtensions.cs | 5 ++--- src/ModelContextProtocol/ModelContextProtocol.csproj | 1 + .../UseMcpClientTests.cs | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Common/Experimentals.cs b/src/Common/Experimentals.cs index c81ef981e..b12f11a4b 100644 --- a/src/Common/Experimentals.cs +++ b/src/Common/Experimentals.cs @@ -24,4 +24,7 @@ internal static class Experimentals // public const string Tasks_DiagnosticId = "MCP5001"; // public const string Tasks_Message = "The Tasks feature is experimental within specification version 2025-11-25 and is subject to change. See SEP-1686 for more information."; // public const string Tasks_Url = "https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1686"; + + public const string UseMcpClient_DiagnosticId = "MCP5002"; + public const string UseMcpClient_Message = "The UseMcpClient middleware for integrating hosted MCP servers with IChatClient is experimental and subject to change."; } diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs index 058a2fdc1..c27a90ed1 100644 --- a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. namespace ModelContextProtocol.Client; @@ -31,7 +32,7 @@ public static class McpChatClientBuilderExtensions /// Use this method as an alternative when working with chat providers that don't have built-in support for hosted MCP servers. /// /// - [Experimental("MEAI001")] + [Experimental(Experimentals.UseMcpClient_DiagnosticId)] public static ChatClientBuilder UseMcpClient( this ChatClientBuilder builder, HttpClient? httpClient = null, @@ -46,7 +47,6 @@ public static ChatClientBuilder UseMcpClient( }); } - [Experimental("MEAI001")] private sealed class McpChatClient : DelegatingChatClient { private readonly ILoggerFactory? _loggerFactory; @@ -240,7 +240,6 @@ private sealed class DummyHostedMcpServerTool(string serverName, Uri serverAddre /// An AI function wrapper that retries the invocation by recreating an MCP client when an occurs. /// For example, this can happen if a session is revoked or a server error occurs. The retry evicts the cached MCP client. /// - [Experimental("MEAI001")] private sealed class McpRetriableAIFunction : DelegatingAIFunction { private readonly HostedMcpServerTool _hostedMcpTool; diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj index fe394a056..a4f9fc9c4 100644 --- a/src/ModelContextProtocol/ModelContextProtocol.csproj +++ b/src/ModelContextProtocol/ModelContextProtocol.csproj @@ -15,6 +15,7 @@ + diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs index 8fa884464..5320ebc49 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs @@ -8,7 +8,8 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using Moq; -#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +#pragma warning disable MCP5002 +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. namespace ModelContextProtocol.AspNetCore.Tests; From 9920d160c4028ff41c06d3694c62d2ed39145ab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Cant=C3=BA?= Date: Wed, 7 Jan 2026 22:02:27 -0600 Subject: [PATCH 6/6] Trailing whitespace --- .../McpChatClientBuilderExtensions.cs | 14 ++-- .../McpClientTasksLruCache.cs | 2 +- .../UseMcpClientTests.cs | 80 +++++++++---------- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs index c27a90ed1..82f2cd45d 100644 --- a/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpChatClientBuilderExtensions.cs @@ -14,7 +14,7 @@ namespace ModelContextProtocol.Client; public static class McpChatClientBuilderExtensions { /// - /// Adds a chat client to the chat client pipeline that creates an for each + /// Adds a chat client to the chat client pipeline that creates an for each /// in and augments it with the tools from MCP servers as instances. /// /// The to configure. @@ -24,7 +24,7 @@ public static class McpChatClientBuilderExtensions /// The for method chaining. /// /// - /// When a HostedMcpServerTool is encountered in the tools collection, the client + /// When a HostedMcpServerTool is encountered in the tools collection, the client /// connects to the MCP server, retrieves available tools, and expands them into callable AI functions. /// Connections are cached by server address to avoid redundant connections. /// @@ -220,7 +220,7 @@ protected override void Dispose(bool disposing) await client.DisposeAsync().ConfigureAwait(false); } catch { } // allow the original exception to propagate - + throw; } } @@ -269,16 +269,16 @@ public McpRetriableAIFunction(AIFunction innerFunction, HostedMcpServerTool host var freshTool = await GetCurrentToolAsync().ConfigureAwait(false); return await freshTool.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); } - + private async Task GetCurrentToolAsync() { Debug.Assert(Uri.TryCreate(_hostedMcpTool.ServerAddress, UriKind.Absolute, out var parsedAddress) && (parsedAddress.Scheme == Uri.UriSchemeHttp || parsedAddress.Scheme == Uri.UriSchemeHttps), "Server address should have been validated before construction"); - var (client, tools) = await _chatClient.GetClientAndToolsAsync(_hostedMcpTool, _serverAddressUri!).ConfigureAwait(false); - - return tools.FirstOrDefault(t => t.Name == Name) ?? + var (_, tools) = await _chatClient.GetClientAndToolsAsync(_hostedMcpTool, _serverAddressUri!).ConfigureAwait(false); + + return tools.FirstOrDefault(t => t.Name == Name) ?? throw new McpProtocolException($"Tool '{Name}' no longer exists on the MCP server.", McpErrorCode.InvalidParams); } } diff --git a/src/ModelContextProtocol/McpClientTasksLruCache.cs b/src/ModelContextProtocol/McpClientTasksLruCache.cs index 9f9255a0c..5646aad37 100644 --- a/src/ModelContextProtocol/McpClientTasksLruCache.cs +++ b/src/ModelContextProtocol/McpClientTasksLruCache.cs @@ -71,7 +71,7 @@ public bool TryRemove(string key, [MaybeNullWhen(false)] out Task<(McpClient Cli return false; } } - + public void Dispose() { lock (_lock) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs index 5320ebc49..6d9a6bf3a 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/UseMcpClientTests.cs @@ -167,16 +167,16 @@ private IChatClient CreateTestChatClient(out LeafChatClientState leafClientState var mockInnerClient = new Mock(); mockInnerClient .Setup(c => c.GetResponseAsync( - It.IsAny>(), - It.IsAny(), + It.IsAny>(), + It.IsAny(), It.IsAny())) - .Returns((IEnumerable messages, ChatOptions? options, CancellationToken ct) => + .Returns((IEnumerable messages, ChatOptions? options, CancellationToken ct) => GetStreamingResponseAsync(messages, options, ct).ToChatResponseAsync(ct)); mockInnerClient .Setup(c => c.GetStreamingResponseAsync( - It.IsAny>(), - It.IsAny(), + It.IsAny>(), + It.IsAny(), It.IsAny())) .Returns(GetStreamingResponseAsync); @@ -184,12 +184,12 @@ private IChatClient CreateTestChatClient(out LeafChatClientState leafClientState return mockInnerClient.Object.AsBuilder() .UseMcpClient(HttpClient, LoggerFactory, configureTransportOptions) // Placement is important, must be after UseMcpClient, otherwise, UseFunctionInvocation won't see the MCP tools. - .UseFunctionInvocation() + .UseFunctionInvocation() .Build(); async IAsyncEnumerable GetStreamingResponseAsync( - IEnumerable messages, - ChatOptions? options, + IEnumerable messages, + ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken = default) { state.CapturedOptions = options; @@ -240,15 +240,15 @@ private static void AssertResponseWithInvocation(ChatResponse response) { Assert.NotNull(response); Assert.Equal(3, response.Messages.Count); - + Assert.Equal(ChatRole.Assistant, response.Messages[0].Role); Assert.Single(response.Messages[0].Contents); Assert.IsType(response.Messages[0].Contents[0]); - + Assert.Equal(ChatRole.Tool, response.Messages[1].Role); Assert.Single(response.Messages[1].Contents); Assert.IsType(response.Messages[1].Contents[0]); - + Assert.Equal(ChatRole.Assistant, response.Messages[2].Role); Assert.Equal("Final response", response.Messages[2].Text); } @@ -263,15 +263,15 @@ public async Task UseMcpClient_ShouldProduceTools(bool streaming, bool useUrl) // Arrange await using var _ = await StartServerAsync(); using IChatClient sut = CreateTestChatClient(out var leafClientState); - var mcpTool = useUrl ? - new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) : + var mcpTool = useUrl ? + new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) : new HostedMcpServerTool("serverName", HttpClient.BaseAddress!.ToString()); mcpTool.ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire; var options = new ChatOptions { Tools = [mcpTool] }; // Act - var response = streaming ? - await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); // Assert @@ -305,8 +305,8 @@ public async Task UseMcpClient_DoesNotConflictWithRegularTools(bool streaming) }; // Act - var response = streaming ? - await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); // Assert @@ -353,7 +353,7 @@ public async Task UseMcpClient_ApprovalMode(bool streaming, HostedMcpServerToolA var options = new ChatOptions { Tools = [mcpTool] }; // Act - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); @@ -398,7 +398,7 @@ public async Task UseMcpClient_HandleFunctionApprovalRequest(bool streaming, Hos // Act List chatHistory = []; chatHistory.Add(new ChatMessage(ChatRole.User, "Test message")); - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync(chatHistory, options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync(chatHistory, options, TestContext.Current.CancellationToken); @@ -435,7 +435,7 @@ public async Task UseMcpClient_AllowedTools_FiltersCorrectly(bool streaming, str var options = new ChatOptions { Tools = [mcpTool] }; // Act - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); @@ -445,7 +445,7 @@ await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current var toolNames = leafClientState.CapturedOptions.Tools.Select(t => t.Name).ToList(); Assert.Equal(expectedTools.Length, toolNames.Count); Assert.Equivalent(expectedTools, toolNames); - + if (expectedTools.Contains("echo")) { AssertResponseWithInvocation(response); @@ -489,12 +489,12 @@ public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool strea context.Request.EnableBuffering(); JsonRpcRequest? rpcRequest = await JsonSerializer.DeserializeAsync( - context.Request.Body, - McpJsonUtilities.DefaultOptions, + context.Request.Body, + McpJsonUtilities.DefaultOptions, context.RequestAborted); context.Request.Body.Position = 0; Assert.NotNull(rpcRequest); - + switch (rpcRequest.Method) { case "initialize": @@ -514,7 +514,7 @@ public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool strea await next(); }); }); - + using IChatClient sut = CreateTestChatClient(out var leafClientState); var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) { @@ -527,8 +527,8 @@ public async Task UseMcpClient_AuthorizationTokenHeaderFlowsCorrectly(bool strea }; // Act - var response = streaming ? - await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : + var response = streaming ? + await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); // Assert @@ -576,7 +576,7 @@ public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) await next(); }); }); - + using IChatClient sut = CreateTestChatClient(out var leafClientState); var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) { @@ -585,7 +585,7 @@ public async Task UseMcpClient_CachesClientForSameServerAddress(bool streaming) var options = new ChatOptions { Tools = [mcpTool] }; // Act - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); @@ -606,7 +606,7 @@ await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current leafClientState.Clear(); // Act - var secondResponse = streaming ? + var secondResponse = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); @@ -633,7 +633,7 @@ public async Task UseMcpClient_RetriesWhenSessionRevokedByServer(bool streaming) // Arrange string? firstSessionId = null; string? secondSessionId = null; - + await using var app = await StartServerAsync( configureApp: app => { @@ -643,14 +643,14 @@ public async Task UseMcpClient_RetriesWhenSessionRevokedByServer(bool streaming) { context.Request.EnableBuffering(); var rpcRequest = await JsonSerializer.DeserializeAsync( - context.Request.Body, + context.Request.Body, McpJsonUtilities.DefaultOptions); context.Request.Body.Position = 0; - + if (rpcRequest?.Method == "tools/call" && context.Request.Headers.TryGetValue("Mcp-Session-Id", out var sessionIdHeader)) { var sessionId = sessionIdHeader.ToString(); - + if (firstSessionId == null) { // First tool call - capture session and return 404 to revoke it @@ -668,19 +668,19 @@ public async Task UseMcpClient_RetriesWhenSessionRevokedByServer(bool streaming) await next(); }); }); - + using IChatClient sut = CreateTestChatClient(out var leafClientState); var mcpTool = new HostedMcpServerTool("serverName", HttpClient.BaseAddress!) { ApprovalMode = HostedMcpServerToolApprovalMode.NeverRequire }; var options = new ChatOptions { Tools = [mcpTool] }; - + // Act - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken); - + // Assert Assert.NotNull(firstSessionId); Assert.NotNull(secondSessionId); @@ -745,7 +745,7 @@ public async Task UseMcpClient_ConfigureTransportOptions_CallbackIsInvoked(bool HostedMcpServerTool? capturedTool = null; HttpClientTransportOptions? capturedTransportOptions = null; await using var _ = await StartServerAsync(); - + using IChatClient sut = CreateTestChatClient(out var leafClientState, (tool, transportOptions) => { capturedTool = tool; @@ -760,7 +760,7 @@ public async Task UseMcpClient_ConfigureTransportOptions_CallbackIsInvoked(bool var options = new ChatOptions { Tools = [mcpTool] }; // Act - var response = streaming ? + var response = streaming ? await sut.GetStreamingResponseAsync("Test message", options, TestContext.Current.CancellationToken).ToChatResponseAsync(TestContext.Current.CancellationToken) : await sut.GetResponseAsync("Test message", options, TestContext.Current.CancellationToken);