diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 5da253566..6e2337978 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,11 +3,13 @@ package oauth import ( + "context" "fmt" "net/http" "strings" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -43,8 +45,13 @@ type Config struct { // This is used to construct the OAuth resource URL. BaseURL string + // APIHost is the GitHub API host resolver that provides OAuth URL. + // If set, this takes precedence over AuthorizationServer. + APIHost utils.APIHostResolver + // AuthorizationServer is the OAuth authorization server URL. // Defaults to GitHub's OAuth server if not specified. + // This field is ignored if APIHost is set. AuthorizationServer string // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). @@ -64,8 +71,15 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { cfg = &Config{} } - // Default authorization server to GitHub - if cfg.AuthorizationServer == "" { + // Resolve authorization server from APIHost if provided + if cfg.APIHost != nil { + oauthURL, err := cfg.APIHost.OAuthURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get OAuth URL from API host: %w", err) + } + cfg.AuthorizationServer = oauthURL.String() + } else if cfg.AuthorizationServer == "" { + // Default authorization server to GitHub if not provided cfg.AuthorizationServer = DefaultAuthorizationServer } diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 9133e8331..78e156649 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -1,18 +1,49 @@ package oauth import ( + "context" "crypto/tls" "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// mockAPIHostResolver is a test implementation of utils.APIHostResolver +type mockAPIHostResolver struct { + oauthURL string +} + +func (m mockAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { + return url.Parse(m.oauthURL) +} + +// Ensure mockAPIHostResolver implements utils.APIHostResolver +var _ utils.APIHostResolver = mockAPIHostResolver{} + func TestNewAuthHandler(t *testing.T) { t.Parallel() @@ -51,6 +82,31 @@ func TestNewAuthHandler(t *testing.T) { expectedAuthServer: DefaultAuthorizationServer, expectedResourcePath: "/mcp", }, + { + name: "APIHost with HTTPS GHES", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, + }, + expectedAuthServer: "https://ghes.example.com/login/oauth", + expectedResourcePath: "", + }, + { + name: "APIHost with HTTP GHES", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "http://ghes.local/login/oauth"}, + }, + expectedAuthServer: "http://ghes.local/login/oauth", + expectedResourcePath: "", + }, + { + name: "APIHost takes precedence over AuthorizationServer", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, + AuthorizationServer: "https://should-be-ignored.example.com/oauth", + }, + expectedAuthServer: "https://ghes.example.com/login/oauth", + expectedResourcePath: "", + }, } for _, tc := range tests { diff --git a/pkg/http/server.go b/pkg/http/server.go index 7397e54a8..1de400783 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -126,6 +126,7 @@ func RunHTTPServer(cfg ServerConfig) error { oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, + APIHost: apiHost, } serverOptions := []HandlerOption{} diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 2d887d7a8..08520489c 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { return nil, nil } +func (t testAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { + return nil, nil +} func TestParseScopeHeader(t *testing.T) { tests := []struct { diff --git a/pkg/utils/api.go b/pkg/utils/api.go index a523917de..6e02530fb 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -14,6 +14,7 @@ type APIHostResolver interface { GraphqlURL(ctx context.Context) (*url.URL, error) UploadURL(ctx context.Context) (*url.URL, error) RawURL(ctx context.Context) (*url.URL, error) + OAuthURL(ctx context.Context) (*url.URL, error) } type APIHost struct { @@ -21,6 +22,7 @@ type APIHost struct { gqlURL *url.URL uploadURL *url.URL rawURL *url.URL + oauthURL *url.URL } var _ APIHostResolver = APIHost{} @@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) { return a.rawURL, nil } +func (a APIHost) OAuthURL(_ context.Context) (*url.URL, error) { + return a.oauthURL, nil +} + func newDotcomHost() (APIHost, error) { baseRestURL, err := url.Parse("https://api.github.com/") if err != nil { @@ -73,11 +79,17 @@ func newDotcomHost() (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) } + oauthURL, err := url.Parse("https://github.com/login/oauth") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom OAuth URL: %w", err) + } + return APIHost{ restURL: baseRestURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } @@ -112,11 +124,17 @@ func newGHECHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) } + oauthURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC OAuth URL: %w", err) + } + return APIHost{ restURL: restURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } @@ -164,11 +182,17 @@ func newGHESHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) } + oauthURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES OAuth URL: %w", err) + } + return APIHost{ restURL: restURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } diff --git a/pkg/utils/api_test.go b/pkg/utils/api_test.go new file mode 100644 index 000000000..8fa9279ab --- /dev/null +++ b/pkg/utils/api_test.go @@ -0,0 +1,140 @@ +package utils //nolint:revive //TODO: figure out a better name for this package + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuthURL(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + host string + expectedOAuth string + expectError bool + errorSubstring string + }{ + { + name: "dotcom (empty host)", + host: "", + expectedOAuth: "https://github.com/login/oauth", + }, + { + name: "dotcom (explicit github.com)", + host: "https://github.com", + expectedOAuth: "https://github.com/login/oauth", + }, + { + name: "GHEC with HTTPS", + host: "https://acme.ghe.com", + expectedOAuth: "https://acme.ghe.com/login/oauth", + }, + { + name: "GHEC with HTTP (should error)", + host: "http://acme.ghe.com", + expectError: true, + errorSubstring: "GHEC URL must be HTTPS", + }, + { + name: "GHES with HTTPS", + host: "https://ghes.example.com", + expectedOAuth: "https://ghes.example.com/login/oauth", + }, + { + name: "GHES with HTTP", + host: "http://ghes.example.com", + expectedOAuth: "http://ghes.example.com/login/oauth", + }, + { + name: "GHES with HTTP and custom port (port stripped - not supported yet)", + host: "http://ghes.local:8080", + expectedOAuth: "http://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment + }, + { + name: "GHES with HTTPS and custom port (port stripped - not supported yet)", + host: "https://ghes.local:8443", + expectedOAuth: "https://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment + }, + { + name: "host without scheme (should error)", + host: "ghes.example.com", + expectError: true, + errorSubstring: "host must have a scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiHost, err := NewAPIHost(tt.host) + + if tt.expectError { + require.Error(t, err) + if tt.errorSubstring != "" { + assert.Contains(t, err.Error(), tt.errorSubstring) + } + return + } + + require.NoError(t, err) + require.NotNil(t, apiHost) + + oauthURL, err := apiHost.OAuthURL(ctx) + require.NoError(t, err) + require.NotNil(t, oauthURL) + + assert.Equal(t, tt.expectedOAuth, oauthURL.String()) + }) + } +} + +func TestAPIHost_AllURLsHaveConsistentScheme(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + host string + expectedScheme string + }{ + { + name: "GHES with HTTPS", + host: "https://ghes.example.com", + expectedScheme: "https", + }, + { + name: "GHES with HTTP", + host: "http://ghes.example.com", + expectedScheme: "http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiHost, err := NewAPIHost(tt.host) + require.NoError(t, err) + + restURL, err := apiHost.BaseRESTURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, restURL.Scheme, "REST URL scheme should match") + + gqlURL, err := apiHost.GraphqlURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, gqlURL.Scheme, "GraphQL URL scheme should match") + + uploadURL, err := apiHost.UploadURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, uploadURL.Scheme, "Upload URL scheme should match") + + rawURL, err := apiHost.RawURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, rawURL.Scheme, "Raw URL scheme should match") + + oauthURL, err := apiHost.OAuthURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, oauthURL.Scheme, "OAuth URL scheme should match") + }) + } +}