diff --git a/mcp/streamable.go b/mcp/streamable.go index c047c156..60c5ffb1 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -2432,21 +2432,36 @@ func (c *streamableClientConn) handleSSE(ctx context.Context, requestSummary str } // The stream was interrupted or ended by the server. Attempt to reconnect. - newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) - if err != nil { - // If the client didn't cancel this request, any failure to execute it - // breaks the logical MCP session. - if ctx.Err() == nil { - // All reconnection attempts failed: fail the connection. - c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + for { + newResp, err := c.connectSSE(ctx, lastEventID, reconnectDelay, false) + if err != nil { + // If the client didn't cancel this request, any failure to execute it + // breaks the logical MCP session. + if ctx.Err() == nil { + // All reconnection attempts failed: fail the connection. + c.fail(fmt.Errorf("%s: failed to reconnect (session ID: %v): %v", requestSummary, c.sessionID, err)) + } + return } - return - } - resp = newResp - if err := c.checkResponse(ctx, requestSummary, resp); err != nil { - c.fail(err) - return + if err := c.checkResponse(ctx, requestSummary, newResp); err != nil { + if errors.Is(err, jsonrpc2.ErrRejected) { + retriesWithoutProgress++ + if retriesWithoutProgress > c.maxRetries { + if ctx.Err() == nil { + c.fail(fmt.Errorf("%s: exceeded %d retries without progress (session ID: %v)", requestSummary, c.maxRetries, c.sessionID)) + } + return + } + reconnectDelay = 0 + continue + } + c.fail(err) + return + } + + resp = newResp + break } } } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 59bf1bd2..2ed677fd 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -740,6 +740,83 @@ func TestStreamableClientTransientErrors(t *testing.T) { } } +// TestStreamableClientReconnectTransientErrors verifies that a transient HTTP +// error on an SSE reconnect is retried rather than breaking the session (#683). +func TestStreamableClientReconnectTransientErrors(t *testing.T) { + const tick = 10 * time.Millisecond + defer func(delay int64) { + reconnectInitialDelay.Store(delay) + }(reconnectInitialDelay.Load()) + reconnectInitialDelay.Store(int64(tick)) + + ctx := context.Background() + + var callID atomic.Int64 + var sentTransient atomic.Bool + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: protocolVersion20251125, + }, + {"GET", "123", "", ""}: { + status: http.StatusMethodNotAllowed, + }, + {"POST", "123", methodCallTool, ""}: { + header: header{ + "Content-Type": "text/event-stream", + }, + responseFunc: func(r *jsonrpc.Request) (string, int) { + callID.Store(r.ID.Raw().(int64)) + return "id: evt_1\n" + + `data: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"progress"}}` + + "\n\n", http.StatusOK + }, + }, + {"GET", "123", "", "evt_1"}: { + header: header{ + "Content-Type": "text/event-stream", + }, + responseFunc: func(r *jsonrpc.Request) (string, int) { + if !sentTransient.Swap(true) { + return "", http.StatusServiceUnavailable + } + body := jsonBody(t, resp(callID.Load(), &CallToolResult{Content: []Content{}}, nil)) + return "id: evt_2\ndata: " + body + "\n\n", http.StatusOK + }, + }, + {"DELETE", "123", "", ""}: {optional: true}, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{Endpoint: httpServer.URL} + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, &ClientSessionOptions{protocolVersion: protocolVersion20251125}) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer session.Close() + + if _, err := session.CallTool(ctx, &CallToolParams{Name: "test"}); err != nil { + t.Fatalf("CallTool failed after transient reconnect error: %v (session should survive)", err) + } + if !sentTransient.Load() { + t.Error("expected a transient error to be exercised on reconnect") + } +} + // TestStreamableClientRetryWithoutProgress verifies that the client fails after // exceeding the retry limit when no progress is made (Last-Event-ID does not advance). // This tests the fix for issue #679.