Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f00d0f4
refactor: centralize MCP headers and add support for validating stand…
guglielmo-san Apr 24, 2026
d00288d
fix
guglielmo-san Apr 24, 2026
57659c0
formatter
guglielmo-san Apr 24, 2026
017e0fc
Merge branch 'main' into guglielmoc/SEP-2243_http_standardization
guglielmo-san Apr 24, 2026
604f2d4
fix after merge
guglielmo-san Apr 24, 2026
7df5ab6
refactor: decouple standard header population from setMCPHeaders in s…
guglielmo-san Apr 24, 2026
9de3bec
feat: implement SEP-2243 header encoding and support Mcp-Param- heade…
guglielmo-san Apr 26, 2026
f429bc5
feat: implement validation for tool-specific parameter headers in MCP…
guglielmo-san Apr 26, 2026
ad17562
feat: implement base64 header decoding and standardize primitive valu…
guglielmo-san Apr 26, 2026
a4e1e23
refactor: simplify MCP header logic and remove redundant tool cache a…
guglielmo-san Apr 27, 2026
aeada36
refactor: prohibit x-mcp-header annotations on nested object properti…
guglielmo-san Apr 27, 2026
b223143
refactor: rename mcp_http_headers to streamable_headers
guglielmo-san Apr 29, 2026
005d33d
Merge remote-tracking branch 'origin/main' into guglielmoc/SEP-2243_h…
guglielmo-san Apr 29, 2026
24cc607
feat: enable MCP parameter headers and add validation tests using int…
guglielmo-san Apr 29, 2026
9dd6907
fix: add mutex protection to toolCache and provide thread-safe access…
guglielmo-san Apr 30, 2026
8d4e94d
test: remove assertion for non-annotated query parameter headers in s…
guglielmo-san Apr 30, 2026
d8c5a76
Merge branch 'main' into guglielmoc/SEP-2243_http_standardization_2
guglielmo-san Apr 30, 2026
e754ff7
refactor: remove redundant blank line in streamable_headers.go
guglielmo-san Apr 30, 2026
87093cb
Merge remote-tracking branch 'refs/remotes/origin/guglielmoc/SEP-2243…
guglielmo-san Apr 30, 2026
04652ee
merge tests
guglielmo-san Apr 30, 2026
d7ef2c1
minor fix
guglielmo-san Apr 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ type ClientOptions struct {
KeepAlive time.Duration
}

// toolContextKeyType is the context key type for passing tool definitions
// from CallTool to the transport layer.
type toolContextKeyType struct{}

var toolContextKey = toolContextKeyType{}

// bind implements the binder[*ClientSession] interface, so that Clients can
// be connected using [connect].
func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession {
Expand Down Expand Up @@ -318,6 +324,13 @@ type ClientSession struct {
// Pending URL elicitations waiting for completion notifications.
pendingElicitationsMu sync.Mutex
pendingElicitations map[string]chan struct{}

// toolCacheMu guards toolCache.
toolCacheMu sync.RWMutex
// toolCache stores tool definitions keyed by name.
// It is used to look up x-mcp-header annotations when
// constructing Mcp-Param-* headers for tools/call requests.
toolCache map[string]*Tool
}

type clientSessionState struct {
Expand Down Expand Up @@ -363,6 +376,23 @@ func (cs *ClientSession) Wait() error {
return cs.conn.Wait()
}

func (cs *ClientSession) cacheTools(tools []*Tool) {
cs.toolCacheMu.Lock()
defer cs.toolCacheMu.Unlock()
if cs.toolCache == nil {
cs.toolCache = make(map[string]*Tool, len(tools))
}
for _, tool := range tools {
cs.toolCache[tool.Name] = tool
}
}

func (cs *ClientSession) getCachedTool(name string) *Tool {
cs.toolCacheMu.RLock()
defer cs.toolCacheMu.RUnlock()
return cs.toolCache[name]
}

// registerElicitationWaiter registers a waiter for an elicitation complete
// notification with the given elicitation ID. It returns two functions: an await
// function that waits for the notification or context cancellation, and a cleanup
Expand Down Expand Up @@ -981,7 +1011,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams)

// ListTools lists tools that are currently available on the server.
func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) {
return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params)))
if err != nil {
return nil, err
}
result.Tools = filterValidTools(result.Tools)
cs.cacheTools(result.Tools)
return result, nil
}

