diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go new file mode 100644 index 00000000..6b2861fe --- /dev/null +++ b/examples/server/custom-method/main.go @@ -0,0 +1,110 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// The custom-method example demonstrates registering and calling a custom +// JSON-RPC method that is not part of the standard MCP spec. +// +// The server registers a "latin/translate" method that translates simple +// English phrases into Latin. A client connects over an in-memory transport, +// calls the custom method, and prints the result. +package main + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type TranslateParams struct { + mcp.ParamsBase + Text string `json:"text"` +} + +type TranslateResult struct { + mcp.ResultBase + Latin string `json:"latin"` +} + +var translations = map[string]string{ + "hello": "salve", + "goodbye": "vale", + "thank you": "gratias tibi ago", + "how are you": "quid agis", + "good morning": "bonum mane", + "good night": "bonam noctem", + "friend": "amicus", + "water": "aqua", + "love": "amor", + "war": "bellum", + "peace": "pax", + "truth": "veritas", + "light": "lux", + "time": "tempus", + "life": "vita", + "death": "mors", + "star": "stella", + "earth": "terra", + "sea": "mare", + "the die is cast": "alea iacta est", + "i came i saw i conquered": "veni vidi vici", + "seize the day": "carpe diem", +} + +func main() { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) + + if err := mcp.AddReceivingCustomMethod(server, "latin/translate", + func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { + key := strings.ToLower(strings.TrimSpace(params.Text)) + latin, ok := translations[key] + if !ok { + latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) + } + return &TranslateResult{Latin: latin}, nil + }); err != nil { + log.Fatal(err) + } + + ct, st := mcp.NewInMemoryTransports() + + ss, err := server.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) + if err := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate"); err != nil { + log.Fatal(err) + } + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} + for _, phrase := range phrases { + result, err := mcp.CallCustomMethod[*TranslateParams, *TranslateResult]( + ctx, cs, "latin/translate", &TranslateParams{Text: phrase}) + if err != nil { + log.Fatalf("translate %q: %v", phrase, err) + } + fmt.Printf("%-35s → %s\n", phrase, result.Latin) + } +} + +func knownPhrases() string { + phrases := make([]string, 0, len(translations)) + for k := range translations { + phrases = append(phrases, fmt.Sprintf("%q", k)) + } + return strings.Join(phrases, ", ") +} diff --git a/mcp/client.go b/mcp/client.go index 552b3d60..0a6c906f 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -10,6 +10,8 @@ import ( "fmt" "iter" "log/slog" + "maps" + "reflect" "slices" "strings" "sync" @@ -32,6 +34,11 @@ type Client struct { sessions []*ClientSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + // sendMethods is the list of methods this client may send to a + // server: it always contains the standard server methods (from + // serverMethodInfos) plus any custom methods registered via + // [AddSendingCustomMethod]. + sendMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -58,12 +65,16 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { opts.Logger = ensureLogger(nil) } + sendMethods := make(map[string]methodInfo, len(serverMethodInfos)) + maps.Copy(sendMethods, serverMethodInfos) + c := &Client{ impl: impl, opts: opts, roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + sendMethods: sendMethods, } if opts.MultiRoundTrip == nil || !opts.MultiRoundTrip.Disabled { c.AddSendingMiddleware(clientMultiRoundTripMiddleware()) @@ -504,7 +515,7 @@ func injectRequestMeta[T any, P interface { Params }](cs *ClientSession, params P) P { res := cs.state.InitializeResult - if params.isNil() { + if params == nil { params = new(T) } m := params.GetMeta() @@ -1153,7 +1164,9 @@ var clientMethodInfos = map[string]methodInfo{ } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - return serverMethodInfos + cs.client.mu.Lock() + defer cs.client.mu.Unlock() + return cs.client.sendMethods } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { @@ -1597,3 +1610,67 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params } } } + +// AddSendingCustomMethod registers a custom JSON-RPC method +// that the client may send to the server. +// +// Registration is decoupled from invocation: extensions typically call this +// during setup, while the actual call site uses [CallCustomMethod]. +// +// if err := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search"); err != nil { +// return err +// } +// // ... later, anywhere a *ClientSession is available: +// result, err := mcp.CallCustomMethod[*SearchParams, *SearchResult]( +// ctx, cs, "acme/search", &SearchParams{Query: "hello"}) +// +// AddSendingCustomMethod returns an error if method is the name of a standard +// MCP method. Registering the same method twice replaces the previous +// registration. +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) error { + if _, ok := serverMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddSendingCustomMethod: %q shadows a standard MCP method", method) + } + + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.sendMethods[method] = mi + return nil +} + +// CallCustomMethod sends a custom (non-standard) JSON-RPC method to the +// server and decodes the response into R. +// +// The method must have been registered on the session's client via +// [AddSendingCustomMethod]. +func CallCustomMethod[P paramsPtr[PT], R Result, PT any]( + ctx context.Context, + cs *ClientSession, + method string, + params P, +) (R, error) { + c := cs.client + c.mu.Lock() + _, ok := c.sendMethods[method] + c.mu.Unlock() + if !ok { + var zero R + return zero, fmt.Errorf("mcp: CallCustomMethod: %q is not registered; call AddSendingCustomMethod first", method) + } + if cs.usesNewProtocol() { + params = injectRequestMeta(cs, params) + } + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 82a919d5..c0a9cfef 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -3233,3 +3233,145 @@ func TestSubscriptionsListen_DisconnectScrubsMaps(t *testing.T) { time.Sleep(10 * time.Millisecond) } } + +func TestCustomMethods(t *testing.T) { + type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` + } + + type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` + } + + callCustom := func(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) + } + + ctx := context.Background() + s := NewServer(testImpl, nil) + + if err := AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }); err != nil { + t.Fatal(err) + } + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + if err := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search"); err != nil { + t.Fatal(err) + } + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + // Test raw JSON-RPC call. + var result1 searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil { + t.Fatal(err) + } + if len(result1.Hits) != 1 || result1.Hits[0] != "result for hello" { + t.Errorf("raw call: unexpected hits: %v", result1.Hits) + } + if result1.Total != 1 { + t.Errorf("raw call: unexpected total: %d", result1.Total) + } + + // Test the typed CallCustomMethod helper. + result2, err := CallCustomMethod[*searchParams, *searchResult]( + ctx, cs, "acme/search", &searchParams{Query: "world"}) + if err != nil { + t.Fatal(err) + } + if len(result2.Hits) != 1 || result2.Hits[0] != "result for world" { + t.Errorf("CallCustomMethod: unexpected hits: %v", result2.Hits) + } + if result2.Total != 1 { + t.Errorf("CallCustomMethod: unexpected total: %d", result2.Total) + } + + // CallCustomMethod must reject methods that were never registered. + if _, err := CallCustomMethod[*searchParams, *searchResult]( + ctx, cs, "acme/unregistered", &searchParams{Query: "x"}); err == nil { + t.Error("CallCustomMethod: expected error for unregistered method, got nil") + } +} + +func TestAddCustomMethodRejectsStandardMethods(t *testing.T) { + type params struct{ ParamsBase } + type result struct{ ResultBase } + + t.Run("server", func(t *testing.T) { + s := NewServer(testImpl, nil) + err := AddReceivingCustomMethod(s, "tools/call", + func(ctx context.Context, ss *ServerSession, p *params) (*result, error) { + return &result{}, nil + }) + if err == nil { + t.Fatal("AddReceivingCustomMethod: expected error when shadowing a standard method, got nil") + } + }) + + t.Run("client", func(t *testing.T) { + c := NewClient(testImpl, nil) + if err := AddSendingCustomMethod[*params, *result](c, "tools/call"); err == nil { + t.Fatal("AddSendingCustomMethod: expected error when shadowing a standard method, got nil") + } + }) +} + +// TestCallCustomMethodTypedNilParams exercises the typed-nil params path. +// User param types embed ParamsBase, so the inherited isNil forwarder would +// dereference a typed-nil outer if injectRequestMeta were called with it. +// CallCustomMethod must allocate a fresh value before the meta-injection step. +func TestCallCustomMethodTypedNilParams(t *testing.T) { + type pingParams struct{ ParamsBase } + type pingResult struct{ ResultBase } + + ctx := context.Background() + s := NewServer(testImpl, nil) + if err := AddReceivingCustomMethod(s, "acme/ping", + func(ctx context.Context, ss *ServerSession, p *pingParams) (*pingResult, error) { + return &pingResult{}, nil + }); err != nil { + t.Fatal(err) + } + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + if err := AddSendingCustomMethod[*pingParams, *pingResult](c, "acme/ping"); err != nil { + t.Fatal(err) + } + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var typedNil *pingParams + if _, err := CallCustomMethod[*pingParams, *pingResult](ctx, cs, "acme/ping", typedNil); err != nil { + t.Fatalf("CallCustomMethod with typed-nil params: %v", err) + } +} diff --git a/mcp/server.go b/mcp/server.go index e07d3c25..21b5a722 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -57,6 +57,11 @@ type Server struct { resourceChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "resources/changed" resourceSubscriptions map[string]map[*ServerSession]jsonrpc.ID // uri -> session -> requestID pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + // receiveMethods is the merged map of methods this server may receive + // from a client: it always contains the standard server methods (from + // serverMethodInfos) plus any custom methods registered via + // [AddReceivingCustomMethod]. + receiveMethods map[string]methodInfo } // ServerOptions is used to configure behavior of the server. @@ -203,6 +208,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { opts.Logger = ensureLogger(nil) } + receiveMethods := make(map[string]methodInfo, len(serverMethodInfos)) + maps.Copy(receiveMethods, serverMethodInfos) + s := &Server{ impl: impl, opts: opts, @@ -217,6 +225,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { resourceChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), resourceSubscriptions: make(map[string]map[*ServerSession]jsonrpc.ID), pendingNotifications: make(map[string]*time.Timer), + receiveMethods: receiveMethods, } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) return s @@ -1761,7 +1770,15 @@ func initializeMethodInfo() methodInfo { func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } +func (s *Server) receivingMethodInfos() map[string]methodInfo { + s.mu.Lock() + defer s.mu.Unlock() + return s.receiveMethods +} + +func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { + return ss.server.receivingMethodInfos() +} func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server @@ -2016,3 +2033,55 @@ func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageS *res.nextCursorPtr() = nextCursor return res, nil } + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// if err := mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }); err != nil { +// return err +// } +// +// AddReceivingCustomMethod returns an error if method is the name of a +// standard MCP method. Registering the same custom method twice replaces the +// previous handler. +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) error { + if _, ok := serverMethodInfos[method]; ok { + return fmt.Errorf("mcp: AddReceivingCustomMethod: %q shadows a standard MCP method", method) + } + + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.receiveMethods[method] = newServerMethodInfo(typed, missingParamsOK) + return nil +} diff --git a/mcp/shared.go b/mcp/shared.go index 2cf37b15..6711d304 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -727,6 +727,22 @@ type Params interface { isNil() bool } +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + +func (p *ParamsBase) isNil() bool { return p == nil } + // RequestParams is a parameter (input) type for an MCP request. type RequestParams interface { Params @@ -751,6 +767,20 @@ type Result interface { SetMeta(map[string]any) } +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + // emptyResult is returned by methods that have no result, like ping. // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} diff --git a/mcp/streamable.go b/mcp/streamable.go index b71bc774..c047c156 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -486,6 +486,7 @@ func connectStreamable(ctx context.Context, server *Server, transport *Streamabl if err != nil { return nil, err } + transport.connection.server = server transport.connection.toolLookup = server.getServerTool return s, nil } @@ -846,6 +847,7 @@ type streamableServerConn struct { logger *slog.Logger + server *Server toolLookup func(name string) (*serverTool, bool) incoming chan jsonrpc.Message // messages from the client to the server @@ -1415,7 +1417,15 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + // Use the server's receiving method infos (which include any custom + // methods registered via AddReceivingCustomMethod) when available; + // fall back to the standard methods otherwise, e.g. in tests that + // exercise streamableServerConn directly without a server. + methodInfos := serverMethodInfos + if c.server != nil { + methodInfos = c.server.receivingMethodInfos() + } + if _, err := checkRequest(jreq, methodInfos); err != nil { if protocolVersion >= protocolVersion20260728 && errors.Is(err, jsonrpc2.ErrNotHandled) && jreq.IsCall() { writeJSONRPCError(w, http.StatusNotFound, jreq.ID, &jsonrpc.Error{ Code: jsonrpc.CodeMethodNotFound,