Skip to content
36 changes: 26 additions & 10 deletions pkg/http/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/modelcontextprotocol/go-sdk/auth"
"github.com/modelcontextprotocol/go-sdk/oauthex"
Expand All @@ -16,9 +17,6 @@ import (
const (
// OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata.
OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource"

// DefaultAuthorizationServer is GitHub's OAuth authorization server.
DefaultAuthorizationServer = "https://github.com/login/oauth"
)

// SupportedScopes lists all OAuth scopes that may be required by MCP tools.
Expand Down Expand Up @@ -55,22 +53,27 @@ type Config struct {

// AuthHandler handles OAuth-related HTTP endpoints.
type AuthHandler struct {
cfg *Config
cfg *Config
apiHost utils.APIHostResolver
}

// NewAuthHandler creates a new OAuth auth handler.
func NewAuthHandler(cfg *Config) (*AuthHandler, error) {
func NewAuthHandler(cfg *Config, apiHost utils.APIHostResolver) (*AuthHandler, error) {
if cfg == nil {
cfg = &Config{}
}

// Default authorization server to GitHub
if cfg.AuthorizationServer == "" {
cfg.AuthorizationServer = DefaultAuthorizationServer
if apiHost == nil {
var err error
apiHost, err = utils.NewAPIHost("https://api.github.com")
if err != nil {
return nil, fmt.Errorf("failed to create default API host: %w", err)
}
}

return &AuthHandler{
cfg: cfg,
cfg: cfg,
apiHost: apiHost,
}, nil
}

Expand All @@ -95,15 +98,28 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) {

func (h *AuthHandler) metadataHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resourcePath := resolveResourcePath(
strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix),
h.cfg.ResourcePath,
)
resourceURL := h.buildResourceURL(r, resourcePath)

var authorizationServerURL string
if h.cfg.AuthorizationServer != "" {
authorizationServerURL = h.cfg.AuthorizationServer
} else {
authURL, err := h.apiHost.AuthorizationServerURL(ctx)
if err != nil {
http.Error(w, fmt.Sprintf("failed to resolve authorization server URL: %v", err), http.StatusInternalServerError)
return
}
authorizationServerURL = authURL.String()
}

metadata := &oauthex.ProtectedResourceMetadata{
Resource: resourceURL,
AuthorizationServers: []string{h.cfg.AuthorizationServer},
AuthorizationServers: []string{authorizationServerURL},
ResourceName: "GitHub MCP Server",
ScopesSupported: SupportedScopes,
BearerMethodsSupported: []string{"header"},
Expand Down
162 changes: 142 additions & 20 deletions pkg/http/oauth/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,28 @@ import (
"testing"

"github.com/github/github-mcp-server/pkg/http/headers"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
defaultAuthorizationServer = "https://github.com/login/oauth"
)

func TestNewAuthHandler(t *testing.T) {
t.Parallel()

dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

tests := []struct {
name string
cfg *Config
expectedAuthServer string
expectedResourcePath string
}{
{
name: "nil config uses defaults",
cfg: nil,
expectedAuthServer: DefaultAuthorizationServer,
expectedResourcePath: "",
},
{
name: "empty config uses defaults",
cfg: &Config{},
expectedAuthServer: DefaultAuthorizationServer,
expectedResourcePath: "",
},
{
name: "custom authorization server",
cfg: &Config{
Expand All @@ -48,7 +44,7 @@ func TestNewAuthHandler(t *testing.T) {
BaseURL: "https://example.com",
ResourcePath: "/mcp",
},
expectedAuthServer: DefaultAuthorizationServer,
expectedAuthServer: "",
expectedResourcePath: "/mcp",
},
}
Expand All @@ -57,11 +53,12 @@ func TestNewAuthHandler(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(tc.cfg)
handler, err := NewAuthHandler(tc.cfg, dotcomHost)
require.NoError(t, err)
require.NotNil(t, handler)

assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer)
assert.Equal(t, tc.expectedResourcePath, handler.cfg.ResourcePath)
})
}
}
Expand Down Expand Up @@ -372,7 +369,7 @@ func TestHandleProtectedResource(t *testing.T) {
authServers, ok := body["authorization_servers"].([]any)
require.True(t, ok)
require.Len(t, authServers, 1)
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
assert.Equal(t, defaultAuthorizationServer, authServers[0])
},
},
{
Expand Down Expand Up @@ -451,7 +448,10 @@ func TestHandleProtectedResource(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handler, err := NewAuthHandler(tc.cfg)
dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(tc.cfg, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -493,9 +493,12 @@ func TestHandleProtectedResource(t *testing.T) {
func TestRegisterRoutes(t *testing.T) {
t.Parallel()

dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(&Config{
BaseURL: "https://api.example.com",
})
}, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -559,9 +562,12 @@ func TestSupportedScopes(t *testing.T) {
func TestProtectedResourceResponseFormat(t *testing.T) {
t.Parallel()

dotcomHost, err := utils.NewAPIHost("https://api.github.com")
require.NoError(t, err)

handler, err := NewAuthHandler(&Config{
BaseURL: "https://api.example.com",
})
}, dotcomHost)
require.NoError(t, err)

router := chi.NewRouter()
Expand Down Expand Up @@ -598,7 +604,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) {
authServers, ok := response["authorization_servers"].([]any)
require.True(t, ok)
assert.Len(t, authServers, 1)
assert.Equal(t, DefaultAuthorizationServer, authServers[0])
assert.Equal(t, defaultAuthorizationServer, authServers[0])
}

func TestOAuthProtectedResourcePrefix(t *testing.T) {
Expand All @@ -611,5 +617,121 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) {
func TestDefaultAuthorizationServer(t *testing.T) {
t.Parallel()

assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer)
assert.Equal(t, "https://github.com/login/oauth", defaultAuthorizationServer)
}

func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) {
t.Parallel()

tests := []struct {
name string
host string
oauthConfig *Config
expectedURL string
expectedError bool
expectedStatusCode int
errorContains string
}{
{
name: "valid host returns authorization server URL",
host: "https://github.com",
expectedURL: "https://github.com/login/oauth",
expectedStatusCode: http.StatusOK,
},
{
name: "invalid host returns error",
host: "://invalid-url",
expectedURL: "",
expectedError: true,
errorContains: "could not parse host as URL",
},
{
name: "host without scheme returns error",
host: "github.com",
expectedURL: "",
expectedError: true,
errorContains: "host must have a scheme",
},
{
name: "GHEC host returns correct authorization server URL",
host: "https://test.ghe.com",
expectedURL: "https://test.ghe.com/login/oauth",
expectedStatusCode: http.StatusOK,
},
{
name: "GHES host returns correct authorization server URL",
host: "https://ghe.example.com",
expectedURL: "https://ghe.example.com/login/oauth",
expectedStatusCode: http.StatusOK,
},
{
name: "GHES with http scheme returns the correct authorization server URL",
host: "http://ghe.example.com",
expectedURL: "http://ghe.example.com/login/oauth",
expectedStatusCode: http.StatusOK,
},
{
name: "custom authorization server in config takes precedence",
host: "https://github.com",
oauthConfig: &Config{
AuthorizationServer: "https://custom.auth.example.com/oauth",
},
expectedURL: "https://custom.auth.example.com/oauth",
expectedStatusCode: http.StatusOK,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

apiHost, err := utils.NewAPIHost(tc.host)
if tc.expectedError {
require.Error(t, err)
if tc.errorContains != "" {
assert.Contains(t, err.Error(), tc.errorContains)
}
return
}
require.NoError(t, err)

config := tc.oauthConfig
if config == nil {
config = &Config{}
}
config.BaseURL = tc.host

handler, err := NewAuthHandler(config, apiHost)
require.NoError(t, err)

router := chi.NewRouter()
handler.RegisterRoutes(router)

req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil)
req.Host = "api.example.com"

rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)

require.Equal(t, http.StatusOK, rec.Code)

var response map[string]any
err = json.Unmarshal(rec.Body.Bytes(), &response)
require.NoError(t, err)

assert.Contains(t, response, "authorization_servers")
if tc.expectedStatusCode != http.StatusOK {
require.Equal(t, tc.expectedStatusCode, rec.Code)
if tc.errorContains != "" {
assert.Contains(t, rec.Body.String(), tc.errorContains)
}
return
}

responseAuthServers, ok := response["authorization_servers"].([]any)
require.True(t, ok)
require.Len(t, responseAuthServers, 1)
assert.Equal(t, tc.expectedURL, responseAuthServers[0])
})
}
}
2 changes: 1 addition & 1 deletion pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func RunHTTPServer(cfg ServerConfig) error {

r := chi.NewRouter()
handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...)
oauthHandler, err := oauth.NewAuthHandler(oauthCfg)
oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost)
if err != nil {
return fmt.Errorf("failed to create OAuth handler: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/scopes/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) {
func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) {
return nil, nil
}
func (t testAPIHostResolver) AuthorizationServerURL(_ context.Context) (*url.URL, error) {
return nil, nil
}

func TestParseScopeHeader(t *testing.T) {
tests := []struct {
Expand Down
Loading