Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions auth/authorization_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Requ
Scopes: scps,
}

authRes, err := h.getAuthorizationCode(ctx, cfg, req.URL.String())
authRes, err := h.getAuthorizationCode(ctx, cfg, prm.Resource)
if err != nil {
// Purposefully leaving the error unwrappable so it can be handled by the caller.
return err
Expand Down Expand Up @@ -289,9 +289,24 @@ func (h *AuthorizationCodeHandler) getProtectedResourceMetadata(ctx context.Cont
errs = append(errs, fmt.Errorf("protected resource metadata is nil"))
continue
}
if len(prm.AuthorizationServers) == 0 {
// If we found PRM, we enforce the 2025-11-25 spec and not search further.
return nil, fmt.Errorf("protected resource metadata has no authorization servers specified")
}
return prm, nil
}
return nil, fmt.Errorf("failed to get protected resource metadata: %v", errors.Join(errs...))
// Fallback to 2025-03-26 spec MCP server root is the Authorization Server:
// https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#server-metadata-discovery
u, err := url.Parse(mcpServerURL)
if err != nil {
return nil, fmt.Errorf("failed to parse MCP server URL: %v", err)
}
u.Path = ""
prm := &oauthex.ProtectedResourceMetadata{
AuthorizationServers: []string{u.String()},
Resource: mcpServerURL,
}
return prm, nil
}

type prmURL struct {
Expand Down Expand Up @@ -335,26 +350,14 @@ func protectedResourceMetadataURLs(metadataURL, resourceURL string) []prmURL {
}

// getAuthServerMetadata returns the authorization server metadata.
// The provided Protected Resource Metadata must not be nil.
// The provided Protected Resource Metadata must not be nil and must contain
// at least one authorization server.
// It returns an error if the metadata request fails with non-4xx HTTP status code
// or the fetched metadata fails security checks.
// If no metadata was found, it returns a minimal set of endpoints
// as a fallback to 2025-03-26 spec.
func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, prm *oauthex.ProtectedResourceMetadata) (*oauthex.AuthServerMeta, error) {
var authServerURL string
if len(prm.AuthorizationServers) > 0 {
// Use the first authorization server, similarly to other SDKs.
authServerURL = prm.AuthorizationServers[0]
} else {
// Fallback to 2025-03-26 spec: MCP server base URL acts as Authorization Server.
authURL, err := url.Parse(prm.Resource)
if err != nil {
return nil, fmt.Errorf("failed to parse resource URL: %v", err)
}
authURL.Path = ""
authServerURL = authURL.String()
}

authServerURL := prm.AuthorizationServers[0]
for _, u := range authorizationServerMetadataURLs(authServerURL) {
asm, err := oauthex.GetAuthServerMeta(ctx, u, authServerURL, http.DefaultClient)
if err != nil {
Expand Down
120 changes: 68 additions & 52 deletions auth/authorization_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -322,43 +323,40 @@ func TestNewAuthorizationCodeHandler_Error(t *testing.T) {
}
}

func TestGetProtectedResourceMetadata(t *testing.T) {
func TestGetProtectedResourceMetadata_Success(t *testing.T) {
handler := &AuthorizationCodeHandler{} // No config needed for this method
pathForChallenge := "/protected-resource"

tests := []struct {
name string
challengesProvided bool
prmPath string
resourcePath string
wantError bool
// Path of the PRM endpoint.
prmPath string
// Path of the MCP server that is accessed.
mcpServerPath string
// Path for the Resource expected in the returned PRM.
resourcePath string
}{
{
name: "FromChallenges",
challengesProvided: true,
prmPath: pathForChallenge,
mcpServerPath: "/resource",
resourcePath: "/resource",
wantError: false,
},
{
name: "FallbackToEndpoint",
challengesProvided: false,
prmPath: "/.well-known/oauth-protected-resource/resource",
mcpServerPath: "/resource",
resourcePath: "/resource",
wantError: false,
},
{
name: "FallbackToRoot",
challengesProvided: false,
prmPath: "/.well-known/oauth-protected-resource",
mcpServerPath: "/resource",
resourcePath: "",
wantError: false,
},
{
name: "NoMetadata",
challengesProvided: false,
prmPath: "/incorrect",
wantError: true,
},
}

Expand All @@ -367,10 +365,10 @@ func TestGetProtectedResourceMetadata(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
resourceURL := server.URL + tt.resourcePath
metadata := &oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
ScopesSupported: []string{"read", "write"},
Resource: server.URL + tt.resourcePath,
AuthorizationServers: []string{"https://oauth.example.com"},
ScopesSupported: []string{"read", "write"},
}
mux.Handle(tt.prmPath, ProtectedResourceMetadataHandler(metadata))
var challenges []oauthex.Challenge
Expand All @@ -385,12 +383,9 @@ func TestGetProtectedResourceMetadata(t *testing.T) {
}
}

got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, resourceURL)
got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, server.URL+tt.mcpServerPath)
if err != nil {
if !tt.wantError {
t.Fatalf("getProtectedResourceMetadata() error = %v, want nil", err)
}
return
t.Fatalf("getProtectedResourceMetadata() error = %v", err)
}
if got == nil {
t.Fatal("getProtectedResourceMetadata() got nil, want metadata")
Expand All @@ -402,59 +397,86 @@ func TestGetProtectedResourceMetadata(t *testing.T) {
}
}

func TestGetProtectedResourceMetadata_Backcompat(t *testing.T) {
var challenges []oauthex.Challenge
handler := &AuthorizationCodeHandler{} // No config needed for this method
got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, "http://localhost:1234/resource")
if err != nil {
t.Fatalf("getProtectedResourceMetadata() error = %v", err)
}
wantPRM := &oauthex.ProtectedResourceMetadata{
Resource: "http://localhost:1234/resource",
AuthorizationServers: []string{"http://localhost:1234"},
}
if diff := cmp.Diff(wantPRM, got); diff != "" {
t.Errorf("getProtectedResourceMetadata() metadata mismatch (-want +got):\n%s", diff)
}
}

