Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,19 @@ internal static string MakeNewSessionId()
// Implementation for reading a JSON-RPC message from the request body
var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted);

if (context.User?.Identity?.IsAuthenticated == true && message is not null)
if (message is null)
{
message.Context = new()
{
User = context.User,
};
return null;
}

message.Context = new()
{
CancellationToken = context.RequestAborted,
};

if (context.User?.Identity?.IsAuthenticated == true)
{
message.Context.User = context.User;
}

return message;
Expand Down
14 changes: 11 additions & 3 deletions src/ModelContextProtocol.Core/McpSessionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,16 @@ async Task ProcessMessageAsync()
// out of order.
if (messageWithId is not null)
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var requestCancellationToken = message.Context?.CancellationToken;
if (requestCancellationToken is { CanBeCanceled: true })
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, requestCancellationToken.Value);
}
else
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
}

_handlingRequests[messageWithId.Id] = combinedCts;
}

Expand Down Expand Up @@ -523,15 +532,14 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
LogSendingRequest(EndpointName, request.Method);
}

await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);

// Now that the request has been sent, register for cancellation. If we registered before,
// a cancellation request could arrive before the server knew about that request ID, in which
// case the server could ignore it.
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id);
JsonRpcMessage? response;
using (var registration = RegisterCancellation(cancellationToken, request))
{
await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false);
response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,13 @@ public sealed class JsonRpcMessageContext
/// </para>
/// </remarks>
public IDictionary<string, object?>? Items { get; set; }

/// <summary>
/// Gets or sets the cancellation token associated with the transport request that carried this JSON-RPC message.
/// </summary>
/// <remarks>
/// For HTTP transports, this can be linked to the underlying HTTP request's aborted cancellation token (e.g. HttpContext.RequestAborted)
/// to propagate transport-level cancellations to tool executions.
/// </remarks>
public CancellationToken? CancellationToken { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public override void Dispose()

protected abstract HttpClientTransportOptions ClientTransportOptions { get; }

private Task<McpClient> GetClientAsync(McpClientOptions? options = null)
protected Task<McpClient> GetClientAsync(McpClientOptions? options = null)
{
return _fixture.ConnectMcpClientAsync(options, LoggerFactory);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using System.Text;

namespace ModelContextProtocol.AspNetCore.Tests;
Expand Down Expand Up @@ -60,4 +61,57 @@ public async Task EventSourceStream_Includes_MessageEventType()
var messageEvent = await streamReader.ReadLineAsync(TestContext.Current.CancellationToken);
Assert.Equal("event: message", messageEvent);
}

[Fact]
public async Task CallTool_CancelToken_SendsCancellationNotification_KeepsConnectionOpen()
{
await using var client = await GetClientAsync();

using CancellationTokenSource cts = new();
var toolTask = client.CallToolAsync(
"longRunning",
new Dictionary<string, object?> { ["durationMs"] = 10000 },
cancellationToken: cts.Token
);

// Allow some time for the request to be sent
await Task.Delay(500, TestContext.Current.CancellationToken);

cts.Cancel();

// Client throws OperationCanceledException
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await toolTask);

// Verify the connection is still open by pinging
var pingResult = await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(pingResult);
}

[Fact]
public async Task CallTool_ClientDisconnectsAbruptly_CancelsServerToken()
{
var client = await GetClientAsync();

// Send the tool call
var toolTask = client.CallToolAsync(
"longRunning",
new Dictionary<string, object?> { ["durationMs"] = 10000 },
cancellationToken: TestContext.Current.CancellationToken
);

// Allow some time for the request to be sent and processing to start on the server
await Task.Delay(500, TestContext.Current.CancellationToken);

// Disposing the client will tear down the transport and drop the underlying HTTP connection,
// simulating a client crash or network drop without sending notifications/cancelled.
await client.DisposeAsync();

// The local client task will throw because the transport disconnected
await Assert.ThrowsAnyAsync<Exception>(async () => await toolTask);

// Verify the server is still alive and handling requests from a *new* client
await using var newClient = await GetClientAsync();
var pingResult = await newClient.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(pingResult);
}
}
54 changes: 54 additions & 0 deletions tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,60 @@ public async Task CallTool_Stdio_EchoSessionId_ReturnsEmpty()
Assert.Empty(textContent.Text);
}

[Fact]
public async Task CallTool_Stdio_CancelToken_ThrowsOperationCanceledException()
{
await using var client = await _fixture.CreateClientAsync("test_server");

using CancellationTokenSource cts = new();
var toolTask = client.CallToolAsync(
"longRunning",
new Dictionary<string, object?> { ["durationMs"] = 10000 },
cancellationToken: cts.Token
);

// Allow some time for the request to be sent
await Task.Delay(500, TestContext.Current.CancellationToken);

cts.Cancel();

// Client throws OperationCanceledException
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await toolTask);

// Verify the connection is still open by pinging
var pingResult = await client.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(pingResult);
}

[Fact]
public async Task CallTool_Stdio_ClientDisconnectsAbruptly_CancelsServerToken()
{
var client = await _fixture.CreateClientAsync("test_server");

// Send the tool call
var toolTask = client.CallToolAsync(
"longRunning",
new Dictionary<string, object?> { ["durationMs"] = 10000 },
cancellationToken: TestContext.Current.CancellationToken
);

// Allow some time for the request to be sent and processing to start
await Task.Delay(500, TestContext.Current.CancellationToken);

// Disposing the client will tear down the stdio pipes abruptly
await client.DisposeAsync();

// The local client task will throw because the transport disconnected
await Assert.ThrowsAnyAsync<Exception>(async () => await toolTask);

// Verify the server process was terminated or we can create a new connection depending on server behavior.
// For Stdio, the server process is typically isolated to the client connection instance,
// so we start a new client to ensure the transport factory is healthy.
await using var newClient = await _fixture.CreateClientAsync("test_server");
var pingResult = await newClient.PingAsync(cancellationToken: TestContext.Current.CancellationToken);
Assert.NotNull(pingResult);
}

[Theory]
[MemberData(nameof(GetClients))]
public async Task CallTool_Stdio_ViaAIFunction_EchoServer(string clientId)
Expand Down