Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pkg/gateway/mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ type Server struct {
client *hookdeck.Client
cfg *config.Config
mcpServer *mcpsdk.Server

// sessionCtx is the context passed to RunStdio. It is cancelled when the
// MCP transport closes (stdin EOF). Background goroutines (e.g. login
// polling) should select on this — NOT on the per-request ctx passed to
// tool handlers, which is cancelled when the handler returns.
sessionCtx context.Context
}

// NewServer creates an MCP server with all Hookdeck tools registered.
Expand Down Expand Up @@ -56,7 +62,7 @@ func (s *Server) registerTools() {
"reauth": {Type: "boolean", Desc: "If true, clear stored credentials and start a new browser login. Use when project listing fails — complete login in the browser, then retry hookdeck_projects."},
}),
},
s.wrapWithTelemetry("hookdeck_login", handleLogin(s.client, s.cfg)),
s.wrapWithTelemetry("hookdeck_login", handleLogin(s)),
)
}

Expand Down Expand Up @@ -126,5 +132,13 @@ func extractAction(req *mcpsdk.CallToolRequest) string {
// RunStdio starts the MCP server on stdin/stdout and blocks until the
// connection is closed (i.e. stdin reaches EOF).
func (s *Server) RunStdio(ctx context.Context) error {
return s.mcpServer.Run(ctx, &mcpsdk.StdioTransport{})
return s.Run(ctx, &mcpsdk.StdioTransport{})
}

// Run starts the MCP server on the given transport. It stores ctx as the
// session-level context so background goroutines (e.g. login polling) can
// detect when the session ends.
func (s *Server) Run(ctx context.Context, transport mcpsdk.Transport) error {
s.sessionCtx = ctx
return s.mcpServer.Run(ctx, transport)
}
73 changes: 69 additions & 4 deletions pkg/gateway/mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"strings"
"testing"
"time"

mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -45,7 +46,7 @@ func connectInMemory(t *testing.T, client *hookdeck.Client) *mcpsdk.ClientSessio
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() {
_ = srv.mcpServer.Run(ctx, serverTransport)
_ = srv.Run(ctx, serverTransport)
}()

mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{
Expand Down Expand Up @@ -1143,7 +1144,7 @@ func TestLoginTool_ReauthStartsFreshLogin(t *testing.T) {
serverTransport, clientTransport := mcpsdk.NewInMemoryTransports()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }()
go func() { _ = srv.Run(ctx, serverTransport) }()

mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil)
session, err := mcpClient.Connect(ctx, clientTransport, nil)
Expand Down Expand Up @@ -1182,7 +1183,7 @@ func TestLoginTool_ReturnsURLImmediately(t *testing.T) {
serverTransport, clientTransport := mcpsdk.NewInMemoryTransports()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }()
go func() { _ = srv.Run(ctx, serverTransport) }()

mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil)
session, err := mcpClient.Connect(ctx, clientTransport, nil)
Expand Down Expand Up @@ -1218,7 +1219,7 @@ func TestLoginTool_InProgressShowsURL(t *testing.T) {
serverTransport, clientTransport := mcpsdk.NewInMemoryTransports()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }()
go func() { _ = srv.Run(ctx, serverTransport) }()

mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil)
session, err := mcpClient.Connect(ctx, clientTransport, nil)
Expand All @@ -1236,6 +1237,70 @@ func TestLoginTool_InProgressShowsURL(t *testing.T) {
assert.Contains(t, text, "https://hookdeck.com/auth?code=xyz")
}

func TestLoginTool_PollSurvivesAcrossToolCalls(t *testing.T) {
// Regression: the login polling goroutine must use the session-level
// context, not the per-request ctx (which is cancelled when the handler
// returns). If the goroutine selected on per-request ctx, it would be
// cancelled immediately and the second hookdeck_login call would see a
// "login cancelled" error instead of "Already authenticated".
pollCount := 0
api := mockAPI(t, map[string]http.HandlerFunc{
"/2025-07-01/cli-auth": func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"browser_url": "https://hookdeck.com/auth?code=survive",
"poll_url": "http://" + r.Host + "/2025-07-01/cli-auth/poll?key=survive",
})
},
"/2025-07-01/cli-auth/poll": func(w http.ResponseWriter, r *http.Request) {
pollCount++
if pollCount >= 2 {
// Simulate user completing browser auth on 2nd poll.
json.NewEncoder(w).Encode(map[string]any{
"claimed": true,
"key": "sk_test_survive12345",
"team_id": "proj_survive",
"team_name": "Survive Project",
"team_mode": "console",
"user_name": "test-user",
"organization_name": "test-org",
})
return
}
json.NewEncoder(w).Encode(map[string]any{"claimed": false})
},
})

unauthClient := newTestClient(api.URL, "")
cfg := &config.Config{APIBaseURL: api.URL}
srv := NewServer(unauthClient, cfg)

serverTransport, clientTransport := mcpsdk.NewInMemoryTransports()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() { _ = srv.Run(ctx, serverTransport) }()

mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil)
session, err := mcpClient.Connect(ctx, clientTransport, nil)
require.NoError(t, err)
t.Cleanup(func() { _ = session.Close() })

// First call initiates the flow — handler returns immediately.
result := callTool(t, session, "hookdeck_login", map[string]any{})
assert.False(t, result.IsError)
assert.Contains(t, textContent(t, result), "https://hookdeck.com/auth?code=survive")

// Wait briefly for the polling goroutine to complete (poll interval is 2s
// in production, but the mock returns instantly so it completes quickly).
time.Sleep(500 * time.Millisecond)

// Second call — if the goroutine survived, the client is now authenticated.
result2 := callTool(t, session, "hookdeck_login", map[string]any{})
assert.False(t, result2.IsError)
text := textContent(t, result2)
assert.Contains(t, text, "Already authenticated")
assert.Equal(t, "sk_test_survive12345", unauthClient.APIKey)
}

// ---------------------------------------------------------------------------
// API error scenarios (shared across tools)
// ---------------------------------------------------------------------------
Expand Down
16 changes: 10 additions & 6 deletions pkg/gateway/mcp/tool_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ type loginState struct {
err error // non-nil if polling failed
}

func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler {
func handleLogin(srv *Server) mcpsdk.ToolHandler {
client := srv.client
cfg := srv.cfg
var stateMu sync.Mutex
var state *loginState

Expand Down Expand Up @@ -121,9 +123,11 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler
}

// Poll in the background so we return the URL to the agent immediately.
// WaitForAPIKey blocks with time.Sleep; run it in a goroutine and
// select on ctx so we abandon the attempt when the session closes.
go func(s *loginState, ctx context.Context) {
// WaitForAPIKey blocks with time.Sleep internally, so we run it in an
// inner goroutine and select on the session-level context (not the
// per-request ctx, which is cancelled when this handler returns).
sessionCtx := srv.sessionCtx
go func(s *loginState) {
defer close(s.done)

type pollResult struct {
Expand All @@ -138,7 +142,7 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler

var response *hookdeck.PollAPIKeyResponse
select {
case <-ctx.Done():
case <-sessionCtx.Done():
s.err = fmt.Errorf("login cancelled: MCP session closed")
log.Debug("Login polling cancelled — MCP session closed")
return
Expand Down Expand Up @@ -187,7 +191,7 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler
"user": response.UserName,
"project": response.ProjectName,
}).Info("MCP login completed successfully")
}(state, ctx)
}(state)

// Return the URL immediately so the agent can show it to the user.
return TextResult(fmt.Sprintf(
Expand Down