func TestGetProtectedResourceMetadata_Error(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
t.Cleanup(server.Close)
metadata := &oauthex.ProtectedResourceMetadata{
Resource: server.URL + "/resource",
AuthorizationServers: nil, // Empty list is invalid
ScopesSupported: []string{"read", "write"},
}
mux.Handle("/.well-known/oauth-protected-resource/resource", ProtectedResourceMetadataHandler(metadata))
var challenges []oauthex.Challenge
handler := &AuthorizationCodeHandler{} // No config needed for this method
got, err := handler.getProtectedResourceMetadata(t.Context(), challenges, server.URL+"/resource")
if err == nil || !strings.Contains(err.Error(), "authorization servers") {
t.Errorf("getProtectedResourceMetadata() = %v, want error containing \"authorization servers\"", err)
}
if got != nil {
t.Errorf("getProtectedResourceMetadata() = %+v, want nil", got)
}
}

func TestGetAuthServerMetadata(t *testing.T) {
handler := &AuthorizationCodeHandler{} // No config needed for this method

tests := []struct {
name string
authorizationAtMCPServer bool
issuerPath string
endpointConfig *oauthtest.MetadataEndpointConfig
name string
issuerPath string
endpointConfig *oauthtest.MetadataEndpointConfig
}{
{
name: "OAuthEndpoint_Root",
authorizationAtMCPServer: false,
issuerPath: "",
name: "OAuthEndpoint_Root",
issuerPath: "",
endpointConfig: &oauthtest.MetadataEndpointConfig{
ServeOAuthInsertedEndpoint: true,
},
},
{
name: "OpenIDEndpoint_Root",
authorizationAtMCPServer: false,
issuerPath: "",
name: "OpenIDEndpoint_Root",
issuerPath: "",
endpointConfig: &oauthtest.MetadataEndpointConfig{
ServeOpenIDInsertedEndpoint: true,
},
},
{
name: "OAuthEndpoint_Path",
authorizationAtMCPServer: false,
issuerPath: "/oauth",
name: "OAuthEndpoint_Path",
issuerPath: "/oauth",
endpointConfig: &oauthtest.MetadataEndpointConfig{
ServeOAuthInsertedEndpoint: true,
},
},
{
name: "OpenIDEndpoint_Path",
authorizationAtMCPServer: false,
issuerPath: "/openid",
name: "OpenIDEndpoint_Path",
issuerPath: "/openid",
endpointConfig: &oauthtest.MetadataEndpointConfig{
ServeOpenIDInsertedEndpoint: true,
},
},
{
name: "OpenIDAppendedEndpoint_Path",
authorizationAtMCPServer: false,
issuerPath: "/openid",
name: "OpenIDAppendedEndpoint_Path",
issuerPath: "/openid",
endpointConfig: &oauthtest.MetadataEndpointConfig{
ServeOpenIDAppendedEndpoint: true,
},
},
{
name: "FallbackToMCPServer",
authorizationAtMCPServer: true,
},
{
name: "NoMetadata",
issuerPath: "",
Expand All @@ -474,15 +496,9 @@ func TestGetAuthServerMetadata(t *testing.T) {
})
s.Start(t)
issuerURL := s.URL() + tt.issuerPath
resourceURL := "https://example.com/resource"
authServers := []string{issuerURL}
if tt.authorizationAtMCPServer {
resourceURL = issuerURL
authServers = nil
}
prm := &oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
AuthorizationServers: authServers,
Resource: "https://example.com/resource",
AuthorizationServers: []string{issuerURL},
}

got, err := handler.getAuthServerMetadata(t.Context(), prm)
Expand Down