// CallTool calls the tool with the given parameters.
Expand All @@ -995,6 +1031,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
// Avoid sending nil over the wire.
params.Arguments = map[string]any{}
}
if tool := cs.getCachedTool(params.Name); tool != nil {
ctx = context.WithValue(ctx, toolContextKey, tool)
}
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
}

Expand Down
74 changes: 74 additions & 0 deletions mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) {
}
}

func TestToolCache(t *testing.T) {
tool1 := &Tool{Name: "tool1", Description: "first"}
tool2 := &Tool{Name: "tool2", Description: "second"}
tool1Updated := &Tool{Name: "tool1", Description: "updated"}

testCases := []struct {
name string
cacheBatches [][]*Tool
lookup string
want *Tool
}{
{
name: "empty cache",
lookup: "tool1",
want: nil,
},
{
name: "single tool found",
cacheBatches: [][]*Tool{{tool1}},
lookup: "tool1",
want: tool1,
},
{
name: "unknown tool",
cacheBatches: [][]*Tool{{tool1}},
lookup: "nonexistent",
want: nil,
},
{
name: "multiple tools single batch",
cacheBatches: [][]*Tool{{tool1, tool2}},
lookup: "tool2",
want: tool2,
},
{
name: "additive first tool retained",
cacheBatches: [][]*Tool{{tool1}, {tool2}},
lookup: "tool1",
want: tool1,
},
{
name: "additive second tool added",
cacheBatches: [][]*Tool{{tool1}, {tool2}},
lookup: "tool2",
want: tool2,
},
{
name: "overwrite existing entry",
cacheBatches: [][]*Tool{{tool1}, {tool1Updated}},
lookup: "tool1",
want: tool1Updated,
},
{
name: "empty batch no-op",
cacheBatches: [][]*Tool{{}},
lookup: "tool1",
want: nil,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cs := &ClientSession{}
for _, batch := range tc.cacheBatches {
cs.cacheTools(batch)
}
got := cs.getCachedTool(tc.lookup)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff)
}
})
}
}

func TestClientCapabilitiesOverWire(t *testing.T) {
testCases := []struct {
name string
Expand Down
86 changes: 86 additions & 0 deletions mcp/header_encoding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by the license
// that can be found in the LICENSE file.

package mcp

import (
"encoding/base64"
"fmt"
"strings"
)

const (
base64Prefix = "=?base64?"
base64Suffix = "?="
)

// encodeHeaderValue converts a parameter value to an HTTP header-safe string
// per the SEP-2243 encoding rules:
// - string: used as-is if safe ASCII, otherwise Base64 encoded
// - number (float64): decimal string representation
// - bool: lowercase "true" or "false"
// - nil: returns "", false
//
// Values that contain non-ASCII characters, control characters, or
// leading/trailing whitespace are Base64-encoded with the =?base64?...?= wrapper.
func encodeHeaderValue(value any) (string, bool) {
var s string
switch v := value.(type) {
case string:
s = v
case float64:
s = fmt.Sprintf("%g", v)
case bool:
if v {
s = "true"
} else {
s = "false"
}
default:
return "", false
}

if requiresBase64Encoding(s) {
return encodeBase64(s), true
}
return s, true
}

// decodeHeaderValue decodes a header value that may be Base64-encoded
// with the =?base64?...?= wrapper.
func decodeHeaderValue(headerValue string) (string, bool) {
if len(headerValue) == 0 {
return headerValue, true
}

if strings.HasPrefix(strings.ToLower(headerValue), strings.ToLower(base64Prefix)) &&
strings.HasSuffix(headerValue, base64Suffix) {
encoded := headerValue[len(base64Prefix) : len(headerValue)-len(base64Suffix)]
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return "", false
}
return string(decoded), true
}
return headerValue, true
}

func requiresBase64Encoding(s string) bool {
if len(s) == 0 {
return false
}
if s[0] == ' ' || s[0] == '\t' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' {
return true
}
for _, c := range s {
if c < 0x20 || c > 0x7E {
return true
}
}
return false
}

