From f78c3875cce3502134909e7442a3503ef8681827 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Fri, 26 Jun 2026 16:10:04 +0200 Subject: [PATCH] Add MCP OAuth lifecycle SDK support Expose host-delegated MCP OAuth handling across SDK languages, sync generated RPC and event models to the lifecycle contract, and add cross-language E2E coverage for initial auth, refresh, upscope, reauth, and cancellation. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/Client.cs | 10 + dotnet/src/Session.cs | 124 +++++ dotnet/src/Types.cs | 75 +++ dotnet/test/E2E/McpOAuthE2ETests.cs | 302 ++++++++++++ dotnet/test/Harness/E2ETestContext.cs | 2 + .../test/Unit/ClientSessionLifetimeTests.cs | 207 +++++++++ dotnet/test/Unit/PublicDtoTests.cs | 19 + .../Unit/SessionEventSerializationTests.cs | 70 +++ go/client.go | 21 + go/client_test.go | 287 ++++++++++++ go/internal/e2e/mcp_oauth_e2e_test.go | 340 ++++++++++++++ go/internal/e2e/testharness/context.go | 2 + go/session.go | 96 ++++ go/session_test.go | 199 ++++++++ go/types.go | 71 +++ .../com/github/copilot/CopilotClient.java | 46 +- .../com/github/copilot/CopilotSession.java | 77 ++++ .../github/copilot/SessionRequestBuilder.java | 6 + .../github/copilot/rpc/McpAuthHandler.java | 26 ++ .../github/copilot/rpc/McpAuthInvocation.java | 36 ++ .../github/copilot/rpc/McpAuthRequest.java | 19 + .../com/github/copilot/rpc/McpAuthResult.java | 32 ++ .../com/github/copilot/rpc/McpAuthToken.java | 13 + .../copilot/rpc/ResumeSessionConfig.java | 24 + .../com/github/copilot/rpc/SessionConfig.java | 27 ++ .../com/github/copilot/E2ETestContext.java | 3 +- .../McpAuthInterestRegistrationTest.java | 299 ++++++++++++ .../com/github/copilot/McpOAuthE2ETest.java | 301 ++++++++++++ nodejs/src/client.ts | 18 +- nodejs/src/session.ts | 55 ++- nodejs/src/types.ts | 77 ++++ nodejs/test/client.test.ts | 222 +++++++++ nodejs/test/e2e/harness/sdkTestContext.ts | 26 +- nodejs/test/e2e/mcp_oauth.e2e.test.ts | 311 +++++++++++++ nodejs/test/e2e/provider_endpoint.e2e.test.ts | 12 +- python/copilot/__init__.py | 14 + python/copilot/client.py | 15 + python/copilot/session.py | 182 ++++++++ python/e2e/test_mcp_oauth_e2e.py | 258 +++++++++++ python/e2e/testharness/context.py | 2 + python/test_client.py | 310 +++++++++++++ rust/src/handler.rs | 97 +++- rust/src/session.rs | 116 ++++- rust/src/types.rs | 30 +- rust/tests/e2e.rs | 2 + rust/tests/e2e/mcp_oauth.rs | 433 ++++++++++++++++++ rust/tests/e2e/support.rs | 2 + rust/tests/session_test.rs | 308 ++++++++++++- test/harness/test-mcp-oauth-server.mjs | 325 +++++++++++++ 49 files changed, 5508 insertions(+), 41 deletions(-) create mode 100644 dotnet/test/E2E/McpOAuthE2ETests.cs create mode 100644 go/internal/e2e/mcp_oauth_e2e_test.go create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthInvocation.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthResult.java create mode 100644 java/src/main/java/com/github/copilot/rpc/McpAuthToken.java create mode 100644 java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java create mode 100644 java/src/test/java/com/github/copilot/McpOAuthE2ETest.java create mode 100644 nodejs/test/e2e/mcp_oauth.e2e.test.ts create mode 100644 python/e2e/test_mcp_oauth_e2e.py create mode 100644 rust/tests/e2e/mcp_oauth.rs create mode 100644 test/harness/test-mcp-oauth-server.mjs diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index a67eb96817..9fbe8c5a72 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -630,6 +630,7 @@ private CopilotSession InitializeSession( this); session.RegisterTools(config.Tools ?? []); session.RegisterPermissionHandler(config.OnPermissionRequest); + session.RegisterMcpAuthHandler(config.OnMcpAuthRequest); session.RegisterCommands(config.Commands); session.RegisterElicitationHandler(config.OnElicitationRequest); session.RegisterExitPlanModeHandler(config.OnExitPlanModeRequest); @@ -1080,6 +1081,11 @@ public async Task CreateSessionAsync(SessionConfig config, Cance $"session.create returned sessionId {response.SessionId} but the caller requested {localSessionId}."); } + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } + session.WorkspacePath = response.WorkspacePath; session.SetCapabilities(response.Capabilities); session.SetOpenCanvases(response.OpenCanvases); @@ -1166,6 +1172,10 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes transformCallbacks, hasHooks, "CopilotClient.ResumeSessionAsync"); + if (config.OnMcpAuthRequest is not null) + { + await session.Rpc.EventLog.RegisterInterestAsync("mcp.oauth_required", cancellationToken); + } try { diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 0985848e26..f009faa0bd 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -63,6 +63,7 @@ public sealed partial class CopilotSession : IAsyncDisposable private readonly CopilotClient _parentClient; private volatile Func>? _permissionHandler; + private volatile Func>? _mcpAuthHandler; private volatile Func>? _userInputHandler; private volatile Func>? _elicitationHandler; private volatile Func>? _exitPlanModeHandler; @@ -558,6 +559,11 @@ internal void RegisterPermissionHandler(Func>? handler) + { + _mcpAuthHandler = handler; + } + /// /// Handles a permission request from the Copilot CLI. /// @@ -633,6 +639,39 @@ private async Task HandleBroadcastEventAsync(SessionEvent sessionEvent) break; } + case McpOauthRequiredEvent authEvent: + { + var data = authEvent.Data; + if (string.IsNullOrEmpty(data.RequestId)) + return; + + var handler = _mcpAuthHandler; + if (handler is null) + { + if (_logger.IsEnabled(LogLevel.Warning)) + { + _logger.LogWarning( + "Received MCP OAuth request without a registered MCP auth handler. SessionId={SessionId}, RequestId={RequestId}", + SessionId, + data.RequestId); + } + return; + } + + await ExecuteMcpAuthAndRespondAsync(data.RequestId, new McpAuthContext + { + SessionId = SessionId, + RequestId = data.RequestId, + ServerName = data.ServerName, + ServerUrl = data.ServerUrl, + Reason = data.Reason, + WwwAuthenticateParams = data.WwwAuthenticateParams, + ResourceMetadata = data.ResourceMetadata, + StaticClientConfig = data.StaticClientConfig + }, handler); + break; + } + case CommandExecuteEvent cmdEvent: { var data = cmdEvent.Data; @@ -702,6 +741,91 @@ await HandleElicitationRequestAsync( } } + private async Task ExecuteMcpAuthAndRespondAsync( + string requestId, + McpAuthContext context, + Func> handler) + { + try + { + var result = await handler(context); + McpOauthPendingRequestResponse response = + result is { Cancelled: false, Token: { } token } + ? new McpOauthPendingRequestResponseToken + { + AccessToken = token.AccessToken, + TokenType = token.TokenType, + ExpiresIn = token.ExpiresIn + } + : new McpOauthPendingRequestResponseCancelled(); + + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, response); + } + catch (OperationCanceledException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (ObjectDisposedException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (InvalidOperationException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (ArgumentException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (NotSupportedException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (JsonException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (RemoteRpcException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (IOException) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + catch (Exception ex) when (IsRecoverableMcpAuthFailure(ex)) + { + await TryCancelMcpAuthRequestAsync(requestId); + } + } + + private static bool IsRecoverableMcpAuthFailure(Exception exception) + => exception is not OperationCanceledException + and not OutOfMemoryException + and not StackOverflowException + and not AccessViolationException + and not AppDomainUnloadedException; + + private async Task TryCancelMcpAuthRequestAsync(string requestId) + { + try + { + await Rpc.Mcp.Oauth.HandlePendingRequestAsync(requestId, new McpOauthPendingRequestResponseCancelled()); + } + catch (IOException) + { + // Connection lost — nothing we can do. + } + catch (ObjectDisposedException) + { + // Connection already disposed — nothing we can do. + } + catch (RemoteRpcException) + { + // The pending request may already be gone — nothing we can do. + } + } + /// /// Executes a tool handler and sends the result back via the HandlePendingToolCall RPC. /// diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 5ae9657813..ecb2774398 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -1128,6 +1128,72 @@ public sealed class ElicitationContext public string? Url { get; set; } } +/// +/// Context for an MCP OAuth request callback. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthContext +{ + /// Identifier of the session that triggered the MCP OAuth request. + public string SessionId { get; set; } = string.Empty; + + /// Identifier of the pending MCP OAuth request. + public string RequestId { get; set; } = string.Empty; + + /// Display name of the MCP server that requires OAuth. + public string ServerName { get; set; } = string.Empty; + + /// URL of the MCP server that requires OAuth. + public string ServerUrl { get; set; } = string.Empty; + + /// Why the runtime is requesting host-provided OAuth credentials. + public McpOauthRequestReason Reason { get; set; } + + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + public McpOauthWWWAuthenticateParams? WwwAuthenticateParams { get; set; } + + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + public string? ResourceMetadata { get; set; } + + /// Static OAuth client configuration, if the server specifies one. + public McpOauthRequiredStaticClientConfig? StaticClientConfig { get; set; } +} + +/// +/// Host-provided OAuth token data for a pending MCP OAuth request. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthToken +{ + /// Access token acquired by the SDK host. + public required string AccessToken { get; set; } + + /// OAuth token type. Defaults to Bearer when omitted. + public string? TokenType { get; set; } + + /// Token lifetime in seconds, if known. + public long? ExpiresIn { get; set; } +} + +/// +/// Result returned by an MCP auth request handler. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class McpAuthResult +{ + /// Whether the request should be cancelled instead of resolved with a token. + public bool Cancelled { get; set; } + + /// Host-provided token data. Ignored when is true. + public McpAuthToken? Token { get; set; } + + /// Create a token result. + public static McpAuthResult FromToken(McpAuthToken token) => new() { Token = token }; + + /// Create a cancellation result. + public static McpAuthResult Cancel() => new() { Cancelled = true }; +} + // ============================================================================ // Session Capabilities // ============================================================================ @@ -2719,6 +2785,7 @@ protected SessionConfigBase(SessionConfigBase? other) OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; OnExitPlanModeRequest = other.OnExitPlanModeRequest; + OnMcpAuthRequest = other.OnMcpAuthRequest; OnPermissionRequest = other.OnPermissionRequest; OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; @@ -3180,6 +3247,14 @@ protected SessionConfigBase(SessionConfigBase? other) [JsonIgnore] public ICanvasHandler? CanvasHandler { get; set; } #pragma warning restore GHCP001 + + /// + /// Optional handler for MCP OAuth requests from MCP servers. + /// When provided, the SDK can satisfy MCP server OAuth requests with host-provided token data or cancellation. + /// + [Experimental(Diagnostics.Experimental)] + [JsonIgnore] + public Func>? OnMcpAuthRequest { get; set; } } /// diff --git a/dotnet/test/E2E/McpOAuthE2ETests.cs b/dotnet/test/E2E/McpOAuthE2ETests.cs new file mode 100644 index 0000000000..417b7ad1bd --- /dev/null +++ b/dotnet/test/E2E/McpOAuthE2ETests.cs @@ -0,0 +1,302 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using GitHub.Copilot.Test.Harness; +using System.Diagnostics; +using System.Net.Http; +using System.Text.Json; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.Test.E2E; + +public class McpOAuthE2ETests(E2ETestFixture fixture, ITestOutputHelper output) : E2ETestBase(fixture, "mcp_oauth", output) +{ + private const string ExpectedToken = "sdk-host-token"; + private const string RefreshToken = ExpectedToken + "-refresh"; + private const string UpscopeToken = ExpectedToken + "-upscope"; + private const string ReauthToken = ExpectedToken + "-reauth"; + + [Fact] + public async Task Should_Satisfy_MCP_OAuth_Using_Host_Provided_Token() + { + await using var oauthServer = await OAuthMcpServer.StartAsync(ExpectedToken); + var serverName = "oauth-protected-mcp"; + McpAuthContext? observedRequest = null; + + await using var session = await CreateSessionAsync(new SessionConfig + { + OnMcpAuthRequest = request => + { + observedRequest = request; + return Task.FromResult(McpAuthResult.FromToken(new McpAuthToken + { + AccessToken = ExpectedToken, + TokenType = "Bearer", + ExpiresIn = 3600 + })); + }, + McpServers = new Dictionary + { + [serverName] = new McpHttpServerConfig + { + Url = $"{oauthServer.Url}/mcp", + Tools = ["*"] + } + } + }); + + await WaitForMcpServerStatusAsync(session, serverName, McpServerStatus.Connected); + var tools = await session.Rpc.Mcp.ListToolsAsync(serverName); + Assert.Contains(tools.Tools, tool => tool.Name == "whoami"); + + Assert.NotNull(observedRequest); + Assert.NotEmpty(observedRequest!.RequestId); + Assert.Equal(serverName, observedRequest!.ServerName); + Assert.Equal($"{oauthServer.Url}/mcp", observedRequest.ServerUrl); + Assert.Equal(McpOauthRequestReason.Initial, observedRequest.Reason); + Assert.NotNull(observedRequest.WwwAuthenticateParams); + Assert.Equal($"{oauthServer.Url}/.well-known/oauth-protected-resource", observedRequest.WwwAuthenticateParams!.ResourceMetadataUrl); + Assert.Equal("mcp.read", observedRequest.WwwAuthenticateParams.Scope); + Assert.Equal("invalid_token", observedRequest.WwwAuthenticateParams.Error); + + using var metadata = JsonDocument.Parse(observedRequest.ResourceMetadata!); + Assert.Equal($"{oauthServer.Url}/mcp", metadata.RootElement.GetProperty("resource").GetString()); + + var requests = await oauthServer.GetRequestsAsync(); + Assert.Contains(requests, request => request.Authorization is null); + Assert.Contains(requests, request => request.Authorization == $"Bearer {ExpectedToken}"); + } + + [Fact] + public async Task Should_Request_Replacement_Tokens_Across_MCP_OAuth_Lifecycle() + { + await using var oauthServer = await OAuthMcpServer.StartAsync(ExpectedToken); + var serverName = "oauth-lifecycle-mcp"; + List observedReasons = []; + var refreshCount = 0; + + await using var session = await CreateSessionAsync(new SessionConfig + { + EnableMcpApps = true, + OnMcpAuthRequest = request => + { + observedReasons.Add(request.Reason); + if (request.Reason == McpOauthRequestReason.Refresh) + { + refreshCount++; + Assert.NotNull(request.WwwAuthenticateParams); + Assert.Null(request.WwwAuthenticateParams!.ResourceMetadataUrl); + Assert.Equal("invalid_token", request.WwwAuthenticateParams.Error); + if (refreshCount > 1) + { + return Task.FromResult(McpAuthResult.Cancel()); + } + } + + if (request.Reason == McpOauthRequestReason.Upscope) + { + Assert.NotNull(request.WwwAuthenticateParams); + Assert.Equal($"{oauthServer.Url}/.well-known/oauth-protected-resource", request.WwwAuthenticateParams!.ResourceMetadataUrl); + Assert.Equal("mcp.write", request.WwwAuthenticateParams.Scope); + Assert.Equal("insufficient_scope", request.WwwAuthenticateParams.Error); + } + + var token = request.Reason == McpOauthRequestReason.Refresh + ? RefreshToken + : request.Reason == McpOauthRequestReason.Upscope + ? UpscopeToken + : request.Reason == McpOauthRequestReason.Reauth + ? ReauthToken + : ExpectedToken; + + return Task.FromResult(McpAuthResult.FromToken(new McpAuthToken + { + AccessToken = token + })); + }, + McpServers = new Dictionary + { + [serverName] = new McpHttpServerConfig + { + Url = $"{oauthServer.Url}/mcp", + Tools = ["*"] + } + } + }); + + await WaitForMcpServerStatusAsync(session, serverName, McpServerStatus.Connected); + await CallWhoamiAsync(session, serverName, "refresh"); + await CallWhoamiAsync(session, serverName, "upscope"); + await CallWhoamiAsync(session, serverName, "reauth"); + + Assert.Equal( + [ + McpOauthRequestReason.Initial, + McpOauthRequestReason.Refresh, + McpOauthRequestReason.Upscope, + McpOauthRequestReason.Refresh, + McpOauthRequestReason.Reauth + ], + observedReasons); + + var requests = await oauthServer.GetRequestsAsync(); + Assert.Contains(requests, request => request.Authorization == $"Bearer {RefreshToken}"); + Assert.Contains(requests, request => request.Authorization == $"Bearer {UpscopeToken}"); + Assert.Contains(requests, request => request.Authorization == $"Bearer {ReauthToken}"); + } + + [Fact] + public async Task Should_Cancel_Pending_MCP_OAuth_Request() + { + await using var oauthServer = await OAuthMcpServer.StartAsync(ExpectedToken); + var serverName = "oauth-cancelled-mcp"; + McpAuthContext? observedRequest = null; + + await using var session = await CreateSessionAsync(new SessionConfig + { + OnMcpAuthRequest = request => + { + observedRequest = request; + return Task.FromResult(McpAuthResult.Cancel()); + }, + McpServers = new Dictionary + { + [serverName] = new McpHttpServerConfig + { + Url = $"{oauthServer.Url}/mcp", + Tools = ["*"] + } + } + }); + + await WaitForMcpServerStatusAsync(session, serverName, McpServerStatus.Failed); + + Assert.NotNull(observedRequest); + Assert.NotEmpty(observedRequest!.RequestId); + Assert.Equal(serverName, observedRequest!.ServerName); + Assert.Equal(McpOauthRequestReason.Initial, observedRequest.Reason); + } + + private static async Task CallWhoamiAsync(CopilotSession session, string serverName, string scenario) + { + using var argumentDocument = JsonDocument.Parse($"{{\"scenario\":\"{scenario}\"}}"); + var result = await session.Rpc.Mcp.Apps.CallToolAsync( + serverName, + "whoami", + serverName, + new Dictionary + { + ["scenario"] = argumentDocument.RootElement.GetProperty("scenario").Clone() + }); + + var content = result["content"].EnumerateArray().ToList(); + Assert.Single(content); + Assert.Equal("oauth-test-user", content[0].GetProperty("text").GetString()); + } + + private sealed class OAuthMcpServer : IAsyncDisposable + { + private readonly Process _process; + private readonly HttpClient _http = new(); + + private OAuthMcpServer(Process process, string url) + { + _process = process; + Url = url; + } + + public string Url { get; } + + public static async Task StartAsync(string expectedToken) + { + var repoRoot = FindRepoRoot(); + var script = GetRepoRelativePath(repoRoot, "test", "harness", "test-mcp-oauth-server.mjs"); + var startInfo = new ProcessStartInfo + { + FileName = "node", + Arguments = QuoteProcessArgument(script), + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false + }; + startInfo.Environment["EXPECTED_TOKEN"] = expectedToken; + + var process = Process.Start(startInfo) + ?? throw new InvalidOperationException("Failed to start OAuth MCP server."); + var stderrTask = process.StandardError.ReadToEndAsync(); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + while (!cts.IsCancellationRequested) + { + var line = await process.StandardOutput.ReadLineAsync(cts.Token); + if (line is null) + { + throw new InvalidOperationException($"OAuth MCP server exited before listening: {await stderrTask}"); + } + if (line.StartsWith("Listening: ", StringComparison.Ordinal)) + { + return new OAuthMcpServer(process, line["Listening: ".Length..]); + } + } + + throw new TimeoutException($"Timed out waiting for OAuth MCP server: {await stderrTask}"); + } + + public async Task> GetRequestsAsync() + { + var json = await _http.GetStringAsync($"{Url}/__requests"); + using var document = JsonDocument.Parse(json); + return document.RootElement.EnumerateArray() + .Select(element => new OAuthMcpRequest( + element.TryGetProperty("authorization", out var authorization) + && authorization.ValueKind is JsonValueKind.String + ? authorization.GetString() + : null)) + .ToList(); + } + + public async ValueTask DisposeAsync() + { + _http.Dispose(); + if (!_process.HasExited) + { + _process.Kill(entireProcessTree: true); + await _process.WaitForExitAsync(); + } + _process.Dispose(); + } + + private static string FindRepoRoot() + { + var dir = new DirectoryInfo(AppContext.BaseDirectory); + while (dir != null) + { + var candidate = GetRepoRelativePath(dir.FullName, "test", "harness", "test-mcp-oauth-server.mjs"); + if (File.Exists(candidate)) + return dir.FullName; + dir = dir.Parent; + } + throw new InvalidOperationException("Could not find repository root."); + } + + private static string GetRepoRelativePath(string repoRoot, params string[] relativeSegments) + { + var path = repoRoot; + foreach (var segment in relativeSegments) + { + if (Path.IsPathRooted(segment)) + throw new ArgumentException("Repository-relative path segments must not be rooted.", nameof(relativeSegments)); + path = Path.Join(path, segment); + } + return Path.GetFullPath(path); + } + + private static string QuoteProcessArgument(string argument) + => "\"" + argument.Replace("\"", "\\\"") + "\""; + } + + private sealed record OAuthMcpRequest(string? Authorization); +} diff --git a/dotnet/test/Harness/E2ETestContext.cs b/dotnet/test/Harness/E2ETestContext.cs index 2e2043183a..6e26299a49 100644 --- a/dotnet/test/Harness/E2ETestContext.cs +++ b/dotnet/test/Harness/E2ETestContext.cs @@ -192,6 +192,8 @@ public Dictionary GetEnvironment() env["GH_CONFIG_DIR"] = HomeDir; env["XDG_CONFIG_HOME"] = HomeDir; env["XDG_STATE_HOME"] = HomeDir; + env["COPILOT_MCP_APPS"] = "true"; + env["MCP_APPS"] = "true"; if (!string.IsNullOrEmpty(_proxy.ConnectProxyUrl) && !string.IsNullOrEmpty(_proxy.CaFilePath)) { const string noProxy = "127.0.0.1,localhost,::1"; diff --git a/dotnet/test/Unit/ClientSessionLifetimeTests.cs b/dotnet/test/Unit/ClientSessionLifetimeTests.cs index 2c11c7d6b5..e51a4b9119 100644 --- a/dotnet/test/Unit/ClientSessionLifetimeTests.cs +++ b/dotnet/test/Unit/ClientSessionLifetimeTests.cs @@ -16,6 +16,8 @@ namespace GitHub.Copilot.Test.Unit; public sealed class ClientSessionLifetimeTests { + private sealed record RpcRequestRecord(string Method, JsonElement Params); + [Fact] public async Task StopAsync_Requests_Runtime_Shutdown_For_Owned_Process() { @@ -188,6 +190,151 @@ public async Task ResumeSessionAsync_Throws_When_Same_Client_Already_Tracks_Sess AssertSessionCount(client, sessions: 1); } + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.create" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => Assert.Equal("session.create", request.Method), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }); + } + + [Fact] + public async Task CreateSessionAsync_Registers_McpAuth_Interest_After_Cloud_Create_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + var cloud = new CloudSessionOptions + { + Repository = new CloudSessionRepository + { + Owner = "github", + Name = "copilot-sdk", + Branch = "main" + } + }; + + await using var withoutAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + Cloud = cloud + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + + server.ClearRequests(); + + await using var withAuth = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()), + Cloud = cloud + }); + + Assert.Collection( + server.Requests.Take(2), + request => Assert.Equal("session.create", request.Method), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }); + } + + [Fact] + public async Task ResumeSessionAsync_Registers_McpAuth_Interest_Only_When_Handler_Configured() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + + await using var withoutAuth = await client.ResumeSessionAsync("session-without-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnEvent = _ => { } + }); + + Assert.DoesNotContain(server.Requests, request => + request.Method == "session.eventLog.registerInterest" + && request.Params.GetProperty("eventType").GetString() == "mcp.oauth_required"); + Assert.Contains(server.Requests, request => + request.Method == "session.resume" + && request.Params.GetProperty("requestPermission").GetBoolean()); + + server.ClearRequests(); + + await using var withAuth = await client.ResumeSessionAsync("session-with-auth", new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => Task.FromResult(McpAuthResult.Cancel()) + }); + + Assert.Collection( + server.Requests.Take(2), + request => + { + Assert.Equal("session.eventLog.registerInterest", request.Method); + Assert.Equal("mcp.oauth_required", request.Params.GetProperty("eventType").GetString()); + }, + request => Assert.Equal("session.resume", request.Method)); + } + + [Fact] + public async Task McpAuth_Handler_Exception_Cancels_Pending_Request() + { + await using var server = await FakeCopilotServer.StartAsync(); + await using var client = new CopilotClient(new CopilotClientOptions { Connection = RuntimeConnection.ForUri(server.Url) }); + await using var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + OnMcpAuthRequest = _ => throw new ApplicationException("boom") + }); + + DispatchEvent(session, new McpOauthRequiredEvent + { + Data = new McpOauthRequiredData + { + RequestId = "mcp-auth-request-1", + ServerName = "oauth-mcp", + ServerUrl = "http://localhost/mcp", + Reason = McpOauthRequestReason.Initial + } + }); + + var request = await WaitForRequestAsync(server, "session.mcp.oauth.handlePendingRequest"); + Assert.Equal("mcp-auth-request-1", request.Params.GetProperty("requestId").GetString()); + Assert.Equal("cancelled", request.Params.GetProperty("result").GetProperty("kind").GetString()); + } + [Fact] public async Task Generated_Session_Rpc_Throws_When_Session_Disposed() { @@ -238,6 +385,30 @@ private static int GetPrivateDictionaryCount(CopilotClient client, string fieldN return (int)count.GetValue(dictionary)!; } + private static void DispatchEvent(CopilotSession session, SessionEvent evt) + { + var method = typeof(CopilotSession).GetMethod("DispatchEvent", BindingFlags.Instance | BindingFlags.NonPublic) + ?? throw new InvalidOperationException("DispatchEvent method was not found."); + method.Invoke(session, [evt]); + } + + private static async Task WaitForRequestAsync(FakeCopilotServer server, string method) + { + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + while (!timeout.IsCancellationRequested) + { + var request = server.Requests.FirstOrDefault(request => request.Method == method); + if (request is not null) + { + return request; + } + + await Task.Delay(20, CancellationToken.None); + } + + throw new TimeoutException($"Timed out waiting for RPC method '{method}'."); + } + private static async Task ReplaceConnectionCliProcessAsync(CopilotClient client, Process process) { var field = typeof(CopilotClient).GetField("_connectionTask", BindingFlags.Instance | BindingFlags.NonPublic) @@ -277,6 +448,8 @@ private sealed class FakeCopilotServer : IAsyncDisposable private readonly TaskCompletionSource _destroyStarted = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource _allowDestroy = new(TaskCreationOptions.RunContinuationsAsynchronously); private readonly Task _serverTask; + private readonly List _requests = []; + private readonly object _requestsLock = new(); private string? _lastSessionId; private bool _delayDestroy; private bool _failRuntimeShutdown; @@ -307,6 +480,25 @@ public static Task StartAsync() public int RuntimeShutdownCount { get; private set; } + public IReadOnlyList Requests + { + get + { + lock (_requestsLock) + { + return _requests.ToArray(); + } + } + } + + public void ClearRequests() + { + lock (_requestsLock) + { + _requests.Clear(); + } + } + public void DelayDestroy() { _delayDestroy = true; @@ -382,6 +574,13 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel return; } + var paramsElement = request.TryGetProperty("params", out var rawParams) + ? rawParams.Clone() + : JsonDocument.Parse("{}").RootElement.Clone(); + lock (_requestsLock) + { + _requests.Add(new RpcRequestRecord(method!, paramsElement)); + } object? result = method switch { "connect" => new Dictionary @@ -392,10 +591,18 @@ private async Task HandleRequestAsync(Stream stream, JsonElement request, Cancel }, "session.create" => CreateSessionResult(request), "session.resume" => CreateSessionResult(request), + "session.eventLog.registerInterest" => new Dictionary + { + ["id"] = "interest-1" + }, "session.send" => new Dictionary { ["messageId"] = "message-1" }, + "session.mcp.oauth.handlePendingRequest" => new Dictionary + { + ["success"] = true + }, "session.delete" => new Dictionary { ["success"] = true diff --git a/dotnet/test/Unit/PublicDtoTests.cs b/dotnet/test/Unit/PublicDtoTests.cs index c81a8a7a64..d1918d2b9a 100644 --- a/dotnet/test/Unit/PublicDtoTests.cs +++ b/dotnet/test/Unit/PublicDtoTests.cs @@ -20,6 +20,25 @@ namespace GitHub.Copilot.Test.Unit; /// public class PublicDtoTests { + [Fact] + public void McpAuth_Result_Factories_Represent_Token_And_Cancellation() + { + var token = new McpAuthToken + { + AccessToken = "host-token", + TokenType = "Bearer", + ExpiresIn = 3600, + }; + + var tokenResult = McpAuthResult.FromToken(token); + Assert.Same(token, tokenResult.Token); + Assert.False(tokenResult.Cancelled); + + var cancelled = McpAuthResult.Cancel(); + Assert.True(cancelled.Cancelled); + Assert.Null(cancelled.Token); + } + [Fact] public void Public_Dto_Properties_Can_Be_Set_And_Read() { diff --git a/dotnet/test/Unit/SessionEventSerializationTests.cs b/dotnet/test/Unit/SessionEventSerializationTests.cs index f537f500f3..64e28a5aee 100644 --- a/dotnet/test/Unit/SessionEventSerializationTests.cs +++ b/dotnet/test/Unit/SessionEventSerializationTests.cs @@ -156,9 +156,15 @@ public class SessionEventSerializationTests StaticClientConfig = new McpOauthRequiredStaticClientConfig { ClientId = "client-id", + ClientSecret = "static-secret", GrantType = "client_credentials", PublicClient = false, }, + WwwAuthenticateParams = new McpOauthWWWAuthenticateParams + { + ResourceMetadataUrl = "https://example.com/.well-known/oauth-protected-resource", + }, + ResourceMetadata = """{"resource":"https://example.com/mcp"}""", }, }, "mcp.oauth_required" @@ -282,6 +288,17 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven .GetProperty("staticClientConfig") .GetProperty("grantType") .GetString()); + Assert.Equal( + "static-secret", + root.GetProperty("data") + .GetProperty("staticClientConfig") + .GetProperty("clientSecret") + .GetString()); + Assert.Equal( + """{"resource":"https://example.com/mcp"}""", + root.GetProperty("data") + .GetProperty("resourceMetadata") + .GetString()); break; case "assistant.message_start": @@ -298,4 +315,57 @@ public void SessionEvent_ToJson_RoundTrips_JsonElementBackedPayloads(SessionEven break; } } + + [Fact] + public void McpOauthRequiredData_Allows_Missing_Optional_Metadata() + { + const string json = """ + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + } + """; + + var authEvent = Assert.IsType(SessionEvent.FromJson(json)); + Assert.Null(authEvent.Data.WwwAuthenticateParams); + Assert.Null(authEvent.Data.ResourceMetadata); + } + + [Fact] + public void McpOauthRequiredData_Preserves_Static_Client_Secret() + { + const string json = """ + { + "id": "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + "timestamp": "2026-03-15T21:26:54.987Z", + "parentId": null, + "type": "mcp.oauth_required", + "data": { + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "staticClientConfig": { + "clientId": "static-client", + "clientSecret": "static-secret", + "grantType": "client_credentials", + "publicClient": false + } + } + } + """; + + var authEvent = Assert.IsType(SessionEvent.FromJson(json)); + + Assert.NotNull(authEvent.Data.StaticClientConfig); + Assert.Equal("static-secret", authEvent.Data.StaticClientConfig.ClientSecret); + } } diff --git a/go/client.go b/go/client.go index 970f046425..9e2819047e 100644 --- a/go/client.go +++ b/go/client.go @@ -806,6 +806,7 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses s.registerTools(config.Tools) s.registerPermissionHandler(config.OnPermissionRequest) + s.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { s.registerUserInputHandler(config.OnUserInputRequest) } @@ -937,6 +938,14 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses c.sessionsMux.Unlock() return nil, fmt.Errorf("session.create returned sessionId %s but the caller requested %s", response.SessionID, localSessionID) } + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": session.SessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + return nil, err + } + } session.workspacePath = response.WorkspacePath session.setCapabilities(response.Capabilities) @@ -1106,6 +1115,7 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, session.registerTools(config.Tools) session.registerPermissionHandler(config.OnPermissionRequest) + session.registerMCPAuthHandler(config.OnMCPAuthRequest) if config.OnUserInputRequest != nil { session.registerUserInputHandler(config.OnUserInputRequest) } @@ -1140,6 +1150,17 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessionsMux.Lock() c.sessions[sessionID] = session c.sessionsMux.Unlock() + if config.OnMCPAuthRequest != nil { + if _, err := c.client.Request(ctx, "session.eventLog.registerInterest", map[string]any{ + "sessionId": sessionID, + "eventType": "mcp.oauth_required", + }); err != nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, err + } + } if c.options.SessionFS != nil { if config.CreateSessionFSProvider == nil { diff --git a/go/client_test.go b/go/client_test.go index d59c71c6f9..c889ced8d5 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -3,6 +3,8 @@ package copilot import ( "context" "encoding/json" + "fmt" + "io" "net" "os" "os/exec" @@ -1315,6 +1317,291 @@ func TestClient_StartStopRace(t *testing.T) { } } +func TestClient_MCPAuthInterestRegistration(t *testing.T) { + t.Run("create skips MCP OAuth interest without auth handler", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.create") + assertCreateRequestPermission(t, requests.snapshot()) + }) + + t.Run("create registers MCP OAuth interest after local session create when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + session, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return MCPAuthResultCancelled(), nil + }, + }) + if err != nil { + t.Fatalf("CreateSession failed: %v", err) + } + defer session.Disconnect() + + snapshot := requests.snapshot() + assertRequestMethod(t, snapshot, "session.eventLog.registerInterest") + if snapshot[0].Method != "session.create" { + t.Fatalf("expected session.create before MCP auth interest, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest after session.create, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[1]) + assertCreateRequestPermission(t, snapshot) + }) + + t.Run("cloud create registers MCP OAuth interest after server assigns id only when auth handler is configured", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession without auth failed: %v", err) + } + defer withoutAuth.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + requests.clear() + + withAuth, err := client.CreateSession(t.Context(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return MCPAuthResultCancelled(), nil + }, + Cloud: &CloudSessionOptions{ + Repository: &CloudSessionRepository{Owner: "github", Name: "copilot-sdk", Branch: "main"}, + }, + }) + if err != nil { + t.Fatalf("CreateSession with auth failed: %v", err) + } + defer withAuth.Disconnect() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.create" { + t.Fatalf("expected cloud session.create before MCP auth interest, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest after cloud session.create, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[1]) + }) + + t.Run("resume conditionally registers MCP OAuth interest before session resume", func(t *testing.T) { + client, requests, cleanup := newInMemoryClient(t) + defer cleanup() + + withoutAuth, err := client.ResumeSession(t.Context(), "session-without-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnEvent: func(SessionEvent) {}, + }) + if err != nil { + t.Fatalf("ResumeSession without auth failed: %v", err) + } + defer withoutAuth.Disconnect() + + assertNoMCPAuthInterest(t, requests.snapshot()) + assertRequestMethod(t, requests.snapshot(), "session.resume") + requests.clear() + + withAuth, err := client.ResumeSession(t.Context(), "session-with-auth", &ResumeSessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(MCPAuthRequest, MCPAuthInvocation) (*MCPAuthResult, error) { + return MCPAuthResultCancelled(), nil + }, + }) + if err != nil { + t.Fatalf("ResumeSession with auth failed: %v", err) + } + defer withAuth.Disconnect() + + snapshot := requests.snapshot() + if snapshot[0].Method != "session.eventLog.registerInterest" { + t.Fatalf("expected MCP auth interest before session.resume, got %s", snapshot[0].Method) + } + if snapshot[1].Method != "session.resume" { + t.Fatalf("expected session.resume after MCP auth interest, got %s", snapshot[1].Method) + } + assertMCPAuthInterest(t, snapshot[0]) + }) +} + +type recordedRequest struct { + Method string + Params map[string]any +} + +type requestRecorder struct { + mu sync.Mutex + requests []recordedRequest +} + +func (r *requestRecorder) append(request recordedRequest) { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = append(r.requests, request) +} + +func (r *requestRecorder) snapshot() []recordedRequest { + r.mu.Lock() + defer r.mu.Unlock() + out := make([]recordedRequest, len(r.requests)) + copy(out, r.requests) + return out +} + +func (r *requestRecorder) clear() { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = nil +} + +func newInMemoryClient(t *testing.T) (*Client, *requestRecorder, func()) { + t.Helper() + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + rpcClient := jsonrpc2.NewClient(stdinW, stdoutR) + rpcClient.Start() + + client := NewClient(&ClientOptions{}) + client.client = rpcClient + client.RPC = rpc.NewServerRPC(rpcClient) + client.state = stateConnected + + requests := &requestRecorder{} + done := make(chan struct{}) + go serveInMemoryRuntime(t, stdinR, stdoutW, requests, done) + + cleanup := func() { + rpcClient.Stop() + stdinR.Close() + stdinW.Close() + stdoutR.Close() + stdoutW.Close() + <-done + } + return client, requests, cleanup +} + +func serveInMemoryRuntime(t *testing.T, stdinR *io.PipeReader, stdoutW *io.PipeWriter, requests *requestRecorder, done chan<- struct{}) { + t.Helper() + defer close(done) + + serverAssignedSessions := 0 + for { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + t.Errorf("failed to unmarshal JSON-RPC request: %v", err) + return + } + requests.append(recordedRequest{Method: request.Method, Params: request.Params}) + + var result map[string]any + switch request.Method { + case "session.create", "session.resume": + sessionID, _ := request.Params["sessionId"].(string) + if sessionID == "" { + serverAssignedSessions++ + sessionID = fmt.Sprintf("server-assigned-session-%d", serverAssignedSessions) + } + result = map[string]any{"sessionId": sessionID, "workspacePath": nil} + case "session.eventLog.registerInterest": + result = map[string]any{"id": "interest-1"} + case "session.options.update": + result = map[string]any{"success": true} + case "session.skills.reload", "session.destroy": + result = map[string]any{} + default: + t.Errorf("unexpected JSON-RPC method %s", request.Method) + return + } + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": result, + } + data, err := json.Marshal(response) + if err != nil { + t.Errorf("failed to marshal JSON-RPC response: %v", err) + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + return + } + } +} + +func assertRequestMethod(t *testing.T, requests []recordedRequest, method string) { + t.Helper() + for _, request := range requests { + if request.Method == method { + return + } + } + t.Fatalf("expected %s request in %+v", method, requests) +} + +func assertNoMCPAuthInterest(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.eventLog.registerInterest" && request.Params["eventType"] == "mcp.oauth_required" { + t.Fatalf("did not expect MCP auth interest registration in %+v", requests) + } + } +} + +func assertMCPAuthInterest(t *testing.T, request recordedRequest) { + t.Helper() + if request.Method != "session.eventLog.registerInterest" { + t.Fatalf("expected registerInterest request, got %s", request.Method) + } + if request.Params["eventType"] != "mcp.oauth_required" { + t.Fatalf("expected mcp.oauth_required interest, got %v", request.Params["eventType"]) + } +} + +func assertCreateRequestPermission(t *testing.T, requests []recordedRequest) { + t.Helper() + for _, request := range requests { + if request.Method == "session.create" { + if request.Params["requestPermission"] != true { + t.Fatalf("expected create requestPermission=true, got %v", request.Params["requestPermission"]) + } + return + } + } + t.Fatalf("session.create request not found in %+v", requests) +} + func TestCreateSessionRequest_Commands(t *testing.T) { t.Run("forwards commands in session.create RPC", func(t *testing.T) { req := createSessionRequest{ diff --git a/go/internal/e2e/mcp_oauth_e2e_test.go b/go/internal/e2e/mcp_oauth_e2e_test.go new file mode 100644 index 0000000000..e423f12d12 --- /dev/null +++ b/go/internal/e2e/mcp_oauth_e2e_test.go @@ -0,0 +1,340 @@ +package e2e + +import ( + "bufio" + "encoding/json" + "net/http" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "sync" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" + "github.com/github/copilot-sdk/go/rpc" +) + +const expectedMCPOAuthToken = "sdk-host-token" +const refreshMCPOAuthToken = expectedMCPOAuthToken + "-refresh" +const upscopeMCPOAuthToken = expectedMCPOAuthToken + "-upscope" +const reauthMCPOAuthToken = expectedMCPOAuthToken + "-reauth" + +func TestMCPOAuthE2E(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient() + t.Cleanup(func() { client.ForceStop() }) + + t.Run("satisfy MCP OAuth using host-provided token", func(t *testing.T) { + baseURL := startOAuthMCPServer(t) + serverName := "oauth-protected-mcp" + tokenType := "Bearer" + expiresIn := int64(3600) + var observedRequest copilot.MCPAuthRequest + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(request copilot.MCPAuthRequest, _ copilot.MCPAuthInvocation) (*copilot.MCPAuthResult, error) { + observedRequest = request + return copilot.MCPAuthResultToken(&copilot.MCPAuthToken{ + AccessToken: expectedMCPOAuthToken, + TokenType: &tokenType, + ExpiresIn: &expiresIn, + }), nil + }, + MCPServers: map[string]copilot.MCPServerConfig{ + serverName: copilot.MCPHTTPServerConfig{ + URL: baseURL + "/mcp", + Tools: []string{"*"}, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + t.Cleanup(func() { session.Disconnect() }) + + waitForMCPServerStatus(t, session, serverName, rpc.MCPServerStatusConnected) + tools, err := session.RPC.MCP.ListTools(t.Context(), &rpc.MCPListToolsRequest{ServerName: serverName}) + if err != nil { + t.Fatalf("Failed to list MCP tools: %v", err) + } + if len(tools.Tools) != 1 || tools.Tools[0].Name != "whoami" { + t.Fatalf("Expected whoami tool, got %#v", tools.Tools) + } + + if observedRequest.ServerName != serverName { + t.Fatalf("Expected serverName %q, got %q", serverName, observedRequest.ServerName) + } + if observedRequest.ServerURL != baseURL+"/mcp" { + t.Fatalf("Expected serverUrl %q, got %q", baseURL+"/mcp", observedRequest.ServerURL) + } + if observedRequest.WwwAuthenticateParams == nil { + t.Fatal("Expected WWW-Authenticate params") + } + if observedRequest.Reason != "initial" { + t.Fatalf("Unexpected auth request reason: %q", observedRequest.Reason) + } + if observedRequest.WwwAuthenticateParams.ResourceMetadataURL == nil || + *observedRequest.WwwAuthenticateParams.ResourceMetadataURL != baseURL+"/.well-known/oauth-protected-resource" { + t.Fatalf("Unexpected resource metadata URL: %v", observedRequest.WwwAuthenticateParams.ResourceMetadataURL) + } + if stringValue(observedRequest.WwwAuthenticateParams.Scope) != "mcp.read" || stringValue(observedRequest.WwwAuthenticateParams.Error) != "invalid_token" { + t.Fatalf("Unexpected WWW-Authenticate params: %#v", observedRequest.WwwAuthenticateParams) + } + + var metadata map[string]any + if observedRequest.ResourceMetadata == nil { + t.Fatal("Expected resource metadata to be propagated") + } + if err := json.Unmarshal([]byte(*observedRequest.ResourceMetadata), &metadata); err != nil { + t.Fatalf("Failed to parse resource metadata: %v", err) + } + if metadata["resource"] != baseURL+"/mcp" { + t.Fatalf("Expected resource %q, got %#v", baseURL+"/mcp", metadata["resource"]) + } + + requests := fetchOAuthMCPRequests(t, baseURL) + if !hasAuthorization(requests, "") { + t.Fatal("Expected at least one unauthenticated MCP request") + } + if !hasAuthorization(requests, "Bearer "+expectedMCPOAuthToken) { + t.Fatal("Expected at least one MCP request with host-provided token") + } + }) + + t.Run("request replacement tokens across MCP OAuth lifecycle", func(t *testing.T) { + baseURL := startOAuthMCPServer(t) + serverName := "oauth-lifecycle-mcp" + var mu sync.Mutex + var observedReasons []copilot.MCPOauthRequestReason + refreshCount := 0 + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + EnableMCPApps: true, + OnMCPAuthRequest: func(request copilot.MCPAuthRequest, _ copilot.MCPAuthInvocation) (*copilot.MCPAuthResult, error) { + mu.Lock() + observedReasons = append(observedReasons, request.Reason) + refreshOrdinal := 0 + if request.Reason == copilot.MCPOauthRequestReasonRefresh { + refreshCount++ + refreshOrdinal = refreshCount + } + mu.Unlock() + + token := expectedMCPOAuthToken + switch request.Reason { + case copilot.MCPOauthRequestReasonRefresh: + if request.WwwAuthenticateParams == nil || + request.WwwAuthenticateParams.ResourceMetadataURL != nil || + stringValue(request.WwwAuthenticateParams.Error) != "invalid_token" { + t.Fatalf("Unexpected refresh WWW-Authenticate params: %#v", request.WwwAuthenticateParams) + } + if refreshOrdinal > 1 { + return copilot.MCPAuthResultCancelled(), nil + } + token = refreshMCPOAuthToken + case copilot.MCPOauthRequestReasonUpscope: + token = upscopeMCPOAuthToken + if request.WwwAuthenticateParams == nil || + request.WwwAuthenticateParams.ResourceMetadataURL == nil || + *request.WwwAuthenticateParams.ResourceMetadataURL != baseURL+"/.well-known/oauth-protected-resource" || + stringValue(request.WwwAuthenticateParams.Scope) != "mcp.write" || + stringValue(request.WwwAuthenticateParams.Error) != "insufficient_scope" { + t.Fatalf("Unexpected upscope WWW-Authenticate params: %#v", request.WwwAuthenticateParams) + } + case copilot.MCPOauthRequestReasonReauth: + token = reauthMCPOAuthToken + } + return copilot.MCPAuthResultToken(&copilot.MCPAuthToken{AccessToken: token}), nil + }, + MCPServers: map[string]copilot.MCPServerConfig{ + serverName: copilot.MCPHTTPServerConfig{ + URL: baseURL + "/mcp", + Tools: []string{"*"}, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + t.Cleanup(func() { session.Disconnect() }) + + waitForMCPServerStatus(t, session, serverName, rpc.MCPServerStatusConnected) + callWhoami(t, session, serverName, "refresh") + callWhoami(t, session, serverName, "upscope") + callWhoami(t, session, serverName, "reauth") + + mu.Lock() + reasons := append([]copilot.MCPOauthRequestReason(nil), observedReasons...) + mu.Unlock() + expectedReasons := []copilot.MCPOauthRequestReason{ + copilot.MCPOauthRequestReasonInitial, + copilot.MCPOauthRequestReasonRefresh, + copilot.MCPOauthRequestReasonUpscope, + copilot.MCPOauthRequestReasonRefresh, + copilot.MCPOauthRequestReasonReauth, + } + if !slices.Equal(reasons, expectedReasons) { + t.Fatalf("Unexpected auth request reasons: %#v", reasons) + } + + requests := fetchOAuthMCPRequests(t, baseURL) + if !hasAuthorization(requests, "Bearer "+refreshMCPOAuthToken) { + t.Fatal("Expected at least one MCP request with refresh token") + } + if !hasAuthorization(requests, "Bearer "+upscopeMCPOAuthToken) { + t.Fatal("Expected at least one MCP request with upscope token") + } + if !hasAuthorization(requests, "Bearer "+reauthMCPOAuthToken) { + t.Fatal("Expected at least one MCP request with reauth token") + } + }) + + t.Run("cancel pending MCP OAuth request", func(t *testing.T) { + baseURL := startOAuthMCPServer(t) + serverName := "oauth-cancelled-mcp" + var observedRequest copilot.MCPAuthRequest + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + OnMCPAuthRequest: func(request copilot.MCPAuthRequest, _ copilot.MCPAuthInvocation) (*copilot.MCPAuthResult, error) { + observedRequest = request + return copilot.MCPAuthResultCancelled(), nil + }, + MCPServers: map[string]copilot.MCPServerConfig{ + serverName: copilot.MCPHTTPServerConfig{ + URL: baseURL + "/mcp", + Tools: []string{"*"}, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + t.Cleanup(func() { session.Disconnect() }) + + waitForMCPServerStatus(t, session, serverName, rpc.MCPServerStatusFailed) + if observedRequest.ServerName != serverName { + t.Fatalf("Expected serverName %q, got %q", serverName, observedRequest.ServerName) + } + if observedRequest.Reason != copilot.MCPOauthRequestReasonInitial { + t.Fatalf("Unexpected auth request reason: %q", observedRequest.Reason) + } + }) +} + +type oauthMCPRequest struct { + Authorization *string `json:"authorization"` +} + +func startOAuthMCPServer(t *testing.T) string { + t.Helper() + + serverPath, err := filepath.Abs("../../../test/harness/test-mcp-oauth-server.mjs") + if err != nil { + t.Fatalf("Failed to resolve OAuth MCP server path: %v", err) + } + cmd := exec.Command("node", serverPath) + cmd.Env = append(os.Environ(), "EXPECTED_TOKEN="+expectedMCPOAuthToken) + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("Failed to pipe OAuth MCP server stdout: %v", err) + } + var stderr strings.Builder + cmd.Stderr = &stderr + if err := cmd.Start(); err != nil { + t.Fatalf("Failed to start OAuth MCP server: %v", err) + } + t.Cleanup(func() { + if cmd.ProcessState != nil && cmd.ProcessState.Exited() { + return + } + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + }) + + lines := make(chan string, 1) + go func() { + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + lines <- scanner.Text() + return + } + close(lines) + }() + + select { + case line, ok := <-lines: + if !ok { + t.Fatalf("OAuth MCP server exited before listening: %s", stderr.String()) + } + const prefix = "Listening: " + if !strings.HasPrefix(line, prefix) { + t.Fatalf("Unexpected OAuth MCP server startup line %q. stderr=%s", line, stderr.String()) + } + return strings.TrimPrefix(line, prefix) + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for OAuth MCP server: %s", stderr.String()) + } + return "" +} + +func stringValue(value *string) string { + if value == nil { + return "" + } + return *value +} + +func fetchOAuthMCPRequests(t *testing.T, baseURL string) []oauthMCPRequest { + t.Helper() + + response, err := http.Get(baseURL + "/__requests") + if err != nil { + t.Fatalf("Failed to fetch OAuth MCP requests: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + t.Fatalf("Failed to fetch OAuth MCP requests: %s", response.Status) + } + var requests []oauthMCPRequest + if err := json.NewDecoder(response.Body).Decode(&requests); err != nil { + t.Fatalf("Failed to decode OAuth MCP requests: %v", err) + } + return requests +} + +func hasAuthorization(requests []oauthMCPRequest, expected string) bool { + for _, request := range requests { + if request.Authorization == nil && expected == "" { + return true + } + if request.Authorization != nil && *request.Authorization == expected { + return true + } + } + return false +} + +func callWhoami(t *testing.T, session *copilot.Session, serverName string, scenario string) { + t.Helper() + + result, err := session.RPC.MCP.Apps().CallTool(t.Context(), &rpc.MCPAppsCallToolRequest{ + OriginServerName: serverName, + ServerName: serverName, + ToolName: "whoami", + Arguments: map[string]any{"scenario": scenario}, + }) + if err != nil { + t.Fatalf("Failed to call whoami for %s: %v", scenario, err) + } + content, ok := (*result)["content"].([]any) + if !ok || len(content) != 1 { + t.Fatalf("Unexpected whoami result: %#v", result) + } +} diff --git a/go/internal/e2e/testharness/context.go b/go/internal/e2e/testharness/context.go index adceb9a746..2643980b75 100644 --- a/go/internal/e2e/testharness/context.go +++ b/go/internal/e2e/testharness/context.go @@ -223,6 +223,8 @@ func (c *TestContext) Env() []string { "GH_CONFIG_DIR="+c.HomeDir, "GH_TOKEN="+defaultGitHubToken, "GITHUB_TOKEN="+defaultGitHubToken, + "COPILOT_MCP_APPS=true", + "MCP_APPS=true", "XDG_CONFIG_HOME="+c.HomeDir, "XDG_STATE_HOME="+c.HomeDir, ) diff --git a/go/session.go b/go/session.go index 851157ba87..808876761e 100644 --- a/go/session.go +++ b/go/session.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log" "sync" "time" @@ -61,6 +62,8 @@ type Session struct { toolHandlersM sync.RWMutex permissionHandler PermissionHandlerFunc permissionMux sync.RWMutex + mcpAuthHandler MCPAuthHandler + mcpAuthMu sync.RWMutex userInputHandler UserInputHandler userInputMux sync.RWMutex exitPlanModeHandler ExitPlanModeRequestHandler @@ -924,6 +927,53 @@ func (s *Session) getElicitationHandler() ElicitationHandler { return s.elicitationHandler } +func (s *Session) registerMCPAuthHandler(handler MCPAuthHandler) { + s.mcpAuthMu.Lock() + defer s.mcpAuthMu.Unlock() + s.mcpAuthHandler = handler +} + +func (s *Session) getMCPAuthHandler() MCPAuthHandler { + s.mcpAuthMu.RLock() + defer s.mcpAuthMu.RUnlock() + return s.mcpAuthHandler +} + +func (s *Session) handleMCPAuthRequest(request MCPAuthRequest) { + handler := s.getMCPAuthHandler() + if handler == nil { + return + } + + ctx := context.Background() + cancel := &rpc.MCPOauthPendingRequestResponseCancelled{} + result, err := handler(request, MCPAuthInvocation{SessionID: s.SessionID}) + if err != nil { + log.Printf( + "MCP OAuth handler failed. SessionId=%s, RequestId=%s, Error=%v", + s.SessionID, + request.RequestID, + err, + ) + } + if err != nil || result == nil || result.Kind == MCPAuthResultKindCancelled || result.Token == nil { + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: cancel, + }) + return + } + + s.RPC.MCP.Oauth().HandlePendingRequest(ctx, &rpc.MCPOauthHandlePendingRequest{ + RequestID: request.RequestID, + Result: &rpc.MCPOauthPendingRequestResponseToken{ + AccessToken: result.Token.AccessToken, + TokenType: result.Token.TokenType, + ExpiresIn: result.Token.ExpiresIn, + }, + }) +} + // handleElicitationRequest dispatches an elicitation.requested event to the registered handler // and sends the result back via the RPC layer. Auto-cancels on error. func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, requestID string) { @@ -1370,6 +1420,52 @@ func (s *Session) handleBroadcastEvent(event SessionEvent) { } s.executePermissionAndRespond(d.RequestID, d.PermissionRequest, handler) + case *MCPOauthRequiredData: + handler := s.getMCPAuthHandler() + if d.RequestID == "" { + return + } + if handler == nil { + log.Printf( + "Received MCP OAuth request without a registered MCP auth handler. SessionId=%s, RequestId=%s", + s.SessionID, + d.RequestID, + ) + return + } + var staticClientConfig *MCPAuthStaticClientConfig + if d.StaticClientConfig != nil { + var grantType *string + if d.StaticClientConfig.GrantType != nil { + value := string(*d.StaticClientConfig.GrantType) + grantType = &value + } + staticClientConfig = &MCPAuthStaticClientConfig{ + ClientID: d.StaticClientConfig.ClientID, + ClientSecret: d.StaticClientConfig.ClientSecret, + GrantType: grantType, + PublicClient: d.StaticClientConfig.PublicClient, + } + } + request := MCPAuthRequest{ + RequestID: d.RequestID, + ServerName: d.ServerName, + ServerURL: d.ServerURL, + Reason: d.Reason, + StaticClientConfig: staticClientConfig, + } + if d.ResourceMetadata != nil { + request.ResourceMetadata = d.ResourceMetadata + } + if d.WwwAuthenticateParams != nil { + request.WwwAuthenticateParams = &MCPAuthWwwAuthenticateParams{ + ResourceMetadataURL: d.WwwAuthenticateParams.ResourceMetadataURL, + Scope: d.WwwAuthenticateParams.Scope, + Error: d.WwwAuthenticateParams.Error, + } + } + s.handleMCPAuthRequest(request) + case *CommandExecuteData: s.executeCommandAndRespond(d.RequestID, d.CommandName, d.Command, d.Args) diff --git a/go/session_test.go b/go/session_test.go index 654be6ce46..277ea29e3e 100644 --- a/go/session_test.go +++ b/go/session_test.go @@ -60,6 +60,205 @@ func TestSession_SetModelOmitsContextTierWhenUnset(t *testing.T) { } } +func TestSession_MCPAuthRequestSendsHostToken(t *testing.T) { + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + defer stdinR.Close() + defer stdinW.Close() + defer stdoutR.Close() + defer stdoutW.Close() + + client := jsonrpc2.NewClient(stdinW, stdoutR) + client.Start() + defer client.Stop() + + paramsCh := make(chan map[string]any, 1) + errCh := make(chan error, 1) + + go func() { + frame, err := readTestJSONRPCFrame(stdinR) + if err != nil { + errCh <- err + return + } + + var request struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(frame, &request); err != nil { + errCh <- err + return + } + if request.Method != "session.mcp.oauth.handlePendingRequest" { + errCh <- fmt.Errorf("expected session.mcp.oauth.handlePendingRequest, got %s", request.Method) + return + } + + paramsCh <- request.Params + + response := map[string]any{ + "jsonrpc": "2.0", + "id": json.RawMessage(request.ID), + "result": map[string]any{"success": true}, + } + data, err := json.Marshal(response) + if err != nil { + errCh <- err + return + } + if _, err := fmt.Fprintf(stdoutW, "Content-Length: %d\r\n\r\n%s", len(data), data); err != nil { + errCh <- err + } + }() + + session := &Session{ + SessionID: "session-1", + client: client, + RPC: rpc.NewSessionRPC(client, "session-1"), + } + var observedRequest MCPAuthRequest + session.registerMCPAuthHandler(func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) { + observedRequest = request + if invocation.SessionID != "session-1" { + t.Fatalf("expected invocation session-1, got %s", invocation.SessionID) + } + if request.RequestID != "oauth-request" { + t.Fatalf("expected oauth-request, got %s", request.RequestID) + } + tokenType := "Bearer" + return MCPAuthResultToken(&MCPAuthToken{ + AccessToken: "host-token", + TokenType: &tokenType, + }), nil + }) + resourceMetadataURL := "https://example.com/.well-known/oauth-protected-resource" + resourceMetadata := `{"resource":"https://example.com/mcp"}` + clientSecret := "static-secret" + grantType := rpc.MCPOauthRequiredStaticClientConfigGrantTypeClientCredentials + publicClient := false + session.handleBroadcastEvent(SessionEvent{ + Data: &MCPOauthRequiredData{ + RequestID: "oauth-request", + Reason: rpc.MCPOauthRequestReasonInitial, + ServerName: "oauth-server", + ServerURL: "https://example.com/mcp", + ResourceMetadata: &resourceMetadata, + StaticClientConfig: &MCPOauthRequiredStaticClientConfig{ + ClientID: "static-client", + ClientSecret: &clientSecret, + GrantType: &grantType, + PublicClient: &publicClient, + }, + WwwAuthenticateParams: &MCPOauthWwwAuthenticateParams{ + ResourceMetadataURL: &resourceMetadataURL, + }, + }, + }) + if observedRequest.ResourceMetadata == nil || *observedRequest.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata to be propagated, got %#v", observedRequest.ResourceMetadata) + } + if observedRequest.Reason != MCPOauthRequestReasonInitial { + t.Fatalf("expected initial reason, got %q", observedRequest.Reason) + } + if observedRequest.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params to be propagated") + } + if observedRequest.StaticClientConfig == nil { + t.Fatal("expected static client config to be propagated") + } + if observedRequest.StaticClientConfig.ClientSecret == nil || *observedRequest.StaticClientConfig.ClientSecret != "static-secret" { + t.Fatalf("expected static client secret to be propagated, got %#v", observedRequest.StaticClientConfig.ClientSecret) + } + if observedRequest.StaticClientConfig.GrantType == nil || *observedRequest.StaticClientConfig.GrantType != "client_credentials" { + t.Fatalf("expected static client grant type to be propagated, got %#v", observedRequest.StaticClientConfig.GrantType) + } + + select { + case params := <-paramsCh: + if params["sessionId"] != "session-1" { + t.Fatalf("expected sessionId session-1, got %v", params["sessionId"]) + } + if params["requestId"] != "oauth-request" { + t.Fatalf("expected requestId oauth-request, got %v", params["requestId"]) + } + result, ok := params["result"].(map[string]any) + if !ok { + t.Fatalf("expected result object, got %T", params["result"]) + } + if result["kind"] != "token" { + t.Fatalf("expected token kind, got %v", result["kind"]) + } + if result["accessToken"] != "host-token" { + t.Fatalf("expected accessToken host-token, got %v", result["accessToken"]) + } + if result["tokenType"] != "Bearer" { + t.Fatalf("expected tokenType Bearer, got %v", result["tokenType"]) + } + case err := <-errCh: + t.Fatal(err) + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for MCP OAuth request") + } +} + +func TestMCPAuthRequestAllowsMissingOptionalMetadata(t *testing.T) { + request := MCPAuthRequest{RequestID: "oauth-request"} + if request.ResourceMetadata != nil { + t.Fatalf("expected no resource metadata, got %#v", request.ResourceMetadata) + } + if request.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", request.WwwAuthenticateParams) + } +} + +func TestMCPOauthRequiredDataAllowsOptionalMetadata(t *testing.T) { + var withMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}", + "staticClientConfig": { + "clientId": "static-client", + "clientSecret": "static-secret", + "publicClient": false + } + }`), &withMetadata); err != nil { + t.Fatal(err) + } + if withMetadata.ResourceMetadata == nil || *withMetadata.ResourceMetadata != `{"resource":"https://example.com/mcp"}` { + t.Fatalf("expected resource metadata, got %#v", withMetadata.ResourceMetadata) + } + if withMetadata.WwwAuthenticateParams == nil { + t.Fatal("expected WWW-Authenticate params") + } + if withMetadata.StaticClientConfig == nil || withMetadata.StaticClientConfig.ClientSecret == nil || *withMetadata.StaticClientConfig.ClientSecret != "static-secret" { + t.Fatalf("expected static client secret, got %#v", withMetadata.StaticClientConfig) + } + + var withoutMetadata rpc.MCPOauthRequiredData + if err := json.Unmarshal([]byte(`{ + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + }`), &withoutMetadata); err != nil { + t.Fatal(err) + } + if withoutMetadata.ResourceMetadata != nil { + t.Fatalf("expected no resource metadata, got %#v", withoutMetadata.ResourceMetadata) + } + if withoutMetadata.WwwAuthenticateParams != nil { + t.Fatalf("expected no WWW-Authenticate params, got %#v", withoutMetadata.WwwAuthenticateParams) + } +} + func captureSetModelRequest(t *testing.T, opts *SetModelOptions) map[string]any { t.Helper() diff --git a/go/types.go b/go/types.go index 8a7df3c46a..ffb8af12a4 100644 --- a/go/types.go +++ b/go/types.go @@ -326,6 +326,70 @@ type PermissionInvocation struct { SessionID string } +// MCPAuthWwwAuthenticateParams contains parsed parameters from an MCP server's WWW-Authenticate response. +type MCPAuthWwwAuthenticateParams struct { + ResourceMetadataURL *string `json:"resourceMetadataUrl,omitempty"` + Scope *string `json:"scope,omitempty"` + Error *string `json:"error,omitempty"` +} + +// MCPAuthStaticClientConfig is static OAuth client configuration supplied by an MCP server. +type MCPAuthStaticClientConfig struct { + ClientID string `json:"clientId"` + ClientSecret *string `json:"clientSecret,omitempty"` + GrantType *string `json:"grantType,omitempty"` + PublicClient *bool `json:"publicClient,omitempty"` +} + +// MCPAuthRequest describes an MCP OAuth request that the SDK host can satisfy with a token. +type MCPAuthRequest struct { + RequestID string `json:"requestId"` + ServerName string `json:"serverName"` + ServerURL string `json:"serverUrl"` + Reason MCPOauthRequestReason `json:"reason"` + WwwAuthenticateParams *MCPAuthWwwAuthenticateParams `json:"wwwAuthenticateParams,omitempty"` + ResourceMetadata *string `json:"resourceMetadata,omitempty"` + StaticClientConfig *MCPAuthStaticClientConfig `json:"staticClientConfig,omitempty"` +} + +// MCPAuthToken is host-provided OAuth token data for a pending MCP OAuth request. +type MCPAuthToken struct { + AccessToken string `json:"accessToken"` + TokenType *string `json:"tokenType,omitempty"` + ExpiresIn *int64 `json:"expiresIn,omitempty"` +} + +// MCPAuthResult is the result returned by an MCP auth request handler. +type MCPAuthResult struct { + Kind string + Token *MCPAuthToken +} + +const ( + // MCPAuthResultKindToken indicates that the host provided token data. + MCPAuthResultKindToken = "token" + // MCPAuthResultKindCancelled indicates that the host declined the request. + MCPAuthResultKindCancelled = "cancelled" +) + +// MCPAuthResultToken returns a token result for an MCP OAuth request. +func MCPAuthResultToken(token *MCPAuthToken) *MCPAuthResult { + return &MCPAuthResult{Kind: MCPAuthResultKindToken, Token: token} +} + +// MCPAuthResultCancelled returns a cancelled result for an MCP OAuth request. +func MCPAuthResultCancelled() *MCPAuthResult { + return &MCPAuthResult{Kind: MCPAuthResultKindCancelled} +} + +// MCPAuthInvocation provides context about an MCP auth handler invocation. +type MCPAuthInvocation struct { + SessionID string +} + +// MCPAuthHandler handles MCP OAuth requests from the runtime. +type MCPAuthHandler func(request MCPAuthRequest, invocation MCPAuthInvocation) (*MCPAuthResult, error) + // UserInputRequest represents a request for user input from the agent type UserInputRequest struct { Question string @@ -975,6 +1039,10 @@ type SessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // When provided, the SDK can satisfy MCP server OAuth requests with host-provided + // token data or cancellation. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events @@ -1405,6 +1473,9 @@ type ResumeSessionConfig struct { // When nil, permission requests are surfaced as events and left pending for the // consumer to resolve via pending permission RPCs. OnPermissionRequest PermissionHandlerFunc + // OnMCPAuthRequest is an optional handler for MCP OAuth requests from MCP servers. + // See SessionConfig.OnMCPAuthRequest. + OnMCPAuthRequest MCPAuthHandler // OnUserInputRequest is a handler for user input requests from the agent (enables ask_user tool) OnUserInputRequest UserInputHandler // Hooks configures hook handlers for session lifecycle events diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index 8473b3bb45..9384bc708c 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -27,6 +27,7 @@ import com.github.copilot.generated.rpc.SessionInstalledPlugin; import com.github.copilot.generated.rpc.ConnectParams; import com.github.copilot.generated.rpc.ServerRpc; +import com.github.copilot.generated.rpc.SessionEventLogRegisterInterestParams; import com.github.copilot.rpc.DeleteSessionResponse; import com.github.copilot.rpc.GetAuthStatusResponse; import com.github.copilot.rpc.GetLastSessionIdResponse; @@ -638,20 +639,27 @@ public CompletableFuture createSession(SessionConfig config) { ? preRegisteredSessionHolder[0] : initializeSession.apply(returnedId); registeredIdHolder[0] = returnedId; + CompletableFuture interest = config.getOnMcpAuthRequest() != null + ? session.getRpc().eventLog.registerInterest( + new SessionEventLogRegisterInterestParams(returnedId, "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); session.setWorkspacePath(response.workspacePath()); session.setCapabilities(response.capabilities()); session.setOpenCanvases(response.openCanvases()); - return updateSessionOptionsForMode(session, config.getSkipCustomInstructions().orElse(null), - config.getCustomAgentsLocalOnly().orElse(null), - config.getCoauthorEnabled().orElse(null), - config.getManageScheduleEnabled().orElse(null)).thenApply(v -> { - LoggingHelpers.logTiming(LOG, Level.FINE, - "CopilotClient.createSession complete. Elapsed={Elapsed}, SessionId=" - + session.getSessionId(), - totalNanos); - return session; - }); + return interest.thenCompose(interestResult -> { + logMcpAuthInterestRegistration(interestResult); + return updateSessionOptionsForMode(session, config.getSkipCustomInstructions().orElse(null), + config.getCustomAgentsLocalOnly().orElse(null), + config.getCoauthorEnabled().orElse(null), + config.getManageScheduleEnabled().orElse(null)); + }).thenApply(v -> { + LoggingHelpers.logTiming(LOG, Level.FINE, + "CopilotClient.createSession complete. Elapsed={Elapsed}, SessionId=" + + session.getSessionId(), + totalNanos); + return session; + }); }).exceptionally(ex -> { if (registeredIdHolder[0] != null) { sessions.remove(registeredIdHolder[0]); @@ -665,6 +673,12 @@ public CompletableFuture createSession(SessionConfig config) { }); } + private static void logMcpAuthInterestRegistration(Object interestResult) { + if (interestResult != null && LOG.isLoggable(Level.FINEST)) { + LOG.finest("MCP OAuth event interest registered"); + } + } + /** * Resumes an existing Copilot session. *

@@ -714,7 +728,6 @@ public CompletableFuture resumeSession(String sessionId, ResumeS if (extracted.transformCallbacks() != null) { session.registerTransformCallbacks(extracted.transformCallbacks()); } - var request = SessionRequestBuilder.buildResumeRequest(sessionId, config); if (extracted.wireSystemMessage() != config.getSystemMessage()) { request.setSystemMessage(extracted.wireSystemMessage()); @@ -766,6 +779,17 @@ public CompletableFuture resumeSession(String sessionId, ResumeS "CopilotClient.resumeSession session resume request completed. Elapsed={Elapsed}, SessionId=" + sessionId, rpcNanos); + String returnedId = response.sessionId(); + String interestSessionId = returnedId != null ? returnedId : sessionId; + CompletableFuture interest = config.getOnMcpAuthRequest() != null + ? session.getRpc().eventLog.registerInterest(new SessionEventLogRegisterInterestParams( + interestSessionId, "mcp.oauth_required")) + : CompletableFuture.completedFuture(null); + return interest.thenApply(interestResult -> { + logMcpAuthInterestRegistration(interestResult); + return response; + }); + }).thenCompose(response -> { session.setWorkspacePath(response.workspacePath()); session.setCapabilities(response.capabilities()); session.setOpenCanvases(response.openCanvases()); diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 90f76b6df5..194ce12773 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -33,6 +33,7 @@ import com.github.copilot.generated.rpc.SessionCommandsHandlePendingCommandParams; import com.github.copilot.generated.rpc.SessionLogParams; import com.github.copilot.generated.rpc.SessionLogLevel; +import com.github.copilot.generated.rpc.SessionMcpOauthHandlePendingRequestParams; import com.github.copilot.generated.rpc.ModelCapabilitiesOverride; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideLimits; import com.github.copilot.generated.rpc.ModelCapabilitiesOverrideSupports; @@ -49,6 +50,7 @@ import com.github.copilot.generated.CommandExecuteEvent; import com.github.copilot.generated.ElicitationRequestedEvent; import com.github.copilot.generated.ExternalToolRequestedEvent; +import com.github.copilot.generated.McpOauthRequiredEvent; import com.github.copilot.generated.PermissionRequestedEvent; import com.github.copilot.generated.SessionCanvasClosedEvent; import com.github.copilot.generated.SessionCanvasOpenedEvent; @@ -79,6 +81,10 @@ import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.McpAuthHandler; +import com.github.copilot.rpc.McpAuthInvocation; +import com.github.copilot.rpc.McpAuthRequest; +import com.github.copilot.rpc.McpAuthResult; import com.github.copilot.rpc.PermissionHandler; import com.github.copilot.rpc.PermissionInvocation; import com.github.copilot.rpc.PermissionRequest; @@ -171,6 +177,7 @@ public final class CopilotSession implements AutoCloseable { private final Map commandHandlers = new ConcurrentHashMap<>(); private final Map bearerTokenProviders = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); + private final AtomicReference mcpAuthHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); private final AtomicReference exitPlanModeHandler = new AtomicReference<>(); @@ -839,6 +846,20 @@ private void handleBroadcastEventAsync(SessionEvent event) { } executePermissionAndRespondAsync(data.requestId(), MAPPER.convertValue(data.permissionRequest(), PermissionRequest.class), handler); + } else if (event instanceof McpOauthRequiredEvent authEvent) { + var data = authEvent.getData(); + if (data == null || data.requestId() == null) { + return; + } + McpAuthHandler handler = mcpAuthHandler.get(); + if (handler == null) { + LOG.warning(() -> "Received MCP OAuth request without a registered MCP auth handler. SessionId=" + + sessionId + ", RequestId=" + data.requestId()); + return; + } + executeMcpAuthAndRespondAsync(new McpAuthRequest(data.requestId(), data.serverName(), data.serverUrl(), + data.reason(), data.wwwAuthenticateParams(), data.resourceMetadata(), data.staticClientConfig()), + handler); } else if (event instanceof CommandExecuteEvent cmdEvent) { var data = cmdEvent.getData(); if (data == null || data.requestId() == null || data.commandName() == null) { @@ -1006,6 +1027,58 @@ private void executePermissionAndRespondAsync(String requestId, PermissionReques } } + private void executeMcpAuthAndRespondAsync(McpAuthRequest request, McpAuthHandler handler) { + Runnable task = () -> { + try { + var invocation = new McpAuthInvocation().setSessionId(sessionId); + handler.handle(request, invocation) + .thenAccept(result -> sendMcpAuthResponse(request.requestId(), result)).exceptionally(ex -> { + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + return null; + }); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error executing MCP auth handler for requestId=" + request.requestId(), e); + sendMcpAuthResponse(request.requestId(), McpAuthResult.cancelled()); + } + }; + try { + if (executor != null) { + CompletableFuture.runAsync(task, executor); + } else { + CompletableFuture.runAsync(task); + } + } catch (RejectedExecutionException e) { + LOG.log(Level.WARNING, + "Executor rejected MCP auth task for requestId=" + request.requestId() + "; running inline", e); + task.run(); + } + } + + private void sendMcpAuthResponse(String requestId, McpAuthResult result) { + try { + Object response; + if (result == null || result.isCancelled() || result.token() == null) { + response = Map.of("kind", "cancelled"); + } else { + var token = result.token(); + var tokenResponse = new java.util.HashMap(); + tokenResponse.put("kind", "token"); + tokenResponse.put("accessToken", token.accessToken()); + if (token.tokenType() != null) { + tokenResponse.put("tokenType", token.tokenType()); + } + if (token.expiresIn() != null) { + tokenResponse.put("expiresIn", token.expiresIn()); + } + response = tokenResponse; + } + getRpc().mcp.oauth.handlePendingRequest( + new SessionMcpOauthHandlePendingRequestParams(sessionId, requestId, response)); + } catch (Exception e) { + LOG.log(Level.WARNING, "Error sending MCP auth response for requestId=" + requestId, e); + } + } + /** * Registers custom tool handlers for this session. *

@@ -1269,6 +1342,10 @@ void registerPermissionHandler(PermissionHandler handler) { permissionHandler.set(handler); } + void registerMcpAuthHandler(McpAuthHandler handler) { + mcpAuthHandler.set(handler); + } + /** * Handles a permission request from the Copilot CLI. *

diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index 6000bdef82..8a4b016e1b 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -323,6 +323,9 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } @@ -370,6 +373,9 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnPermissionRequest() != null) { session.registerPermissionHandler(config.getOnPermissionRequest()); } + if (config.getOnMcpAuthRequest() != null) { + session.registerMcpAuthHandler(config.getOnMcpAuthRequest()); + } if (config.getOnUserInputRequest() != null) { session.registerUserInputHandler(config.getOnUserInputRequest()); } diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java new file mode 100644 index 0000000000..55c6a6f180 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthHandler.java @@ -0,0 +1,26 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import java.util.concurrent.CompletableFuture; + +/** + * Handles MCP OAuth requests from the runtime. + * + * @since 1.0.0 + */ +@FunctionalInterface +public interface McpAuthHandler { + /** + * Handles an MCP OAuth request. + * + * @param request + * the MCP OAuth request details + * @param invocation + * the invocation context with session information + * @return a future resolving to token data or cancellation + */ + CompletableFuture handle(McpAuthRequest request, McpAuthInvocation invocation); +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthInvocation.java b/java/src/main/java/com/github/copilot/rpc/McpAuthInvocation.java new file mode 100644 index 0000000000..c7a80a96d3 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthInvocation.java @@ -0,0 +1,36 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Context for an MCP OAuth request invocation. + * + * @since 1.0.0 + */ +public class McpAuthInvocation { + + private String sessionId; + + /** + * Gets the session ID. + * + * @return the session ID + */ + public String getSessionId() { + return sessionId; + } + + /** + * Sets the session ID. + * + * @param sessionId + * the session ID + * @return this instance for method chaining + */ + public McpAuthInvocation setSessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java new file mode 100644 index 0000000000..a672685557 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthRequest.java @@ -0,0 +1,19 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +import com.github.copilot.generated.McpOauthRequiredStaticClientConfig; +import com.github.copilot.generated.McpOauthRequestReason; +import com.github.copilot.generated.McpOauthWWWAuthenticateParams; + +/** + * MCP OAuth request that the SDK host can satisfy with a host-acquired token. + * + * @since 1.0.0 + */ +public record McpAuthRequest(String requestId, String serverName, String serverUrl, McpOauthRequestReason reason, + McpOauthWWWAuthenticateParams wwwAuthenticateParams, String resourceMetadata, + McpOauthRequiredStaticClientConfig staticClientConfig) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java new file mode 100644 index 0000000000..6b7fda34f9 --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthResult.java @@ -0,0 +1,32 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Result returned by an MCP auth request handler. + * + * @since 1.0.0 + */ +public record McpAuthResult(boolean isCancelled, McpAuthToken token) { + /** + * Creates a token result. + * + * @param token + * the host-provided OAuth token data + * @return token result + */ + public static McpAuthResult token(McpAuthToken token) { + return new McpAuthResult(false, token); + } + + /** + * Creates a cancellation result. + * + * @return cancellation result + */ + public static McpAuthResult cancelled() { + return new McpAuthResult(true, null); + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java new file mode 100644 index 0000000000..3cf6748fbf --- /dev/null +++ b/java/src/main/java/com/github/copilot/rpc/McpAuthToken.java @@ -0,0 +1,13 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot.rpc; + +/** + * Host-provided OAuth token data for a pending MCP OAuth request. + * + * @since 1.0.0 + */ +public record McpAuthToken(String accessToken, String tokenType, Long expiresIn) { +} diff --git a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java index e3e79eab01..48e333f05b 100644 --- a/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ResumeSessionConfig.java @@ -60,6 +60,7 @@ public class ResumeSessionConfig { private String contextTier; private ModelCapabilitiesOverride modelCapabilities; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -635,6 +636,28 @@ public ResumeSessionConfig setOnPermissionRequest(PermissionHandler onPermission return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public ResumeSessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1697,6 +1720,7 @@ public ResumeSessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java index 38b357e7e3..e5e0e629e1 100644 --- a/java/src/main/java/com/github/copilot/rpc/SessionConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/SessionConfig.java @@ -60,6 +60,7 @@ public class SessionConfig { private Boolean coauthorEnabled; private Boolean manageScheduleEnabled; private PermissionHandler onPermissionRequest; + private McpAuthHandler onMcpAuthRequest; private UserInputHandler onUserInputRequest; private SessionHooks hooks; private String workingDirectory; @@ -678,6 +679,31 @@ public SessionConfig setOnPermissionRequest(PermissionHandler onPermissionReques return this; } + /** + * Gets the MCP OAuth request handler. + * + * @return the handler, or {@code null} if not set + */ + @JsonIgnore + public McpAuthHandler getOnMcpAuthRequest() { + return onMcpAuthRequest; + } + + /** + * Sets the MCP OAuth request handler. + *

+ * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. + * + * @param onMcpAuthRequest + * the handler + * @return this config instance for method chaining + */ + public SessionConfig setOnMcpAuthRequest(McpAuthHandler onMcpAuthRequest) { + this.onMcpAuthRequest = onMcpAuthRequest; + return this; + } + /** * Gets the user input request handler. * @@ -1829,6 +1855,7 @@ public SessionConfig clone() { copy.onEvent = this.onEvent; copy.commands = this.commands != null ? new ArrayList<>(this.commands) : null; copy.onElicitationRequest = this.onElicitationRequest; + copy.onMcpAuthRequest = this.onMcpAuthRequest; copy.onExitPlanMode = this.onExitPlanMode; copy.onAutoModeSwitch = this.onAutoModeSwitch; copy.enableMcpApps = this.enableMcpApps; diff --git a/java/src/test/java/com/github/copilot/E2ETestContext.java b/java/src/test/java/com/github/copilot/E2ETestContext.java index 4089e10ff7..4a4da04229 100644 --- a/java/src/test/java/com/github/copilot/E2ETestContext.java +++ b/java/src/test/java/com/github/copilot/E2ETestContext.java @@ -288,6 +288,8 @@ public Map getEnvironment() { env.put("GH_CONFIG_DIR", homeDir.toString()); env.put("XDG_CONFIG_HOME", homeDir.toString()); env.put("XDG_STATE_HOME", homeDir.toString()); + env.put("COPILOT_MCP_APPS", "true"); + env.put("MCP_APPS", "true"); // Configure CONNECT proxy for HTTPS interception if available String connectUrl = proxy.getConnectProxyUrl(); @@ -438,7 +440,6 @@ private static Path findRepoRoot() throws IOException { } private static String getCliPath(Path repoRoot) throws IOException { - // Try environment variable first (explicit override) String envPath = System.getenv("COPILOT_CLI_PATH"); if (envPath != null && !envPath.isEmpty()) { return envPath; diff --git a/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java new file mode 100644 index 0000000000..06ac08a2a4 --- /dev/null +++ b/java/src/test/java/com/github/copilot/McpAuthInterestRegistrationTest.java @@ -0,0 +1,299 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.github.copilot.generated.McpOauthRequiredEvent; +import com.github.copilot.rpc.CloudSessionOptions; +import com.github.copilot.rpc.CloudSessionRepository; +import com.github.copilot.rpc.CopilotClientOptions; +import com.github.copilot.rpc.McpAuthResult; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ResumeSessionConfig; +import com.github.copilot.rpc.SessionConfig; + +class McpAuthInterestRegistrationTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + @Test + void mcpOauthRequiredEventExposesOptionalResourceMetadata() throws Exception { + var data = MAPPER.readValue(""" + { + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\\"resource\\":\\"https://example.com/mcp\\"}", + "staticClientConfig": { + "clientId": "static-client", + "clientSecret": "static-secret", + "grantType": "client_credentials", + "publicClient": false + } + } + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); + + assertEquals("{\"resource\":\"https://example.com/mcp\"}", data.resourceMetadata()); + assertNotNull(data.wwwAuthenticateParams()); + assertNotNull(data.staticClientConfig()); + assertEquals("static-secret", data.staticClientConfig().clientSecret()); + + var withoutMetadata = MAPPER.readValue(""" + { + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + } + """, McpOauthRequiredEvent.McpOauthRequiredEventData.class); + + assertNull(withoutMetadata.resourceMetadata()); + assertNull(withoutMetadata.wwwAuthenticateParams()); + } + + @Test + void createSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.create".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> { + assertNotNull(request); + assertNotNull(invocation); + return java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()); + })) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.create", requests.get(0).method()); + assertEquals("session.eventLog.registerInterest", requests.get(1).method()); + assertEquals("mcp.oauth_required", requests.get(1).params().path("eventType").asText()); + } + } + + @Test + void cloudCreateSessionRegistersMcpAuthInterestAfterCreateOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + var cloud = new CloudSessionOptions().setRepository( + new CloudSessionRepository().setOwner("github").setName("copilot-sdk").setBranch("main")); + + try (var session = client + .createSession( + new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setCloud(cloud)) + .get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + server.clearRequests(); + + try (var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setCloud(cloud).setOnMcpAuthRequest((request, invocation) -> { + assertNotNull(request); + assertNotNull(invocation); + return java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()); + })) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.create", requests.get(0).method()); + assertEquals("session.eventLog.registerInterest", requests.get(1).method()); + assertEquals("mcp.oauth_required", requests.get(1).params().path("eventType").asText()); + } + } + + @Test + void resumeSessionRegistersMcpAuthInterestOnlyWhenHandlerConfigured() throws Exception { + try (var server = new RecordingRuntime(); + var client = new CopilotClient(new CopilotClientOptions().setCliUrl(server.url()))) { + try (var session = client.resumeSession("session-without-auth", new ResumeSessionConfig() + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL).setOnEvent(event -> { + })).get()) { + assertNotNull(session); + } + + assertNoMcpAuthInterest(server.requests()); + assertTrue(server.requests().stream().anyMatch(request -> "session.resume".equals(request.method()) + && request.params().path("requestPermission").asBoolean())); + + server.clearRequests(); + + try (var session = client.resumeSession("session-with-auth", + new ResumeSessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> { + assertNotNull(request); + assertNotNull(invocation); + return java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()); + })) + .get()) { + assertNotNull(session); + } + + List requests = server.requests(); + assertEquals("session.resume", requests.get(0).method()); + assertEquals("session.eventLog.registerInterest", requests.get(1).method()); + assertEquals("mcp.oauth_required", requests.get(1).params().path("eventType").asText()); + } + } + + private static void assertNoMcpAuthInterest(List requests) { + assertFalse(requests.stream().anyMatch(request -> "session.eventLog.registerInterest".equals(request.method()) + && "mcp.oauth_required".equals(request.params().path("eventType").asText()))); + } + + private record RpcRequest(String method, JsonNode params) { + } + + private static final class RecordingRuntime implements AutoCloseable { + private final ServerSocket listener; + private final Thread thread; + private final List requests = new CopyOnWriteArrayList<>(); + private volatile boolean running = true; + + RecordingRuntime() throws Exception { + listener = new ServerSocket(0); + thread = new Thread(this::run, "mcp-auth-interest-test-runtime"); + thread.setDaemon(true); + thread.start(); + } + + String url() { + return "127.0.0.1:" + listener.getLocalPort(); + } + + List requests() { + return List.copyOf(requests); + } + + void clearRequests() { + requests.clear(); + } + + @Override + public void close() throws Exception { + running = false; + listener.close(); + thread.join(2000); + } + + private void run() { + try (Socket socket = listener.accept()) { + var in = socket.getInputStream(); + var out = socket.getOutputStream(); + while (running) { + JsonNode message = readMessage(in); + if (message == null) { + return; + } + String method = message.path("method").asText(); + requests.add(new RpcRequest(method, message.path("params").deepCopy())); + sendResponse(out, message.path("id").asLong(), resultFor(method, message.path("params"))); + } + } catch (Exception ex) { + if (running) { + throw new RuntimeException(ex); + } + } + } + + private static JsonNode resultFor(String method, JsonNode params) { + ObjectNode result = MAPPER.createObjectNode(); + switch (method) { + case "connect" -> { + result.put("ok", true); + result.put("protocolVersion", 3); + result.put("version", "test"); + } + case "session.create", "session.resume" -> { + String sessionId = params.path("sessionId").asText("server-assigned-session"); + if (sessionId.isEmpty()) { + sessionId = "server-assigned-session"; + } + result.put("sessionId", sessionId); + result.putNull("workspacePath"); + result.putNull("capabilities"); + } + case "session.eventLog.registerInterest" -> result.put("id", "interest-1"); + case "session.options.update" -> result.put("success", true); + case "session.skills.reload", "session.destroy" -> { + } + default -> throw new IllegalStateException("Unexpected RPC method " + method); + } + return result; + } + + private static JsonNode readMessage(java.io.InputStream in) throws Exception { + StringBuilder header = new StringBuilder(); + int b; + while ((b = in.read()) != -1) { + header.append((char) b); + if (header.toString().endsWith("\r\n\r\n")) { + break; + } + } + if (b == -1) { + return null; + } + int contentLength = 0; + for (String line : header.toString().split("\r\n")) { + int colon = line.indexOf(':'); + if (colon > 0 && "Content-Length".equals(line.substring(0, colon))) { + contentLength = Integer.parseInt(line.substring(colon + 1).trim()); + } + } + byte[] body = in.readNBytes(contentLength); + return MAPPER.readTree(body); + } + + private static void sendResponse(OutputStream out, long id, JsonNode result) throws Exception { + ObjectNode response = MAPPER.createObjectNode(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.set("result", result); + byte[] body = MAPPER.writeValueAsBytes(response); + out.write(("Content-Length: " + body.length + "\r\n\r\n").getBytes(StandardCharsets.UTF_8)); + out.write(body); + out.flush(); + } + } +} diff --git a/java/src/test/java/com/github/copilot/McpOAuthE2ETest.java b/java/src/test/java/com/github/copilot/McpOAuthE2ETest.java new file mode 100644 index 0000000000..8da3b08b47 --- /dev/null +++ b/java/src/test/java/com/github/copilot/McpOAuthE2ETest.java @@ -0,0 +1,301 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static org.junit.jupiter.api.Assertions.*; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.generated.McpOauthRequestReason; +import com.github.copilot.generated.rpc.SessionMcpAppsCallToolParams; +import com.github.copilot.generated.rpc.McpServerStatus; +import com.github.copilot.generated.rpc.SessionMcpListToolsParams; +import com.github.copilot.rpc.McpAuthInvocation; +import com.github.copilot.rpc.McpAuthResult; +import com.github.copilot.rpc.McpAuthToken; +import com.github.copilot.rpc.McpHttpServerConfig; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.SessionConfig; + +public class McpOAuthE2ETest { + private static final String EXPECTED_TOKEN = "sdk-host-token"; + private static final String REFRESH_TOKEN = EXPECTED_TOKEN + "-refresh"; + private static final String UPSCOPE_TOKEN = EXPECTED_TOKEN + "-upscope"; + private static final String REAUTH_TOKEN = EXPECTED_TOKEN + "-reauth"; + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + void testShouldSatisfyMcpOauthUsingHostProvidedToken() throws Exception { + try (var oauthServer = OAuthMcpServer.start(ctx.getRepoRoot())) { + var serverName = "oauth-protected-mcp"; + var observedRequest = new java.util.concurrent.atomic.AtomicReference(); + var observedInvocation = new java.util.concurrent.atomic.AtomicReference(); + + try (var client = ctx.createClient(); + var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> { + observedRequest.set(request); + observedInvocation.set(invocation); + return java.util.concurrent.CompletableFuture.completedFuture( + McpAuthResult.token(new McpAuthToken(EXPECTED_TOKEN, "Bearer", 3600L))); + }).setMcpServers(Map.of(serverName, new McpHttpServerConfig() + .setUrl(oauthServer.url() + "/mcp").setTools(List.of("*"))))) + .get()) { + waitForMcpServerStatus(session, serverName, McpServerStatus.CONNECTED, observedRequest); + assertNotNull(observedInvocation.get(), "MCP auth invocation should be provided"); + assertEquals(session.getSessionId(), observedInvocation.get().getSessionId()); + var tools = session.getRpc().mcp.listTools(new SessionMcpListToolsParams(null, serverName)).get(30, + TimeUnit.SECONDS); + assertTrue(tools.tools().stream().anyMatch(tool -> "whoami".equals(tool.name()))); + } + + var request = observedRequest.get(); + assertNotNull(request, "MCP auth handler should be invoked"); + assertEquals(serverName, request.serverName()); + assertEquals(oauthServer.url() + "/mcp", request.serverUrl()); + assertEquals(McpOauthRequestReason.INITIAL, request.reason()); + assertNotNull(request.wwwAuthenticateParams()); + assertEquals(oauthServer.url() + "/.well-known/oauth-protected-resource", + request.wwwAuthenticateParams().resourceMetadataUrl()); + assertEquals("mcp.read", request.wwwAuthenticateParams().scope()); + assertEquals("invalid_token", request.wwwAuthenticateParams().error()); + assertEquals(oauthServer.url() + "/mcp", + MAPPER.readTree(request.resourceMetadata()).path("resource").asText()); + + var requests = oauthServer.requests(); + assertTrue(requests.stream().anyMatch(record -> record.authorization() == null)); + assertTrue( + requests.stream().anyMatch(record -> ("Bearer " + EXPECTED_TOKEN).equals(record.authorization()))); + } + } + + @Test + void testShouldRequestReplacementTokensAcrossMcpOauthLifecycle() throws Exception { + try (var oauthServer = OAuthMcpServer.start(ctx.getRepoRoot())) { + var serverName = "oauth-lifecycle-mcp"; + var observedReasons = new CopyOnWriteArrayList(); + var refreshCount = new java.util.concurrent.atomic.AtomicInteger(); + + try (var client = ctx.createClient(); + var session = client.createSession(new SessionConfig().setEnableMcpApps(true) + .setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> { + assertNotNull(invocation); + observedReasons.add(request.reason()); + var result = switch (request.reason()) { + case REFRESH -> { + assertNotNull(request.wwwAuthenticateParams()); + assertNull(request.wwwAuthenticateParams().resourceMetadataUrl()); + assertEquals("invalid_token", request.wwwAuthenticateParams().error()); + if (refreshCount.incrementAndGet() > 1) { + yield McpAuthResult.cancelled(); + } + yield McpAuthResult.token(new McpAuthToken(REFRESH_TOKEN, null, null)); + } + case UPSCOPE -> { + assertNotNull(request.wwwAuthenticateParams()); + assertEquals(oauthServer.url() + "/.well-known/oauth-protected-resource", + request.wwwAuthenticateParams().resourceMetadataUrl()); + assertEquals("mcp.write", request.wwwAuthenticateParams().scope()); + assertEquals("insufficient_scope", request.wwwAuthenticateParams().error()); + yield McpAuthResult.token(new McpAuthToken(UPSCOPE_TOKEN, null, null)); + } + case REAUTH -> McpAuthResult.token(new McpAuthToken(REAUTH_TOKEN, null, null)); + default -> McpAuthResult.token(new McpAuthToken(EXPECTED_TOKEN, null, null)); + }; + return java.util.concurrent.CompletableFuture.completedFuture(result); + }).setMcpServers(Map.of(serverName, new McpHttpServerConfig() + .setUrl(oauthServer.url() + "/mcp").setTools(List.of("*"))))) + .get()) { + waitForMcpServerStatus(session, serverName, McpServerStatus.CONNECTED, + new java.util.concurrent.atomic.AtomicReference<>()); + callWhoami(session, serverName, "refresh"); + callWhoami(session, serverName, "upscope"); + callWhoami(session, serverName, "reauth"); + } + + assertEquals(List.of(McpOauthRequestReason.INITIAL, McpOauthRequestReason.REFRESH, + McpOauthRequestReason.UPSCOPE, McpOauthRequestReason.REFRESH, McpOauthRequestReason.REAUTH), + observedReasons); + + var requests = oauthServer.requests(); + assertTrue( + requests.stream().anyMatch(record -> ("Bearer " + REFRESH_TOKEN).equals(record.authorization()))); + assertTrue( + requests.stream().anyMatch(record -> ("Bearer " + UPSCOPE_TOKEN).equals(record.authorization()))); + assertTrue(requests.stream().anyMatch(record -> ("Bearer " + REAUTH_TOKEN).equals(record.authorization()))); + } + } + + @Test + void testShouldCancelPendingMcpOauthRequest() throws Exception { + try (var oauthServer = OAuthMcpServer.start(ctx.getRepoRoot())) { + var serverName = "oauth-cancelled-mcp"; + var observedRequest = new java.util.concurrent.atomic.AtomicReference(); + + try (var client = ctx.createClient(); + var session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setOnMcpAuthRequest((request, invocation) -> { + assertNotNull(invocation); + observedRequest.set(request); + return java.util.concurrent.CompletableFuture + .completedFuture(McpAuthResult.cancelled()); + }).setMcpServers(Map.of(serverName, new McpHttpServerConfig() + .setUrl(oauthServer.url() + "/mcp").setTools(List.of("*"))))) + .get()) { + waitForMcpServerStatus(session, serverName, McpServerStatus.FAILED, observedRequest); + } + + var request = observedRequest.get(); + assertNotNull(request, "MCP auth handler should be invoked"); + assertEquals(serverName, request.serverName()); + assertEquals(McpOauthRequestReason.INITIAL, request.reason()); + } + } + + private static void callWhoami(CopilotSession session, String serverName, String scenario) throws Exception { + var result = session.getRpc().mcp.apps.callTool( + new SessionMcpAppsCallToolParams(null, serverName, "whoami", Map.of("scenario", scenario), serverName)) + .get(30, TimeUnit.SECONDS); + var content = result.path("content"); + assertEquals(1, content.size()); + assertEquals("oauth-test-user", content.get(0).path("text").asText()); + } + + private static void waitForMcpServerStatus(CopilotSession session, String serverName, McpServerStatus status, + java.util.concurrent.atomic.AtomicReference observedRequest) + throws Exception { + var deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + var lastStatus = ""; + while (System.nanoTime() < deadline) { + var result = session.getRpc().mcp.list().get(5, TimeUnit.SECONDS); + var server = result.servers().stream().filter(candidate -> serverName.equals(candidate.name())).findFirst(); + if (server.isPresent()) { + lastStatus = String.valueOf(server.get().status()); + } + if (server.isPresent() && status.equals(server.get().status())) { + return; + } + Thread.sleep(200); + } + fail(serverName + " did not reach " + status + "; last status was " + lastStatus + "; auth handler invoked=" + + (observedRequest.get() != null)); + } + + @JsonIgnoreProperties(ignoreUnknown = true) + private record OAuthMcpRequest(String authorization) { + } + + private record OAuthMcpServer(Process process, String url) implements AutoCloseable { + static OAuthMcpServer start(Path repoRoot) throws Exception { + var script = repoRoot.resolve("test").resolve("harness").resolve("test-mcp-oauth-server.mjs"); + var processBuilder = new ProcessBuilder(resolveExecutable("node"), script.toString()); + processBuilder.environment().put("EXPECTED_TOKEN", EXPECTED_TOKEN); + var process = processBuilder.start(); + var stderr = new StringBuilder(); + Thread stderrThread = new Thread(() -> { + try (var reader = new BufferedReader(new InputStreamReader(process.getErrorStream()))) { + reader.lines().forEach(stderr::append); + } catch (IOException ex) { + stderr.append(ex.getMessage()); + } + }); + stderrThread.setDaemon(true); + stderrThread.start(); + try (var reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) { + var deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10); + while (System.nanoTime() < deadline) { + if (reader.ready()) { + var line = reader.readLine(); + if (line != null && line.startsWith("Listening: ")) { + return new OAuthMcpServer(process, line.substring("Listening: ".length())); + } + } + Thread.sleep(50); + } + } + process.destroyForcibly(); + throw new AssertionError("Timed out waiting for OAuth MCP server: " + stderr); + } + + List requests() throws Exception { + var client = HttpClient.newHttpClient(); + var response = client.send(HttpRequest.newBuilder(URI.create(url + "/__requests")) + .timeout(Duration.ofSeconds(10)).GET().build(), HttpResponse.BodyHandlers.ofString()); + assertEquals(200, response.statusCode()); + return MAPPER.readValue(response.body(), new TypeReference>() { + }); + } + + private static String resolveExecutable(String executable) { + var path = System.getenv("PATH"); + if (path == null || path.isBlank()) { + throw new IllegalStateException("PATH is not configured; cannot find " + executable); + } + + var extensions = isWindows() + ? System.getenv().getOrDefault("PATHEXT", ".COM;.EXE;.BAT;.CMD").split(";") + : new String[]{""}; + for (var directory : path.split(java.util.regex.Pattern.quote(File.pathSeparator))) { + if (directory.isBlank()) { + continue; + } + for (var extension : extensions) { + var candidate = Path.of(directory).resolve(executable + extension).toAbsolutePath().normalize(); + if (Files.isRegularFile(candidate) && Files.isExecutable(candidate)) { + return candidate.toString(); + } + } + } + throw new IllegalStateException("Could not find " + executable + " on PATH."); + } + + private static boolean isWindows() { + return System.getProperty("os.name", "").toLowerCase(java.util.Locale.ROOT).contains("win"); + } + + @Override + public void close() { + process.destroyForcibly(); + } + } +} diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 53686a6ca3..613985103f 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -1326,7 +1326,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); s.registerTools(config.tools); s.registerCanvases(config.canvases); @@ -1473,6 +1474,12 @@ export class CopilotClient { session = initializeSession(returnedSessionId); registeredId = returnedSessionId; } + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId: returnedSessionId, + eventType: "mcp.oauth_required", + }); + } session["_workspacePath"] = workspacePath; session.setCapabilities(capabilities); @@ -1522,7 +1529,8 @@ export class CopilotClient { sessionId, this.connection!, undefined, - this.onGetTraceContext + this.onGetTraceContext, + { mcpAuthHandler: config.onMcpAuthRequest } ); session.registerTools(config.tools); session.registerCanvases(config.canvases); @@ -1567,6 +1575,12 @@ export class CopilotClient { } this.sessions.set(sessionId, session); this.setupSessionFs(session, config); + if (config.onMcpAuthRequest) { + await this.connection!.sendRequest("session.eventLog.registerInterest", { + sessionId, + eventType: "mcp.oauth_required", + }); + } const toolFilterOptions = this.resolveToolFilterOptions(config); diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 8bf9589c39..8d8fc6714f 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -10,7 +10,11 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; import { ConnectionError, ErrorCodes, ResponseError } from "vscode-jsonrpc/node.js"; import { createSessionRpc } from "./generated/rpc.js"; -import type { ClientSessionApiHandlers, CanvasActionInvokeResult } from "./generated/rpc.js"; +import type { + ClientSessionApiHandlers, + CanvasActionInvokeResult, + McpOauthPendingRequestResponse, +} from "./generated/rpc.js"; import { type Canvas, CanvasError } from "./canvas.js"; import type { OpenCanvasInstance } from "./generated/rpc.js"; import { getTraceContext } from "./telemetry.js"; @@ -29,6 +33,8 @@ import type { BearerTokenProvider, UiInputOptions, MessageOptions, + McpAuthHandler, + McpAuthRequest, PermissionHandler, PermissionRequest, ContextTier, @@ -124,6 +130,7 @@ export class CopilotSession { private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; + private mcpAuthHandler?: McpAuthHandler; private userInputHandler?: UserInputHandler; private elicitationHandler?: ElicitationHandler; private exitPlanModeHandler?: ExitPlanModeHandler; @@ -152,9 +159,11 @@ export class CopilotSession { public readonly sessionId: string, private connection: MessageConnection, private _workspacePath?: string, - traceContextProvider?: TraceContextProvider + traceContextProvider?: TraceContextProvider, + options?: { mcpAuthHandler?: McpAuthHandler } ) { this.traceContextProvider = traceContextProvider; + this.mcpAuthHandler = options?.mcpAuthHandler; } /** @@ -499,6 +508,19 @@ export class CopilotSession { if (this.permissionHandler) { void this._executePermissionAndRespond(requestId, permissionRequest); } + } else if (event.type === "mcp.oauth_required") { + const data = event.data as McpAuthRequest | undefined; + if (!data?.requestId) { + return; + } + if (!this.mcpAuthHandler) { + console.warn( + "Received MCP OAuth request without a registered MCP auth handler. " + + `SessionId=${this.sessionId}, RequestId=${data.requestId}` + ); + return; + } + void this._executeMcpAuthAndRespond(data); } else if (event.type === "command.execute") { const { requestId, commandName, command, args } = event.data as { requestId: string; @@ -661,6 +683,35 @@ export class CopilotSession { } } + /** + * Executes an MCP auth handler and sends the result back via RPC. + * @internal + */ + private async _executeMcpAuthAndRespond(request: McpAuthRequest): Promise { + try { + const result = await this.mcpAuthHandler!(request, { sessionId: this.sessionId }); + const response: McpOauthPendingRequestResponse = + result && "accessToken" in result + ? { kind: "token", ...result } + : { kind: "cancelled" }; + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: response, + }); + } catch (_error) { + try { + await this.rpc.mcp.oauth.handlePendingRequest({ + requestId: request.requestId, + result: { kind: "cancelled" }, + }); + } catch (rpcError) { + if (!(rpcError instanceof ConnectionError || rpcError instanceof ResponseError)) { + throw rpcError; + } + } + } + } + /** * Executes a command handler and sends the result back via RPC. * @internal diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index e354bd8218..4adb35b25a 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -1615,6 +1615,76 @@ export type ReasoningEffort = "low" | "medium" | "high" | "xhigh"; */ export type ContextTier = "default" | "long_context"; +/** Parsed parameters from an MCP server's WWW-Authenticate response. */ +export interface McpAuthWwwAuthenticateParams { + /** Parsed resource_metadata URL used for protected-resource metadata discovery, if present. */ + resourceMetadataUrl?: string; + /** Parsed OAuth scope, if present. */ + scope?: string; + /** Parsed OAuth error, if present. */ + error?: string; +} + +/** Static OAuth client configuration supplied by the MCP server, if available. */ +export interface McpAuthStaticClientConfig { + /** OAuth client ID for the server. */ + clientId: string; + /** Optional OAuth client secret for confidential static clients. */ + clientSecret?: string; + /** Optional non-default OAuth grant type. */ + grantType?: "client_credentials"; + /** Whether this is a public OAuth client. */ + publicClient?: boolean; +} + +/** MCP OAuth request that the SDK host can satisfy with a host-acquired token. */ +export interface McpAuthRequest { + /** Unique request identifier used by the SDK when responding. */ + requestId: string; + /** Display name of the MCP server that requires OAuth. */ + serverName: string; + /** URL of the MCP server that requires OAuth. */ + serverUrl: string; + /** Why the runtime is requesting host-provided OAuth credentials. */ + reason: "initial" | "refresh" | "reauth" | "upscope"; + /** Parsed WWW-Authenticate parameters from the MCP server. */ + wwwAuthenticateParams?: McpAuthWwwAuthenticateParams; + /** Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. */ + resourceMetadata?: string; + /** Static OAuth client configuration, if the server specifies one. */ + staticClientConfig?: McpAuthStaticClientConfig; +} + +/** Host-provided OAuth token data for a pending MCP OAuth request. */ +export interface McpAuthToken { + /** Access token acquired by the SDK host. */ + accessToken: string; + /** OAuth token type. Defaults to Bearer when omitted. */ + tokenType?: string; + /** Token lifetime in seconds, if known. */ + expiresIn?: number; +} + +/** + * Result returned by an MCP auth request handler. + * + * Return `null`/`undefined` or `{ kind: "cancelled" }` to cancel the pending + * OAuth request. Return `{ kind: "token", ... }` to provide host-acquired + * OAuth token data. + */ +export type McpAuthResult = ({ kind: "token" } & McpAuthToken) | { kind: "cancelled" }; + +/** Callback invoked when an MCP server requires OAuth and the SDK host opted in. */ +export type McpAuthHandler = ( + request: McpAuthRequest, + context: { sessionId: string } +) => + | McpAuthResult + | McpAuthToken + | null + | undefined + | Promise; + /** * Stable extension identity for session participants that provide canvases. */ @@ -1898,6 +1968,13 @@ export interface SessionConfigBase { */ onPermissionRequest?: PermissionHandler; + /** + * Optional handler for MCP OAuth requests from MCP servers. + * When provided, the SDK can satisfy MCP server OAuth requests with + * host-provided token data or cancellation. + */ + onMcpAuthRequest?: McpAuthHandler; + /** * Handler for user input requests from the agent. * When provided, enables the ask_user tool allowing the agent to ask questions. diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index 96d7da30cf..07cd079df6 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -41,6 +41,228 @@ describe("CopilotClient", () => { expect(spy).not.toHaveBeenCalled(); }); + it("responds to MCP OAuth requests with host token data", async () => { + const sendRequest = vi.fn(async () => ({ success: true })); + let observedRequest: any; + const session = new CopilotSession( + "session-1", + { sendRequest } as any, + undefined, + undefined, + { + mcpAuthHandler: async (request) => { + observedRequest = request; + return { + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }; + }, + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + reason: "initial", + wwwAuthenticateParams: { + resourceMetadataUrl: "https://example.com/.well-known/oauth-protected-resource", + }, + resourceMetadata: '{"resource":"https://example.com/mcp"}', + staticClientConfig: { + clientId: "static-client", + clientSecret: "static-secret", + grantType: "client_credentials", + publicClient: false, + }, + }); + + expect(observedRequest.resourceMetadata).toBe('{"resource":"https://example.com/mcp"}'); + expect(observedRequest.staticClientConfig).toEqual({ + clientId: "static-client", + clientSecret: "static-secret", + grantType: "client_credentials", + publicClient: false, + }); + expect(sendRequest).toHaveBeenCalledWith("session.mcp.oauth.handlePendingRequest", { + sessionId: "session-1", + requestId: "oauth-request", + result: { + kind: "token", + accessToken: "host-token", + tokenType: "Bearer", + expiresIn: 3600, + }, + }); + }); + + it("passes MCP OAuth requests through when optional metadata is absent", async () => { + let observedRequest: any; + const session = new CopilotSession( + "session-1", + { sendRequest: vi.fn(async () => ({ success: true })) } as any, + undefined, + undefined, + { + mcpAuthHandler: async (request) => { + observedRequest = request; + return { kind: "cancelled" }; + }, + } + ); + + await (session as any)._executeMcpAuthAndRespond({ + requestId: "oauth-request", + serverName: "oauth-server", + serverUrl: "https://example.com/mcp", + reason: "initial", + }); + + expect(observedRequest.reason).toBe("initial"); + expect(observedRequest.resourceMetadata).toBeUndefined(); + expect(observedRequest.wwwAuthenticateParams).toBeUndefined(); + }); + + it("registers interest in MCP OAuth required events after create when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }), + ]); + expect(spy.mock.calls[1][1].sessionId).toBe(spy.mock.calls[0][1].sessionId); + }); + + it("does not register MCP OAuth interest without an auth handler", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.create") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.create", + expect.objectContaining({ requestPermission: true }) + ); + }); + + it("registers MCP OAuth interest after cloud create only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + let cloudCreateCount = 0; + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, _params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.create") + return { sessionId: `server-assigned-session-${++cloudCreateCount}` }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.createSession({ + onPermissionRequest: approveAll, + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + + spy.mockClear(); + await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + cloud: { repository: { owner: "github", name: "copilot-sdk", branch: "main" } }, + }); + + expect(spy.mock.calls[0][0]).toBe("session.create"); + expect(spy.mock.calls[1]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "server-assigned-session-2", eventType: "mcp.oauth_required" }, + ]); + }); + + it("registers MCP OAuth interest before resuming only when an auth handler is configured", async () => { + const client = new CopilotClient(); + await client.start(); + onTestFinished(() => client.forceStop()); + + const spy = vi + .spyOn((client as any).connection!, "sendRequest") + .mockImplementation(async (method: string, params: any) => { + if (method === "session.eventLog.registerInterest") { + return { id: "interest-1" }; + } + if (method === "session.resume") return { sessionId: params.sessionId }; + throw new Error(`Unexpected method: ${method}`); + }); + + await client.resumeSession("session-with-auth", { + onPermissionRequest: approveAll, + onMcpAuthRequest: () => ({ kind: "cancelled" }), + }); + + expect(spy.mock.calls[0]).toEqual([ + "session.eventLog.registerInterest", + { sessionId: "session-with-auth", eventType: "mcp.oauth_required" }, + ]); + expect(spy.mock.calls[1][0]).toBe("session.resume"); + expect(spy.mock.calls[1][1]).toEqual(expect.objectContaining({ requestPermission: true })); + + spy.mockClear(); + await client.resumeSession("session-without-auth", { + onPermissionRequest: approveAll, + onEvent: () => {}, + }); + + expect(spy).not.toHaveBeenCalledWith( + "session.eventLog.registerInterest", + expect.objectContaining({ eventType: "mcp.oauth_required" }) + ); + expect(spy).toHaveBeenCalledWith( + "session.resume", + expect.objectContaining({ sessionId: "session-without-auth", requestPermission: true }) + ); + }); + it("forwards canvas declarations and request flags in session.create", async () => { const client = new CopilotClient(); await client.start(); diff --git a/nodejs/test/e2e/harness/sdkTestContext.ts b/nodejs/test/e2e/harness/sdkTestContext.ts index cd6494cad3..a59f62126d 100644 --- a/nodejs/test/e2e/harness/sdkTestContext.ts +++ b/nodejs/test/e2e/harness/sdkTestContext.ts @@ -20,6 +20,13 @@ const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); const SNAPSHOTS_DIR = resolve(__dirname, "../../../../test/snapshots"); +function getCliPathForTests(): string | undefined { + if (process.env.COPILOT_CLI_PATH) { + return process.env.COPILOT_CLI_PATH; + } + return undefined; +} + export async function createSdkTestContext({ logLevel, useStdio, @@ -39,6 +46,7 @@ export async function createSdkTestContext({ await openAiEndpoint.setCopilotUserByToken(DEFAULT_GITHUB_TOKEN, { login: "e2e-test-user", copilot_plan: "individual_pro", + is_mcp_enabled: true, endpoints: { api: proxyUrl, telemetry: "https://localhost:1/telemetry", @@ -72,6 +80,7 @@ export async function createSdkTestContext({ }; const userConn = copilotClientOptions?.connection; + const cliPath = getCliPathForTests(); let connection: RuntimeConnection; if (userConn) { // Caller supplied a RuntimeConnection — merge in the harness-managed @@ -82,13 +91,13 @@ export async function createSdkTestContext({ const { kind: _k, ...tcp } = userConn; connection = RuntimeConnection.forTcp({ ...tcp, - path: tcp.path ?? process.env.COPILOT_CLI_PATH, + path: tcp.path ?? cliPath, }); } else if (userConn.kind === "stdio") { const { kind: _k, ...stdio } = userConn; connection = RuntimeConnection.forStdio({ ...stdio, - path: stdio.path ?? process.env.COPILOT_CLI_PATH, + path: stdio.path ?? cliPath, }); } else { connection = userConn; @@ -96,15 +105,18 @@ export async function createSdkTestContext({ } else { connection = useStdio === false - ? RuntimeConnection.forTcp({ path: process.env.COPILOT_CLI_PATH }) - : RuntimeConnection.forStdio({ path: process.env.COPILOT_CLI_PATH }); + ? RuntimeConnection.forTcp({ path: cliPath }) + : RuntimeConnection.forStdio({ path: cliPath }); } - const { connection: _ignoredConnection, ...remainingClientOptions } = - copilotClientOptions ?? {}; + const { + connection: _ignoredConnection, + env: userEnv, + ...remainingClientOptions + } = copilotClientOptions ?? {}; const copilotClient = new CopilotClient({ workingDirectory: workDir, - env, + env: { ...env, ...userEnv }, logLevel: logLevel || "error", connection, gitHubToken: authTokenToUse, diff --git a/nodejs/test/e2e/mcp_oauth.e2e.test.ts b/nodejs/test/e2e/mcp_oauth.e2e.test.ts new file mode 100644 index 0000000000..29ed089edb --- /dev/null +++ b/nodejs/test/e2e/mcp_oauth.e2e.test.ts @@ -0,0 +1,311 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { spawn, type ChildProcessWithoutNullStreams } from "node:child_process"; +import { dirname, resolve } from "node:path"; +import { createInterface } from "node:readline"; +import { fileURLToPath } from "node:url"; +import { describe, expect, it, onTestFinished } from "vitest"; +import type { CopilotSession, MCPServerConfig, McpAuthRequest } from "../../src/index.js"; +import { approveAll } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; +import { waitForCondition } from "./harness/sdkTestHelper.js"; + +const __filename = fileURLToPath(import.meta.url); +const __dirname = dirname(__filename); +const TEST_MCP_OAUTH_SERVER = resolve(__dirname, "../../../test/harness/test-mcp-oauth-server.mjs"); +const EXPECTED_TOKEN = "sdk-host-token"; +const REFRESH_TOKEN = `${EXPECTED_TOKEN}-refresh`; +const UPSCOPE_TOKEN = `${EXPECTED_TOKEN}-upscope`; +const REAUTH_TOKEN = `${EXPECTED_TOKEN}-reauth`; + +describe("MCP OAuth host auth", async () => { + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + env: { + COPILOT_MCP_APPS: "true", + MCP_APPS: "true", + }, + }, + }); + + it("should satisfy MCP OAuth using host-provided token", { timeout: 120_000 }, async () => { + const oauthServer = await startOAuthMcpServer(); + const serverName = "oauth-protected-mcp"; + let authRequest: McpAuthRequest | undefined; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + enableMcpApps: true, + onMcpAuthRequest: async (request) => { + authRequest = request; + return { + kind: "token", + accessToken: EXPECTED_TOKEN, + tokenType: "Bearer", + expiresIn: 3600, + }; + }, + mcpServers: { + [serverName]: { + type: "http", + url: `${oauthServer.url}/mcp`, + tools: ["*"], + oauthClientId: "sdk-e2e-client", + oauthPublicClient: true, + } as unknown as MCPServerConfig, + }, + }); + onTestFinished(() => disconnectSession(session)); + + await waitForMcpServerStatus(session, serverName); + + const tools = await session.rpc.mcp.listTools({ serverName }); + expect(tools.tools.map((tool) => tool.name)).toContain("whoami"); + + expect(authRequest).toMatchObject({ + requestId: expect.any(String), + serverName, + serverUrl: `${oauthServer.url}/mcp`, + reason: "initial", + wwwAuthenticateParams: { + resourceMetadataUrl: `${oauthServer.url}/.well-known/oauth-protected-resource`, + scope: "mcp.read", + error: "invalid_token", + }, + resourceMetadata: JSON.stringify({ + resource: `${oauthServer.url}/mcp`, + authorization_servers: [oauthServer.url], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }), + }); + + const requests = await oauthServer.requests(); + expect(requests.some((request) => request.authorization === null)).toBe(true); + expect( + requests.some((request) => request.authorization === `Bearer ${EXPECTED_TOKEN}`) + ).toBe(true); + }); + + it( + "should request host-owned replacement tokens across the MCP OAuth lifecycle", + { timeout: 120_000 }, + async () => { + const oauthServer = await startOAuthMcpServer(); + const serverName = "oauth-lifecycle-mcp"; + const authRequests: McpAuthRequest[] = []; + let refreshCount = 0; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + enableMcpApps: true, + onMcpAuthRequest: async (request) => { + authRequests.push(request); + switch (request.reason) { + case "initial": + return { kind: "token", accessToken: EXPECTED_TOKEN }; + case "refresh": + refreshCount++; + if (refreshCount === 1) { + return { kind: "token", accessToken: REFRESH_TOKEN }; + } + return { kind: "cancelled" }; + case "upscope": + return { kind: "token", accessToken: UPSCOPE_TOKEN }; + case "reauth": + return { kind: "token", accessToken: REAUTH_TOKEN }; + } + }, + mcpServers: { + [serverName]: { + type: "http", + url: `${oauthServer.url}/mcp`, + tools: ["*"], + oauthClientId: "sdk-e2e-client", + oauthPublicClient: true, + } as unknown as MCPServerConfig, + }, + }); + onTestFinished(() => disconnectSession(session)); + + await waitForMcpServerStatus(session, serverName); + await callWhoami(session, serverName, "refresh"); + await callWhoami(session, serverName, "upscope"); + await callWhoami(session, serverName, "reauth"); + + expect(authRequests.map((request) => request.reason)).toEqual([ + "initial", + "refresh", + "upscope", + "refresh", + "reauth", + ]); + + const upscopeRequest = authRequests.find((request) => request.reason === "upscope"); + expect(upscopeRequest?.wwwAuthenticateParams).toEqual({ + resourceMetadataUrl: `${oauthServer.url}/.well-known/oauth-protected-resource`, + scope: "mcp.write", + error: "insufficient_scope", + }); + expect(upscopeRequest?.resourceMetadata).toBe( + JSON.stringify({ + resource: `${oauthServer.url}/mcp`, + authorization_servers: [oauthServer.url], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }) + ); + + const requests = await oauthServer.requests(); + for (const token of [EXPECTED_TOKEN, REFRESH_TOKEN, UPSCOPE_TOKEN, REAUTH_TOKEN]) { + expect( + requests.some((request) => request.authorization === `Bearer ${token}`) + ).toBe(true); + } + } + ); + + it( + "should cancel pending MCP OAuth requests when the host declines", + { timeout: 120_000 }, + async () => { + const oauthServer = await startOAuthMcpServer(); + const serverName = "oauth-cancelled-mcp"; + let authRequest: McpAuthRequest | undefined; + + const session = await client.createSession({ + onPermissionRequest: approveAll, + onMcpAuthRequest: async (request) => { + authRequest = request; + return { kind: "cancelled" }; + }, + mcpServers: { + [serverName]: { + type: "http", + url: `${oauthServer.url}/mcp`, + tools: ["*"], + oauthClientId: "sdk-e2e-client", + oauthPublicClient: true, + } as unknown as MCPServerConfig, + }, + }); + onTestFinished(() => disconnectSession(session)); + + await waitForMcpServerStatus(session, serverName, "failed"); + + expect(authRequest).toMatchObject({ + serverName, + reason: "initial", + }); + } + ); +}); + +async function waitForMcpServerStatus( + session: CopilotSession, + serverName: string, + expectedStatus = "connected" +): Promise { + let lastStatus = ""; + await waitForCondition( + async () => { + const result = await session.rpc.mcp.list(); + const server = result.servers.find((entry) => entry.name === serverName); + lastStatus = server?.status ?? ""; + return server?.status === expectedStatus; + }, + { + timeoutMs: 60_000, + intervalMs: 200, + timeoutMessage: `${serverName} did not reach ${expectedStatus}; last status was ${lastStatus}`, + } + ); +} + +async function callWhoami( + session: CopilotSession, + serverName: string, + scenario: "refresh" | "upscope" | "reauth" +): Promise { + const result = await session.rpc.mcp.apps.callTool({ + serverName, + originServerName: serverName, + toolName: "whoami", + arguments: { scenario }, + }); + expect(result.content).toEqual([{ type: "text", text: "oauth-test-user" }]); +} + +async function startOAuthMcpServer(): Promise<{ + url: string; + requests: () => Promise>; +}> { + const child = spawn(process.execPath, [TEST_MCP_OAUTH_SERVER], { + env: { ...process.env, EXPECTED_TOKEN }, + stdio: ["ignore", "pipe", "pipe"], + }); + onTestFinished(() => stopChild(child)); + + const stderr: string[] = []; + child.stderr.on("data", (chunk) => stderr.push(String(chunk))); + + const url = await new Promise((resolvePromise, reject) => { + const rl = createInterface({ input: child.stdout }); + const timeout = setTimeout(() => { + rl.close(); + reject(new Error(`Timed out waiting for OAuth MCP server. ${stderr.join("")}`)); + }, 10_000); + + child.once("exit", (code, signal) => { + clearTimeout(timeout); + rl.close(); + reject( + new Error( + `OAuth MCP server exited before listening. code=${code} signal=${signal} ${stderr.join("")}` + ) + ); + }); + + rl.on("line", (line) => { + const match = /^Listening: (.+)$/.exec(line); + if (!match) { + return; + } + clearTimeout(timeout); + rl.close(); + resolvePromise(match[1]); + }); + }); + + return { + url, + requests: async () => { + const response = await fetch(`${url}/__requests`); + if (!response.ok) { + throw new Error(`Failed to fetch OAuth MCP requests: ${response.status}`); + } + return response.json(); + }, + }; +} + +async function disconnectSession(session: CopilotSession): Promise { + try { + await session.disconnect(); + } catch { + // Best-effort cleanup. + } +} + +function stopChild(child: ChildProcessWithoutNullStreams): Promise { + if (child.exitCode !== null || child.killed) { + return Promise.resolve(); + } + const exitPromise = new Promise((resolvePromise) => { + child.once("exit", () => resolvePromise()); + }); + child.kill("SIGTERM"); + return exitPromise; +} diff --git a/nodejs/test/e2e/provider_endpoint.e2e.test.ts b/nodejs/test/e2e/provider_endpoint.e2e.test.ts index 1bac76253b..8acf6a2469 100644 --- a/nodejs/test/e2e/provider_endpoint.e2e.test.ts +++ b/nodejs/test/e2e/provider_endpoint.e2e.test.ts @@ -7,12 +7,12 @@ import { approveAll } from "../../src/index.js"; import { createSdkTestContext } from "./harness/sdkTestContext.js"; describe("session.provider.getEndpoint RPC", async () => { - const { copilotClient: client, env } = await createSdkTestContext(); - - // The provider endpoint API is gated behind an opt-in env var; the harness - // env object is the same one passed to the CLI subprocess, so mutating it - // here enables the API for this test file's client. - env.COPILOT_ALLOW_GET_PROVIDER_ENDPOINT = "true"; + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + // The provider endpoint API is gated behind an opt-in env var. + env: { COPILOT_ALLOW_GET_PROVIDER_ENDPOINT: "true" }, + }, + }); it("returns the BYOK provider endpoint when a custom provider is configured", async () => { const session = await client.createSession({ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index ff13d47de3..51be3727ac 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -104,6 +104,13 @@ InfiniteSessionConfig, InputOptions, LargeToolOutputConfig, + McpAuthContext, + McpAuthHandler, + McpAuthRequest, + McpAuthResult, + McpAuthStaticClientConfig, + McpAuthToken, + McpAuthWwwAuthenticateParams, MCPHTTPServerConfig, MCPServerConfig, MCPStdioServerConfig, @@ -226,6 +233,13 @@ "MCPHTTPServerConfig", "MCPServerConfig", "MCPStdioServerConfig", + "McpAuthContext", + "McpAuthHandler", + "McpAuthRequest", + "McpAuthResult", + "McpAuthStaticClientConfig", + "McpAuthToken", + "McpAuthWwwAuthenticateParams", "ModelBilling", "ModelBillingTokenPrices", "ModelBillingTokenPricesLongContext", diff --git a/python/copilot/client.py b/python/copilot/client.py index c7d11d12b1..7dade44403 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -92,6 +92,7 @@ ExitPlanModeHandler, InfiniteSessionConfig, LargeToolOutputConfig, + McpAuthHandler, MCPServerConfig, MemoryConfiguration, ModelCapabilitiesOverride, @@ -1697,6 +1698,7 @@ async def create_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2149,6 +2151,7 @@ def _initialize_session(sid: str) -> CopilotSession: s._register_tools(tools) s._register_commands(commands) s._register_permission_handler(on_permission_request) + s._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: s._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2229,6 +2232,11 @@ def _register_inline(raw_response: Any) -> None: f"session.create returned sessionId {response.get('sessionId')} " f"but the caller requested {local_session_id}" ) + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session.session_id, "eventType": "mcp.oauth_required"}, + ) session._workspace_path = response.get("workspacePath") capabilities = response.get("capabilities") session._set_capabilities(capabilities) @@ -2319,6 +2327,7 @@ async def resume_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + on_mcp_auth_request: McpAuthHandler | None = None, enable_mcp_apps: bool = False, on_exit_plan_mode_request: ExitPlanModeHandler | None = None, on_auto_mode_switch_request: AutoModeSwitchHandler | None = None, @@ -2723,6 +2732,7 @@ async def resume_session( session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) + session._register_mcp_auth_handler(on_mcp_auth_request) if on_user_input_request: session._register_user_input_handler(on_user_input_request) if on_elicitation_request: @@ -2744,6 +2754,11 @@ async def resume_session( session.on(on_event) with self._sessions_lock: self._sessions[session_id] = session + if on_mcp_auth_request is not None: + await self._client.request( + "session.eventLog.registerInterest", + {"sessionId": session_id, "eventType": "mcp.oauth_required"}, + ) log_timing( logger, logging.DEBUG, diff --git a/python/copilot/session.py b/python/copilot/session.py index 0dc569f258..bf34a73340 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -39,6 +39,9 @@ ExternalToolTextResultForLlm, HandlePendingToolCallRequest, LogRequest, + MCPOauthHandlePendingRequest, + MCPOauthPendingRequestResponse, + MCPOauthPendingRequestResponseKind, ModelSwitchToRequest, PermissionDecision, PermissionDecisionApproveOnce, @@ -67,6 +70,7 @@ CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + McpOauthRequiredData, PermissionRequest, PermissionRequestedData, SessionCanvasClosedData, @@ -367,6 +371,72 @@ def approve_all( return PermissionDecisionApproveOnce() +# ============================================================================ +# MCP Auth Types +# ============================================================================ + + +class McpAuthWwwAuthenticateParams(TypedDict, total=False): + """Parsed parameters from an MCP server's WWW-Authenticate response.""" + + resourceMetadataUrl: str + scope: str + error: str + + +class McpAuthStaticClientConfig(TypedDict, total=False): + """Static OAuth client configuration supplied by the MCP server, if available.""" + + clientId: Required[str] + clientSecret: str + grantType: Literal["client_credentials"] + publicClient: bool + + +class McpAuthRequest(TypedDict, total=False): + """MCP OAuth request that the SDK host can satisfy with a host-acquired token.""" + + requestId: Required[str] + serverName: Required[str] + serverUrl: Required[str] + reason: Required[Literal["initial", "refresh", "reauth", "upscope"]] + wwwAuthenticateParams: McpAuthWwwAuthenticateParams + resourceMetadata: str + staticClientConfig: McpAuthStaticClientConfig + + +class McpAuthToken(TypedDict, total=False): + """Host-provided OAuth token data for a pending MCP OAuth request.""" + + accessToken: Required[str] + tokenType: str + expiresIn: int + + +class McpAuthResult(TypedDict, total=False): + """Result returned by an MCP auth request handler.""" + + kind: Required[Literal["token", "cancelled"]] + accessToken: str + tokenType: str + expiresIn: int + + +class McpAuthContext(TypedDict): + """Context for an MCP auth request handler invocation.""" + + sessionId: str + + +McpAuthHandlerResult = McpAuthResult | McpAuthToken | None + + +McpAuthHandler = Callable[ + [McpAuthRequest, McpAuthContext], + McpAuthHandlerResult | Awaitable[McpAuthHandlerResult], +] + + # ============================================================================ # User Input Request Types # ============================================================================ @@ -1340,6 +1410,8 @@ def __init__( self._tool_handlers_lock = threading.Lock() self._permission_handler: _PermissionHandlerFn | None = None self._permission_handler_lock = threading.Lock() + self._mcp_auth_handler: McpAuthHandler | None = None + self._mcp_auth_handler_lock = threading.Lock() self._user_input_handler: UserInputHandler | None = None self._user_input_handler_lock = threading.Lock() self._exit_plan_mode_handler: ExitPlanModeHandler | None = None @@ -1729,6 +1801,58 @@ def _handle_broadcast_event(self, event: SessionEvent) -> None: ) ) + case McpOauthRequiredData() as data: + with self._mcp_auth_handler_lock: + handler = self._mcp_auth_handler + if not data.request_id: + return + if not handler: + logger.warning( + "Received MCP OAuth request without a registered MCP auth handler. " + "SessionId=%s, RequestId=%s", + self.session_id, + data.request_id, + ) + return + request: McpAuthRequest = { + "requestId": data.request_id, + "serverName": data.server_name, + "serverUrl": data.server_url, + "reason": data.reason.value, + } + if data.www_authenticate_params is not None: + request["wwwAuthenticateParams"] = {} + if data.www_authenticate_params.resource_metadata_url is not None: + request["wwwAuthenticateParams"]["resourceMetadataUrl"] = ( + data.www_authenticate_params.resource_metadata_url + ) + if data.www_authenticate_params.scope is not None: + request["wwwAuthenticateParams"]["scope"] = ( + data.www_authenticate_params.scope + ) + if data.www_authenticate_params.error is not None: + request["wwwAuthenticateParams"]["error"] = ( + data.www_authenticate_params.error + ) + if data.resource_metadata is not None: + request["resourceMetadata"] = data.resource_metadata + if data.static_client_config is not None: + static_client_config: McpAuthStaticClientConfig = { + "clientId": data.static_client_config.client_id, + } + if data.static_client_config.client_secret is not None: + static_client_config["clientSecret"] = ( + data.static_client_config.client_secret + ) + if data.static_client_config.grant_type is not None: + static_client_config["grantType"] = data.static_client_config.grant_type + if data.static_client_config.public_client is not None: + static_client_config["publicClient"] = ( + data.static_client_config.public_client + ) + request["staticClientConfig"] = static_client_config + asyncio.ensure_future(self._execute_mcp_auth_and_respond(request, handler)) + case CommandExecuteData() as data: request_id = data.request_id command_name = data.command_name @@ -1942,6 +2066,59 @@ async def _execute_permission_and_respond( except (JsonRpcError, ProcessExitedError, OSError): pass # Connection lost or RPC error — nothing we can do + async def _execute_mcp_auth_and_respond( + self, + request: McpAuthRequest, + handler: McpAuthHandler, + ) -> None: + """Execute an MCP auth handler and respond via RPC.""" + request_id = request["requestId"] + try: + handler_start = time.perf_counter() + maybe_result = handler(request, {"sessionId": self.session_id}) + if inspect.isawaitable(maybe_result): + result = cast(McpAuthHandlerResult, await maybe_result) + else: + result = maybe_result + log_timing( + logger, + logging.DEBUG, + "CopilotSession._execute_mcp_auth_and_respond dispatch", + handler_start, + session_id=self.session_id, + request_id=request_id, + ) + + if result and result.get("kind", "token") == "token": + rpc_result = MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.TOKEN, + access_token=result["accessToken"], + expires_in=result.get("expiresIn"), + token_type=result.get("tokenType"), + ) + else: + rpc_result = MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.CANCELLED + ) + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=rpc_result, + ) + ) + except Exception: + try: + await self.rpc.mcp.oauth.handle_pending_request( + MCPOauthHandlePendingRequest( + request_id=request_id, + result=MCPOauthPendingRequestResponse( + kind=MCPOauthPendingRequestResponseKind.CANCELLED + ), + ) + ) + except (JsonRpcError, ProcessExitedError, OSError): + pass # Connection lost or RPC error — nothing we can do + async def _execute_command_and_respond( self, request_id: str, @@ -2126,6 +2303,11 @@ def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> N with self._elicitation_handler_lock: self._elicitation_handler = handler + def _register_mcp_auth_handler(self, handler: McpAuthHandler | None) -> None: + """Register the MCP auth handler for this session.""" + with self._mcp_auth_handler_lock: + self._mcp_auth_handler = handler + def _register_exit_plan_mode_handler(self, handler: ExitPlanModeHandler | None) -> None: """Register the exit-plan-mode handler for this session.""" with self._exit_plan_mode_handler_lock: diff --git a/python/e2e/test_mcp_oauth_e2e.py b/python/e2e/test_mcp_oauth_e2e.py new file mode 100644 index 0000000000..47897f69c0 --- /dev/null +++ b/python/e2e/test_mcp_oauth_e2e.py @@ -0,0 +1,258 @@ +import asyncio +import json +import os +from pathlib import Path +from typing import Any + +import httpx +import pytest + +from copilot.generated.rpc import MCPAppsCallToolRequest, MCPListToolsRequest +from copilot.session import MCPServerConfig, PermissionHandler +from copilot.session_events import McpServerStatus + +from .testharness import E2ETestContext, wait_for_condition + +TEST_MCP_OAUTH_SERVER = str( + (Path(__file__).parents[2] / "test" / "harness" / "test-mcp-oauth-server.mjs").resolve() +) +EXPECTED_TOKEN = "sdk-host-token" +REFRESH_TOKEN = f"{EXPECTED_TOKEN}-refresh" +UPSCOPE_TOKEN = f"{EXPECTED_TOKEN}-upscope" +REAUTH_TOKEN = f"{EXPECTED_TOKEN}-reauth" + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +async def _start_oauth_mcp_server() -> tuple[str, asyncio.subprocess.Process]: + process = await asyncio.create_subprocess_exec( + "node", + TEST_MCP_OAUTH_SERVER, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env={**os.environ, "EXPECTED_TOKEN": EXPECTED_TOKEN}, + ) + assert process.stdout is not None + + try: + line = await asyncio.wait_for(process.stdout.readline(), timeout=10) + except TimeoutError as exc: + await _stop_process(process) + assert process.stderr is not None + stderr = (await process.stderr.read()).decode(errors="replace") + raise TimeoutError(f"Timed out waiting for OAuth MCP server: {stderr}") from exc + if not line: + assert process.stderr is not None + stderr = (await process.stderr.read()).decode(errors="replace") + raise RuntimeError(f"OAuth MCP server exited before listening: {stderr}") + text = line.decode().strip() + if text.startswith("Listening: "): + return text.removeprefix("Listening: "), process + + await _stop_process(process) + raise RuntimeError(f"Unexpected OAuth MCP server startup line: {text}") + + +async def _stop_process(process: asyncio.subprocess.Process) -> None: + if process.returncode is not None: + return + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5) + except TimeoutError: + process.kill() + await process.wait() + + +async def _requests(base_url: str) -> list[dict[str, Any]]: + async with httpx.AsyncClient() as client: + response = await client.get(f"{base_url}/__requests") + response.raise_for_status() + return response.json() + + +async def _wait_for_mcp_server_status( + session, server_name: str, expected_status: McpServerStatus = McpServerStatus.CONNECTED +) -> None: + last_status = "" + + async def matches() -> bool: + nonlocal last_status + result = await session.rpc.mcp.list() + server = next((s for s in result.servers if s.name == server_name), None) + last_status = server.status.value if server is not None else "" + return server is not None and server.status == expected_status + + await wait_for_condition( + matches, + timeout=60.0, + poll_interval=0.2, + timeout_message=( + f"{server_name} did not reach {expected_status.value}; last status was {last_status}" + ), + ) + + +class TestMcpOAuth: + async def test_should_satisfy_mcp_oauth_using_host_provided_token(self, ctx: E2ETestContext): + url, process = await _start_oauth_mcp_server() + server_name = "oauth-protected-mcp" + observed_request = None + + def on_mcp_auth_request(request, _invocation): + nonlocal observed_request + observed_request = request + return { + "kind": "token", + "accessToken": EXPECTED_TOKEN, + "tokenType": "Bearer", + "expiresIn": 3600, + } + + try: + mcp_servers: dict[str, MCPServerConfig] = { + server_name: { + "type": "http", + "url": f"{url}/mcp", + "tools": ["*"], + } + } + async with await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=on_mcp_auth_request, + mcp_servers=mcp_servers, + ) as session: + await _wait_for_mcp_server_status(session, server_name) + + tools = await session.rpc.mcp.list_tools( + MCPListToolsRequest(server_name=server_name) + ) + assert [tool.name for tool in tools.tools] == ["whoami"] + + assert observed_request is not None + assert observed_request["serverName"] == server_name + assert observed_request["serverUrl"] == f"{url}/mcp" + assert observed_request["reason"] == "initial" + assert observed_request["wwwAuthenticateParams"] == { + "resourceMetadataUrl": f"{url}/.well-known/oauth-protected-resource", + "scope": "mcp.read", + "error": "invalid_token", + } + assert json.loads(observed_request["resourceMetadata"]) == { + "resource": f"{url}/mcp", + "authorization_servers": [url], + "scopes_supported": ["mcp.read"], + "bearer_methods_supported": ["header"], + } + + requests = await _requests(url) + assert any(request["authorization"] is None for request in requests) + assert any( + request["authorization"] == f"Bearer {EXPECTED_TOKEN}" for request in requests + ) + finally: + await _stop_process(process) + + async def test_should_request_replacement_tokens_across_mcp_oauth_lifecycle( + self, ctx: E2ETestContext + ): + url, process = await _start_oauth_mcp_server() + server_name = "oauth-lifecycle-mcp" + observed_requests: list[dict[str, Any]] = [] + refresh_count = 0 + + def on_mcp_auth_request(request, _invocation): + nonlocal refresh_count + observed_requests.append(request) + if request["reason"] == "refresh": + refresh_count += 1 + assert request["wwwAuthenticateParams"] == {"error": "invalid_token"} + if refresh_count > 1: + return {"kind": "cancelled"} + return {"kind": "token", "accessToken": REFRESH_TOKEN} + if request["reason"] == "upscope": + assert request["wwwAuthenticateParams"] == { + "resourceMetadataUrl": f"{url}/.well-known/oauth-protected-resource", + "scope": "mcp.write", + "error": "insufficient_scope", + } + return {"kind": "token", "accessToken": UPSCOPE_TOKEN} + if request["reason"] == "reauth": + return {"kind": "token", "accessToken": REAUTH_TOKEN} + return {"kind": "token", "accessToken": EXPECTED_TOKEN} + + try: + mcp_servers: dict[str, MCPServerConfig] = { + server_name: { + "type": "http", + "url": f"{url}/mcp", + "tools": ["*"], + } + } + async with await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=on_mcp_auth_request, + mcp_servers=mcp_servers, + enable_mcp_apps=True, + ) as session: + await _wait_for_mcp_server_status(session, server_name) + + for scenario in ("refresh", "upscope", "reauth"): + result = await session.rpc.mcp.apps.call_tool( + MCPAppsCallToolRequest( + origin_server_name=server_name, + server_name=server_name, + tool_name="whoami", + arguments={"scenario": scenario}, + ) + ) + assert result["content"] == [{"type": "text", "text": "oauth-test-user"}] + + assert [request["reason"] for request in observed_requests] == [ + "initial", + "refresh", + "upscope", + "refresh", + "reauth", + ] + requests = await _requests(url) + assert any( + request["authorization"] == f"Bearer {REFRESH_TOKEN}" for request in requests + ) + assert any( + request["authorization"] == f"Bearer {UPSCOPE_TOKEN}" for request in requests + ) + assert any(request["authorization"] == f"Bearer {REAUTH_TOKEN}" for request in requests) + finally: + await _stop_process(process) + + async def test_should_cancel_pending_mcp_oauth_request(self, ctx: E2ETestContext): + url, process = await _start_oauth_mcp_server() + server_name = "oauth-cancelled-mcp" + observed_request = None + + def on_mcp_auth_request(request, _invocation): + nonlocal observed_request + observed_request = request + return {"kind": "cancelled"} + + try: + mcp_servers: dict[str, MCPServerConfig] = { + server_name: { + "type": "http", + "url": f"{url}/mcp", + "tools": ["*"], + } + } + async with await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=on_mcp_auth_request, + mcp_servers=mcp_servers, + ) as session: + await _wait_for_mcp_server_status(session, server_name, McpServerStatus.FAILED) + + assert observed_request is not None + assert observed_request["serverName"] == server_name + assert observed_request["reason"] == "initial" + finally: + await _stop_process(process) diff --git a/python/e2e/testharness/context.py b/python/e2e/testharness/context.py index 735c365c5f..de1bf03292 100644 --- a/python/e2e/testharness/context.py +++ b/python/e2e/testharness/context.py @@ -168,6 +168,8 @@ def get_env(self) -> dict: "XDG_CONFIG_HOME": self.home_dir, "XDG_STATE_HOME": self.home_dir, "GITHUB_TOKEN": DEFAULT_GITHUB_TOKEN, + "COPILOT_MCP_APPS": "true", + "MCP_APPS": "true", } ) return env diff --git a/python/test_client.py b/python/test_client.py index f3f46c4d8b..db6703c3b0 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -4,6 +4,7 @@ This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead. """ +import asyncio from datetime import UTC, datetime from unittest.mock import AsyncMock, Mock, patch @@ -28,6 +29,14 @@ ModelSupports, ) from copilot.session import PermissionHandler +from copilot.session_events import ( + McpOauthRequestReason, + McpOauthRequiredData, + McpOauthRequiredStaticClientConfig, + McpOauthWWWAuthenticateParams, + SessionEvent, + SessionEventType, +) from e2e.testharness import CLI_PATH @@ -139,6 +148,307 @@ async def test_resume_session_allows_none_permission_handler(self): class TestCreateSessionConfig: + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_in_create_session(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + create_method, create_payload = captured[0] + interest_method, interest_payload = captured[1] + assert create_method == "session.create" + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload["eventType"] == "mcp.oauth_required" + assert interest_payload["sessionId"] == create_payload["sessionId"] + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_interest_is_not_registered_without_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + await client.resume_session( + "session-without-auth", + on_permission_request=PermissionHandler.approve_all, + on_event=lambda event: None, + ) + + assert session.session_id + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + assert any( + method == "session.create" and params["requestPermission"] is True + for method, params in captured + ) + assert any( + method == "session.resume" and params["requestPermission"] is True + for method, params in captured + ) + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_before_resume(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.resume": + return {"sessionId": params["sessionId"], "workspacePath": None} + return {} + + client._client.request = mock_request + await client.resume_session( + "session-with-auth", + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + ) + + interest_method, interest_payload = captured[0] + resume_method, resume_payload = captured[1] + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "session-with-auth", + "eventType": "mcp.oauth_required", + } + assert resume_method == "session.resume" + assert resume_payload["requestPermission"] is True + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_handler_registers_interest_after_cloud_create_only_with_handler(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + create_count = 0 + + async def mock_request(method, params, **kwargs): + nonlocal create_count + captured.append((method, params)) + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + if method == "session.create": + create_count += 1 + result = { + "sessionId": f"server-assigned-session-{create_count}", + "workspacePath": None, + } + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + return {} + + cloud = CloudSessionOptions( + repository=CloudSessionRepository( + owner="github", + name="copilot-sdk", + branch="main", + ) + ) + + client._client.request = mock_request + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + cloud=cloud, + ) + + assert not any( + method == "session.eventLog.registerInterest" + and params["eventType"] == "mcp.oauth_required" + for method, params in captured + ) + + captured.clear() + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=lambda request: {"kind": "cancelled"}, + cloud=cloud, + ) + + create_method, _create_payload = captured[0] + interest_method, interest_payload = captured[1] + assert create_method == "session.create" + assert interest_method == "session.eventLog.registerInterest" + assert interest_payload == { + "sessionId": "server-assigned-session-2", + "eventType": "mcp.oauth_required", + } + finally: + await client.force_stop() + + @pytest.mark.asyncio + async def test_mcp_auth_required_event_sends_host_token(self): + client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) + await client.start() + try: + captured: list[tuple[str, dict]] = [] + + async def mock_request(method, params, **kwargs): + if method == "session.mcp.oauth.handlePendingRequest": + captured.append((method, params)) + return {"success": True} + if method == "session.create": + result = {"sessionId": params["sessionId"], "workspacePath": None} + callback = kwargs.get("on_response_inline") + if callback is not None: + callback(result) + return result + if method == "session.eventLog.registerInterest": + return {"id": "interest-1"} + return {} + + client._client.request = mock_request + observed_request = None + + def handle_mcp_auth_request(request, invocation): + nonlocal observed_request + observed_request = request + assert invocation == {"sessionId": session.session_id} + return { + "accessToken": "host-token", + "tokenType": "Bearer", + } + + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + on_mcp_auth_request=handle_mcp_auth_request, + ) + + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request", + server_name="oauth-server", + server_url="https://example.com/mcp", + reason=McpOauthRequestReason.INITIAL, + www_authenticate_params=McpOauthWWWAuthenticateParams( + resource_metadata_url="https://example.com/.well-known/oauth-protected-resource" + ), + resource_metadata='{"resource":"https://example.com/mcp"}', + static_client_config=McpOauthRequiredStaticClientConfig( + client_id="static-client", + client_secret="static-secret", + grant_type="client_credentials", + public_client=False, + ), + ), + id="evt-1", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if captured: + break + await asyncio.sleep(0.005) + + assert observed_request is not None + assert observed_request["resourceMetadata"] == '{"resource":"https://example.com/mcp"}' + assert observed_request["wwwAuthenticateParams"]["resourceMetadataUrl"] == ( + "https://example.com/.well-known/oauth-protected-resource" + ) + assert observed_request["staticClientConfig"] == { + "clientId": "static-client", + "clientSecret": "static-secret", + "grantType": "client_credentials", + "publicClient": False, + } + assert captured == [ + ( + "session.mcp.oauth.handlePendingRequest", + { + "sessionId": session.session_id, + "requestId": "oauth-request", + "result": { + "kind": "token", + "accessToken": "host-token", + "tokenType": "Bearer", + }, + }, + ) + ] + + observed_request = None + session._dispatch_event( + SessionEvent( + data=McpOauthRequiredData( + request_id="oauth-request-without-metadata", + server_name="oauth-server", + server_url="https://example.com/mcp", + reason=McpOauthRequestReason.INITIAL, + ), + id="evt-2", + timestamp="2026-01-01T00:00:00Z", + type=SessionEventType.MCP_OAUTH_REQUIRED, + ephemeral=True, + parent_id=None, + ) + ) + + for _ in range(200): + if observed_request is not None: + break + await asyncio.sleep(0.005) + + assert observed_request is not None + assert "resourceMetadata" not in observed_request + assert "wwwAuthenticateParams" not in observed_request + finally: + await client.force_stop() + @pytest.mark.asyncio async def test_create_session_forwards_cloud_options(self): client = CopilotClient(connection=RuntimeConnection.for_stdio(path=CLI_PATH)) diff --git a/rust/src/handler.rs b/rust/src/handler.rs index dadd1706ff..3287a4f093 100644 --- a/rust/src/handler.rs +++ b/rust/src/handler.rs @@ -19,8 +19,13 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use crate::generated::api_types::{ - PermissionDecision, PermissionDecisionApproveOnce, PermissionDecisionReject, - PermissionDecisionUserNotAvailable, + McpOauthPendingRequestResponse, McpOauthPendingRequestResponseCancelled, + McpOauthPendingRequestResponseCancelledKind, McpOauthPendingRequestResponseToken, + McpOauthPendingRequestResponseTokenKind, PermissionDecision, PermissionDecisionApproveOnce, + PermissionDecisionReject, PermissionDecisionUserNotAvailable, +}; +use crate::session_events::{ + McpOauthRequestReason, McpOauthRequiredStaticClientConfig, McpOauthWWWAuthenticateParams, }; use crate::types::{ ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, @@ -159,6 +164,75 @@ pub trait ElicitationHandler: Send + Sync + 'static { ) -> ElicitationResult; } +/// MCP OAuth request that the SDK host can satisfy with a host-acquired token. +#[derive(Debug, Clone)] +pub struct McpAuthRequest { + /// Identifier for the pending MCP OAuth request. + pub request_id: RequestId, + /// Display name of the MCP server that requires OAuth. + pub server_name: String, + /// URL of the MCP server that requires OAuth. + pub server_url: String, + /// Why the runtime is requesting host-provided OAuth credentials. + pub reason: McpOauthRequestReason, + /// Parsed WWW-Authenticate parameters from the MCP server, if available. + pub www_authenticate_params: Option, + /// Raw RFC 9728 protected-resource metadata JSON fetched by the runtime, if available. + pub resource_metadata: Option, + /// Static OAuth client configuration, if the server specifies one. + pub static_client_config: Option, +} + +/// Result returned by an MCP auth request handler. +#[derive(Debug, Clone)] +pub enum McpAuthResult { + /// Supplies host-acquired OAuth token data. + Token { + /// Access token acquired by the SDK host. + access_token: String, + /// OAuth token type. Defaults to Bearer when omitted. + token_type: Option, + /// Token lifetime in seconds, if known. + expires_in: Option, + }, + /// Declines or cancels the pending OAuth request. + Cancelled, +} + +impl McpAuthResult { + pub(crate) fn into_wire(self) -> McpOauthPendingRequestResponse { + match self { + Self::Token { + access_token, + token_type, + expires_in, + } => McpOauthPendingRequestResponse::Token(McpOauthPendingRequestResponseToken { + access_token, + token_type, + expires_in, + kind: McpOauthPendingRequestResponseTokenKind::Token, + }), + Self::Cancelled => { + McpOauthPendingRequestResponse::Cancelled(McpOauthPendingRequestResponseCancelled { + kind: McpOauthPendingRequestResponseCancelledKind::Cancelled, + }) + } + } + } +} + +/// Handler for MCP server OAuth requests. +#[async_trait] +pub trait McpAuthHandler: Send + Sync + 'static { + /// Resolve an MCP OAuth request with host token data or cancellation. + async fn handle( + &self, + session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult; +} + /// Handler for `user_input.requested` events from the `ask_user` tool. /// /// When unset, `requestUserInput: false` goes on the wire and the @@ -266,4 +340,23 @@ mod tests { PermissionResult::Decision(PermissionDecision::Reject(_)) )); } + + #[test] + fn mcp_auth_result_token_converts_to_wire_response() { + let wire = McpAuthResult::Token { + access_token: "host-token".to_string(), + token_type: Some("Bearer".to_string()), + expires_in: Some(3600), + } + .into_wire(); + + match wire { + McpOauthPendingRequestResponse::Token(token) => { + assert_eq!(token.access_token, "host-token"); + assert_eq!(token.token_type.as_deref(), Some("Bearer")); + assert_eq!(token.expires_in, Some(3600)); + } + McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"), + } + } } diff --git a/rust/src/session.rs b/rust/src/session.rs index 18b91b4377..b9f17f7dec 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -11,14 +11,17 @@ use tokio_util::sync::CancellationToken; use tracing::{Instrument, warn}; use crate::canvas::CanvasHandler; -use crate::generated::api_types::{LogRequest, ModelSwitchToRequest, OpenCanvasInstance}; +use crate::generated::api_types::{ + LogRequest, ModelSwitchToRequest, OpenCanvasInstance, RegisterEventInterestParams, rpc_methods, +}; use crate::generated::session_events::{ - CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, + CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, McpOauthRequiredData, SessionCanvasClosedData, SessionErrorData, SessionEventType, }; use crate::handler::{ AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler, - PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse, + McpAuthHandler, McpAuthRequest, McpAuthResult, PermissionHandler, PermissionResult, + UserInputHandler, UserInputResponse, }; use crate::hooks::SessionHooks; use crate::provider_token::BearerTokenProvider; @@ -49,6 +52,7 @@ use crate::{ pub(crate) struct SessionHandlers { pub permission: Option>, pub elicitation: Option>, + pub mcp_auth: Option>, pub user_input: Option>, pub exit_plan_mode: Option>, pub auto_mode_switch: Option>, @@ -881,6 +885,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -895,6 +900,7 @@ impl Client { let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1030,6 +1036,9 @@ impl Client { "Client::create_session local setup complete" ); *capabilities.write() = create_result.capabilities.unwrap_or_default(); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, &session_id).await?; + } tracing::debug!( elapsed_ms = total_start.elapsed().as_millis(), @@ -1139,6 +1148,7 @@ impl Client { let handlers = SessionHandlers { permission: permission_handler, elicitation: runtime.elicitation_handler.take(), + mcp_auth: runtime.mcp_auth_handler.take(), user_input: runtime.user_input_handler.take(), exit_plan_mode: runtime.exit_plan_mode_handler.take(), auto_mode_switch: runtime.auto_mode_switch_handler.take(), @@ -1153,6 +1163,7 @@ impl Client { let canvas_handler = runtime.canvas_handler.take(); let session_fs_provider = runtime.session_fs_provider.take(); let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers); + let has_mcp_auth_handler = handlers.mcp_auth.is_some(); if self.inner.session_fs_configured && session_fs_provider.is_none() { return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into()); } @@ -1170,6 +1181,9 @@ impl Client { let mut params = serde_json::to_value(&wire)?; let trace_ctx = self.resolve_trace_context().await; inject_trace_context(&mut params, &trace_ctx); + if has_mcp_auth_handler { + register_mcp_auth_interest(self, &session_id).await?; + } let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default())); let setup_start = Instant::now(); @@ -1477,6 +1491,17 @@ fn notification_permission_payload(result: &PermissionResult) -> Option { } } +async fn register_mcp_auth_interest(client: &Client, session_id: &SessionId) -> Result<(), Error> { + let mut params = serde_json::to_value(RegisterEventInterestParams { + event_type: "mcp.oauth_required".to_string(), + })?; + params["sessionId"] = Value::String(session_id.to_string()); + client + .call(rpc_methods::SESSION_EVENTLOG_REGISTERINTEREST, Some(params)) + .await?; + Ok(()) +} + fn tool_failure_result(message: impl Into) -> ToolResult { let message = message.into(); ToolResult::Expanded(ToolResultExpanded { @@ -1944,6 +1969,91 @@ async fn handle_notification( .instrument(span), ); } + SessionEventType::McpOauthRequired => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else { + warn!( + session_id = %session_id, + request_id = %request_id, + "received MCP OAuth request without a registered MCP auth handler" + ); + return; + }; + let data: McpOauthRequiredData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize MCP OAuth request"); + return; + } + }; + let request = McpAuthRequest { + request_id: request_id.clone(), + server_name: data.server_name, + server_url: data.server_url, + reason: data.reason, + www_authenticate_params: data.www_authenticate_params, + resource_metadata: data.resource_metadata, + static_client_config: data.static_client_config, + }; + let client = client.clone(); + let sid = session_id.clone(); + let span = tracing::error_span!( + "mcp_auth_request_handler", + session_id = %sid, + request_id = %request_id + ); + tokio::spawn( + async move { + let cancel = McpAuthResult::Cancelled; + let handler_task = tokio::spawn({ + let sid = sid.clone(); + let request_id = request_id.clone(); + let span = tracing::error_span!( + "mcp_auth_callback", + session_id = %sid, + request_id = %request_id + ); + async move { + let handler_start = Instant::now(); + let response = mcp_auth_handler + .handle(sid.clone(), request_id.clone(), request) + .await; + tracing::debug!( + elapsed_ms = handler_start.elapsed().as_millis(), + session_id = %sid, + request_id = %request_id, + "McpAuthHandler::handle dispatch" + ); + response + } + .instrument(span) + }); + let result = match handler_task.await { + Ok(result) => result, + Err(_) => cancel, + }; + let rpc_start = Instant::now(); + let _ = client + .call( + "session.mcp.oauth.handlePendingRequest", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result.into_wire(), + })), + ) + .await; + tracing::debug!( + elapsed_ms = rpc_start.elapsed().as_millis(), + "Session::handle_notification MCP auth response sent" + ); + } + .instrument(span), + ); + } SessionEventType::CommandExecute => { let data: CommandExecuteData = match serde_json::from_value(notification.event.data.clone()) { diff --git a/rust/src/types.rs b/rust/src/types.rs index 75408db026..290937e392 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -24,8 +24,8 @@ use crate::generated::api_types::OpenCanvasInstance; pub use crate::generated::session_events::ContextTier; use crate::generated::session_events::ReasoningSummary; use crate::handler::{ - AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, PermissionHandler, - UserInputHandler, + AutoModeSwitchHandler, ElicitationHandler, ExitPlanModeHandler, McpAuthHandler, + PermissionHandler, UserInputHandler, }; use crate::hooks::SessionHooks; use crate::provider_token::BearerTokenProvider; @@ -1772,6 +1772,9 @@ pub struct SessionConfig { /// Optional elicitation-request handler. When `None`, /// `requestElicitation: false` goes on the wire. pub elicitation_handler: Option>, + /// Optional MCP OAuth request handler. When set, the SDK can satisfy MCP + /// server OAuth requests with host-acquired token data or cancellation. + pub mcp_auth_handler: Option>, /// Optional user-input handler. When `None`, /// `requestUserInput: false` goes on the wire and the `ask_user` /// tool is disabled. @@ -1901,6 +1904,10 @@ impl std::fmt::Debug for SessionConfig { "elicitation_handler", &self.elicitation_handler.as_ref().map(|_| ""), ) + .field( + "mcp_auth_handler", + &self.mcp_auth_handler.as_ref().map(|_| ""), + ) .field( "user_input_handler", &self.user_input_handler.as_ref().map(|_| ""), @@ -1990,6 +1997,7 @@ impl Default for SessionConfig { session_fs_provider: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -2013,6 +2021,7 @@ pub(crate) struct SessionConfigRuntime { pub permission_handler: Option>, pub permission_policy: Option, pub elicitation_handler: Option>, + pub mcp_auth_handler: Option>, pub user_input_handler: Option>, pub exit_plan_mode_handler: Option>, pub auto_mode_switch_handler: Option>, @@ -2143,6 +2152,7 @@ impl SessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -2173,6 +2183,12 @@ impl SessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`]. Required for the `ask_user` tool /// to be enabled. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { @@ -2851,6 +2867,8 @@ pub struct ResumeSessionConfig { /// Optional elicitation handler. See /// [`SessionConfig::elicitation_handler`]. pub elicitation_handler: Option>, + /// Optional MCP OAuth handler. See [`SessionConfig::mcp_auth_handler`]. + pub mcp_auth_handler: Option>, /// Optional user-input handler. See /// [`SessionConfig::user_input_handler`]. pub user_input_handler: Option>, @@ -3103,6 +3121,7 @@ impl ResumeSessionConfig { permission_handler: self.permission_handler, permission_policy: self.permission_policy, elicitation_handler: self.elicitation_handler, + mcp_auth_handler: self.mcp_auth_handler, user_input_handler: self.user_input_handler, exit_plan_mode_handler: self.exit_plan_mode_handler, auto_mode_switch_handler: self.auto_mode_switch_handler, @@ -3182,6 +3201,7 @@ impl ResumeSessionConfig { continue_pending_work: None, permission_handler: None, elicitation_handler: None, + mcp_auth_handler: None, user_input_handler: None, exit_plan_mode_handler: None, auto_mode_switch_handler: None, @@ -3207,6 +3227,12 @@ impl ResumeSessionConfig { self } + /// Install an [`McpAuthHandler`] for host-provided MCP OAuth tokens. + pub fn with_mcp_auth_handler(mut self, handler: Arc) -> Self { + self.mcp_auth_handler = Some(handler); + self + } + /// Install a [`UserInputHandler`] for the resumed session. pub fn with_user_input_handler(mut self, handler: Arc) -> Self { self.user_input_handler = Some(handler); diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index 59b83ab27c..79059c7f28 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -37,6 +37,8 @@ mod hooks; mod hooks_extended; #[path = "e2e/mcp_and_agents.rs"] mod mcp_and_agents; +#[path = "e2e/mcp_oauth.rs"] +mod mcp_oauth; #[path = "e2e/mode_empty.rs"] mod mode_empty; #[path = "e2e/mode_handlers.rs"] diff --git a/rust/tests/e2e/mcp_oauth.rs b/rust/tests/e2e/mcp_oauth.rs new file mode 100644 index 0000000000..b1d932372c --- /dev/null +++ b/rust/tests/e2e/mcp_oauth.rs @@ -0,0 +1,433 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{McpAuthHandler, McpAuthRequest, McpAuthResult}; +use github_copilot_sdk::rpc::{McpAppsCallToolRequest, McpListToolsRequest}; +use github_copilot_sdk::session::Session; +use github_copilot_sdk::session_events::{McpOauthRequestReason, McpServerStatus}; +use github_copilot_sdk::{McpHttpServerConfig, McpServerConfig, RequestId, SessionId}; +use parking_lot::Mutex; +use serde::Deserialize; +use serde_json::Value; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::process::{Child, Command}; + +use super::support::{wait_for_condition, with_e2e_context_no_snapshot}; + +const EXPECTED_TOKEN: &str = "sdk-host-token"; +const REFRESH_TOKEN: &str = "sdk-host-token-refresh"; +const UPSCOPE_TOKEN: &str = "sdk-host-token-upscope"; +const REAUTH_TOKEN: &str = "sdk-host-token-reauth"; + +#[tokio::test] +async fn should_satisfy_mcp_oauth_using_host_provided_token() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let mut oauth_server = OAuthMcpServer::start( + ctx.repo_root() + .join("test/harness/test-mcp-oauth-server.mjs"), + ) + .await; + let server_name = "oauth-protected-mcp"; + let handler = Arc::new(TokenAuthHandler::default()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_auth_handler(handler.clone()) + .with_mcp_servers(HashMap::from([( + server_name.to_string(), + McpServerConfig::Http(McpHttpServerConfig { + tools: Some(vec!["*".to_string()]), + timeout: None, + url: format!("{}/mcp", oauth_server.url), + headers: HashMap::new(), + }), + )])), + ) + .await + .expect("create session"); + + wait_for_mcp_server_status(&session, server_name, McpServerStatus::Connected).await; + let tools = session + .rpc() + .mcp() + .list_tools(McpListToolsRequest { + server_name: server_name.to_string(), + }) + .await + .expect("list MCP tools"); + assert!(tools.tools.iter().any(|tool| tool.name == "whoami")); + + let request = handler + .request + .lock() + .clone() + .expect("MCP auth handler should be invoked"); + assert_eq!(request.server_name, server_name); + assert_eq!(request.server_url, format!("{}/mcp", oauth_server.url)); + assert_eq!(request.reason, McpOauthRequestReason::Initial); + let www_authenticate = request + .www_authenticate_params + .expect("WWW-Authenticate params"); + assert_eq!( + www_authenticate.resource_metadata_url, + Some(format!( + "{}/.well-known/oauth-protected-resource", + oauth_server.url + )) + ); + assert_eq!(www_authenticate.scope.as_deref(), Some("mcp.read")); + assert_eq!(www_authenticate.error.as_deref(), Some("invalid_token")); + let metadata: Value = serde_json::from_str( + request + .resource_metadata + .as_deref() + .expect("resource metadata"), + ) + .expect("parse resource metadata"); + assert_eq!(metadata["resource"], format!("{}/mcp", oauth_server.url)); + + let requests = oauth_server.requests().await; + assert!( + requests + .iter() + .any(|request| request.authorization.is_none()) + ); + assert!( + requests.iter().any( + |request| request.authorization.as_deref() == Some("Bearer sdk-host-token") + ) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + oauth_server.stop().await; + }) + }) + .await; +} + +#[tokio::test] +async fn should_request_replacement_tokens_across_mcp_oauth_lifecycle() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let mut oauth_server = OAuthMcpServer::start( + ctx.repo_root() + .join("test/harness/test-mcp-oauth-server.mjs"), + ) + .await; + let server_name = "oauth-lifecycle-mcp"; + let handler = Arc::new(LifecycleAuthHandler::default()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_enable_mcp_apps(true) + .with_mcp_auth_handler(handler.clone()) + .with_mcp_servers(HashMap::from([( + server_name.to_string(), + McpServerConfig::Http(McpHttpServerConfig { + tools: Some(vec!["*".to_string()]), + timeout: None, + url: format!("{}/mcp", oauth_server.url), + headers: HashMap::new(), + }), + )])), + ) + .await + .expect("create session"); + + wait_for_mcp_server_status(&session, server_name, McpServerStatus::Connected).await; + call_whoami(&session, server_name, "refresh").await; + call_whoami(&session, server_name, "upscope").await; + call_whoami(&session, server_name, "reauth").await; + + assert_eq!( + handler.reasons.lock().as_slice(), + [ + McpOauthRequestReason::Initial, + McpOauthRequestReason::Refresh, + McpOauthRequestReason::Upscope, + McpOauthRequestReason::Refresh, + McpOauthRequestReason::Reauth, + ] + ); + + let requests = oauth_server.requests().await; + assert!( + requests + .iter() + .any(|request| request.authorization.as_deref() + == Some("Bearer sdk-host-token-refresh")) + ); + assert!( + requests + .iter() + .any(|request| request.authorization.as_deref() + == Some("Bearer sdk-host-token-upscope")) + ); + assert!( + requests + .iter() + .any(|request| request.authorization.as_deref() + == Some("Bearer sdk-host-token-reauth")) + ); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + oauth_server.stop().await; + }) + }) + .await; +} + +#[tokio::test] +async fn should_cancel_pending_mcp_oauth_request() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let mut oauth_server = OAuthMcpServer::start( + ctx.repo_root() + .join("test/harness/test-mcp-oauth-server.mjs"), + ) + .await; + let server_name = "oauth-cancelled-mcp"; + let handler = Arc::new(CancelAuthHandler::default()); + let client = ctx.start_client().await; + let session = client + .create_session( + ctx.approve_all_session_config() + .with_mcp_auth_handler(handler.clone()) + .with_mcp_servers(HashMap::from([( + server_name.to_string(), + McpServerConfig::Http(McpHttpServerConfig { + tools: Some(vec!["*".to_string()]), + timeout: None, + url: format!("{}/mcp", oauth_server.url), + headers: HashMap::new(), + }), + )])), + ) + .await + .expect("create session"); + + wait_for_mcp_server_status(&session, server_name, McpServerStatus::Failed).await; + + let request = handler + .request + .lock() + .clone() + .expect("MCP auth handler should be invoked"); + assert_eq!(request.server_name, server_name); + assert_eq!(request.reason, McpOauthRequestReason::Initial); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + oauth_server.stop().await; + }) + }) + .await; +} + +#[derive(Default)] +struct TokenAuthHandler { + request: Mutex>, +} + +#[async_trait] +impl McpAuthHandler for TokenAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult { + assert_eq!(request.request_id, request_id); + *self.request.lock() = Some(request); + McpAuthResult::Token { + access_token: EXPECTED_TOKEN.to_string(), + token_type: Some("Bearer".to_string()), + expires_in: Some(3600), + } + } +} + +#[derive(Default)] +struct LifecycleAuthHandler { + reasons: Mutex>, + refresh_count: Mutex, +} + +#[async_trait] +impl McpAuthHandler for LifecycleAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult { + assert_eq!(request.request_id, request_id); + let reason = request.reason.clone(); + self.reasons.lock().push(reason.clone()); + let token = match reason { + McpOauthRequestReason::Refresh => { + let www_authenticate = request + .www_authenticate_params + .as_ref() + .expect("refresh WWW-Authenticate params"); + assert_eq!(www_authenticate.resource_metadata_url, None); + assert_eq!(www_authenticate.error.as_deref(), Some("invalid_token")); + let mut refresh_count = self.refresh_count.lock(); + *refresh_count += 1; + if *refresh_count > 1 { + return McpAuthResult::Cancelled; + } + REFRESH_TOKEN + } + McpOauthRequestReason::Upscope => { + let www_authenticate = request + .www_authenticate_params + .as_ref() + .expect("upscope WWW-Authenticate params"); + assert!( + www_authenticate + .resource_metadata_url + .as_deref() + .is_some_and(|url| url.ends_with("/.well-known/oauth-protected-resource")) + ); + assert_eq!(www_authenticate.scope.as_deref(), Some("mcp.write")); + assert_eq!( + www_authenticate.error.as_deref(), + Some("insufficient_scope") + ); + UPSCOPE_TOKEN + } + McpOauthRequestReason::Reauth => REAUTH_TOKEN, + _ => EXPECTED_TOKEN, + }; + McpAuthResult::Token { + access_token: token.to_string(), + token_type: None, + expires_in: None, + } + } +} + +#[derive(Default)] +struct CancelAuthHandler { + request: Mutex>, +} + +#[async_trait] +impl McpAuthHandler for CancelAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + request_id: RequestId, + request: McpAuthRequest, + ) -> McpAuthResult { + assert_eq!(request.request_id, request_id); + *self.request.lock() = Some(request); + McpAuthResult::Cancelled + } +} + +#[derive(Deserialize)] +struct OAuthMcpRequest { + authorization: Option, +} + +struct OAuthMcpServer { + child: Child, + url: String, +} + +impl OAuthMcpServer { + async fn start(script: PathBuf) -> Self { + let mut child = Command::new("node") + .arg(script) + .env("EXPECTED_TOKEN", EXPECTED_TOKEN) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .expect("start OAuth MCP server"); + let stdout = child.stdout.take().expect("OAuth MCP stdout"); + let mut lines = BufReader::new(stdout).lines(); + let line = tokio::time::timeout(std::time::Duration::from_secs(10), lines.next_line()) + .await + .expect("OAuth MCP server startup timeout") + .expect("read OAuth MCP startup line") + .expect("OAuth MCP server stdout closed"); + let url = line + .strip_prefix("Listening: ") + .unwrap_or_else(|| panic!("unexpected OAuth MCP startup line: {line}")) + .to_string(); + Self { child, url } + } + + async fn requests(&self) -> Vec { + let text = reqwest::get(format!("{}/__requests", self.url)) + .await + .expect("fetch OAuth MCP requests") + .error_for_status() + .expect("OAuth MCP request status") + .text() + .await + .expect("read OAuth MCP requests"); + serde_json::from_str(&text).expect("decode OAuth MCP requests") + } + + async fn stop(&mut self) { + let _ = self.child.kill().await; + let _ = self.child.wait().await; + } +} + +async fn wait_for_mcp_server_status( + session: &Session, + server_name: &str, + expected_status: McpServerStatus, +) { + wait_for_condition("MCP server status", || async { + session + .rpc() + .mcp() + .list() + .await + .expect("list MCP servers") + .servers + .iter() + .any(|server| server.name == server_name && server.status == expected_status) + }) + .await; +} + +async fn call_whoami(session: &Session, server_name: &str, scenario: &str) { + let result = session + .rpc() + .mcp() + .apps() + .call_tool(McpAppsCallToolRequest { + arguments: Some(HashMap::from([( + "scenario".to_string(), + serde_json::Value::String(scenario.to_string()), + )])), + origin_server_name: server_name.to_string(), + server_name: server_name.to_string(), + tool_name: "whoami".to_string(), + }) + .await + .expect("call whoami"); + let content = result.get("content").expect("whoami content"); + assert_eq!( + content, + &serde_json::json!([{ "type": "text", "text": "oauth-test-user" }]) + ); +} diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index 1805eb145b..5052ef1be4 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -310,6 +310,8 @@ impl E2eContext { .as_os_str() .to_owned(), ), + ("COPILOT_MCP_APPS".into(), "true".into()), + ("MCP_APPS".into(), "true".into()), ]); if std::env::var("GITHUB_ACTIONS").as_deref() == Ok("true") { env.push(("GH_TOKEN".into(), "fake-token-for-e2e-tests".into())); diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs index 98c6248230..31b0cc2330 100644 --- a/rust/tests/session_test.rs +++ b/rust/tests/session_test.rs @@ -9,17 +9,19 @@ use async_trait::async_trait; use github_copilot_sdk::canvas::{CanvasDeclaration, CanvasHandler, CanvasResult}; use github_copilot_sdk::handler::{ ApproveAllHandler, AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, - ExitPlanModeHandler, ExitPlanModeResult, UserInputHandler, UserInputResponse, + ExitPlanModeHandler, ExitPlanModeResult, McpAuthHandler, McpAuthRequest, McpAuthResult, + UserInputHandler, UserInputResponse, }; use github_copilot_sdk::rpc::{ CanvasProviderInvokeActionRequest, CanvasProviderOpenRequest, CanvasProviderOpenResult, OpenCanvasInstance, }; -use github_copilot_sdk::session_events::ReasoningSummary; +use github_copilot_sdk::session_events::{McpOauthRequiredData, ReasoningSummary}; use github_copilot_sdk::types::{ - CommandContext, CommandDefinition, CommandHandler, DeliveryMode, ElicitationRequest, - ElicitationResult, ExitPlanModeData, ExtensionInfo, MessageOptions, RequestId, SessionConfig, - SessionId, SetModelOptions, Tool, ToolInvocation, ToolResult, + CloudSessionOptions, CloudSessionRepository, CommandContext, CommandDefinition, CommandHandler, + DeliveryMode, ElicitationRequest, ElicitationResult, ExitPlanModeData, ExtensionInfo, + MessageOptions, RequestId, SessionConfig, SessionId, SetModelOptions, Tool, ToolInvocation, + ToolResult, }; use github_copilot_sdk::{Client, ContextTier, tool}; use serde_json::Value; @@ -30,6 +32,20 @@ const TIMEOUT: Duration = Duration::from_secs(2); struct TestCanvasHandler; +struct CancelMcpAuthHandler; + +#[async_trait] +impl McpAuthHandler for CancelMcpAuthHandler { + async fn handle( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: McpAuthRequest, + ) -> McpAuthResult { + McpAuthResult::Cancelled + } +} + #[async_trait] impl CanvasHandler for TestCanvasHandler { async fn on_open( @@ -220,12 +236,294 @@ fn rand_id() -> u64 { COUNTER.fetch_add(1, Ordering::Relaxed) as u64 } +#[test] +fn mcp_oauth_required_data_allows_optional_metadata() { + let with_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp", + "wwwAuthenticateParams": { + "resourceMetadataUrl": "https://example.com/.well-known/oauth-protected-resource" + }, + "resourceMetadata": "{\"resource\":\"https://example.com/mcp\"}", + "staticClientConfig": { + "clientId": "static-client", + "clientSecret": "static-secret", + "publicClient": false + } + })) + .unwrap(); + assert_eq!( + with_metadata.resource_metadata.as_deref(), + Some("{\"resource\":\"https://example.com/mcp\"}") + ); + assert!(with_metadata.www_authenticate_params.is_some()); + assert_eq!( + with_metadata + .static_client_config + .as_ref() + .and_then(|config| config.client_secret.as_deref()), + Some("static-secret") + ); + + let without_metadata: McpOauthRequiredData = serde_json::from_value(serde_json::json!({ + "requestId": "oauth-request", + "reason": "initial", + "serverName": "oauth-server", + "serverUrl": "https://example.com/mcp" + })) + .unwrap(); + assert!(without_metadata.resource_metadata.is_none()); + assert!(without_metadata.www_authenticate_params.is_none()); +} + fn requested_session_id(request: &Value) -> &str { request["params"]["sessionId"] .as_str() .expect("session request should include sessionId") } +#[tokio::test] +async fn create_session_registers_mcp_auth_interest_only_with_handler() { + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default().with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert_eq!(create_req["params"]["requestPermission"], true); + let session_id = requested_session_id(&create_req).to_string(); + server_respond_create(&mut server_write, &create_req, &session_id).await; + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn cloud_create_session_registers_mcp_auth_interest_after_create_only_with_handler() { + let cloud = || { + CloudSessionOptions::with_repository( + CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"), + ) + }; + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-1").await; + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session( + SessionConfig::default() + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)) + .with_cloud(cloud()), + ) + .await + .unwrap() + } + }); + + let create_req = read_framed(&mut server_read).await; + assert_eq!(create_req["method"], "session.create"); + assert!(create_req["params"].get("sessionId").is_none()); + assert_eq!(create_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &create_req, "server-assigned-session-2").await; + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!( + interest_req["params"]["sessionId"], + "server-assigned-session-2" + ); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + let _session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn resume_session_registers_mcp_auth_interest_only_with_handler() { + use github_copilot_sdk::types::ResumeSessionConfig; + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-without-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)), + ) + .await + .unwrap() + } + }); + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-without-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); + let no_extra_request = timeout(Duration::from_millis(50), read_framed(&mut server_read)).await; + assert!(no_extra_request.is_err()); + drop(session); + + let (client, mut server_read, mut server_write) = make_client(); + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .resume_session( + ResumeSessionConfig::new(SessionId::from("session-with-auth")) + .with_permission_handler(Arc::new(ApproveAllHandler)) + .with_mcp_auth_handler(Arc::new(CancelMcpAuthHandler)), + ) + .await + .unwrap() + } + }); + + let interest_req = read_framed(&mut server_read).await; + assert_eq!(interest_req["method"], "session.eventLog.registerInterest"); + assert_eq!(interest_req["params"]["eventType"], "mcp.oauth_required"); + let id = interest_req["id"].as_u64().unwrap(); + write_framed( + &mut server_write, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "id": "interest-1" }, + })) + .unwrap(), + ) + .await; + + let resume_req = read_framed(&mut server_read).await; + assert_eq!(resume_req["method"], "session.resume"); + assert_eq!(resume_req["params"]["requestPermission"], true); + server_respond_create(&mut server_write, &resume_req, "session-with-auth").await; + respond_to_reload(&mut server_read, &mut server_write).await; + let _session = timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); +} + +async fn server_respond_create( + writer: &mut (impl AsyncWrite + Unpin), + request: &Value, + session_id: &str, +) { + let id = request["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { "sessionId": session_id, "workspacePath": "/tmp/workspace" }, + })) + .unwrap(), + ) + .await; +} + +async fn respond_to_reload( + reader: &mut (impl tokio::io::AsyncRead + Unpin), + writer: &mut (impl AsyncWrite + Unpin), +) { + let reload = read_framed(reader).await; + assert_eq!(reload["method"], "session.skills.reload"); + let id = reload["id"].as_u64().unwrap(); + write_framed( + writer, + &serde_json::to_vec(&serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} })) + .unwrap(), + ) + .await; +} + #[tokio::test] async fn session_subscribe_yields_events_observe_only() { let (session, mut server) = create_session_pair().await; diff --git a/test/harness/test-mcp-oauth-server.mjs b/test/harness/test-mcp-oauth-server.mjs new file mode 100644 index 0000000000..eacd35f304 --- /dev/null +++ b/test/harness/test-mcp-oauth-server.mjs @@ -0,0 +1,325 @@ +#!/usr/bin/env node +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +/** + * Minimal OAuth-protected Streamable HTTP MCP server for SDK E2E tests. + * + * The `/mcp` endpoint returns a WWW-Authenticate challenge until requests include + * an accepted test token, then serves enough JSON-RPC MCP methods for the runtime + * to initialize and list/call one tool. Specific tool-call scenarios trigger + * replacement-token challenges so SDK E2E tests can cover refresh, upscope, and + * reauth flows without relying on a real OAuth server. + */ + +import http from "node:http"; +import path from "node:path"; +import { fileURLToPath } from "node:url"; + +const DEFAULT_EXPECTED_TOKEN = "sdk-host-token"; +const PROTOCOL_VERSION = "2025-03-26"; +const PROTECTED_RESOURCE_PATH = "/.well-known/oauth-protected-resource"; + +export async function startOAuthMcpServer({ + expectedToken = DEFAULT_EXPECTED_TOKEN, + host = "127.0.0.1", + port = 0, +} = {}) { + const requests = []; + const tokens = { + initial: expectedToken, + refresh: `${expectedToken}-refresh`, + upscope: `${expectedToken}-upscope`, + reauth: `${expectedToken}-reauth`, + rejected: `${expectedToken}-rejected`, + }; + const acceptedTokens = new Set([ + tokens.initial, + tokens.refresh, + tokens.upscope, + tokens.reauth, + ]); + + const server = http.createServer(async (req, res) => { + const url = new URL( + req.url ?? "/", + `http://${req.headers.host ?? `${host}:${port}`}`, + ); + const baseUrl = url.origin; + + if (req.method === "GET" && url.pathname === "/__requests") { + respondJson(res, 200, requests); + return; + } + + if ( + req.method === "GET" && + url.pathname === PROTECTED_RESOURCE_PATH + ) { + respondJson(res, 200, { + resource: `${baseUrl}/mcp`, + authorization_servers: [baseUrl], + scopes_supported: ["mcp.read"], + bearer_methods_supported: ["header"], + }); + return; + } + + if ( + req.method === "GET" && + url.pathname === "/.well-known/oauth-authorization-server" + ) { + respondJson(res, 200, { + issuer: baseUrl, + authorization_endpoint: `${baseUrl}/authorize`, + token_endpoint: `${baseUrl}/token`, + response_types_supported: ["code"], + grant_types_supported: ["authorization_code"], + }); + return; + } + + if (url.pathname !== "/mcp") { + respondJson(res, 404, { error: "not_found" }); + return; + } + + const body = await readBody(req); + requests.push({ + method: req.method, + path: url.pathname, + authorization: req.headers.authorization ?? null, + body: body || null, + }); + + const token = parseBearerToken(req.headers.authorization); + if (!token || !acceptedTokens.has(token)) { + challengeInitial(res, baseUrl); + return; + } + + if (req.method !== "POST") { + respondJson(res, 405, { error: "method_not_allowed" }); + return; + } + + const parsedBody = parseJsonBody(body); + if (!parsedBody.ok) { + respondJson(res, 400, { error: "invalid_json" }); + return; + } + + const message = parsedBody.value; + const replacementChallenge = getReplacementChallenge( + message, + token, + tokens, + baseUrl, + ); + if (replacementChallenge) { + res.writeHead(replacementChallenge.statusCode, { + "www-authenticate": replacementChallenge.wwwAuthenticate, + "content-type": "application/json", + }); + res.end(JSON.stringify({ error: replacementChallenge.error })); + return; + } + + const response = Array.isArray(message) + ? message + .map((item) => handleJsonRpcMessage(item)) + .filter((item) => item !== undefined) + : handleJsonRpcMessage(message); + + if ( + response === undefined || + (Array.isArray(response) && response.length === 0) + ) { + res.writeHead(202, { "mcp-session-id": "oauth-test-session" }); + res.end(); + return; + } + + res.writeHead(200, { + "content-type": "application/json", + "mcp-session-id": "oauth-test-session", + }); + res.end(JSON.stringify(response)); + }); + + await new Promise((resolve, reject) => { + server.once("error", reject); + server.listen(port, host, () => { + server.off("error", reject); + resolve(); + }); + }); + + const address = server.address(); + if (!address || typeof address === "string") { + throw new Error("Expected TCP server address"); + } + + return { + url: `http://${host}:${address.port}`, + requests, + close: () => + new Promise((resolve, reject) => + server.close((err) => (err ? reject(err) : resolve())), + ), + }; +} + +function getReplacementChallenge(message, token, tokens, baseUrl) { + const messages = Array.isArray(message) ? message : [message]; + const toolCall = messages.find((item) => item?.method === "tools/call"); + const scenario = toolCall?.params?.arguments?.scenario; + + if (scenario === "refresh" && token !== tokens.refresh) { + return { + statusCode: 401, + wwwAuthenticate: 'Bearer error="invalid_token"', + error: "token_expired", + }; + } + + if (scenario === "upscope" && token !== tokens.upscope) { + return { + statusCode: 403, + wwwAuthenticate: `Bearer resource_metadata="${baseUrl}${PROTECTED_RESOURCE_PATH}", scope="mcp.write", error="insufficient_scope"`, + error: "insufficient_scope", + }; + } + + if (scenario === "reauth" && token !== tokens.reauth) { + return { + statusCode: 401, + wwwAuthenticate: 'Bearer error="invalid_token"', + error: "reauth_required", + }; + } + + if (scenario === "cancel" && token !== tokens.refresh) { + return { + statusCode: 401, + wwwAuthenticate: 'Bearer error="invalid_token"', + error: "token_expired", + }; + } + + return undefined; +} + +function handleJsonRpcMessage(message) { + if (!message || typeof message !== "object" || !("id" in message)) { + return undefined; + } + + switch (message.method) { + case "initialize": + return { + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: message.params?.protocolVersion ?? PROTOCOL_VERSION, + capabilities: { tools: {} }, + serverInfo: { name: "oauth-test-server", version: "1.0.0" }, + }, + }; + case "tools/list": + return { + jsonrpc: "2.0", + id: message.id, + result: { + tools: [ + { + name: "whoami", + description: "Returns the authenticated test principal.", + inputSchema: { + type: "object", + properties: { + scenario: { + type: "string", + enum: ["initial", "refresh", "upscope", "reauth", "cancel"], + }, + }, + additionalProperties: false, + }, + _meta: { "ui.visibility": ["model", "app"] }, + }, + ], + }, + }; + case "tools/call": + return { + jsonrpc: "2.0", + id: message.id, + result: { + content: [{ type: "text", text: "oauth-test-user" }], + isError: false, + }, + }; + default: + return { + jsonrpc: "2.0", + id: message.id, + error: { code: -32601, message: `Method not found: ${message.method}` }, + }; + } +} + +function parseBearerToken(authorization) { + const match = /^Bearer (.+)$/.exec(authorization ?? ""); + return match?.[1]; +} + +function challengeInitial(res, baseUrl) { + const resourceMetadataUrl = `${baseUrl}${PROTECTED_RESOURCE_PATH}`; + res.writeHead(401, { + "www-authenticate": `Bearer resource_metadata="${resourceMetadataUrl}", scope="mcp.read", error="invalid_token"`, + "content-type": "application/json", + }); + res.end(JSON.stringify({ error: "missing_or_invalid_token" })); +} + +function readBody(req) { + return new Promise((resolve, reject) => { + const chunks = []; + req.on("data", (chunk) => chunks.push(chunk)); + req.on("error", reject); + req.on("end", () => resolve(Buffer.concat(chunks).toString("utf8"))); + }); +} + +function parseJsonBody(body) { + if (!body) { + return { ok: true, value: undefined }; + } + + try { + return { ok: true, value: JSON.parse(body) }; + } catch { + return { ok: false, value: undefined }; + } +} + +function respondJson(res, statusCode, body) { + const data = JSON.stringify(body); + res.writeHead(statusCode, { + "content-type": "application/json", + "content-length": Buffer.byteLength(data), + }); + res.end(data); +} + +if (process.argv[1] && path.resolve(process.argv[1]) === fileURLToPath(import.meta.url)) { + const server = await startOAuthMcpServer({ + expectedToken: process.env.EXPECTED_TOKEN ?? DEFAULT_EXPECTED_TOKEN, + }); + console.log(`Listening: ${server.url}`); + process.on("SIGTERM", async () => { + await server.close(); + process.exit(0); + }); +}