diff --git a/pkg/cmd/root.go b/pkg/cmd/root.go index 0270436..db766f4 100644 --- a/pkg/cmd/root.go +++ b/pkg/cmd/root.go @@ -146,6 +146,8 @@ func argvContainsGatewayMCP(argv []string) bool { return false } +// flagNeedsNextArg lists global flags that consume the next argv token as their value. +// Keep in sync with the PersistentFlags registered in init() below. var flagNeedsNextArg = map[string]bool{ "profile": true, "p": true, diff --git a/pkg/cmd/root_argv_test.go b/pkg/cmd/root_argv_test.go index dffac21..57524b8 100644 --- a/pkg/cmd/root_argv_test.go +++ b/pkg/cmd/root_argv_test.go @@ -21,6 +21,10 @@ func TestArgvContainsGatewayMCP(t *testing.T) { {"wrong order", []string{"hookdeck", "mcp", "gateway"}, false}, {"too short", []string{"hookdeck", "gateway"}, false}, {"api key flag before", []string{"hookdeck", "--api-key", "k", "gateway", "mcp"}, true}, + // Boolean flags (--insecure, --version) are not in flagNeedsNextArg, so + // globalPositionalArgs treats them as single-token flags and skips them. + {"bool flag before gateway", []string{"hookdeck", "--insecure", "gateway", "mcp"}, true}, + {"bool flag between gateway and mcp", []string{"hookdeck", "gateway", "--insecure", "mcp"}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/gateway/mcp/tool_login.go b/pkg/gateway/mcp/tool_login.go index 5ba1b39..0bc5792 100644 --- a/pkg/gateway/mcp/tool_login.go +++ b/pkg/gateway/mcp/tool_login.go @@ -26,6 +26,10 @@ const ( // loginState tracks a background login poll so that repeated calls to // hookdeck_login don't start duplicate auth flows. +// +// Synchronization: err is written by the goroutine before close(done). +// The handler only reads err after receiving from done, so the channel +// close provides the happens-before guarantee — no separate mutex needed. type loginState struct { browserURL string // URL the user must open done chan struct{} // closed when polling finishes @@ -37,7 +41,6 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler var state *loginState return func(ctx context.Context, req *mcpsdk.CallToolRequest) (*mcpsdk.CallToolResult, error) { - _ = ctx in, err := parseInput(req.Params.Arguments) if err != nil { return ErrorResult(err.Error()), nil @@ -118,14 +121,34 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler } // Poll in the background so we return the URL to the agent immediately. - go func(s *loginState) { + // WaitForAPIKey blocks with time.Sleep; run it in a goroutine and + // select on ctx so we abandon the attempt when the session closes. + go func(s *loginState, ctx context.Context) { defer close(s.done) - response, err := session.WaitForAPIKey(loginPollInterval, loginMaxAttempts) - if err != nil { - s.err = err - log.WithError(err).Debug("Login polling failed") + type pollResult struct { + resp *hookdeck.PollAPIKeyResponse + err error + } + ch := make(chan pollResult, 1) + go func() { + resp, err := session.WaitForAPIKey(loginPollInterval, loginMaxAttempts) + ch <- pollResult{resp, err} + }() + + var response *hookdeck.PollAPIKeyResponse + select { + case <-ctx.Done(): + s.err = fmt.Errorf("login cancelled: MCP session closed") + log.Debug("Login polling cancelled — MCP session closed") return + case r := <-ch: + if r.err != nil { + s.err = r.err + log.WithError(r.err).Debug("Login polling failed") + return + } + response = r.resp } if err := validators.APIKey(response.APIKey); err != nil { @@ -164,7 +187,7 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler "user": response.UserName, "project": response.ProjectName, }).Info("MCP login completed successfully") - }(state) + }(state, ctx) // Return the URL immediately so the agent can show it to the user. return TextResult(fmt.Sprintf( diff --git a/pkg/gateway/mcp/tool_projects_errors.go b/pkg/gateway/mcp/tool_projects_errors.go index 231bc7e..921fd8a 100644 --- a/pkg/gateway/mcp/tool_projects_errors.go +++ b/pkg/gateway/mcp/tool_projects_errors.go @@ -26,9 +26,11 @@ func shouldSuggestReauthAfterListProjectsFailure(err error) bool { } return strings.Contains(strings.ToLower(apiErr.Message), "fatal") } - // ListProjects wraps some failures as plain fmt.Errorf with status in the text. + // ListProjects wraps some failures as plain fmt.Errorf with the status code + // in the text (e.g. "unexpected http status code: 403 "). + // Match "status code: 4xx" to avoid false positives on IDs containing "401"/"403". msg := strings.ToLower(err.Error()) - if strings.Contains(msg, "403") || strings.Contains(msg, "401") { + if strings.Contains(msg, "status code: 403") || strings.Contains(msg, "status code: 401") { return true } return strings.Contains(msg, "fatal") diff --git a/pkg/gateway/mcp/tool_projects_errors_test.go b/pkg/gateway/mcp/tool_projects_errors_test.go new file mode 100644 index 0000000..836bc6b --- /dev/null +++ b/pkg/gateway/mcp/tool_projects_errors_test.go @@ -0,0 +1,63 @@ +package mcp + +import ( + "fmt" + "testing" + + "github.com/hookdeck/hookdeck-cli/pkg/hookdeck" + "github.com/stretchr/testify/assert" +) + +func TestShouldSuggestReauthAfterListProjectsFailure(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "APIError 403", + err: &hookdeck.APIError{StatusCode: 403, Message: "not allowed"}, + want: true, + }, + { + name: "APIError 401", + err: &hookdeck.APIError{StatusCode: 401, Message: "unauthorized"}, + want: true, + }, + { + name: "APIError 500 no match", + err: &hookdeck.APIError{StatusCode: 500, Message: "internal error"}, + want: false, + }, + { + name: "APIError with fatal message", + err: &hookdeck.APIError{StatusCode: 500, Message: "fatal: cannot proceed"}, + want: true, + }, + { + name: "plain error with status code 403", + err: fmt.Errorf("unexpected http status code: 403 "), + want: true, + }, + { + name: "plain error with status code 401", + err: fmt.Errorf("unexpected http status code: 401 "), + want: true, + }, + { + name: "plain error without status code", + err: fmt.Errorf("network timeout"), + want: false, + }, + { + name: "plain error with 403 in ID should not match", + err: fmt.Errorf("project proj_403_abc not found"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, shouldSuggestReauthAfterListProjectsFailure(tt.err)) + }) + } +} diff --git a/pkg/hookdeck/projects.go b/pkg/hookdeck/projects.go index 3fd9ce2..5a46b4d 100644 --- a/pkg/hookdeck/projects.go +++ b/pkg/hookdeck/projects.go @@ -2,8 +2,6 @@ package hookdeck import ( "context" - "fmt" - "net/http" ) type Project struct { @@ -17,8 +15,8 @@ func (c *Client) ListProjects() ([]Project, error) { if err != nil { return []Project{}, err } - if res.StatusCode != http.StatusOK { - return []Project{}, fmt.Errorf("unexpected http status code: %d %s", res.StatusCode, err) + if err := checkAndPrintError(res); err != nil { + return []Project{}, err } projects := []Project{} postprocessJsonResponse(res, &projects) diff --git a/test/acceptance/helpers.go b/test/acceptance/helpers.go index b352a42..85a8e8d 100644 --- a/test/acceptance/helpers.go +++ b/test/acceptance/helpers.go @@ -759,10 +759,10 @@ type Request struct { // Attempt represents a Hookdeck attempt for testing type Attempt struct { - ID string `json:"id"` - EventID string `json:"event_id"` - AttemptNumber int `json:"attempt_number"` - Status string `json:"status"` + ID string `json:"id"` + EventID string `json:"event_id"` + AttemptNumber int `json:"attempt_number"` + Status string `json:"status"` } // createTestConnection creates a basic test connection and returns its ID