func encodeBase64(s string) string {
return base64Prefix + base64.StdEncoding.EncodeToString([]byte(s)) + base64Suffix
}
110 changes: 110 additions & 0 deletions mcp/header_encoding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by the license
// that can be found in the LICENSE file.

package mcp

import "testing"

func TestEncodeHeaderValue(t *testing.T) {
tests := []struct {
name string
value any
want string
wantOK bool
}{
// Strings
{"plain ASCII", "us-west1", "us-west1", true},
{"empty string", "", "", true},
{"string with internal spaces", "us west 1", "us west 1", true},
{"string with leading space", " us-west1", "=?base64?IHVzLXdlc3Qx?=", true},
{"string with trailing space", "us-west1 ", "=?base64?dXMtd2VzdDEg?=", true},
{"string with both spaces", " us-west1 ", "=?base64?IHVzLXdlc3QxIA==?=", true},
{"non-ASCII", "日本語", "=?base64?5pel5pys6Kqe?=", true},
{"mixed ASCII and non-ASCII", "Hello, 世界", "=?base64?SGVsbG8sIOS4lueVjA==?=", true},
{"string with newline", "line1\nline2", "=?base64?bGluZTEKbGluZTI=?=", true},
{"string with carriage return", "line1\r\nline2", "=?base64?bGluZTENCmxpbmUy?=", true},
{"string with leading tab", "\tindented", "=?base64?CWluZGVudGVk?=", true},

// Numbers
{"integer", float64(42), "42", true},
{"float", float64(3.14159), "3.14159", true},

// Booleans
{"true", true, "true", true},
{"false", false, "false", true},

// Unsupported types
{"nil", nil, "", false},
{"slice", []string{"a"}, "", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := encodeHeaderValue(tt.value)
if ok != tt.wantOK {
t.Fatalf("encodeHeaderValue(%v) ok = %v, want %v", tt.value, ok, tt.wantOK)
}
if got != tt.want {
t.Errorf("encodeHeaderValue(%v) = %q, want %q", tt.value, got, tt.want)
}
})
}
}

func TestDecodeHeaderValue(t *testing.T) {
tests := []struct {
name string
input string
want string
wantOK bool
}{
{"plain value", "us-west1", "us-west1", true},
{"empty value", "", "", true},
{"valid base64", "=?base64?SGVsbG8=?=", "Hello", true},
{"non-ASCII decoded", "=?base64?5pel5pys6Kqe?=", "日本語", true},
{"leading space decoded", "=?base64?IHVzLXdlc3Qx?=", " us-west1", true},
{"case-insensitive prefix", "=?BASE64?SGVsbG8=?=", "Hello", true},
{"invalid base64 chars", "=?base64?SGVs!!!bG8=?=", "", false},
// Missing prefix or suffix: treated as literal values, not base64
{"missing prefix", "SGVsbG8=", "SGVsbG8=", true},
{"missing suffix", "=?base64?SGVsbG8=", "=?base64?SGVsbG8=", true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := decodeHeaderValue(tt.input)
if ok != tt.wantOK {
t.Fatalf("decodeHeaderValue(%q) ok = %v, want %v", tt.input, ok, tt.wantOK)
}
if got != tt.want {
t.Errorf("decodeHeaderValue(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

func TestEncodeDecodeRoundTrip(t *testing.T) {
values := []string{
"us-west1",
"",
" leading",
"trailing ",
"Hello, 世界",
"line1\nline2",
"\ttab",
}
for _, v := range values {
encoded, ok := encodeHeaderValue(v)
if !ok {
t.Fatalf("encodeHeaderValue(%q) failed", v)
}
decoded, ok := decodeHeaderValue(encoded)
if !ok {
t.Fatalf("decodeHeaderValue(%q) failed", encoded)
}
if decoded != v {
t.Errorf("round-trip failed: %q -> %q -> %q", v, encoded, decoded)
}
}
}
Loading
Loading