diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index b3a51957b..4d866cf32 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -479,12 +479,19 @@ internal static string MakeNewSessionId() // Implementation for reading a JSON-RPC message from the request body var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); - if (context.User?.Identity?.IsAuthenticated == true && message is not null) + if (message is null) { - message.Context = new() - { - User = context.User, - }; + return null; + } + + message.Context = new() + { + CancellationToken = context.RequestAborted, + }; + + if (context.User?.Identity?.IsAuthenticated == true) + { + message.Context.User = context.User; } return message; diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 1aa444692..29c013e2d 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -193,7 +193,16 @@ async Task ProcessMessageAsync() // out of order. if (messageWithId is not null) { - combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var requestCancellationToken = message.Context?.CancellationToken; + if (requestCancellationToken is { CanBeCanceled: true }) + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, requestCancellationToken.Value); + } + else + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + } + _handlingRequests[messageWithId.Id] = combinedCts; } @@ -523,8 +532,6 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc LogSendingRequest(EndpointName, request.Method); } - await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); - // Now that the request has been sent, register for cancellation. If we registered before, // a cancellation request could arrive before the server knew about that request ID, in which // case the server could ignore it. @@ -532,6 +539,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Canc JsonRpcMessage? response; using (var registration = RegisterCancellation(cancellationToken, request)) { + await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); } diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index b9c9a2483..08da9b378 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -74,4 +74,13 @@ public sealed class JsonRpcMessageContext /// /// public IDictionary? Items { get; set; } + + /// + /// Gets or sets the cancellation token associated with the transport request that carried this JSON-RPC message. + /// + /// + /// For HTTP transports, this can be linked to the underlying HTTP request's aborted cancellation token (e.g. HttpContext.RequestAborted) + /// to propagate transport-level cancellations to tool executions. + /// + public CancellationToken? CancellationToken { get; set; } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 0af1bdc68..b70d12ec0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -23,7 +23,7 @@ public override void Dispose() protected abstract HttpClientTransportOptions ClientTransportOptions { get; } - private Task GetClientAsync(McpClientOptions? options = null) + protected Task GetClientAsync(McpClientOptions? options = null) { return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index b2b0b5499..8c86d5ba2 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; using System.Text; namespace ModelContextProtocol.AspNetCore.Tests; @@ -60,4 +61,57 @@ public async Task EventSourceStream_Includes_MessageEventType() var messageEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken); Assert.Equal("event: message", messageEvent); } + + [Fact] + public async Task CallTool_CancelToken_SendsCancellationNotification_KeepsConnectionOpen() + { + await using var client = await GetClientAsync(); + + using CancellationTokenSource cts = new(); + var toolTask = client.CallToolAsync( + "longRunning", + new Dictionary { ["durationMs"] = 10000 }, + cancellationToken: cts.Token + ); + + // Allow some time for the request to be sent + await Task.Delay(500, TestContext.Current.CancellationToken); + + cts.Cancel(); + + // Client throws OperationCanceledException + await Assert.ThrowsAnyAsync(async () => await toolTask); + + // Verify the connection is still open by pinging + var pingResult = await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(pingResult); + } + + [Fact] + public async Task CallTool_ClientDisconnectsAbruptly_CancelsServerToken() + { + var client = await GetClientAsync(); + + // Send the tool call + var toolTask = client.CallToolAsync( + "longRunning", + new Dictionary { ["durationMs"] = 10000 }, + cancellationToken: TestContext.Current.CancellationToken + ); + + // Allow some time for the request to be sent and processing to start on the server + await Task.Delay(500, TestContext.Current.CancellationToken); + + // Disposing the client will tear down the transport and drop the underlying HTTP connection, + // simulating a client crash or network drop without sending notifications/cancelled. + await client.DisposeAsync(); + + // The local client task will throw because the transport disconnected + await Assert.ThrowsAnyAsync(async () => await toolTask); + + // Verify the server is still alive and handling requests from a *new* client + await using var newClient = await GetClientAsync(); + var pingResult = await newClient.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(pingResult); + } } diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 079be04f7..5d1dad112 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -112,6 +112,60 @@ public async Task CallTool_Stdio_EchoSessionId_ReturnsEmpty() Assert.Empty(textContent.Text); } + [Fact] + public async Task CallTool_Stdio_CancelToken_ThrowsOperationCanceledException() + { + await using var client = await _fixture.CreateClientAsync("test_server"); + + using CancellationTokenSource cts = new(); + var toolTask = client.CallToolAsync( + "longRunning", + new Dictionary { ["durationMs"] = 10000 }, + cancellationToken: cts.Token + ); + + // Allow some time for the request to be sent + await Task.Delay(500, TestContext.Current.CancellationToken); + + cts.Cancel(); + + // Client throws OperationCanceledException + await Assert.ThrowsAnyAsync(async () => await toolTask); + + // Verify the connection is still open by pinging + var pingResult = await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(pingResult); + } + + [Fact] + public async Task CallTool_Stdio_ClientDisconnectsAbruptly_CancelsServerToken() + { + var client = await _fixture.CreateClientAsync("test_server"); + + // Send the tool call + var toolTask = client.CallToolAsync( + "longRunning", + new Dictionary { ["durationMs"] = 10000 }, + cancellationToken: TestContext.Current.CancellationToken + ); + + // Allow some time for the request to be sent and processing to start + await Task.Delay(500, TestContext.Current.CancellationToken); + + // Disposing the client will tear down the stdio pipes abruptly + await client.DisposeAsync(); + + // The local client task will throw because the transport disconnected + await Assert.ThrowsAnyAsync(async () => await toolTask); + + // Verify the server process was terminated or we can create a new connection depending on server behavior. + // For Stdio, the server process is typically isolated to the client connection instance, + // so we start a new client to ensure the transport factory is healthy. + await using var newClient = await _fixture.CreateClientAsync("test_server"); + var pingResult = await newClient.PingAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.NotNull(pingResult); + } + [Theory] [MemberData(nameof(GetClients))] public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId)