diff --git a/mcp/streamable.go b/mcp/streamable.go index 708b1326..ecf4a0ff 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1523,6 +1523,17 @@ type StreamableClientTransport struct { // - You want to avoid maintaining a persistent connection DisableStandaloneSSE bool + // Headers contains additional HTTP headers to include with every request to + // the MCP endpoint. Headers that the transport itself sets (Content-Type, + // Accept, Authorization, Mcp-Protocol-Version, Mcp-Session-Id) take + // precedence over any conflicting entries here. + Headers http.Header + + // MaxResponseBytes, if positive, limits the size of response bodies read + // from the MCP endpoint for outgoing POST requests. It does not apply to + // the standalone SSE stream. A value of zero means no limit. + MaxResponseBytes int64 + // OAuthHandler is an optional field that, if provided, will be used to authorize the requests. OAuthHandler auth.OAuthHandler @@ -1605,6 +1616,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er cancel: cancel, failed: make(chan struct{}), disableStandaloneSSE: t.DisableStandaloneSSE, + headers: t.Headers, + maxResponseBytes: t.MaxResponseBytes, oauthHandler: t.OAuthHandler, } return conn, nil @@ -1624,6 +1637,9 @@ type streamableClientConn struct { // for receiving server-to-client notifications when no request is in flight. disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE] + headers http.Header // from [StreamableClientTransport.Headers] + maxResponseBytes int64 // from [StreamableClientTransport.MaxResponseBytes] + // oauthHandler is the OAuth handler for the connection. oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler] @@ -1819,6 +1835,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // and permanently break the connection. err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err) } + if err == nil && c.maxResponseBytes > 0 { + resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxResponseBytes) + } return req, resp, err } @@ -1947,7 +1966,12 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - + for k, v := range c.headers { + ck := http.CanonicalHeaderKey(k) + if _, ok := req.Header[ck]; !ok { + req.Header[ck] = v + } + } return nil } diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index 96c5e75f..a726e265 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -1156,3 +1156,103 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +func TestStreamableClientHeaders(t *testing.T) { + ctx := context.Background() + + var mu sync.Mutex + got := make(map[string]http.Header) // method -> received headers + capture := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + got[r.Method] = r.Header.Clone() + mu.Unlock() + next.ServeHTTP(w, r) + }) + } + + server := NewServer(testImpl, nil) + streamHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(capture(streamHandler)) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + // Mix canonical and non-canonical keys; the non-canonical + // "content-type" must not override the transport-set value. + Headers: http.Header{ + "X-Custom": {"custom-value"}, + "content-type": {"text/plain"}, + }, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + + mu.Lock() + defer mu.Unlock() + for _, method := range []string{"POST", "GET", "DELETE"} { + h, ok := got[method] + if !ok { + t.Errorf("%s: no request received", method) + continue + } + if g, w := h.Get("X-Custom"), "custom-value"; g != w { + t.Errorf("%s: X-Custom = %q, want %q", method, g, w) + } + } + // Headers set by the transport must take precedence even when + // the user supplies a non-canonical key. POST is the only method + // where the transport sets Content-Type. + if g := got["POST"].Values("Content-Type"); len(g) != 1 || g[0] != "application/json" { + t.Errorf("POST: Content-Type = %q, want [\"application/json\"]", g) + } +} + +func TestStreamableClientMaxResponseBytes(t *testing.T) { + ctx := context.Background() + + server := NewServer(testImpl, nil) + big := func(context.Context, *CallToolRequest, struct{}) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: strings.Repeat("a", 10000)}}}, nil, nil + } + AddTool(server, &Tool{Name: "big", Description: "return big result"}, big) + + streamHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(streamHandler) + defer httpServer.Close() + + connect := func(max int64) (*ClientSession, error) { + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + MaxResponseBytes: max, + } + return NewClient(testImpl, nil).Connect(ctx, transport, nil) + } + + // A limit large enough for the response should succeed. + session, err := connect(1 << 20) + if err != nil { + t.Fatalf("connect with large limit failed: %v", err) + } + if _, err := session.CallTool(ctx, &CallToolParams{Name: "big"}); err != nil { + t.Errorf("CallTool with large limit failed: %v", err) + } + session.Close() + + // A limit large enough for initialize but smaller than the tool + // response should fail the tool call. + session, err = connect(2000) + if err != nil { + t.Fatalf("connect with small limit failed: %v", err) + } + defer session.Close() + if _, err := session.CallTool(ctx, &CallToolParams{Name: "big"}); err == nil { + t.Error("CallTool with small limit succeeded, want error") + } +}