From 370304adc84d6166d02dae35634f7dfcbcccaa6d Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 12 May 2026 16:12:11 +0000 Subject: [PATCH 1/8] feat: add support for registering and calling custom JSON-RPC methods with type safety --- examples/server/custom-method/main.go | 105 ++++++++++++ mcp/client.go | 11 +- mcp/custom.go | 109 ++++++++++++ mcp/custom_test.go | 235 ++++++++++++++++++++++++++ mcp/server.go | 16 +- mcp/streamable.go | 4 +- 6 files changed, 477 insertions(+), 3 deletions(-) create mode 100644 examples/server/custom-method/main.go create mode 100644 mcp/custom.go create mode 100644 mcp/custom_test.go diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go new file mode 100644 index 00000000..d36bc6eb --- /dev/null +++ b/examples/server/custom-method/main.go @@ -0,0 +1,105 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// The custom-method example demonstrates registering and calling a custom +// JSON-RPC method that is not part of the standard MCP spec. +// +// The server registers a "latin/translate" method that translates simple +// English phrases into Latin. A client connects over an in-memory transport, +// calls the custom method, and prints the result. +package main + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type TranslateParams struct { + mcp.ParamsBase + Text string `json:"text"` +} + +type TranslateResult struct { + mcp.ResultBase + Latin string `json:"latin"` +} + +var translations = map[string]string{ + "hello": "salve", + "goodbye": "vale", + "thank you": "gratias tibi ago", + "how are you": "quid agis", + "good morning": "bonum mane", + "good night": "bonam noctem", + "friend": "amicus", + "water": "aqua", + "love": "amor", + "war": "bellum", + "peace": "pax", + "truth": "veritas", + "light": "lux", + "time": "tempus", + "life": "vita", + "death": "mors", + "star": "stella", + "earth": "terra", + "sea": "mare", + "the die is cast": "alea iacta est", + "i came i saw i conquered": "veni vidi vici", + "seize the day": "carpe diem", +} + +func main() { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) + + mcp.AddReceivingCustomMethod(server, "latin/translate", + func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { + key := strings.ToLower(strings.TrimSpace(params.Text)) + latin, ok := translations[key] + if !ok { + latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) + } + return &TranslateResult{Latin: latin}, nil + }) + + ct, st := mcp.NewInMemoryTransports() + + ss, err := server.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) + translate := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate") + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} + for _, phrase := range phrases { + result, err := translate(ctx, cs, &TranslateParams{Text: phrase}) + if err != nil { + log.Fatalf("translate %q: %v", phrase, err) + } + fmt.Printf("%-35s → %s\n", phrase, result.Latin) + } +} + +func knownPhrases() string { + phrases := make([]string, 0, len(translations)) + for k := range translations { + phrases = append(phrases, fmt.Sprintf("%q", k)) + } + return strings.Join(phrases, ", ") +} diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..47686481 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -10,6 +10,7 @@ import ( "fmt" "iter" "log/slog" + "maps" "slices" "strings" "sync" @@ -32,6 +33,7 @@ type Client struct { sessions []*ClientSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customSendMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -64,6 +66,7 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + customSendMethods: make(map[string]methodInfo), } } @@ -945,7 +948,13 @@ var clientMethodInfos = map[string]methodInfo{ } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - return serverMethodInfos + if len(cs.client.customSendMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(cs.client.customSendMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, cs.client.customSendMethods) + return infos } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { diff --git a/mcp/custom.go b/mcp/custom.go new file mode 100644 index 00000000..70b451ba --- /dev/null +++ b/mcp/custom.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "reflect" +) + +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }) +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) { + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) +} + +// AddSendingCustomMethod registers a custom method that the client can send +// to the server and returns a typed caller function. +// +// The returned function calls the custom method through the client's sending +// middleware chain, with full type safety on both params and result. +// +// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") +// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) func(ctx context.Context, cs *ClientSession, params P) (R, error) { + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.customSendMethods[method] = mi + + return func(ctx context.Context, cs *ClientSession, params P) (R, error) { + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) + } +} diff --git a/mcp/custom_test.go b/mcp/custom_test.go new file mode 100644 index 00000000..5a6d76d8 --- /dev/null +++ b/mcp/custom_test.go @@ -0,0 +1,235 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "net/http" + "testing" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" +) + +type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` +} + +type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` +} + +// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, +// bypassing the SDK's typed method dispatch. +func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) +} + +func TestAddReceivingCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { + t.Fatal(err) + } + + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodGoesThoughMiddleware(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + var middlewareCalled bool + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == "acme/ping" { + middlewareCalled = true + } + return next(ctx, method, req) + } + }) + + type pingParams struct { + ParamsBase + } + type pingResult struct { + ResultBase + Pong bool `json:"pong"` + } + AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { + return &pingResult{Pong: true}, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result pingResult + if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { + t.Fatal(err) + } + + if !result.Pong { + t.Error("expected Pong to be true") + } + if !middlewareCalled { + t.Error("middleware was not called for custom method") + } +} + +func TestCustomMethodNotOnOtherServers(t *testing.T) { + ctx := context.Background() + + type emptyParams struct{ ParamsBase } + type emptyResult struct{ ResultBase } + + // Server 1 has the custom method. + s1 := NewServer(testImpl, nil) + AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { + return &emptyResult{}, nil + }) + + // Server 2 does NOT have the custom method. + s2 := NewServer(testImpl, nil) + + // Test s1: custom method should work. + ct1, st1 := NewInMemoryTransports() + ss1, err := s1.Connect(ctx, st1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss1.Close() }) + + c1 := NewClient(testImpl, nil) + cs1, err := c1.Connect(ctx, ct1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs1.Close() }) + + if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { + t.Fatalf("custom method on s1 should work: %v", err) + } + + // Test s2: custom method should fail. + ct2, st2 := NewInMemoryTransports() + ss2, err := s2.Connect(ctx, st2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss2.Close() }) + + c2 := NewClient(testImpl, nil) + cs2, err := c2.Connect(ctx, ct2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs2.Close() }) + + err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) + if err == nil { + t.Fatal("expected error calling custom method on s2") + } +} + +func TestCallCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + return &searchResult{ + Hits: []string{"result for " + params.Query}, + Total: 1, + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodStreamableHTTP(t *testing.T) { + s := NewServer(testImpl, nil) + + type echoParams struct { + ParamsBase + Msg string `json:"msg"` + } + type echoResult struct { + ResultBase + Reply string `json:"reply"` + } + AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { + return &echoResult{Reply: "echo: " + params.Msg}, nil + }) + + handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) + _ = handler +} diff --git a/mcp/server.go b/mcp/server.go index d25c7922..98c0c985 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -52,6 +52,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customMethods map[string]methodInfo resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } @@ -195,6 +196,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), + customMethods: make(map[string]methodInfo), } } @@ -1422,7 +1424,19 @@ func initializeMethodInfo() methodInfo { func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } +func (s *Server) receivingMethodInfos() map[string]methodInfo { + if len(s.customMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(s.customMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, s.customMethods) + return infos +} + +func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { + return ss.server.receivingMethodInfos() +} func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server diff --git a/mcp/streamable.go b/mcp/streamable.go index b8e36553..a4a34a52 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -396,6 +396,7 @@ func connectStreamable(ctx context.Context, server *Server, transport *Streamabl if err != nil { return nil, err } + transport.connection.server = server transport.connection.toolLookup = server.getServerTool return s, nil } @@ -734,6 +735,7 @@ type streamableServerConn struct { logger *slog.Logger + server *Server toolLookup func(name string) (*serverTool, bool) incoming chan jsonrpc.Message // messages from the client to the server @@ -1241,7 +1243,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + if _, err := checkRequest(jreq, c.server.receivingMethodInfos()); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } From c4b8630abd1ce2b02cbc17ee679d1401d8356e49 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 12:34:12 +0000 Subject: [PATCH 2/8] feat: add support for registering and calling type-safe custom MCP methods --- mcp/client.go | 31 ++++++ mcp/client_test.go | 39 ++++++++ mcp/custom.go | 104 -------------------- mcp/custom_test.go | 230 --------------------------------------------- mcp/server.go | 41 ++++++++ mcp/server_test.go | 184 ++++++++++++++++++++++++++++++++++++ mcp/shared.go | 28 ++++++ 7 files changed, 323 insertions(+), 334 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 47686481..093249ef 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -11,6 +11,7 @@ import ( "iter" "log/slog" "maps" + "reflect" "slices" "strings" "sync" @@ -1227,3 +1228,33 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params } } } + +// AddSendingCustomMethod registers a custom method that the client can send +// to the server and returns a typed caller function. +// +// The returned function calls the custom method through the client's sending +// middleware chain, with full type safety on both params and result. +// +// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") +// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) func(ctx context.Context, cs *ClientSession, params P) (R, error) { + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.customSendMethods[method] = mi + + return func(ctx context.Context, cs *ClientSession, params P) (R, error) { + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) + } +} diff --git a/mcp/client_test.go b/mcp/client_test.go index 609fd501..6bb5ee8e 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -617,3 +617,42 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } + +func TestCallCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + return &searchResult{ + Hits: []string{"result for " + params.Query}, + Total: 1, + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} diff --git a/mcp/custom.go b/mcp/custom.go index 70b451ba..e5aef305 100644 --- a/mcp/custom.go +++ b/mcp/custom.go @@ -3,107 +3,3 @@ // that can be found in the LICENSE file. package mcp - -import ( - "context" - "reflect" -) - -// ParamsBase can be embedded in custom parameter structs to satisfy the -// [Params] interface. It provides the required [Meta] field and the unexported -// isParams marker method. -// -// type SearchParams struct { -// mcp.ParamsBase -// Query string `json:"query"` -// } -type ParamsBase struct { - Meta `json:"_meta,omitempty"` -} - -func (*ParamsBase) isParams() {} - -// ResultBase can be embedded in custom result structs to satisfy the -// [Result] interface. It provides the required [Meta] field and the unexported -// isResult marker method. -// -// type SearchResult struct { -// mcp.ResultBase -// Hits []string `json:"hits"` -// } -type ResultBase struct { - Meta `json:"_meta,omitempty"` -} - -func (*ResultBase) isResult() {} - -// AddReceivingCustomMethod registers a handler for a custom (non-standard) -// JSON-RPC method on the server. -// -// When a client sends a request with the given method name, the params will be -// unmarshaled into P, the handler will be called, and the returned R will be -// marshaled as the JSON-RPC result. -// -// Custom methods go through the server's middleware chain just like standard -// MCP methods (tools/call, prompts/list, etc.). -// -// P and R must implement [Params] and [Result] respectively, which is most -// easily done by embedding [ParamsBase] and [ResultBase]: -// -// type SearchParams struct { -// mcp.ParamsBase -// Query string `json:"query"` -// } -// -// type SearchResult struct { -// mcp.ResultBase -// Hits []string `json:"hits"` -// } -// -// mcp.AddReceivingCustomMethod(server, "acme/search", -// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { -// return &SearchResult{Hits: []string{"result"}}, nil -// }) -func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( - s *Server, - method string, - handler func(ctx context.Context, ss *ServerSession, params P) (R, error), -) { - typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { - return handler(ctx, req.Session, req.Params) - }) - - s.mu.Lock() - defer s.mu.Unlock() - s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) -} - -// AddSendingCustomMethod registers a custom method that the client can send -// to the server and returns a typed caller function. -// -// The returned function calls the custom method through the client's sending -// middleware chain, with full type safety on both params and result. -// -// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") -// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) -func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( - c *Client, - method string, -) func(ctx context.Context, cs *ClientSession, params P) (R, error) { - mi := methodInfo{ - newResult: func() Result { - return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) - }, - } - - c.mu.Lock() - defer c.mu.Unlock() - c.customSendMethods[method] = mi - - return func(ctx context.Context, cs *ClientSession, params P) (R, error) { - return handleSend[R](ctx, method, &ClientRequest[P]{ - Session: cs, - Params: params, - }) - } -} diff --git a/mcp/custom_test.go b/mcp/custom_test.go index 5a6d76d8..e5aef305 100644 --- a/mcp/custom_test.go +++ b/mcp/custom_test.go @@ -3,233 +3,3 @@ // that can be found in the LICENSE file. package mcp - -import ( - "context" - "net/http" - "testing" - - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" -) - -type searchParams struct { - ParamsBase - Query string `json:"query"` - Limit int `json:"limit,omitempty"` -} - -type searchResult struct { - ResultBase - Hits []string `json:"hits"` - Total int `json:"total"` -} - -// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, -// bypassing the SDK's typed method dispatch. -func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { - return conn.Call(ctx, method, params).Await(ctx, result) -} - -func TestAddReceivingCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - hits := []string{"result for " + params.Query} - return &searchResult{ - Hits: hits, - Total: len(hits), - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result searchResult - if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { - t.Fatal(err) - } - - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodGoesThoughMiddleware(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - var middlewareCalled bool - s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == "acme/ping" { - middlewareCalled = true - } - return next(ctx, method, req) - } - }) - - type pingParams struct { - ParamsBase - } - type pingResult struct { - ResultBase - Pong bool `json:"pong"` - } - AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { - return &pingResult{Pong: true}, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result pingResult - if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { - t.Fatal(err) - } - - if !result.Pong { - t.Error("expected Pong to be true") - } - if !middlewareCalled { - t.Error("middleware was not called for custom method") - } -} - -func TestCustomMethodNotOnOtherServers(t *testing.T) { - ctx := context.Background() - - type emptyParams struct{ ParamsBase } - type emptyResult struct{ ResultBase } - - // Server 1 has the custom method. - s1 := NewServer(testImpl, nil) - AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { - return &emptyResult{}, nil - }) - - // Server 2 does NOT have the custom method. - s2 := NewServer(testImpl, nil) - - // Test s1: custom method should work. - ct1, st1 := NewInMemoryTransports() - ss1, err := s1.Connect(ctx, st1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss1.Close() }) - - c1 := NewClient(testImpl, nil) - cs1, err := c1.Connect(ctx, ct1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs1.Close() }) - - if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { - t.Fatalf("custom method on s1 should work: %v", err) - } - - // Test s2: custom method should fail. - ct2, st2 := NewInMemoryTransports() - ss2, err := s2.Connect(ctx, st2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss2.Close() }) - - c2 := NewClient(testImpl, nil) - cs2, err := c2.Connect(ctx, ct2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs2.Close() }) - - err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) - if err == nil { - t.Fatal("expected error calling custom method on s2") - } -} - -func TestCallCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - return &searchResult{ - Hits: []string{"result for " + params.Query}, - Total: 1, - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") - - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) - if err != nil { - t.Fatal(err) - } - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodStreamableHTTP(t *testing.T) { - s := NewServer(testImpl, nil) - - type echoParams struct { - ParamsBase - Msg string `json:"msg"` - } - type echoResult struct { - ResultBase - Reply string `json:"reply"` - } - AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { - return &echoResult{Reply: "echo: " + params.Msg}, nil - }) - - handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) - _ = handler -} diff --git a/mcp/server.go b/mcp/server.go index 98c0c985..549645a6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1644,3 +1644,44 @@ func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageS *res.nextCursorPtr() = nextCursor return res, nil } + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }) +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) { + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) +} diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..c3e7a4b6 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,6 +11,7 @@ import ( "fmt" "log" "log/slog" + "net/http" "slices" "strings" "testing" @@ -1007,3 +1008,186 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` +} + +type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` +} + +// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, +// bypassing the SDK's typed method dispatch. +func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) +} + +func TestAddReceivingCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { + t.Fatal(err) + } + + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodGoesThoughMiddleware(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + var middlewareCalled bool + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == "acme/ping" { + middlewareCalled = true + } + return next(ctx, method, req) + } + }) + + type pingParams struct { + ParamsBase + } + type pingResult struct { + ResultBase + Pong bool `json:"pong"` + } + AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { + return &pingResult{Pong: true}, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result pingResult + if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { + t.Fatal(err) + } + + if !result.Pong { + t.Error("expected Pong to be true") + } + if !middlewareCalled { + t.Error("middleware was not called for custom method") + } +} + +func TestCustomMethodNotOnOtherServers(t *testing.T) { + ctx := context.Background() + + type emptyParams struct{ ParamsBase } + type emptyResult struct{ ResultBase } + + // Server 1 has the custom method. + s1 := NewServer(testImpl, nil) + AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { + return &emptyResult{}, nil + }) + + // Server 2 does NOT have the custom method. + s2 := NewServer(testImpl, nil) + + // Test s1: custom method should work. + ct1, st1 := NewInMemoryTransports() + ss1, err := s1.Connect(ctx, st1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss1.Close() }) + + c1 := NewClient(testImpl, nil) + cs1, err := c1.Connect(ctx, ct1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs1.Close() }) + + if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { + t.Fatalf("custom method on s1 should work: %v", err) + } + + // Test s2: custom method should fail. + ct2, st2 := NewInMemoryTransports() + ss2, err := s2.Connect(ctx, st2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss2.Close() }) + + c2 := NewClient(testImpl, nil) + cs2, err := c2.Connect(ctx, ct2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs2.Close() }) + + err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) + if err == nil { + t.Fatal("expected error calling custom method on s2") + } +} + +func TestCustomMethodStreamableHTTP(t *testing.T) { + s := NewServer(testImpl, nil) + + type echoParams struct { + ParamsBase + Msg string `json:"msg"` + } + type echoResult struct { + ResultBase + Reply string `json:"reply"` + } + AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { + return &echoResult{Reply: "echo: " + params.Msg}, nil + }) + + handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) + _ = handler +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..42472704 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -544,6 +544,20 @@ type Params interface { isParams() } +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + // RequestParams is a parameter (input) type for an MCP request. type RequestParams interface { Params @@ -568,6 +582,20 @@ type Result interface { SetMeta(map[string]any) } +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + // emptyResult is returned by methods that have no result, like ping. // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} From 7edd01cff5949d27a18bbdb220dd2c8f7a5eef5e Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 12:38:55 +0000 Subject: [PATCH 3/8] refactor: remove unused custom MCP implementation files --- mcp/custom.go | 5 ----- mcp/custom_test.go | 5 ----- 2 files changed, 10 deletions(-) delete mode 100644 mcp/custom.go delete mode 100644 mcp/custom_test.go diff --git a/mcp/custom.go b/mcp/custom.go deleted file mode 100644 index e5aef305..00000000 --- a/mcp/custom.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by the license -// that can be found in the LICENSE file. - -package mcp diff --git a/mcp/custom_test.go b/mcp/custom_test.go deleted file mode 100644 index e5aef305..00000000 --- a/mcp/custom_test.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by the license -// that can be found in the LICENSE file. - -package mcp From 51f3a2b26438645601424b75491fe03f5a7b2303 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 13:19:59 +0000 Subject: [PATCH 4/8] refactor: consolidate custom method tests into mcp_test.go --- mcp/client_test.go | 37 --------- mcp/mcp_test.go | 69 +++++++++++++++++ mcp/server_test.go | 186 +-------------------------------------------- 3 files changed, 70 insertions(+), 222 deletions(-) diff --git a/mcp/client_test.go b/mcp/client_test.go index 6bb5ee8e..f36d7d00 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -618,41 +618,4 @@ func TestClientCapabilitiesOverWire(t *testing.T) { } } -func TestCallCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - return &searchResult{ - Hits: []string{"result for " + params.Query}, - Total: 1, - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) - if err != nil { - t.Fatal(err) - } - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..810b1e58 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2372,3 +2372,72 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} + +func TestCustomMethods(t *testing.T) { + type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` + } + + type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` + } + + callCustom := func(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) + } + + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + // Test raw JSON-RPC call + var result1 searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil { + t.Fatal(err) + } + if len(result1.Hits) != 1 || result1.Hits[0] != "result for hello" { + t.Errorf("raw call: unexpected hits: %v", result1.Hits) + } + if result1.Total != 1 { + t.Errorf("raw call: unexpected total: %d", result1.Total) + } + + // Test typed caller + result2, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result2.Hits) != 1 || result2.Hits[0] != "result for hello" { + t.Errorf("typed call: unexpected hits: %v", result2.Hits) + } + if result2.Total != 1 { + t.Errorf("typed call: unexpected total: %d", result2.Total) + } +} diff --git a/mcp/server_test.go b/mcp/server_test.go index c3e7a4b6..609f8e76 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,7 +11,6 @@ import ( "fmt" "log" "log/slog" - "net/http" "slices" "strings" "testing" @@ -1007,187 +1006,4 @@ func TestServerCapabilitiesOverWire(t *testing.T) { } }) } -} - -type searchParams struct { - ParamsBase - Query string `json:"query"` - Limit int `json:"limit,omitempty"` -} - -type searchResult struct { - ResultBase - Hits []string `json:"hits"` - Total int `json:"total"` -} - -// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, -// bypassing the SDK's typed method dispatch. -func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { - return conn.Call(ctx, method, params).Await(ctx, result) -} - -func TestAddReceivingCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - hits := []string{"result for " + params.Query} - return &searchResult{ - Hits: hits, - Total: len(hits), - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result searchResult - if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { - t.Fatal(err) - } - - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodGoesThoughMiddleware(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - var middlewareCalled bool - s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == "acme/ping" { - middlewareCalled = true - } - return next(ctx, method, req) - } - }) - - type pingParams struct { - ParamsBase - } - type pingResult struct { - ResultBase - Pong bool `json:"pong"` - } - AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { - return &pingResult{Pong: true}, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result pingResult - if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { - t.Fatal(err) - } - - if !result.Pong { - t.Error("expected Pong to be true") - } - if !middlewareCalled { - t.Error("middleware was not called for custom method") - } -} - -func TestCustomMethodNotOnOtherServers(t *testing.T) { - ctx := context.Background() - - type emptyParams struct{ ParamsBase } - type emptyResult struct{ ResultBase } - - // Server 1 has the custom method. - s1 := NewServer(testImpl, nil) - AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { - return &emptyResult{}, nil - }) - - // Server 2 does NOT have the custom method. - s2 := NewServer(testImpl, nil) - - // Test s1: custom method should work. - ct1, st1 := NewInMemoryTransports() - ss1, err := s1.Connect(ctx, st1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss1.Close() }) - - c1 := NewClient(testImpl, nil) - cs1, err := c1.Connect(ctx, ct1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs1.Close() }) - - if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { - t.Fatalf("custom method on s1 should work: %v", err) - } - - // Test s2: custom method should fail. - ct2, st2 := NewInMemoryTransports() - ss2, err := s2.Connect(ctx, st2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss2.Close() }) - - c2 := NewClient(testImpl, nil) - cs2, err := c2.Connect(ctx, ct2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs2.Close() }) - - err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) - if err == nil { - t.Fatal("expected error calling custom method on s2") - } -} - -func TestCustomMethodStreamableHTTP(t *testing.T) { - s := NewServer(testImpl, nil) - - type echoParams struct { - ParamsBase - Msg string `json:"msg"` - } - type echoResult struct { - ResultBase - Reply string `json:"reply"` - } - AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { - return &echoResult{Reply: "echo: " + params.Msg}, nil - }) - - handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) - _ = handler -} +} \ No newline at end of file From e5748ee6ea53b19647d05429ace4b5cd7b6336b8 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 13:21:31 +0000 Subject: [PATCH 5/8] refactor: remove trailing whitespace and empty lines from client and server test files --- mcp/client_test.go | 2 -- mcp/server_test.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/mcp/client_test.go b/mcp/client_test.go index f36d7d00..609fd501 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -617,5 +617,3 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } - - diff --git a/mcp/server_test.go b/mcp/server_test.go index 609f8e76..2937ea2b 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -1006,4 +1006,4 @@ func TestServerCapabilitiesOverWire(t *testing.T) { } }) } -} \ No newline at end of file +} From 1b3b509105c29a648fe2635eeb42f2da58616ee4 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 25 Jun 2026 11:46:18 +0000 Subject: [PATCH 6/8] refactor: decouple custom method registration from invocation and add validation to prevent shadowing standard MCP methods --- examples/server/custom-method/main.go | 13 +++-- mcp/client.go | 78 +++++++++++++++++++-------- mcp/mcp_test.go | 52 ++++++++++++++---- mcp/server.go | 44 +++++++++------ 4 files changed, 137 insertions(+), 50 deletions(-) diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go index d36bc6eb..6b2861fe 100644 --- a/examples/server/custom-method/main.go +++ b/examples/server/custom-method/main.go @@ -59,7 +59,7 @@ func main() { server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) - mcp.AddReceivingCustomMethod(server, "latin/translate", + if err := mcp.AddReceivingCustomMethod(server, "latin/translate", func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { key := strings.ToLower(strings.TrimSpace(params.Text)) latin, ok := translations[key] @@ -67,7 +67,9 @@ func main() { latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) } return &TranslateResult{Latin: latin}, nil - }) + }); err != nil { + log.Fatal(err) + } ct, st := mcp.NewInMemoryTransports() @@ -78,7 +80,9 @@ func main() { defer ss.Close() client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) - translate := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate") + if err := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate"); err != nil { + log.Fatal(err) + } cs, err := client.Connect(ctx, ct, nil) if err != nil { @@ -88,7 +92,8 @@ func main() { phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} for _, phrase := range phrases { - result, err := translate(ctx, cs, &TranslateParams{Text: phrase}) + result, err := mcp.CallCustomMethod[*TranslateParams, *TranslateResult]( + ctx, cs, "latin/translate", &TranslateParams{Text: phrase}) if err != nil { log.Fatalf("translate %q: %v", phrase, err) } diff --git a/mcp/client.go b/mcp/client.go index 093249ef..e5ebf759 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -34,7 +34,11 @@ type Client struct { sessions []*ClientSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler - customSendMethods map[string]methodInfo + // sendMethods is the list of methods this client may send to a + // server: it always contains the standard server methods (from + // serverMethodInfos) plus any custom methods registered via + // [AddSendingCustomMethod]. + sendMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -61,13 +65,16 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } + sendMethods := make(map[string]methodInfo, len(serverMethodInfos)) + maps.Copy(sendMethods, serverMethodInfos) + return &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], - customSendMethods: make(map[string]methodInfo), + sendMethods: sendMethods, } } @@ -949,13 +956,9 @@ var clientMethodInfos = map[string]methodInfo{ } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - if len(cs.client.customSendMethods) == 0 { - return serverMethodInfos - } - infos := make(map[string]methodInfo, len(serverMethodInfos)+len(cs.client.customSendMethods)) - maps.Copy(infos, serverMethodInfos) - maps.Copy(infos, cs.client.customSendMethods) - return infos + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.sendMethods } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { @@ -1229,18 +1232,30 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params } } -// AddSendingCustomMethod registers a custom method that the client can send -// to the server and returns a typed caller function. +// AddSendingCustomMethod registers a custom JSON-RPC method +// that the client may send to the server. // -// The returned function calls the custom method through the client's sending -// middleware chain, with full type safety on both params and result. +// Registration is decoupled from invocation: extensions typically call this +// during setup, while the actual call site uses [CallCustomMethod]. // -// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") -// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +// if err := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search"); err != nil { +// return err +// } +// // ... later, anywhere a *ClientSession is available: +// result, err := mcp.CallCustomMethod[*SearchParams, *SearchResult]( +// ctx, cs, "acme/search", &SearchParams{Query: "hello"}) +// +// AddSendingCustomMethod returns an error if method is the name of a standard +// MCP method. Registering the same method twice replaces the previous +// registration. func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( c *Client, method string, -) func(ctx context.Context, cs *ClientSession, params P) (R, error) { +) error { + if _, ok := serverMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddSendingCustomMethod: %q shadows a standard MCP method", method) + } + mi := methodInfo{ newResult: func() Result { return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) @@ -1249,12 +1264,31 @@ func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( c.mu.Lock() defer c.mu.Unlock() - c.customSendMethods[method] = mi + c.sendMethods[method] = mi + return nil +} - return func(ctx context.Context, cs *ClientSession, params P) (R, error) { - return handleSend[R](ctx, method, &ClientRequest[P]{ - Session: cs, - Params: params, - }) +// CallCustomMethod sends a custom (non-standard) JSON-RPC method to the +// server and decodes the response into R. +// +// The method must have been registered on the session's client via +// [AddSendingCustomMethod]. +func CallCustomMethod[P paramsPtr[PT], R Result, PT any]( + ctx context.Context, + cs *ClientSession, + method string, + params P, +) (R, error) { + c := cs.client + c.mu.Lock() + _, ok := c.sendMethods[method] + c.mu.Unlock() + if !ok { + var zero R + return zero, fmt.Errorf("mcp: CallCustomMethod: %q is not registered; call AddSendingCustomMethod first", method) } + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 810b1e58..f0c4f0a9 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2393,13 +2393,15 @@ func TestCustomMethods(t *testing.T) { ctx := context.Background() s := NewServer(testImpl, nil) - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + if err := AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { hits := []string{"result for " + params.Query} return &searchResult{ Hits: hits, Total: len(hits), }, nil - }) + }); err != nil { + t.Fatal(err) + } ct, st := NewInMemoryTransports() ss, err := s.Connect(ctx, st, nil) @@ -2409,7 +2411,9 @@ func TestCustomMethods(t *testing.T) { t.Cleanup(func() { _ = ss.Close() }) c := NewClient(testImpl, nil) - callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + if err := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search"); err != nil { + t.Fatal(err) + } cs, err := c.Connect(ctx, ct, nil) if err != nil { @@ -2417,7 +2421,7 @@ func TestCustomMethods(t *testing.T) { } t.Cleanup(func() { _ = cs.Close() }) - // Test raw JSON-RPC call + // Test raw JSON-RPC call. var result1 searchResult if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil { t.Fatal(err) @@ -2429,15 +2433,45 @@ func TestCustomMethods(t *testing.T) { t.Errorf("raw call: unexpected total: %d", result1.Total) } - // Test typed caller - result2, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + // Test the typed CallCustomMethod helper. + result2, err := CallCustomMethod[*searchParams, *searchResult]( + ctx, cs, "acme/search", &searchParams{Query: "world"}) if err != nil { t.Fatal(err) } - if len(result2.Hits) != 1 || result2.Hits[0] != "result for hello" { - t.Errorf("typed call: unexpected hits: %v", result2.Hits) + if len(result2.Hits) != 1 || result2.Hits[0] != "result for world" { + t.Errorf("CallCustomMethod: unexpected hits: %v", result2.Hits) } if result2.Total != 1 { - t.Errorf("typed call: unexpected total: %d", result2.Total) + t.Errorf("CallCustomMethod: unexpected total: %d", result2.Total) + } + + // CallCustomMethod must reject methods that were never registered. + if _, err := CallCustomMethod[*searchParams, *searchResult]( + ctx, cs, "acme/unregistered", &searchParams{Query: "x"}); err == nil { + t.Error("CallCustomMethod: expected error for unregistered method, got nil") } } + +func TestAddCustomMethodRejectsStandardMethods(t *testing.T) { + type params struct{ ParamsBase } + type result struct{ ResultBase } + + t.Run("server", func(t *testing.T) { + s := NewServer(testImpl, nil) + err := AddReceivingCustomMethod(s, "tools/call", + func(ctx context.Context, ss *ServerSession, p *params) (*result, error) { + return &result{}, nil + }) + if err == nil { + t.Fatal("AddReceivingCustomMethod: expected error when shadowing a standard method, got nil") + } + }) + + t.Run("client", func(t *testing.T) { + c := NewClient(testImpl, nil) + if err := AddSendingCustomMethod[*params, *result](c, "tools/call"); err == nil { + t.Fatal("AddSendingCustomMethod: expected error when shadowing a standard method, got nil") + } + }) +} diff --git a/mcp/server.go b/mcp/server.go index 549645a6..6df9a4b5 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -52,9 +52,13 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler - customMethods map[string]methodInfo - resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + // receiveMethods is the merged map of methods this server may receive + // from a client: it always contains the standard server methods (from + // serverMethodInfos) plus any custom methods registered via + // [AddReceivingCustomMethod]. + receiveMethods map[string]methodInfo + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -185,6 +189,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } + receiveMethods := make(map[string]methodInfo, len(serverMethodInfos)) + maps.Copy(receiveMethods, serverMethodInfos) + return &Server{ impl: impl, opts: opts, @@ -196,7 +203,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), - customMethods: make(map[string]methodInfo), + receiveMethods: receiveMethods, } } @@ -1425,13 +1432,9 @@ func initializeMethodInfo() methodInfo { func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } func (s *Server) receivingMethodInfos() map[string]methodInfo { - if len(s.customMethods) == 0 { - return serverMethodInfos - } - infos := make(map[string]methodInfo, len(serverMethodInfos)+len(s.customMethods)) - maps.Copy(infos, serverMethodInfos) - maps.Copy(infos, s.customMethods) - return infos + s.mu.Lock() + defer s.mu.Unlock() + return s.receiveMethods } func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { @@ -1668,20 +1671,31 @@ func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageS // Hits []string `json:"hits"` // } // -// mcp.AddReceivingCustomMethod(server, "acme/search", +// if err := mcp.AddReceivingCustomMethod(server, "acme/search", // func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { // return &SearchResult{Hits: []string{"result"}}, nil -// }) +// }); err != nil { +// return err +// } +// +// AddReceivingCustomMethod returns an error if method is the name of a +// standard MCP method. Registering the same custom method twice replaces the +// previous handler. func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( s *Server, method string, handler func(ctx context.Context, ss *ServerSession, params P) (R, error), -) { +) error { + if _, ok := serverMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddReceivingCustomMethod: %q shadows a standard MCP method", method) + } + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { return handler(ctx, req.Session, req.Params) }) s.mu.Lock() defer s.mu.Unlock() - s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) + s.receiveMethods[method] = newServerMethodInfo(typed, missingParamsOK) + return nil } From 5ffcb6097dac54522715c566b9f903e3f6e49207 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 25 Jun 2026 12:01:23 +0000 Subject: [PATCH 7/8] refactor: remove redundant documentation for ParamsBase.isNil --- mcp/shared.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mcp/shared.go b/mcp/shared.go index 153df165..6711d304 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -741,13 +741,6 @@ type ParamsBase struct { func (*ParamsBase) isParams() {} -// isNil reports whether the embedded ParamsBase pointer is nil. Custom -// parameter types that embed [ParamsBase] inherit this method, which is -// sufficient for the Params interface as long as callers already nil-check -// the outer pointer before invoking it (see getRequestMeta). -// -// Concrete protocol params types in this package override isNil with their -// own implementation that compares the outer pointer to nil. func (p *ParamsBase) isNil() bool { return p == nil } // RequestParams is a parameter (input) type for an MCP request. From 98e86e6fd5eb2b6f3a63c0cafc19278459c24ffc Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 25 Jun 2026 15:37:44 +0000 Subject: [PATCH 8/8] fix: prevent dereferencing typed-nil parameters in CallCustomMethod by allocating fresh values before meta-injection --- mcp/client.go | 5 ++++- mcp/mcp_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/mcp/client.go b/mcp/client.go index e679c52a..0a6c906f 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -515,7 +515,7 @@ func injectRequestMeta[T any, P interface { Params }](cs *ClientSession, params P) P { res := cs.state.InitializeResult - if params.isNil() { + if params == nil { params = new(T) } m := params.GetMeta() @@ -1666,6 +1666,9 @@ func CallCustomMethod[P paramsPtr[PT], R Result, PT any]( var zero R return zero, fmt.Errorf("mcp: CallCustomMethod: %q is not registered; call AddSendingCustomMethod first", method) } + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } return handleSend[R](ctx, method, &ClientRequest[P]{ Session: cs, Params: params, diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 08d0c897..c0a9cfef 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -3336,3 +3336,42 @@ func TestAddCustomMethodRejectsStandardMethods(t *testing.T) { } }) } + +// TestCallCustomMethodTypedNilParams exercises the typed-nil params path. +// User param types embed ParamsBase, so the inherited isNil forwarder would +// dereference a typed-nil outer if injectRequestMeta were called with it. +// CallCustomMethod must allocate a fresh value before the meta-injection step. +func TestCallCustomMethodTypedNilParams(t *testing.T) { + type pingParams struct{ ParamsBase } + type pingResult struct{ ResultBase } + + ctx := context.Background() + s := NewServer(testImpl, nil) + if err := AddReceivingCustomMethod(s, "acme/ping", + func(ctx context.Context, ss *ServerSession, p *pingParams) (*pingResult, error) { + return &pingResult{}, nil + }); err != nil { + t.Fatal(err) + } + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + if err := AddSendingCustomMethod[*pingParams, *pingResult](c, "acme/ping"); err != nil { + t.Fatal(err) + } + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var typedNil *pingParams + if _, err := CallCustomMethod[*pingParams, *pingResult](ctx, cs, "acme/ping", typedNil); err != nil { + t.Fatalf("CallCustomMethod with typed-nil params: %v", err) + } +}