From b5e4148dd6e900b1cead8585f79f94c7b81c92e1 Mon Sep 17 00:00:00 2001 From: Scott Leggett Date: Thu, 30 Apr 2026 11:55:58 +0800 Subject: [PATCH] mcp: handle oauth2.RetrieveError during token refresh Previously, an expired refresh token in the oauth2.Token returned from OAuthHandler.TokenSource() would cause the connection to fail. From the client perspective, this meant that the MCP connection was in a hard-failed state with no way to re-authorize. The change in this commit causes Authorize() to be called in the event of both an oauth2.RetrieveError from the token endpoint, as well as in the pre-existing case of a 401/403 HTTP response from the MCP server. Clients will handle this in their existing Authorize() flows to get a new valid token for the connection. --- mcp/streamable.go | 19 ++++++++++-- mcp/streamable_client_test.go | 56 +++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 708b1326..b49822e6 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -37,6 +37,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" "github.com/modelcontextprotocol/go-sdk/internal/xcontext" "github.com/modelcontextprotocol/go-sdk/jsonrpc" + "golang.org/x/oauth2" ) // A StreamableHTTPHandler is an http.Handler that serves streamable MCP @@ -1803,6 +1804,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json, text/event-stream") + if err := c.setMCPHeaders(req); err != nil { // Failure to set headers means that the request was not sent. // Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr @@ -1934,9 +1936,20 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { if ts != nil { token, err := ts.Token() if err != nil { - return err - } - if token != nil { + // If the error is an invalid_grant oauth2.RetrieveError it indicates + // that the token source doesn't have valid authorization for the token + // endpoint, per RFC 6749 section 5.2. For example, the refresh token + // may be expired or invalid. + // + // In that case, ignore the error, skip setting the Authorization + // header, and proceed with the request. Callers that support + // authorization flows get a 401/403 response and trigger the + // Authorize() flow to refresh their token. + var retrieveErr *oauth2.RetrieveError + if !errors.As(err, &retrieveErr) || retrieveErr.ErrorCode != "invalid_grant" { + return err + } + } else if token != nil { req.Header.Set("Authorization", "Bearer "+token.AccessToken) } } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 96c5e75f..517e51af 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -6,6 +6,7 @@ package mcp import ( "context" + "errors" "fmt" "io" "net/http" @@ -1156,3 +1157,58 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +// errTestAuthorizeFailed is a sentinel error returned by +// retrieveErrorOAuthHandler.Authorize(). +var errTestAuthorizeFailed = errors.New("authorize intentionally failed for test") + +// retrieveErrorOAuthHandler is a mock OAuthHandler that always returns +// an oauth2.RetrieveError from its TokenSource's Token() method. +type retrieveErrorOAuthHandler struct{} + +func (h *retrieveErrorOAuthHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h, nil +} + +func (h *retrieveErrorOAuthHandler) Token() (*oauth2.Token, error) { + return nil, &oauth2.RetrieveError{ + Response: &http.Response{StatusCode: http.StatusBadRequest}, + Body: []byte("test retrieve error"), + ErrorCode: "invalid_grant", + } +} + +func (h *retrieveErrorOAuthHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + return errTestAuthorizeFailed +} + +// TestStreamableClientOAuth_RetrieveError verifies that an invalid_grant RetrieveError +// from the OAuth token source correctly skips sending Authorization header and relies on +// the server's 401 response to trigger the Authorize fallback flow. +func TestStreamableClientOAuth_RetrieveError(t *testing.T) { + ctx := context.Background() + oauthHandler := &retrieveErrorOAuthHandler{} + + // Mock MCP server returns 401 Unauthorized to simulate a server rejecting + // the request that omitted the Authorization header. + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + t.Cleanup(httpServer.Close) + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + OAuthHandler: oauthHandler, + } + client := NewClient(testImpl, nil) + + // Attempt to connect. The Connect call will trigger the initialization request, + // which will fail to retrieve the token and proceed without auth header, receive 401, + // and invoke Authorize(). + _, err := client.Connect(ctx, transport, nil) + + // Expect the connection to fail with the sentinel error, not the RetrieveError. + if !errors.Is(err, errTestAuthorizeFailed) { + t.Fatalf("client.Connect() error = %v, want %v", err, errTestAuthorizeFailed) + } +}