From 2b4659079e1669aae589826f8b01d2ddaad7936b Mon Sep 17 00:00:00 2001 From: Ravish Date: Mon, 13 Apr 2026 21:42:02 -0700 Subject: [PATCH] extauth: add ClientCredentialsHandler for OAuth client credentials grant Add an implementation of auth.OAuthHandler that uses the OAuth 2.0 Client Credentials grant (RFC 6749 Section 4.4) for service-to-service authentication with pre-registered credentials. This bypasses both dynamic client registration and the authorization code flow. The handler supports two modes: - Direct token endpoint URL - Metadata discovery via AuthServerURL (RFC 8414) Also adds client_credentials grant type support to the fake authorization server in internal/oauthtest for testing. Refs #627 --- auth/extauth/client_credentials.go | 251 +++++++++++++++ auth/extauth/client_credentials_test.go | 293 ++++++++++++++++++ .../oauthtest/fake_authorization_server.go | 28 ++ 3 files changed, 572 insertions(+) create mode 100644 auth/extauth/client_credentials.go create mode 100644 auth/extauth/client_credentials_test.go diff --git a/auth/extauth/client_credentials.go b/auth/extauth/client_credentials.go new file mode 100644 index 00000000..fefa29b9 --- /dev/null +++ b/auth/extauth/client_credentials.go @@ -0,0 +1,251 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package extauth + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +// ClientCredentialsHandlerConfig is the configuration for [ClientCredentialsHandler]. +type ClientCredentialsHandlerConfig struct { + // Credentials contains the pre-registered client ID and secret. + // REQUIRED. Both ClientID and ClientSecretAuth must be set, since the + // client credentials grant requires a confidential client. + Credentials *oauthex.ClientCredentials + + // HTTPClient is an optional HTTP client for customization. + // If nil, http.DefaultClient is used. + // OPTIONAL. + HTTPClient *http.Client +} + +// ClientCredentialsHandler is an implementation of [auth.OAuthHandler] that +// uses the OAuth 2.0 Client Credentials grant (RFC 6749 Section 4.4) to +// obtain access tokens. +// +// This handler is intended for service-to-service authentication where the +// client has pre-registered credentials (client ID and secret) and does not +// require user interaction. It bypasses both dynamic client registration and +// the authorization code flow. +// +// The token endpoint and scopes are discovered automatically via Protected +// Resource Metadata (RFC 9728) and Authorization Server Metadata (RFC 8414), +// following the ext-auth specification SEP-1046. +type ClientCredentialsHandler struct { + config *ClientCredentialsHandlerConfig + tokenSource oauth2.TokenSource +} + +// Compile-time check that ClientCredentialsHandler implements auth.OAuthHandler. +var _ auth.OAuthHandler = (*ClientCredentialsHandler)(nil) + +// NewClientCredentialsHandler creates a new ClientCredentialsHandler. +// It validates the configuration and returns an error if invalid. +func NewClientCredentialsHandler(config *ClientCredentialsHandlerConfig) (*ClientCredentialsHandler, error) { + if config == nil { + return nil, fmt.Errorf("config must be provided") + } + if config.Credentials == nil { + return nil, fmt.Errorf("credentials are required") + } + if err := config.Credentials.Validate(); err != nil { + return nil, fmt.Errorf("invalid credentials: %w", err) + } + if config.Credentials.ClientSecretAuth == nil { + return nil, fmt.Errorf("clientSecretAuth is required for client credentials grant") + } + return &ClientCredentialsHandler{config: config}, nil +} + +// TokenSource returns the token source for outgoing requests. +// Returns nil if authorization has not been performed yet. +func (h *ClientCredentialsHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// Authorize performs the Client Credentials grant to obtain an access token. +// It is called when a request fails with 401 or 403. +// +// The flow follows the ext-auth specification SEP-1046: +// 1. Discover Protected Resource Metadata from the request URL +// 2. Discover Authorization Server Metadata from PRM +// 3. Exchange client credentials for an access token at the token endpoint +func (h *ClientCredentialsHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) + + httpClient := h.config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Step 1: Discover Protected Resource Metadata. + wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) + if err != nil { + return fmt.Errorf("failed to parse WWW-Authenticate header: %v", err) + } + + prm, err := getProtectedResourceMetadata(ctx, wwwChallenges, req.URL.String(), httpClient) + if err != nil { + return err + } + + if len(prm.AuthorizationServers) == 0 { + return fmt.Errorf("protected resource metadata has no authorization servers specified") + } + + // Step 2: Discover Authorization Server Metadata. + asm, err := auth.GetAuthServerMetadata(ctx, prm.AuthorizationServers[0], httpClient) + if err != nil { + return fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm == nil { + // Fallback to 2025-03-26 spec: predefined endpoints. + authServerURL := prm.AuthorizationServers[0] + asm = &oauthex.AuthServerMeta{ + Issuer: authServerURL, + TokenEndpoint: authServerURL + "/token", + } + } + + // Determine scopes: use PRM's scopes_supported if available. + scopes := scopesFromChallenges(wwwChallenges) + if len(scopes) == 0 && len(prm.ScopesSupported) > 0 { + scopes = prm.ScopesSupported + } + + // Step 3: Exchange client credentials for an access token. + creds := h.config.Credentials + cfg := &clientcredentials.Config{ + ClientID: creds.ClientID, + ClientSecret: creds.ClientSecretAuth.ClientSecret, + TokenURL: asm.TokenEndpoint, + Scopes: scopes, + AuthStyle: selectTokenAuthMethod(asm.TokenEndpointAuthMethodsSupported), + } + + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + h.tokenSource = cfg.TokenSource(ctxWithClient) + + // Eagerly fetch a token to surface errors immediately. + if _, err := h.tokenSource.Token(); err != nil { + h.tokenSource = nil + return fmt.Errorf("client credentials token request failed: %w", err) + } + return nil +} + +// getProtectedResourceMetadata discovers Protected Resource Metadata (RFC 9728) +// from the request URL. This mirrors the logic in AuthorizationCodeHandler. +func getProtectedResourceMetadata(ctx context.Context, wwwChallenges []oauthex.Challenge, mcpServerURL string, httpClient *http.Client) (*oauthex.ProtectedResourceMetadata, error) { + // Use MCP server URL as the resource URI per + // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#canonical-server-uri. + for _, u := range protectedResourceMetadataURLs(resourceMetadataURLFromChallenges(wwwChallenges), mcpServerURL) { + prm, err := oauthex.GetProtectedResourceMetadata(ctx, u.url, u.resource, httpClient) + if err != nil { + continue + } + if prm == nil { + continue + } + if len(prm.AuthorizationServers) == 0 { + // If we found PRM, we enforce the 2025-11-25 spec and not search further. + return nil, fmt.Errorf("protected resource metadata has no authorization servers specified") + } + return prm, nil + } + // Fallback to 2025-03-26 spec: MCP server root is the Authorization Server. + // https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#server-metadata-discovery + u, err := url.Parse(mcpServerURL) + if err != nil { + return nil, fmt.Errorf("failed to parse MCP server URL: %v", err) + } + u.Path = "" + return &oauthex.ProtectedResourceMetadata{ + AuthorizationServers: []string{u.String()}, + Resource: mcpServerURL, + }, nil +} + +type prmURL struct { + url string + resource string +} + +// protectedResourceMetadataURLs returns URLs to try for PRM discovery. +// This mirrors the logic in AuthorizationCodeHandler. +func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL { + var urls []prmURL + if metadataURL != "" { + urls = append(urls, prmURL{url: metadataURL, resource: resourceURL}) + } + ru, err := url.Parse(resourceURL) + if err != nil { + return urls + } + mu := *ru + // At the path of the server's MCP endpoint. + mu.Path = "/.well-known/oauth-protected-resource/" + strings.TrimLeft(ru.Path, "/") + urls = append(urls, prmURL{url: mu.String(), resource: resourceURL}) + // At the root. + mu.Path = "/.well-known/oauth-protected-resource" + ru.Path = "" + urls = append(urls, prmURL{url: mu.String(), resource: ru.String()}) + return urls +} + +// resourceMetadataURLFromChallenges returns a resource metadata URL from +// WWW-Authenticate challenges, or the empty string if there is none. +func resourceMetadataURLFromChallenges(cs []oauthex.Challenge) string { + for _, c := range cs { + if u := c.Params["resource_metadata"]; u != "" { + return u + } + } + return "" +} + +// scopesFromChallenges returns scopes from WWW-Authenticate challenges. +// It only looks at challenges with the "Bearer" scheme. +func scopesFromChallenges(cs []oauthex.Challenge) []string { + for _, c := range cs { + if c.Scheme == "bearer" && c.Params["scope"] != "" { + return strings.Fields(c.Params["scope"]) + } + } + return nil +} + +// selectTokenAuthMethod selects the preferred token endpoint auth method based on +// the authorization server's supported methods. Prefers client_secret_post over +// client_secret_basic per the OAuth 2.1 draft. +func selectTokenAuthMethod(supported []string) oauth2.AuthStyle { + prefOrder := []string{ + "client_secret_post", + "client_secret_basic", + } + for _, method := range prefOrder { + if slices.Contains(supported, method) { + switch method { + case "client_secret_post": + return oauth2.AuthStyleInParams + case "client_secret_basic": + return oauth2.AuthStyleInHeader + } + } + } + return oauth2.AuthStyleAutoDetect +} diff --git a/auth/extauth/client_credentials_test.go b/auth/extauth/client_credentials_test.go new file mode 100644 index 00000000..91ae7c04 --- /dev/null +++ b/auth/extauth/client_credentials_test.go @@ -0,0 +1,293 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package extauth + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/internal/oauthtest" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +func validClientCredentialsConfig() *ClientCredentialsHandlerConfig { + return &ClientCredentialsHandlerConfig{ + Credentials: &oauthex.ClientCredentials{ + ClientID: "test-client", + ClientSecretAuth: &oauthex.ClientSecretAuth{ClientSecret: "test-secret"}, + }, + } +} + +func TestNewClientCredentialsHandler_Validation(t *testing.T) { + tests := []struct { + name string + config *ClientCredentialsHandlerConfig + wantError string + }{ + { + name: "nil config", + config: nil, + wantError: "config must be provided", + }, + { + name: "nil credentials", + config: func() *ClientCredentialsHandlerConfig { + c := validClientCredentialsConfig() + c.Credentials = nil + return c + }(), + wantError: "credentials are required", + }, + { + name: "missing client ID", + config: func() *ClientCredentialsHandlerConfig { + c := validClientCredentialsConfig() + c.Credentials.ClientID = "" + return c + }(), + wantError: "ClientID is required", + }, + { + name: "missing client secret auth", + config: func() *ClientCredentialsHandlerConfig { + c := validClientCredentialsConfig() + c.Credentials.ClientSecretAuth = nil + return c + }(), + wantError: "clientSecretAuth is required", + }, + { + name: "empty client secret", + config: func() *ClientCredentialsHandlerConfig { + c := validClientCredentialsConfig() + c.Credentials.ClientSecretAuth.ClientSecret = "" + return c + }(), + wantError: "ClientSecret is required", + }, + { + name: "valid config", + config: validClientCredentialsConfig(), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewClientCredentialsHandler(tc.config) + if tc.wantError != "" { + if err == nil { + t.Errorf("expected error containing %q, got nil", tc.wantError) + } else if !strings.Contains(err.Error(), tc.wantError) { + t.Errorf("error %q does not contain %q", err.Error(), tc.wantError) + } + } else if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestClientCredentialsHandler_Authorize(t *testing.T) { + authServer := oauthtest.NewFakeAuthorizationServer(oauthtest.Config{ + MetadataEndpointConfig: &oauthtest.MetadataEndpointConfig{ + ServeOAuthInsertedEndpoint: true, + }, + RegistrationConfig: &oauthtest.RegistrationConfig{ + PreregisteredClients: map[string]oauthtest.ClientInfo{ + "test-client": {Secret: "test-secret"}, + }, + }, + ClientCredentialsConfig: &oauthtest.ClientCredentialsConfig{ + Enabled: true, + }, + }) + authServer.Start(t) + + resourceMux := http.NewServeMux() + resourceServer := httptest.NewServer(resourceMux) + t.Cleanup(resourceServer.Close) + resourceURL := resourceServer.URL + "/resource" + + resourceMux.Handle("/.well-known/oauth-protected-resource/resource", auth.ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + ScopesSupported: []string{"mcp:read", "mcp:write"}, + })) + + t.Run("successful authorization", func(t *testing.T) { + handler, err := NewClientCredentialsHandler(validClientCredentialsConfig()) + if err != nil { + t.Fatal(err) + } + + // TokenSource should be nil before Authorize. + ts, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatal(err) + } + if ts != nil { + t.Fatal("expected nil TokenSource before Authorize") + } + + // Simulate a 401 response from the MCP server. + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + + // TokenSource should now return a valid token. + ts, err = handler.TokenSource(t.Context()) + if err != nil { + t.Fatal(err) + } + if ts == nil { + t.Fatal("expected non-nil TokenSource after Authorize") + } + tok, err := ts.Token() + if err != nil { + t.Fatalf("Token() failed: %v", err) + } + if tok.AccessToken != "test_access_token" { + t.Errorf("got access token %q, want %q", tok.AccessToken, "test_access_token") + } + }) + + t.Run("bad credentials", func(t *testing.T) { + config := validClientCredentialsConfig() + config.Credentials.ClientSecretAuth.ClientSecret = "wrong-secret" + handler, err := NewClientCredentialsHandler(config) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{}, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err == nil { + t.Error("expected Authorize to fail with bad credentials") + } + + // TokenSource should still be nil after failed Authorize. + ts, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatal(err) + } + if ts != nil { + t.Error("expected nil TokenSource after failed Authorize") + } + }) + + t.Run("scopes from WWW-Authenticate challenge", func(t *testing.T) { + handler, err := NewClientCredentialsHandler(validClientCredentialsConfig()) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{ + "WWW-Authenticate": []string{`Bearer scope="read write"`}, + }, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("Authorize failed: %v", err) + } + + ts, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatal(err) + } + if ts == nil { + t.Fatal("expected non-nil TokenSource") + } + }) + + t.Run("PRM via resource_metadata in challenge", func(t *testing.T) { + prmMux := http.NewServeMux() + prmMux.Handle("/custom-prm", auth.ProtectedResourceMetadataHandler(&oauthex.ProtectedResourceMetadata{ + Resource: resourceURL, + AuthorizationServers: []string{authServer.URL()}, + })) + prmServer := httptest.NewServer(prmMux) + t.Cleanup(prmServer.Close) + + handler, err := NewClientCredentialsHandler(validClientCredentialsConfig()) + if err != nil { + t.Fatal(err) + } + + resp := &http.Response{ + StatusCode: http.StatusUnauthorized, + Header: http.Header{ + "WWW-Authenticate": []string{`Bearer resource_metadata="` + prmServer.URL + `/custom-prm"`}, + }, + Body: http.NoBody, + } + req := httptest.NewRequest("GET", resourceURL, nil) + if err := handler.Authorize(t.Context(), req, resp); err != nil { + t.Fatalf("Authorize with resource_metadata challenge failed: %v", err) + } + + ts, err := handler.TokenSource(t.Context()) + if err != nil { + t.Fatal(err) + } + if ts == nil { + t.Fatal("expected non-nil TokenSource") + } + }) +} + +func TestSelectTokenAuthMethod(t *testing.T) { + tests := []struct { + name string + supported []string + want oauth2.AuthStyle + }{ + { + name: "prefers client_secret_post", + supported: []string{"client_secret_basic", "client_secret_post"}, + want: oauth2.AuthStyleInParams, + }, + { + name: "falls back to client_secret_basic", + supported: []string{"client_secret_basic"}, + want: oauth2.AuthStyleInHeader, + }, + { + name: "auto detect when none supported", + supported: []string{"private_key_jwt"}, + want: oauth2.AuthStyleAutoDetect, + }, + { + name: "auto detect when empty", + supported: nil, + want: oauth2.AuthStyleAutoDetect, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := selectTokenAuthMethod(tc.supported) + if got != tc.want { + t.Errorf("selectTokenAuthMethod(%v) = %v, want %v", tc.supported, got, tc.want) + } + }) + } +} diff --git a/internal/oauthtest/fake_authorization_server.go b/internal/oauthtest/fake_authorization_server.go index 44f937f6..33c25e85 100644 --- a/internal/oauthtest/fake_authorization_server.go +++ b/internal/oauthtest/fake_authorization_server.go @@ -57,6 +57,15 @@ type JWTBearerConfig struct { ValidAssertions []string } +// ClientCredentialsConfig configures support for the client_credentials +// grant type (RFC 6749 Section 4.4) on a [FakeAuthorizationServer]. +type ClientCredentialsConfig struct { + // Enabled controls whether the /token endpoint accepts + // grant_type=client_credentials and returns an access token + // if client authentication succeeds. + Enabled bool +} + // Config holds configuration for FakeAuthorizationServer. type Config struct { // The optional path component of the issuer URL. @@ -70,6 +79,9 @@ type Config struct { // JWTBearerConfig enables RFC 7523 JWT Bearer grant at the /token endpoint. // If non-nil, the server accepts grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer. JWTBearerConfig *JWTBearerConfig + // ClientCredentialsConfig enables RFC 6749 Section 4.4 client credentials + // grant at the /token endpoint. + ClientCredentialsConfig *ClientCredentialsConfig } // FakeAuthorizationServer is a fake OAuth 2.0 Authorization Server for testing. @@ -271,6 +283,8 @@ func (s *FakeAuthorizationServer) handleToken(w http.ResponseWriter, r *http.Req s.handleAuthorizationCodeGrant(w, r) case "urn:ietf:params:oauth:grant-type:jwt-bearer": s.handleJWTBearerGrant(w, r) + case "client_credentials": + s.handleClientCredentialsGrant(w, r) default: http.Error(w, fmt.Sprintf("unsupported grant_type: %s", grantType), http.StatusBadRequest) } @@ -332,6 +346,20 @@ func (s *FakeAuthorizationServer) handleJWTBearerGrant(w http.ResponseWriter, r }) } +func (s *FakeAuthorizationServer) handleClientCredentialsGrant(w http.ResponseWriter, r *http.Request) { + if s.config.ClientCredentialsConfig == nil || !s.config.ClientCredentialsConfig.Enabled { + http.Error(w, "client_credentials grant not supported", http.StatusBadRequest) + return + } + // Client was already authenticated in handleToken. + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test_access_token", + "token_type": "Bearer", + "expires_in": 3600, + }) +} + func (s *FakeAuthorizationServer) authenticateClient(r *http.Request) error { clientID, clientSecret, ok := r.BasicAuth() if !ok {