Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,17 @@ type StreamableClientTransport struct {
// - You want to avoid maintaining a persistent connection
DisableStandaloneSSE bool

// Headers contains additional HTTP headers to include with every request to
// the MCP endpoint. Headers that the transport itself sets (Content-Type,
// Accept, Authorization, Mcp-Protocol-Version, Mcp-Session-Id) take
// precedence over any conflicting entries here.
Headers http.Header

// MaxResponseBytes, if positive, limits the size of response bodies read
// from the MCP endpoint for outgoing POST requests. It does not apply to
// the standalone SSE stream. A value of zero means no limit.
MaxResponseBytes int64

// OAuthHandler is an optional field that, if provided, will be used to authorize the requests.
OAuthHandler auth.OAuthHandler

Expand Down Expand Up @@ -1605,6 +1616,8 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
cancel: cancel,
failed: make(chan struct{}),
disableStandaloneSSE: t.DisableStandaloneSSE,
headers: t.Headers,
maxResponseBytes: t.MaxResponseBytes,
oauthHandler: t.OAuthHandler,
}
return conn, nil
Expand All @@ -1624,6 +1637,9 @@ type streamableClientConn struct {
// for receiving server-to-client notifications when no request is in flight.
disableStandaloneSSE bool // from [StreamableClientTransport.DisableStandaloneSSE]

headers http.Header // from [StreamableClientTransport.Headers]
maxResponseBytes int64 // from [StreamableClientTransport.MaxResponseBytes]

// oauthHandler is the OAuth handler for the connection.
oauthHandler auth.OAuthHandler // from [StreamableClientTransport.OAuthHandler]

Expand Down Expand Up @@ -1819,6 +1835,9 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
// and permanently break the connection.
err = fmt.Errorf("%s: %w: %w", requestSummary, jsonrpc2.ErrRejected, err)
}
if err == nil && c.maxResponseBytes > 0 {
resp.Body = http.MaxBytesReader(nil, resp.Body, c.maxResponseBytes)
}
return req, resp, err
}

Expand Down Expand Up @@ -1947,7 +1966,12 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) error {
if c.sessionID != "" {
req.Header.Set(sessionIDHeader, c.sessionID)
}

for k, v := range c.headers {
ck := http.CanonicalHeaderKey(k)
if _, ok := req.Header[ck]; !ok {
req.Header[ck] = v
}
}
return nil
}

Expand Down
100 changes: 100 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1156,3 +1156,103 @@ func TestTokenInfo(t *testing.T) {
t.Errorf("got %q, want %q", g, w)
}
}

func TestStreamableClientHeaders(t *testing.T) {
ctx := context.Background()

var mu sync.Mutex
got := make(map[string]http.Header) // method -> received headers
capture := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
got[r.Method] = r.Header.Clone()
mu.Unlock()
next.ServeHTTP(w, r)
})
}

server := NewServer(testImpl, nil)
streamHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(capture(streamHandler))
defer httpServer.Close()

transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
// Mix canonical and non-canonical keys; the non-canonical
// "content-type" must not override the transport-set value.
Headers: http.Header{
"X-Custom": {"custom-value"},
"content-type": {"text/plain"},
},
}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}

mu.Lock()
defer mu.Unlock()
for _, method := range []string{"POST", "GET", "DELETE"} {
h, ok := got[method]
if !ok {
t.Errorf("%s: no request received", method)
continue
}
if g, w := h.Get("X-Custom"), "custom-value"; g != w {
t.Errorf("%s: X-Custom = %q, want %q", method, g, w)
}
}
// Headers set by the transport must take precedence even when
// the user supplies a non-canonical key. POST is the only method
// where the transport sets Content-Type.
if g := got["POST"].Values("Content-Type"); len(g) != 1 || g[0] != "application/json" {
t.Errorf("POST: Content-Type = %q, want [\"application/json\"]", g)
}
}

func TestStreamableClientMaxResponseBytes(t *testing.T) {
ctx := context.Background()

server := NewServer(testImpl, nil)
big := func(context.Context, *CallToolRequest, struct{}) (*CallToolResult, any, error) {
return &CallToolResult{Content: []Content{&TextContent{Text: strings.Repeat("a", 10000)}}}, nil, nil
}
AddTool(server, &Tool{Name: "big", Description: "return big result"}, big)

streamHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil)
httpServer := httptest.NewServer(streamHandler)
defer httpServer.Close()

connect := func(max int64) (*ClientSession, error) {
transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
MaxResponseBytes: max,
}
return NewClient(testImpl, nil).Connect(ctx, transport, nil)
}

// A limit large enough for the response should succeed.
session, err := connect(1 << 20)
if err != nil {
t.Fatalf("connect with large limit failed: %v", err)
}
if _, err := session.CallTool(ctx, &CallToolParams{Name: "big"}); err != nil {
t.Errorf("CallTool with large limit failed: %v", err)
}
session.Close()

// A limit large enough for initialize but smaller than the tool
// response should fail the tool call.
session, err = connect(2000)
if err != nil {
t.Fatalf("connect with small limit failed: %v", err)
}
defer session.Close()
if _, err := session.CallTool(ctx, &CallToolParams{Name: "big"}); err == nil {
t.Error("CallTool with small limit succeeded, want error")
}
}
Loading