diff --git a/src/ModelContextProtocol.Core/Client/ClientCompletionDetails.cs b/src/ModelContextProtocol.Core/Client/ClientCompletionDetails.cs new file mode 100644 index 000000000..fc366bc0f --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/ClientCompletionDetails.cs @@ -0,0 +1,21 @@ +namespace ModelContextProtocol.Client; + +/// +/// Provides details about why an MCP client session completed. +/// +/// +/// +/// Transport implementations may return derived types with additional strongly-typed +/// information, such as . +/// +/// +public class ClientCompletionDetails +{ + /// + /// Gets the exception that caused the session to close, if any. + /// + /// + /// This is for graceful closure. + /// + public Exception? Exception { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Client/HttpClientCompletionDetails.cs b/src/ModelContextProtocol.Core/Client/HttpClientCompletionDetails.cs new file mode 100644 index 000000000..eee9bafca --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/HttpClientCompletionDetails.cs @@ -0,0 +1,15 @@ +using System.Net; + +namespace ModelContextProtocol.Client; + +/// +/// Provides details about the completion of an HTTP-based MCP client session, +/// including sessions using the legacy SSE transport or the Streamable HTTP transport. +/// +public sealed class HttpClientCompletionDetails : ClientCompletionDetails +{ + /// + /// Gets the HTTP status code that caused the session to close, or if unavailable. + /// + public HttpStatusCode? HttpStatusCode { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index acd0bb12d..406969121 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -53,4 +53,21 @@ protected McpClient() /// /// public abstract string? ServerInstructions { get; } + + /// + /// Gets a that completes when the client session has completed. + /// + /// + /// + /// The task always completes successfully. The result provides details about why the session + /// completed. Transport implementations may return derived types with additional strongly-typed + /// information, such as . + /// + /// + /// For graceful closure (e.g., explicit disposal), + /// will be . For unexpected closure (e.g., process crash, network failure), + /// it may contain an exception that caused or that represents the failure. + /// + /// + public abstract Task Completion { get; } } diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 8317656d0..4205c28e1 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -521,6 +521,9 @@ private void RegisterTaskHandlers(RequestHandlers requestHandlers, IMcpTaskStore /// public override string? ServerInstructions => _serverInstructions; + /// + public override Task Completion => _sessionHandler.CompletionTask; + /// /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. /// @@ -655,6 +658,14 @@ public override async ValueTask DisposeAsync() _taskCancellationTokenProvider?.Dispose(); await _sessionHandler.DisposeAsync().ConfigureAwait(false); await _transport.DisposeAsync().ConfigureAwait(false); + + // After disposal, the channel writer is complete but ProcessMessagesCoreAsync + // may have been cancelled with unread items still buffered. ChannelReader.Completion + // only resolves once all items are consumed, so drain remaining items. + while (_transport.MessageReader.TryRead(out var _)); + + // Then ensure all work has quiesced. + await Completion.ConfigureAwait(false); } [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 085f8ccfe..0bcc69417 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.Diagnostics; +using System.Net; using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text.Json; @@ -124,7 +125,7 @@ private async Task CloseAsync() } finally { - SetDisconnected(); + SetDisconnected(new TransportClosedException(new HttpClientCompletionDetails())); } } @@ -143,6 +144,7 @@ public override async ValueTask DisposeAsync() private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { + HttpStatusCode? failureStatusCode = null; try { using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); @@ -151,6 +153,11 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + if (!response.IsSuccessStatusCode) + { + failureStatusCode = response.StatusCode; + } + await response.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false); using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); @@ -179,6 +186,12 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } else { + SetDisconnected(new TransportClosedException(new HttpClientCompletionDetails + { + HttpStatusCode = failureStatusCode, + Exception = ex, + })); + LogTransportReadMessagesFailed(Name, ex); _connectionEstablished.TrySetException(ex); throw; @@ -186,7 +199,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) } finally { - SetDisconnected(); + SetDisconnected(new TransportClosedException(new HttpClientCompletionDetails())); } } diff --git a/src/ModelContextProtocol.Core/Client/StdioClientCompletionDetails.cs b/src/ModelContextProtocol.Core/Client/StdioClientCompletionDetails.cs new file mode 100644 index 000000000..9fc845de6 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/StdioClientCompletionDetails.cs @@ -0,0 +1,22 @@ +namespace ModelContextProtocol.Client; + +/// +/// Provides details about the completion of a stdio-based MCP client session. +/// +public sealed class StdioClientCompletionDetails : ClientCompletionDetails +{ + /// + /// Gets the process ID of the server process, or if unavailable. + /// + public int? ProcessId { get; set; } + + /// + /// Gets the exit code of the server process, or if unavailable. + /// + public int? ExitCode { get; set; } + + /// + /// Gets the last lines of the server process's standard error output, or if unavailable. + /// + public IReadOnlyList? StandardErrorTail { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index f8b10746f..a92093246 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -5,14 +5,22 @@ namespace ModelContextProtocol.Client; /// Provides the client side of a stdio-based session transport. -internal sealed class StdioClientSessionTransport( - StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : - StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) +internal sealed class StdioClientSessionTransport : StreamClientSessionTransport { - private readonly StdioClientTransportOptions _options = options; - private readonly Process _process = process; - private readonly Queue _stderrRollingLog = stderrRollingLog; + private readonly StdioClientTransportOptions _options; + private readonly Process _process; + private readonly Queue _stderrRollingLog; private int _cleanedUp = 0; + private readonly int? _processId; + + public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) : + base(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) + { + _options = options; + _process = process; + _stderrRollingLog = stderrRollingLog; + try { _processId = process.Id; } catch { } + } /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) @@ -47,17 +55,26 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell // so create an exception with details about that. error ??= await GetUnexpectedExitExceptionAsync(cancellationToken).ConfigureAwait(false); - // Now terminate the server process. + // Terminate the server process (or confirm it already exited), then build + // and publish strongly-typed completion details while the process handle + // is still valid so we can read the exit code. try { - StdioClientTransport.DisposeProcess(_process, processRunning: true, shutdownTimeout: _options.ShutdownTimeout); + StdioClientTransport.DisposeProcess( + _process, + processRunning: true, + _options.ShutdownTimeout, + beforeDispose: () => SetDisconnected(new TransportClosedException(BuildCompletionDetails(error)))); } catch (Exception ex) { LogTransportShutdownFailed(Name, ex); + SetDisconnected(new TransportClosedException(BuildCompletionDetails(error))); } - // And handle cleanup in the base type. + // And handle cleanup in the base type. SetDisconnected has already been + // called above, so the base call is a no-op for disconnect state but + // still performs other cleanup (cancelling the read task, etc.). await base.CleanupAsync(error, cancellationToken).ConfigureAwait(false); } @@ -104,4 +121,32 @@ protected override async ValueTask CleanupAsync(Exception? error = null, Cancell return new IOException(errorMessage); } + + private StdioClientCompletionDetails BuildCompletionDetails(Exception? error) + { + StdioClientCompletionDetails details = new() + { + Exception = error, + ProcessId = _processId, + }; + + try + { + if (StdioClientTransport.HasExited(_process)) + { + details.ExitCode = _process.ExitCode; + } + } + catch { } + + lock (_stderrRollingLog) + { + if (_stderrRollingLog.Count > 0) + { + details.StandardErrorTail = _stderrRollingLog.ToArray(); + } + } + + return details; + } } diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index 191a512b1..24a47613b 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -213,7 +213,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = } internal static void DisposeProcess( - Process? process, bool processRunning, TimeSpan shutdownTimeout) + Process? process, bool processRunning, TimeSpan shutdownTimeout, Action? beforeDispose = null) { if (process is not null) { @@ -239,6 +239,10 @@ internal static void DisposeProcess( { process.WaitForExit(); } + + // Invoke the callback while the process handle is still valid, + // e.g. to read ExitCode before Dispose() invalidates it. + beforeDispose?.Invoke(); } finally { diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index c19283036..2bb34a761 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -20,11 +20,12 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private readonly McpHttpClient _httpClient; private readonly HttpClientTransportOptions _options; - private readonly CancellationTokenSource _connectionCts; + private readonly CancellationTokenSource _connectionCts = new(); private readonly ILogger _logger; private string? _negotiatedProtocolVersion; private Task? _getReceiveTask; + private volatile TransportClosedException? _disconnectError; private readonly SemaphoreSlim _disposeLock = new(1, 1); private bool _disposed; @@ -42,7 +43,6 @@ public StreamableHttpClientSessionTransport( _options = transportOptions; _httpClient = httpClient; - _connectionCts = new CancellationTokenSource(); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; // We connect with the initialization request with the MCP transport. This means that any errors won't be observed @@ -96,6 +96,13 @@ internal async Task SendHttpRequestAsync(JsonRpcMessage mes // We'll let the caller decide whether to throw or fall back given an unsuccessful response. if (!response.IsSuccessStatusCode) { + // Per the MCP spec, a 404 response to a request containing an Mcp-Session-Id + // indicates the session has ended. Signal completion so McpClient.Completion resolves. + if (response.StatusCode == HttpStatusCode.NotFound && SessionId is not null) + { + SetSessionExpired(response.StatusCode); + } + return response; } @@ -184,10 +191,6 @@ public override async ValueTask DisposeAsync() { LogTransportShutdownFailed(Name, ex); } - finally - { - _connectionCts.Dispose(); - } } finally { @@ -195,7 +198,7 @@ public override async ValueTask DisposeAsync() // This class isn't directly exposed to public callers, so we don't have to worry about changing the _state in this case. if (_options.TransportMode is not HttpTransportMode.AutoDetect || _getReceiveTask is not null) { - SetDisconnected(); + SetDisconnected(_disconnectError ?? new TransportClosedException(new HttpClientCompletionDetails())); } } } @@ -204,8 +207,8 @@ private async Task ReceiveUnsolicitedMessagesAsync() { var state = new SseStreamState(); - // Continuously receive unsolicited messages until canceled - while (!_connectionCts.Token.IsCancellationRequested) + // Continuously receive unsolicited messages until canceled or disconnected + while (!_connectionCts.Token.IsCancellationRequested && IsConnected) { await SendGetSseRequestWithRetriesAsync( relatedRpcRequest: null, @@ -285,6 +288,13 @@ await SendGetSseRequestWithRetriesAsync( if (!response.IsSuccessStatusCode) { + // Per the MCP spec, a 404 response to a request containing an Mcp-Session-Id + // indicates the session has ended. Signal completion so McpClient.Completion resolves. + if (response.StatusCode == HttpStatusCode.NotFound && SessionId is not null) + { + SetSessionExpired(response.StatusCode); + } + // If the server could be reached but returned a non-success status code, // retrying likely won't change that. return null; @@ -474,4 +484,23 @@ private static TimeSpan ElapsedSince(long stopwatchTimestamp) return TimeSpan.FromSeconds((double)(Stopwatch.GetTimestamp() - stopwatchTimestamp) / Stopwatch.Frequency); #endif } + + private void SetSessionExpired(HttpStatusCode statusCode) + { + // Store the error before canceling so DisposeAsync can use it if it races us, especially + // after the call to Cancel below, to invoke SetDisconnected. + _disconnectError = new TransportClosedException(new HttpClientCompletionDetails + { + HttpStatusCode = statusCode, + Exception = new McpException( + "The server returned HTTP 404 for a request with an Mcp-Session-Id, indicating the session has expired. " + + "To continue, create a new client session or call ResumeSessionAsync with a new connection."), + }); + + // Cancel to unblock any in-flight operations (e.g., SSE stream reads in + // SendGetSseRequestWithRetriesAsync) that are waiting on _connectionCts.Token. + _connectionCts.Cancel(); + + SetDisconnected(_disconnectError); + } } diff --git a/src/ModelContextProtocol.Core/Client/TransportClosedException.cs b/src/ModelContextProtocol.Core/Client/TransportClosedException.cs new file mode 100644 index 000000000..55d711991 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/TransportClosedException.cs @@ -0,0 +1,19 @@ +using ModelContextProtocol.Protocol; +using System.Threading.Channels; + +namespace ModelContextProtocol.Client; + +/// +/// used to smuggle through +/// the mechanism. +/// +/// +/// This could be made public in the future to allow custom +/// implementations to provide their own -derived types +/// by completing their channel with this exception. +/// +internal sealed class TransportClosedException(ClientCompletionDetails details) : + IOException(details.Exception?.Message, details.Exception) +{ + public ClientCompletionDetails Details { get; } = details; +} diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 4d9cb01ba..cd4170edf 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -132,6 +132,7 @@ public McpSessionHandler( _incomingMessageFilter = incomingMessageFilter ?? (next => next); _outgoingMessageFilter = outgoingMessageFilter ?? (next => next); _logger = logger; + LogSessionCreated(EndpointName, _sessionId, _transportKind); } @@ -145,6 +146,15 @@ public McpSessionHandler( /// public string? NegotiatedProtocolVersion { get; set; } + /// + /// Gets a task that completes when the client session has completed, providing details about the closure. + /// Completion details are resolved from the transport's channel completion exception: if a transport + /// completes its channel with a , the wrapped + /// is unwrapped. Otherwise, a default instance is returned. + /// + internal Task CompletionTask => + field ??= GetCompletionDetailsAsync(_transport.MessageReader.Completion); + /// /// Starts processing messages from the transport. This method will block until the transport is disconnected. /// This is generally started in a background task or thread from the initialization logic of the derived class. @@ -297,6 +307,28 @@ ex is OperationCanceledException && } } + /// + /// Resolves from the transport's channel completion. + /// If the channel was completed with a , the wrapped + /// details are returned. Otherwise a default instance is created from the completion state. + /// + private static async Task GetCompletionDetailsAsync(Task channelCompletion) + { + try + { + await channelCompletion.ConfigureAwait(false); + return new ClientCompletionDetails(); + } + catch (TransportClosedException tce) + { + return tce.Details; + } + catch (Exception ex) + { + return new ClientCompletionDetails { Exception = ex }; + } + } + private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) { Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; @@ -877,9 +909,11 @@ public async ValueTask DisposeAsync() { await _messageProcessingTask.ConfigureAwait(false); } - catch (OperationCanceledException) + catch { - // Ignore cancellation + // Ignore exceptions from the message processing loop. It may fault with + // OperationCanceledException on normal shutdown or TransportClosedException + // when the transport's channel completes with an error. } } diff --git a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs index e3e8e8c8b..c1c642c30 100644 --- a/src/ModelContextProtocol.Core/Protocol/TransportBase.cs +++ b/src/ModelContextProtocol.Core/Protocol/TransportBase.cs @@ -90,13 +90,15 @@ internal TransportBase(string name, Channel? messageChannel, ILo /// The to monitor for cancellation requests. The default is . protected async Task WriteMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + cancellationToken.ThrowIfCancellationRequested(); + if (!IsConnected) { - throw new InvalidOperationException("Transport is not connected."); + // Transport disconnected concurrently. Silently drop rather than throw, + // to avoid surfacing spurious errors during shutdown races. + return; } - cancellationToken.ThrowIfCancellationRequested(); - if (_logger.IsEnabled(LogLevel.Debug)) { var messageId = (message as JsonRpcMessageWithId)?.Id.ToString() ?? "(no id)"; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 0af1bdc68..5f961fe32 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -303,4 +303,19 @@ public async Task CallTool_Sse_EchoServer_Concurrently() Assert.Equal($"Echo: Hello MCP! {i}", textContent.Text); } } + + [Fact] + public async Task Completion_GracefulDisposal_ReturnsCompletionDetails() + { + var client = await GetClientAsync(); + Assert.False(client.Completion.IsCompleted); + + await client.DisposeAsync(); + Assert.True(client.Completion.IsCompleted); + + var details = await client.Completion; + var httpDetails = Assert.IsType(details); + Assert.Null(httpDetails.Exception); + Assert.Null(httpDetails.HttpStatusCode); + } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 35adccef9..5ed7a48ef 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -307,6 +307,25 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b }); } + [Fact] + public async Task Completion_ServerShutdown_ReturnsHttpCompletionDetails() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + + var mcpClient = await ConnectMcpClientAsync(); + Assert.False(mcpClient.Completion.IsCompleted); + + // Stop the server while the client is still connected. + await app.StopAsync(TestContext.Current.CancellationToken); + + var details = await mcpClient.Completion.WaitAsync(TestContext.Current.CancellationToken); + var httpDetails = Assert.IsType(details); + Assert.Null(httpDetails.HttpStatusCode); + } + public class Envelope { public required string Message { get; set; } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 296af1c95..c8e3f8d7b 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -1,4 +1,4 @@ -using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Json; using Microsoft.Extensions.DependencyInjection; @@ -7,6 +7,7 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; +using System.Net; using System.Threading; using System.Threading.Tasks; using System.Text.Json; @@ -370,6 +371,159 @@ public async Task DisposeAsync_DoesNotHang_WhenOwnsSessionIsFalse_WithActiveGetS } } + [Fact] + public async Task Completion_SessionExpiredOnPost_ReturnsHttpCompletionDetails() + { + bool expireSession = false; + + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + _app = Builder.Build(); + + _app.MapPost("/mcp", (JsonRpcMessage message, HttpContext context) => + { + if (message is not JsonRpcRequest request) + { + return Results.Accepted(); + } + + context.Response.Headers.Append("mcp-session-id", "expiry-test-session"); + + if (expireSession) + { + return Results.NotFound(); + } + + if (request.Method == "initialize") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new() { Tools = new() }, + ServerInfo = new Implementation { Name = "expiry-test", Version = "0.0.1" }, + }, McpJsonUtilities.DefaultOptions) + }); + } + + return Results.Accepted(); + }); + + await _app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal("expiry-test-session", client.SessionId); + Assert.False(client.Completion.IsCompleted); + + // Simulate session expiry by having the server return 404 + expireSession = true; + + await Assert.ThrowsAnyAsync(async () => + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + + var details = await client.Completion.WaitAsync(TestContext.Current.CancellationToken); + var httpDetails = Assert.IsType(details); + Assert.Equal(HttpStatusCode.NotFound, httpDetails.HttpStatusCode); + Assert.NotNull(httpDetails.Exception); + } + + [Fact] + public async Task Completion_SessionExpiredOnGet_ReturnsHttpCompletionDetails() + { + var expireSession = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + Builder.Services.Configure(options => + { + options.SerializerOptions.TypeInfoResolverChain.Add(McpJsonUtilities.DefaultOptions.TypeInfoResolver!); + }); + _app = Builder.Build(); + + _app.MapPost("/mcp", (JsonRpcMessage message, HttpContext context) => + { + if (message is not JsonRpcRequest request) + { + return Results.Accepted(); + } + + context.Response.Headers.Append("mcp-session-id", "get-expiry-test"); + + if (request.Method == "initialize") + { + return Results.Json(new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new InitializeResult + { + ProtocolVersion = "2024-11-05", + Capabilities = new() { Tools = new() }, + ServerInfo = new Implementation { Name = "get-expiry-test", Version = "0.0.1" }, + }, McpJsonUtilities.DefaultOptions) + }); + } + + return Results.Accepted(); + }); + + // GET handler waits for the signal, then returns 404 to simulate session expiry + _app.MapGet("/mcp", async (HttpContext context) => + { + await expireSession.Task; + context.Response.StatusCode = StatusCodes.Status404NotFound; + }); + + await _app.StartAsync(TestContext.Current.CancellationToken); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal("get-expiry-test", client.SessionId); + + // Trigger session expiry on the GET SSE stream + expireSession.SetResult(); + + var details = await client.Completion.WaitAsync(TestContext.Current.CancellationToken); + var httpDetails = Assert.IsType(details); + Assert.Equal(HttpStatusCode.NotFound, httpDetails.HttpStatusCode); + Assert.NotNull(httpDetails.Exception); + } + + [Fact] + public async Task Completion_GracefulDisposal_ReturnsCompletionDetails() + { + await StartAsync(enableDelete: true); + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new("http://localhost:5000/mcp"), + TransportMode = HttpTransportMode.StreamableHttp, + }, HttpClient, LoggerFactory); + + var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + Assert.False(client.Completion.IsCompleted); + + await client.DisposeAsync(); + Assert.True(client.Completion.IsCompleted); + + var details = await client.Completion; + var httpDetails = Assert.IsType(details); + Assert.Null(httpDetails.Exception); + Assert.Null(httpDetails.HttpStatusCode); + } + private static async Task CallEchoAndValidateAsync(McpClientTool echoTool) { var response = await echoTool.CallAsync(new Dictionary() { ["message"] = "Hello world!" }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index a62ddf252..89ad9987d 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -181,6 +181,23 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) { TaskSupport = ToolTaskSupport.Optional } + }, + new Tool + { + Name = "crash", + Description = "Terminates the server process with a specified exit code.", + InputSchema = JsonElement.Parse(""" + { + "type": "object", + "properties": { + "exitCode": { + "type": "number", + "description": "The exit code to terminate with" + } + }, + "required": ["exitCode"] + } + """), } ] }; @@ -241,6 +258,17 @@ private static void ConfigureTools(McpServerOptions options, string? cliArg) Content = [new TextContentBlock { Text = $"Long-running operation completed after {durationMs}ms" }] }; } + else if (request.Params?.Name == "crash") + { + if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("exitCode", out var exitCodeValue)) + { + throw new McpProtocolException("Missing required argument 'exitCode'", McpErrorCode.InvalidParams); + } + int exitCode = Convert.ToInt32(exitCodeValue.GetRawText()); + Console.Error.WriteLine($"Crashing with exit code {exitCode}"); + Environment.Exit(exitCode); + throw new Exception("unreachable"); + } else { throw new McpProtocolException($"Unknown tool: {request.Params?.Name}", McpErrorCode.InvalidParams); diff --git a/tests/ModelContextProtocol.Tests/Client/ClientCompletionDetailsTests.cs b/tests/ModelContextProtocol.Tests/Client/ClientCompletionDetailsTests.cs new file mode 100644 index 000000000..b4d50850b --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/ClientCompletionDetailsTests.cs @@ -0,0 +1,95 @@ +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.Tests.Client; + +public class ClientCompletionDetailsTests +{ + [Fact] + public void ClientCompletionDetails_PropertiesRoundtrip() + { + var exception = new InvalidOperationException("test"); + var details = new ClientCompletionDetails + { + Exception = exception, + }; + + Assert.Same(exception, details.Exception); + } + + [Fact] + public void ClientCompletionDetails_DefaultsToNull() + { + var details = new ClientCompletionDetails(); + Assert.Null(details.Exception); + } + + [Fact] + public void StdioClientCompletionDetails_PropertiesRoundtrip() + { + var exception = new IOException("process exited"); + string[] stderrLines = ["error line 1", "error line 2"]; + + var details = new StdioClientCompletionDetails + { + Exception = exception, + ProcessId = 12345, + ExitCode = 42, + StandardErrorTail = stderrLines, + }; + + Assert.Same(exception, details.Exception); + Assert.Equal(12345, details.ProcessId); + Assert.Equal(42, details.ExitCode); + Assert.Same(stderrLines, details.StandardErrorTail); + } + + [Fact] + public void StdioClientCompletionDetails_DefaultsToNull() + { + var details = new StdioClientCompletionDetails(); + Assert.Null(details.Exception); + Assert.Null(details.ProcessId); + Assert.Null(details.ExitCode); + Assert.Null(details.StandardErrorTail); + } + + [Fact] + public void StdioClientCompletionDetails_IsClientCompletionDetails() + { + ClientCompletionDetails details = new StdioClientCompletionDetails { ExitCode = 1 }; + var stdio = Assert.IsType(details); + Assert.Equal(1, stdio.ExitCode); + } + + [Fact] + public void HttpClientCompletionDetails_PropertiesRoundtrip() + { + var exception = new HttpRequestException("connection refused"); + + var details = new HttpClientCompletionDetails + { + Exception = exception, + HttpStatusCode = System.Net.HttpStatusCode.NotFound, + }; + + Assert.Same(exception, details.Exception); + Assert.Equal(System.Net.HttpStatusCode.NotFound, details.HttpStatusCode); + } + + [Fact] + public void HttpClientCompletionDetails_DefaultsToNull() + { + var details = new HttpClientCompletionDetails(); + Assert.Null(details.Exception); + Assert.Null(details.HttpStatusCode); + } + + [Fact] + public void HttpClientCompletionDetails_IsClientCompletionDetails() + { + ClientCompletionDetails details = new HttpClientCompletionDetails { HttpStatusCode = System.Net.HttpStatusCode.NotFound }; + var http = Assert.IsType(details); + Assert.Equal(System.Net.HttpStatusCode.NotFound, http.HttpStatusCode); + } + +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs index 504b52e21..8fd67f7ac 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs @@ -112,7 +112,11 @@ private class NopTransport : ITransport, IClientTransport public Task ConnectAsync(CancellationToken cancellationToken = default) => Task.FromResult(this); - public ValueTask DisposeAsync() => default; + public ValueTask DisposeAsync() + { + _channel.Writer.TryComplete(); + return default; + } public string Name => "Test Nop Transport"; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs index 32e04da60..0ae491171 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -782,4 +782,17 @@ public async Task SetLoggingLevelAsync_WithRequestParams_NullThrows() await Assert.ThrowsAsync("requestParams", () => client.SetLoggingLevelAsync((SetLevelRequestParams)null!, TestContext.Current.CancellationToken)); } + + [Fact] + public async Task Completion_GracefulDisposal_CompletesWithNoException() + { + var client = await CreateMcpClientForServer(); + Assert.False(client.Completion.IsCompleted); + + await client.DisposeAsync(); + Assert.True(client.Completion.IsCompleted); + + var details = await client.Completion; + Assert.Null(details.Exception); + } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 079be04f7..70553ee45 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -777,4 +777,43 @@ await client.UnsubscribeFromResourceAsync( [JsonSerializable(typeof(TestNotification))] partial class JsonContext3 : JsonSerializerContext; + + [Fact] + public async Task Completion_Stdio_GracefulDisposal_ReturnsStdioDetails() + { + var client = await _fixture.CreateClientAsync("test_server"); + Assert.False(client.Completion.IsCompleted); + + await client.DisposeAsync(); + Assert.True(client.Completion.IsCompleted); + + var details = await client.Completion.WaitAsync(TestContext.Current.CancellationToken); + var stdioDetails = Assert.IsType(details); + Assert.Null(stdioDetails.Exception); + Assert.NotNull(stdioDetails.ProcessId); + Assert.True(stdioDetails.ProcessId > 0); + Assert.NotNull(stdioDetails.ExitCode); + } + + [Fact] + public async Task Completion_Stdio_ServerCrash_ReturnsExitCodeAndStderr() + { + var client = await _fixture.CreateClientAsync("test_server"); + + // Tell the server to crash with a specific exit code. + // CallToolAsync will throw because the server exits before responding. + await Assert.ThrowsAnyAsync(async () => await client.CallToolAsync( + "crash", + new Dictionary { ["exitCode"] = 42 }, + cancellationToken: TestContext.Current.CancellationToken)); + + var details = await client.Completion.WaitAsync(TestContext.Current.CancellationToken); + var stdioDetails = Assert.IsType(details); + + Assert.NotNull(stdioDetails.ProcessId); + Assert.True(stdioDetails.ProcessId > 0); + Assert.Equal(42, stdioDetails.ExitCode); + Assert.NotNull(stdioDetails.StandardErrorTail); + Assert.Contains(stdioDetails.StandardErrorTail, line => line.Contains("Crashing with exit code 42")); + } } diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 16cd352c3..d84ea9377 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -21,10 +21,10 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"] }, LoggerFactory) : new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"] }, LoggerFactory); - await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAnyAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } - [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] + [Fact(Skip= "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { string id = Guid.NewGuid().ToString("N"); @@ -45,7 +45,7 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() new(new() { Command = "cmd", Arguments = ["/c", $"echo {id} >&2 & exit /b 1"], StandardErrorLines = stdErrCallback }, LoggerFactory) : new(new() { Command = "sh", Arguments = ["-c", $"echo {id} >&2; exit 1"], StandardErrorLines = stdErrCallback }, LoggerFactory); - await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAnyAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); // The stderr reading thread may not have delivered the callback yet // after the IOException is thrown. Poll briefly for it to arrive.