diff --git a/mcp/client.go b/mcp/client.go index c82e189d..a3a66132 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -160,6 +160,12 @@ type ClientOptions struct { KeepAlive time.Duration } +// toolContextKeyType is the context key type for passing tool definitions +// from CallTool to the transport layer. +type toolContextKeyType struct{} + +var toolContextKey = toolContextKeyType{} + // bind implements the binder[*ClientSession] interface, so that Clients can // be connected using [connect]. func (c *Client) bind(mcpConn Connection, conn *jsonrpc2.Connection, state *clientSessionState, onClose func()) *ClientSession { @@ -318,6 +324,13 @@ type ClientSession struct { // Pending URL elicitations waiting for completion notifications. pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} + + // toolCacheMu guards toolCache. + toolCacheMu sync.RWMutex + // toolCache stores tool definitions keyed by name. + // It is used to look up x-mcp-header annotations when + // constructing Mcp-Param-* headers for tools/call requests. + toolCache map[string]*Tool } type clientSessionState struct { @@ -363,6 +376,23 @@ func (cs *ClientSession) Wait() error { return cs.conn.Wait() } +func (cs *ClientSession) cacheTools(tools []*Tool) { + cs.toolCacheMu.Lock() + defer cs.toolCacheMu.Unlock() + if cs.toolCache == nil { + cs.toolCache = make(map[string]*Tool, len(tools)) + } + for _, tool := range tools { + cs.toolCache[tool.Name] = tool + } +} + +func (cs *ClientSession) getCachedTool(name string) *Tool { + cs.toolCacheMu.RLock() + defer cs.toolCacheMu.RUnlock() + return cs.toolCache[name] +} + // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup @@ -981,7 +1011,13 @@ func (cs *ClientSession) GetPrompt(ctx context.Context, params *GetPromptParams) // ListTools lists tools that are currently available on the server. func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams) (*ListToolsResult, error) { - return handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + result, err := handleSend[*ListToolsResult](ctx, methodListTools, newClientRequest(cs, orZero[Params](params))) + if err != nil { + return nil, err + } + result.Tools = filterValidTools(result.Tools) + cs.cacheTools(result.Tools) + return result, nil } // CallTool calls the tool with the given parameters. @@ -995,6 +1031,9 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) ( // Avoid sending nil over the wire. params.Arguments = map[string]any{} } + if tool := cs.getCachedTool(params.Name); tool != nil { + ctx = context.WithValue(ctx, toolContextKey, tool) + } return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params))) } diff --git a/mcp/client_test.go b/mcp/client_test.go index fc37c3eb..8c840d9d 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -440,6 +440,80 @@ func TestClientCapabilities(t *testing.T) { } } +func TestToolCache(t *testing.T) { + tool1 := &Tool{Name: "tool1", Description: "first"} + tool2 := &Tool{Name: "tool2", Description: "second"} + tool1Updated := &Tool{Name: "tool1", Description: "updated"} + + testCases := []struct { + name string + cacheBatches [][]*Tool + lookup string + want *Tool + }{ + { + name: "empty cache", + lookup: "tool1", + want: nil, + }, + { + name: "single tool found", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "tool1", + want: tool1, + }, + { + name: "unknown tool", + cacheBatches: [][]*Tool{{tool1}}, + lookup: "nonexistent", + want: nil, + }, + { + name: "multiple tools single batch", + cacheBatches: [][]*Tool{{tool1, tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "additive first tool retained", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool1", + want: tool1, + }, + { + name: "additive second tool added", + cacheBatches: [][]*Tool{{tool1}, {tool2}}, + lookup: "tool2", + want: tool2, + }, + { + name: "overwrite existing entry", + cacheBatches: [][]*Tool{{tool1}, {tool1Updated}}, + lookup: "tool1", + want: tool1Updated, + }, + { + name: "empty batch no-op", + cacheBatches: [][]*Tool{{}}, + lookup: "tool1", + want: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cs := &ClientSession{} + for _, batch := range tc.cacheBatches { + cs.cacheTools(batch) + } + got := cs.getCachedTool(tc.lookup) + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("getCachedTool(%q) mismatch (-want +got):\n%s", tc.lookup, diff) + } + }) + } +} + func TestClientCapabilitiesOverWire(t *testing.T) { testCases := []struct { name string diff --git a/mcp/header_encoding.go b/mcp/header_encoding.go new file mode 100644 index 00000000..77f414f5 --- /dev/null +++ b/mcp/header_encoding.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "encoding/base64" + "fmt" + "strings" +) + +const ( + base64Prefix = "=?base64?" + base64Suffix = "?=" +) + +// encodeHeaderValue converts a parameter value to an HTTP header-safe string +// per the SEP-2243 encoding rules: +// - string: used as-is if safe ASCII, otherwise Base64 encoded +// - number (float64): decimal string representation +// - bool: lowercase "true" or "false" +// - nil: returns "", false +// +// Values that contain non-ASCII characters, control characters, or +// leading/trailing whitespace are Base64-encoded with the =?base64?...?= wrapper. +func encodeHeaderValue(value any) (string, bool) { + var s string + switch v := value.(type) { + case string: + s = v + case float64: + s = fmt.Sprintf("%g", v) + case bool: + if v { + s = "true" + } else { + s = "false" + } + default: + return "", false + } + + if requiresBase64Encoding(s) { + return encodeBase64(s), true + } + return s, true +} + +// decodeHeaderValue decodes a header value that may be Base64-encoded +// with the =?base64?...?= wrapper. +func decodeHeaderValue(headerValue string) (string, bool) { + if len(headerValue) == 0 { + return headerValue, true + } + + if strings.HasPrefix(strings.ToLower(headerValue), strings.ToLower(base64Prefix)) && + strings.HasSuffix(headerValue, base64Suffix) { + encoded := headerValue[len(base64Prefix) : len(headerValue)-len(base64Suffix)] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return "", false + } + return string(decoded), true + } + return headerValue, true +} + +func requiresBase64Encoding(s string) bool { + if len(s) == 0 { + return false + } + if s[0] == ' ' || s[0] == '\t' || s[len(s)-1] == ' ' || s[len(s)-1] == '\t' { + return true + } + for _, c := range s { + if c < 0x20 || c > 0x7E { + return true + } + } + return false +} + +func encodeBase64(s string) string { + return base64Prefix + base64.StdEncoding.EncodeToString([]byte(s)) + base64Suffix +} diff --git a/mcp/header_encoding_test.go b/mcp/header_encoding_test.go new file mode 100644 index 00000000..811ab336 --- /dev/null +++ b/mcp/header_encoding_test.go @@ -0,0 +1,110 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import "testing" + +func TestEncodeHeaderValue(t *testing.T) { + tests := []struct { + name string + value any + want string + wantOK bool + }{ + // Strings + {"plain ASCII", "us-west1", "us-west1", true}, + {"empty string", "", "", true}, + {"string with internal spaces", "us west 1", "us west 1", true}, + {"string with leading space", " us-west1", "=?base64?IHVzLXdlc3Qx?=", true}, + {"string with trailing space", "us-west1 ", "=?base64?dXMtd2VzdDEg?=", true}, + {"string with both spaces", " us-west1 ", "=?base64?IHVzLXdlc3QxIA==?=", true}, + {"non-ASCII", "日本語", "=?base64?5pel5pys6Kqe?=", true}, + {"mixed ASCII and non-ASCII", "Hello, 世界", "=?base64?SGVsbG8sIOS4lueVjA==?=", true}, + {"string with newline", "line1\nline2", "=?base64?bGluZTEKbGluZTI=?=", true}, + {"string with carriage return", "line1\r\nline2", "=?base64?bGluZTENCmxpbmUy?=", true}, + {"string with leading tab", "\tindented", "=?base64?CWluZGVudGVk?=", true}, + + // Numbers + {"integer", float64(42), "42", true}, + {"float", float64(3.14159), "3.14159", true}, + + // Booleans + {"true", true, "true", true}, + {"false", false, "false", true}, + + // Unsupported types + {"nil", nil, "", false}, + {"slice", []string{"a"}, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := encodeHeaderValue(tt.value) + if ok != tt.wantOK { + t.Fatalf("encodeHeaderValue(%v) ok = %v, want %v", tt.value, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("encodeHeaderValue(%v) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestDecodeHeaderValue(t *testing.T) { + tests := []struct { + name string + input string + want string + wantOK bool + }{ + {"plain value", "us-west1", "us-west1", true}, + {"empty value", "", "", true}, + {"valid base64", "=?base64?SGVsbG8=?=", "Hello", true}, + {"non-ASCII decoded", "=?base64?5pel5pys6Kqe?=", "日本語", true}, + {"leading space decoded", "=?base64?IHVzLXdlc3Qx?=", " us-west1", true}, + {"case-insensitive prefix", "=?BASE64?SGVsbG8=?=", "Hello", true}, + {"invalid base64 chars", "=?base64?SGVs!!!bG8=?=", "", false}, + // Missing prefix or suffix: treated as literal values, not base64 + {"missing prefix", "SGVsbG8=", "SGVsbG8=", true}, + {"missing suffix", "=?base64?SGVsbG8=", "=?base64?SGVsbG8=", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := decodeHeaderValue(tt.input) + if ok != tt.wantOK { + t.Fatalf("decodeHeaderValue(%q) ok = %v, want %v", tt.input, ok, tt.wantOK) + } + if got != tt.want { + t.Errorf("decodeHeaderValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestEncodeDecodeRoundTrip(t *testing.T) { + values := []string{ + "us-west1", + "", + " leading", + "trailing ", + "Hello, 世界", + "line1\nline2", + "\ttab", + } + for _, v := range values { + encoded, ok := encodeHeaderValue(v) + if !ok { + t.Fatalf("encodeHeaderValue(%q) failed", v) + } + decoded, ok := decodeHeaderValue(encoded) + if !ok { + t.Fatalf("decodeHeaderValue(%q) failed", encoded) + } + if decoded != v { + t.Errorf("round-trip failed: %q -> %q -> %q", v, encoded, decoded) + } + } +} diff --git a/mcp/streamable.go b/mcp/streamable.go index 708b1326..6f7c1866 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -490,6 +490,14 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque http.Error(w, "failed connection", http.StatusInternalServerError) return } + transport.connection.toolLookup = func(name string) *Tool { + server.mu.Lock() + defer server.mu.Unlock() + if st, ok := server.tools.get(name); ok { + return st.tool + } + return nil + } // Capture the user ID from the token info to enable session hijacking // prevention on subsequent requests. var userID string @@ -668,6 +676,8 @@ type streamableServerConn struct { logger *slog.Logger + toolLookup func(name string) *Tool + incoming chan jsonrpc.Message // messages from the client to the server mu sync.Mutex // guards all fields below @@ -1185,9 +1195,15 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques } } - // Validate MCP standard headers (Mcp-Method, Mcp-Name) + // Validate MCP standard headers (Mcp-Method, Mcp-Name, Mcp-Param-*) if !isBatch && len(incoming) == 1 { - if err := validateMcpHeaders(req.Header, incoming[0]); err != nil { + var tool *Tool + if jreq, ok := incoming[0].(*jsonrpc.Request); ok && jreq.Method == "tools/call" && c.toolLookup != nil { + if name, ok := extractName(jreq.Method, jreq.Params); ok { + tool = c.toolLookup(name) + } + } + if err := validateMcpHeaders(req.Header, incoming[0], tool); err != nil { resp := &jsonrpc.Response{ Error: jsonrpc2.NewError(CodeHeaderMismatch, err.Error()), } @@ -1785,6 +1801,11 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e if msg.IsCall() { forCall = msg } + if msg.Method == "tools/call" { + if tool, ok := ctx.Value(toolContextKey).(*Tool); ok { + msg.Extra = tool + } + } case *jsonrpc.Response: requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID) default: diff --git a/mcp/streamable_headers.go b/mcp/streamable_headers.go index ac9a4121..6383439a 100644 --- a/mcp/streamable_headers.go +++ b/mcp/streamable_headers.go @@ -8,7 +8,9 @@ import ( "encoding/json" "errors" "fmt" + "log" "net/http" + "strings" internaljson "github.com/modelcontextprotocol/go-sdk/internal/json" "github.com/modelcontextprotocol/go-sdk/jsonrpc" @@ -20,7 +22,9 @@ const ( lastEventIDHeader = "Last-Event-ID" methodHeader = "Mcp-Method" nameHeader = "Mcp-Name" + paramHeaderPrefix = "Mcp-Param-" minVersionForStandardHeaders = protocolVersion20260630 + mcpHeaderExtension = "x-mcp-header" ) func extractName(method string, params json.RawMessage) (string, bool) { @@ -45,6 +49,75 @@ func extractName(method string, params json.RawMessage) (string, bool) { return "", false } +func extractSchemaProperties(schema any) map[string]any { + s, ok := schema.(map[string]any) + if !ok { + return nil + } + props, ok := s["properties"].(map[string]any) + if !ok { + return nil + } + return props +} + +// extractToolParamHeaders returns a map of parameter name to header name +// for all properties in the tool's InputSchema that have an x-mcp-header +// annotation. +func extractToolParamHeaders(tool *Tool) map[string]string { + props := extractSchemaProperties(tool.InputSchema) + if props == nil { + return nil + } + result := make(map[string]string) + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + headerName, ok := ps[mcpHeaderExtension].(string) + if !ok || headerName == "" { + continue + } + result[propName] = headerName + } + if len(result) == 0 { + return nil + } + return result +} + +func primitiveToString(value any) string { + switch v := value.(type) { + case string: + return v + case float64: + return fmt.Sprintf("%g", v) + case bool: + if v { + return "true" + } + return "false" + default: + return fmt.Sprintf("%v", v) + } +} + +// unmarshalPrimitive unmarshals a JSON value into a Go primitive +// (string, float64, or bool). Returns nil for non-primitive types. +func unmarshalPrimitive(raw json.RawMessage) any { + var val any + if err := internaljson.Unmarshal(raw, &val); err != nil { + return nil + } + switch val.(type) { + case string, float64, bool: + return val + default: + return nil + } +} + // setStandardHeaders populates standard MCP headers. // It requires the protocol version header to be set. func setStandardHeaders(header http.Header, msg jsonrpc.Message) { @@ -61,10 +134,141 @@ func setStandardHeaders(header http.Header, msg jsonrpc.Message) { if name, ok := extractName(msg.Method, msg.Params); ok { header.Set(nameHeader, name) } + if msg.Method == "tools/call" { + if tool, ok := msg.Extra.(*Tool); ok && tool != nil { + for k, v := range extractParamHeaders(tool, msg.Params) { + header.Set(k, v) + } + } + } } } -func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { +// extractParamHeaders reads x-mcp-header annotations from the tool's InputSchema +// and returns the Mcp-Param-{Name} headers to be set on the HTTP request. +func extractParamHeaders(tool *Tool, params json.RawMessage) map[string]string { + paramHeaders := extractToolParamHeaders(tool) + if len(paramHeaders) == 0 { + return nil + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := internaljson.Unmarshal(params, &raw); err != nil || raw.Arguments == nil { + return nil + } + + res := make(map[string]string) + for paramName, headerName := range paramHeaders { + argRaw, ok := raw.Arguments[paramName] + if !ok { + continue + } + if string(argRaw) == "null" { + continue + } + val := unmarshalPrimitive(argRaw) + if val == nil { + continue + } + encoded, ok := encodeHeaderValue(val) + if !ok { + continue + } + res[paramHeaderPrefix+headerName] = encoded + } + return res +} + +// filterValidTools returns only tools that have valid +// x-mcp-header annotations. Invalid tools are logged and excluded. +func filterValidTools(tools []*Tool) []*Tool { + result := make([]*Tool, 0, len(tools)) + for _, tool := range tools { + if err := validateToolParamHeaders(tool); err != nil { + log.Printf("mcp: excluding tool %q from tools/list: %v", tool.Name, err) + continue + } + result = append(result, tool) + } + return result +} + +// validateToolParamHeaders checks that a tool's x-mcp-header annotations +// are valid. +func validateToolParamHeaders(tool *Tool) error { + props := extractSchemaProperties(tool.InputSchema) + if props == nil { + return nil + } + + seen := make(map[string]bool) + for propName, propSchema := range props { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + if err := checkForNestedHeaders(ps, propName); err != nil { + return err + } + headerNameRaw, exists := ps[mcpHeaderExtension] + if !exists { + continue + } + headerName, ok := headerNameRaw.(string) + if !ok || headerName == "" { + return fmt.Errorf("property %q: x-mcp-header must be a non-empty string", propName) + } + if err := validateHeaderName(headerName); err != nil { + return fmt.Errorf("property %q: %w", propName, err) + } + lower := strings.ToLower(headerName) + if seen[lower] { + return fmt.Errorf("property %q: duplicate x-mcp-header value %q (case-insensitive)", propName, headerName) + } + seen[lower] = true + + propType, _ := ps["type"].(string) + if propType != "" && propType != "string" && propType != "number" && propType != "integer" && propType != "boolean" { + return fmt.Errorf("property %q: x-mcp-header can only be applied to primitive types, got %q", propName, propType) + } + } + return nil +} + +func checkForNestedHeaders(schema map[string]any, path string) error { + nestedProps := extractSchemaProperties(schema) + if nestedProps == nil { + return nil + } + for propName, propSchema := range nestedProps { + ps, ok := propSchema.(map[string]any) + if !ok { + continue + } + if _, exists := ps[mcpHeaderExtension]; exists { + return fmt.Errorf("property %q: x-mcp-header cannot be applied to nested properties", path+"."+propName) + } + if err := checkForNestedHeaders(ps, path+"."+propName); err != nil { + return err + } + } + return nil +} + +// validateHeaderName checks that a header name contains only valid +// ASCII characters (excluding space and ':'). +func validateHeaderName(name string) error { + for _, c := range name { + if c <= 0x20 || c > 0x7E || c == ':' { + return fmt.Errorf("x-mcp-header value %q contains invalid character %q", name, c) + } + } + return nil +} + +func validateMcpHeaders(header http.Header, msg jsonrpc.Message, tool *Tool) error { protocolVersion := header.Get(protocolVersionHeader) if protocolVersion == "" || protocolVersion < minVersionForStandardHeaders { return nil @@ -93,6 +297,59 @@ func validateMcpHeaders(header http.Header, msg jsonrpc.Message) error { return fmt.Errorf("header mismatch: Mcp-Name header value '%s' does not match body value '%s'", nameInHeader, nameInBody) } } + + if msg.Method == "tools/call" && tool != nil { + if err := validateParamHeaders(header, msg, tool); err != nil { + return err + } + } + } + return nil +} + +func validateParamHeaders(header http.Header, msg *jsonrpc.Request, tool *Tool) error { + paramHeaders := extractToolParamHeaders(tool) + if len(paramHeaders) == 0 { + return nil + } + + var raw struct { + Arguments map[string]json.RawMessage `json:"arguments"` + } + if err := internaljson.Unmarshal(msg.Params, &raw); err != nil { + return nil + } + + for paramName, headerName := range paramHeaders { + fullHeader := paramHeaderPrefix + headerName + headerVal := header.Get(fullHeader) + argRaw, argExists := raw.Arguments[paramName] + + if !argExists || string(argRaw) == "null" { + if headerVal != "" { + return fmt.Errorf("header mismatch: unexpected %s header for absent or null parameter %q", fullHeader, paramName) + } + continue + } + + if headerVal == "" { + return fmt.Errorf("header mismatch: missing %s header for parameter %q", fullHeader, paramName) + } + + decoded, ok := decodeHeaderValue(headerVal) + if !ok { + return fmt.Errorf("header mismatch: %s header contains invalid Base64 encoding", fullHeader) + } + + bodyVal := unmarshalPrimitive(argRaw) + if bodyVal == nil { + continue + } + expected := primitiveToString(bodyVal) + + if decoded != expected { + return fmt.Errorf("header mismatch: %s header value '%s' does not match body value", fullHeader, headerVal) + } } return nil } diff --git a/mcp/streamable_headers_test.go b/mcp/streamable_headers_test.go index 0abf1b28..5bf3309b 100644 --- a/mcp/streamable_headers_test.go +++ b/mcp/streamable_headers_test.go @@ -6,6 +6,7 @@ package mcp import ( "encoding/json" + "fmt" "net/http" "strings" "testing" @@ -399,7 +400,7 @@ func TestValidateMcpHeaders(t *testing.T) { header.Set(nameHeader, tt.nameHeader) } - err := validateMcpHeaders(header, tt.msg) + err := validateMcpHeaders(header, tt.msg, nil) if tt.wantErr { if err == nil { t.Fatalf("validateMcpHeaders() = nil, want error containing %q", tt.wantErrContain) @@ -413,3 +414,587 @@ func TestValidateMcpHeaders(t *testing.T) { }) } } + +func TestValidateToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + wantErr bool + wantErrSub string + }{ + { + name: "valid tool with x-mcp-header", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + { + name: "tool with no x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }, + }, + { + name: "empty x-mcp-header value", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "non-empty string", + }, + { + name: "x-mcp-header with space", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "My Region", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with colon", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region:Primary", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "x-mcp-header with non-ASCII", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Région", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "invalid character", + }, + { + name: "duplicate header names same case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "duplicate header names different case", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "a": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "b": map[string]any{"type": "string", "x-mcp-header": "REGION"}, + }, + }, + }, + wantErr: true, + wantErrSub: "duplicate", + }, + { + name: "x-mcp-header on array type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on object type", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "nested": map[string]any{ + "type": "object", + "x-mcp-header": "Nested", + }, + }, + }, + }, + wantErr: true, + wantErrSub: "primitive types", + }, + { + name: "x-mcp-header on number type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "number", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on integer type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "integer", + "x-mcp-header": "Count", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on boolean type is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, + { + name: "x-mcp-header on nested property inside object", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "x-mcp-header on deeply nested property", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "outer": map[string]any{ + "type": "object", + "properties": map[string]any{ + "inner": map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{ + "type": "string", + "x-mcp-header": "Value", + }, + }, + }, + }, + }, + }, + }, + }, + wantErr: true, + wantErrSub: "nested", + }, + { + name: "object property without nested x-mcp-header is valid", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + }, + }, + }, + "flag": map[string]any{ + "type": "boolean", + "x-mcp-header": "Flag", + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateToolParamHeaders(tt.tool) + if tt.wantErr { + if err == nil { + t.Fatal("validateToolParamHeaders() = nil, want error") + } + if !strings.Contains(err.Error(), tt.wantErrSub) { + t.Errorf("error = %q, want substring %q", err.Error(), tt.wantErrSub) + } + } else if err != nil { + t.Errorf("validateToolParamHeaders() = %v, want nil", err) + } + }) + } +} + +func TestFilterValidTools(t *testing.T) { + valid := &Tool{ + Name: "valid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + } + invalid := &Tool{ + Name: "invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": ""}, + }, + }, + } + noAnnotation := &Tool{ + Name: "plain", + InputSchema: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, + } + nestedInvalid := &Tool{ + Name: "nested-invalid", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "config": map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + }, + }, + }, + }, + } + + result := filterValidTools([]*Tool{valid, invalid, noAnnotation, nestedInvalid}) + if len(result) != 2 { + t.Fatalf("filterValidTools returned %d tools, want 2", len(result)) + } + if result[0].Name != "valid" || result[1].Name != "plain" { + t.Errorf("filterValidTools returned [%s, %s], want [valid, plain]", result[0].Name, result[1].Name) + } +} + +func TestSetStandardHeadersWithParamHeaders(t *testing.T) { + toolSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + "priority": map[string]any{ + "type": "string", + "x-mcp-header": "Priority", + }, + }, + } + tool := &Tool{Name: "execute_sql", InputSchema: toolSchema} + + tests := []struct { + name string + tool *Tool + params any + wantHeaders map[string]string + }{ + { + name: "sets param headers from arguments", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1", "priority": "high"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "us-west1", + "Mcp-Param-Priority": "high", + }, + }, + { + name: "omits header when argument is missing", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"query": "SELECT 1"}, + }, + wantHeaders: map[string]string{}, + }, + { + name: "omits header when argument is null", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": nil, "query": "SELECT 1"}, + }, + wantHeaders: map[string]string{}, + }, + { + name: "encodes non-ASCII value", + tool: tool, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "日本", "query": "SELECT 1"}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Region": "=?base64?5pel5pys?=", + }, + }, + { + name: "handles boolean argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "flag": map[string]any{"type": "boolean", "x-mcp-header": "Flag"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"flag": true}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Flag": "true", + }, + }, + { + name: "handles number argument", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{"type": "number", "x-mcp-header": "Count"}, + }, + }, + }, + params: &CallToolParams{ + Name: "test", + Arguments: map[string]any{"count": float64(42)}, + }, + wantHeaders: map[string]string{ + "Mcp-Param-Count": "42", + }, + }, + { + name: "no tool in extra does not add param headers", + tool: nil, + params: &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + }, + wantHeaders: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header := http.Header{} + header.Set(protocolVersionHeader, minVersionForStandardHeaders) + + msg := &jsonrpc.Request{ + Method: "tools/call", + Params: mustMarshal(tt.params), + Extra: tt.tool, + } + + setStandardHeaders(header, msg) + + if got := header.Get(methodHeader); got != "tools/call" { + t.Errorf("MethodHeader = %q, want %q", got, "tools/call") + } + + for h, want := range tt.wantHeaders { + if got := header.Get(h); got != want { + t.Errorf("%s = %q, want %q", h, got, want) + } + } + + if got := header.Get("Mcp-Param-query"); got != "" { + t.Errorf("non-annotated param got header: Mcp-Param-query = %q", got) + } + }) + } +} + +func TestExtractToolParamHeaders(t *testing.T) { + tests := []struct { + name string + tool *Tool + want map[string]string + }{ + { + name: "extracts x-mcp-header annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{"type": "string", "x-mcp-header": "Region"}, + "query": map[string]any{"type": "string"}, + "tenant_id": map[string]any{"type": "string", "x-mcp-header": "TenantId"}, + }, + }, + }, + want: map[string]string{"region": "Region", "tenant_id": "TenantId"}, + }, + { + name: "returns nil for tool without properties", + tool: &Tool{Name: "test", InputSchema: map[string]any{"type": "object"}}, + want: nil, + }, + { + name: "returns nil for non-map schema", + tool: &Tool{Name: "test", InputSchema: "not a map"}, + want: nil, + }, + { + name: "returns nil when no annotations", + tool: &Tool{ + Name: "test", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"q": map[string]any{"type": "string"}}, + }, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractToolParamHeaders(tt.tool) + if tt.want == nil { + if got != nil { + t.Errorf("extractToolParamHeaders() = %v, want nil", got) + } + return + } + if len(got) != len(tt.want) { + t.Fatalf("extractToolParamHeaders() returned %d entries, want %d", len(got), len(tt.want)) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("extractToolParamHeaders()[%q] = %q, want %q", k, got[k], v) + } + } + }) + } +} + +func TestUnmarshalPrimitive(t *testing.T) { + tests := []struct { + name string + raw string + want any + }{ + {"string", `"hello"`, "hello"}, + {"number", `42`, float64(42)}, + {"float", `3.14`, float64(3.14)}, + {"true", `true`, true}, + {"false", `false`, false}, + {"null", `null`, nil}, + {"array", `[1,2]`, nil}, + {"object", `{"a":1}`, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := unmarshalPrimitive(json.RawMessage(tt.raw)) + if fmt.Sprintf("%v", got) != fmt.Sprintf("%v", tt.want) { + t.Errorf("unmarshalPrimitive(%s) = %v (%T), want %v (%T)", tt.raw, got, got, tt.want, tt.want) + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 80fb4baa..2876c4de 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1915,9 +1915,9 @@ func TestStreamableGET(t *testing.T) { } } -// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method and -// Mcp-Name header validation through the full HTTP handler, as specified -// in SEP-2243. +// TestStreamableMcpHeaderValidation tests the server-side Mcp-Method, +// Mcp-Name, and Mcp-Param header validation through the full HTTP handler, +// as specified in SEP-2243. func TestStreamableMcpHeaderValidation(t *testing.T) { // Temporarily register the future version so the handler accepts it. orig := supportedProtocolVersions @@ -1930,6 +1930,25 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { return &CallToolResult{}, nil }) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + }) handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) defer handler.closeAll() @@ -2019,6 +2038,50 @@ func TestStreamableMcpHeaderValidation(t *testing.T) { wantStatusCode: http.StatusOK, wantMessages: []jsonrpc.Message{resp(6, &CallToolResult{Content: []Content{}}, nil)}, }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"us-west1"}, + }, + messages: []jsonrpc.Message{req(7, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + })}, + wantStatusCode: http.StatusOK, + wantMessages: []jsonrpc.Message{resp(7, &CallToolResult{Content: []Content{}}, nil)}, + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + paramHeaderPrefix + "Region": {"eu-central1"}, + }, + messages: []jsonrpc.Message{req(8, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "header mismatch", + }, + { + method: "POST", + headers: http.Header{ + protocolVersionHeader: {minVersionForStandardHeaders}, + methodHeader: {"tools/call"}, + nameHeader: {"execute_sql"}, + }, + messages: []jsonrpc.Message{req(9, "tools/call", &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1"}, + })}, + wantStatusCode: http.StatusBadRequest, + wantBodyContaining: "missing", + }, }) } @@ -2170,6 +2233,166 @@ func TestStreamableMcpHeaderVersionGating(t *testing.T) { }) } +// TestStreamableParamHeadersClientSetsHeaders verifies that the client sets +// Mcp-Param-* headers on tool calls when the tool has x-mcp-header annotations. +func TestStreamableParamHeadersClientSetsHeaders(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + server.AddTool( + &Tool{ + Name: "execute_sql", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + "query": map[string]any{ + "type": "string", + }, + }, + }, + }, + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "ok"}}}, nil + }) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + var capturedHeaders http.Header + customClient := &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header.Get(methodHeader) == "tools/call" { + capturedHeaders = req.Header.Clone() + } + return http.DefaultTransport.RoundTrip(req) + }), + } + + clientTransport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + HTTPClient: customClient, + } + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, clientTransport, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + // ListTools to populate the tool cache (needed for param headers). + if _, err := session.ListTools(ctx, nil); err != nil { + t.Fatal(err) + } + + _, err = session.CallTool(ctx, &CallToolParams{ + Name: "execute_sql", + Arguments: map[string]any{"region": "us-west1", "query": "SELECT 1"}, + }) + if err != nil { + t.Fatal(err) + } + + if capturedHeaders == nil { + t.Fatal("no tool call headers captured") + } + if got := capturedHeaders.Get(methodHeader); got != "tools/call" { + t.Errorf("Mcp-Method = %q, want %q", got, "tools/call") + } + if got := capturedHeaders.Get(nameHeader); got != "execute_sql" { + t.Errorf("Mcp-Name = %q, want %q", got, "execute_sql") + } + if got := capturedHeaders.Get(paramHeaderPrefix + "Region"); got != "us-west1" { + t.Errorf("Mcp-Param-Region = %q, want %q", got, "us-west1") + } +} + +// TestStreamableFilterValidToolsIntegration verifies that invalid tools +// (with bad x-mcp-header annotations) are filtered out when the client +// calls ListTools. +func TestStreamableFilterValidToolsIntegration(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append(slices.Clone(orig), minVersionForStandardHeaders) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) + noop := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + return &CallToolResult{}, nil + } + + // Valid tool with correct x-mcp-header annotation. + server.AddTool(&Tool{ + Name: "valid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "string", + "x-mcp-header": "Region", + }, + }, + }, + }, noop) + + // Invalid tool: x-mcp-header on an array type. + server.AddTool(&Tool{ + Name: "invalid-tool", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "items": map[string]any{ + "type": "array", + "x-mcp-header": "Items", + }, + }, + }, + }, noop) + + // Tool with no x-mcp-header annotations (always valid). + server.AddTool(&Tool{ + Name: "plain-tool", + InputSchema: &jsonschema.Schema{Type: "object"}, + }, noop) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + client := NewClient(&Implementation{Name: "testClient", Version: "v1.0.0"}, nil) + ctx := context.Background() + session, err := client.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, &ClientSessionOptions{protocolVersion: minVersionForStandardHeaders}) + if err != nil { + t.Fatal(err) + } + defer session.Close() + + result, err := session.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + + toolNames := make([]string, len(result.Tools)) + for i, tool := range result.Tools { + toolNames[i] = tool.Name + } + sort.Strings(toolNames) + + wantNames := []string{"plain-tool", "valid-tool"} + if !slices.Equal(toolNames, wantNames) { + t.Errorf("ListTools returned %v, want %v", toolNames, wantNames) + } +} + // TestStreamable405AllowHeader verifies RFC 9110 §15.5.6 compliance: // 405 Method Not Allowed responses MUST include an Allow header. func TestStreamable405AllowHeader(t *testing.T) {