Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions examples/server/custom-method/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by the license
// that can be found in the LICENSE file.

// The custom-method example demonstrates registering and calling a custom
// JSON-RPC method that is not part of the standard MCP spec.
//
// The server registers a "latin/translate" method that translates simple
// English phrases into Latin. A client connects over an in-memory transport,
// calls the custom method, and prints the result.
package main

import (
"context"
"fmt"
"log"
"strings"

"github.com/modelcontextprotocol/go-sdk/mcp"
)

type TranslateParams struct {
mcp.ParamsBase
Text string `json:"text"`
}

type TranslateResult struct {
mcp.ResultBase
Latin string `json:"latin"`
}

var translations = map[string]string{
"hello": "salve",
"goodbye": "vale",
"thank you": "gratias tibi ago",
"how are you": "quid agis",
"good morning": "bonum mane",
"good night": "bonam noctem",
"friend": "amicus",
"water": "aqua",
"love": "amor",
"war": "bellum",
"peace": "pax",
"truth": "veritas",
"light": "lux",
"time": "tempus",
"life": "vita",
"death": "mors",
"star": "stella",
"earth": "terra",
"sea": "mare",
"the die is cast": "alea iacta est",
"i came i saw i conquered": "veni vidi vici",
"seize the day": "carpe diem",
}

func main() {
ctx := context.Background()

server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil)

if err := mcp.AddReceivingCustomMethod(server, "latin/translate",
func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) {
key := strings.ToLower(strings.TrimSpace(params.Text))
latin, ok := translations[key]
if !ok {
latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases())
}
return &TranslateResult{Latin: latin}, nil
}); err != nil {
log.Fatal(err)
}

ct, st := mcp.NewInMemoryTransports()

ss, err := server.Connect(ctx, st, nil)
if err != nil {
log.Fatal(err)
}
defer ss.Close()

client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil)
if err := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate"); err != nil {
log.Fatal(err)
}

cs, err := client.Connect(ctx, ct, nil)
if err != nil {
log.Fatal(err)
}
defer cs.Close()

phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"}
for _, phrase := range phrases {
result, err := mcp.CallCustomMethod[*TranslateParams, *TranslateResult](
ctx, cs, "latin/translate", &TranslateParams{Text: phrase})
if err != nil {
log.Fatalf("translate %q: %v", phrase, err)
}
fmt.Printf("%-35s → %s\n", phrase, result.Latin)
}
}

func knownPhrases() string {
phrases := make([]string, 0, len(translations))
for k := range translations {
phrases = append(phrases, fmt.Sprintf("%q", k))
}
return strings.Join(phrases, ", ")
}
81 changes: 79 additions & 2 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"fmt"
"iter"
"log/slog"
"maps"
"reflect"
"slices"
"strings"
"sync"
Expand All @@ -32,6 +34,11 @@ type Client struct {
sessions []*ClientSession
sendingMethodHandler_ MethodHandler
receivingMethodHandler_ MethodHandler
// sendMethods is the list of methods this client may send to a
// server: it always contains the standard server methods (from
// serverMethodInfos) plus any custom methods registered via
// [AddSendingCustomMethod].
sendMethods map[string]methodInfo
}

// NewClient creates a new [Client].
Expand All @@ -58,12 +65,16 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client {
opts.Logger = ensureLogger(nil)
}

sendMethods := make(map[string]methodInfo, len(serverMethodInfos))
maps.Copy(sendMethods, serverMethodInfos)

c := &Client{
impl: impl,
opts: opts,
roots: newFeatureSet(func(r *Root) string { return r.URI }),
sendingMethodHandler_: defaultSendingMethodHandler,
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
sendMethods: sendMethods,
}
if opts.MultiRoundTrip == nil || !opts.MultiRoundTrip.Disabled {
c.AddSendingMiddleware(clientMultiRoundTripMiddleware())
Expand Down Expand Up @@ -504,7 +515,7 @@ func injectRequestMeta[T any, P interface {
Params
}](cs *ClientSession, params P) P {
res := cs.state.InitializeResult
if params.isNil() {
if params == nil {
params = new(T)
}
m := params.GetMeta()
Expand Down Expand Up @@ -1153,7 +1164,9 @@ var clientMethodInfos = map[string]methodInfo{
}

func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
return serverMethodInfos
cs.client.mu.Lock()
defer cs.client.mu.Unlock()
return cs.client.sendMethods
}

func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
Expand Down Expand Up @@ -1597,3 +1610,67 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params
}
}
}

// AddSendingCustomMethod registers a custom JSON-RPC method
// that the client may send to the server.
//
// Registration is decoupled from invocation: extensions typically call this
// during setup, while the actual call site uses [CallCustomMethod].
//
// if err := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search"); err != nil {
// return err
// }
// // ... later, anywhere a *ClientSession is available:
// result, err := mcp.CallCustomMethod[*SearchParams, *SearchResult](
// ctx, cs, "acme/search", &SearchParams{Query: "hello"})
//
// AddSendingCustomMethod returns an error if method is the name of a standard
// MCP method. Registering the same method twice replaces the previous
// registration.
func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any](
c *Client,
method string,
) error {
if _, ok := serverMethodInfos[method]; ok {
return fmt.Errorf("mcp: AddSendingCustomMethod: %q shadows a standard MCP method", method)
}

mi := methodInfo{
newResult: func() Result {
return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R)
},
}

c.mu.Lock()
defer c.mu.Unlock()
c.sendMethods[method] = mi
return nil
}

// CallCustomMethod sends a custom (non-standard) JSON-RPC method to the
// server and decodes the response into R.
//
// The method must have been registered on the session's client via
// [AddSendingCustomMethod].
func CallCustomMethod[P paramsPtr[PT], R Result, PT any](
ctx context.Context,
cs *ClientSession,
method string,
params P,
) (R, error) {
c := cs.client
c.mu.Lock()
_, ok := c.sendMethods[method]
c.mu.Unlock()
if !ok {
var zero R
return zero, fmt.Errorf("mcp: CallCustomMethod: %q is not registered; call AddSendingCustomMethod first", method)
}
if cs.usesNewProtocol() {
params = injectRequestMeta(cs, params)
}
return handleSend[R](ctx, method, &ClientRequest[P]{
Session: cs,
Params: params,
})
}
142 changes: 142 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3233,3 +3233,145 @@ func TestSubscriptionsListen_DisconnectScrubsMaps(t *testing.T) {
time.Sleep(10 * time.Millisecond)
}
}

func TestCustomMethods(t *testing.T) {
type searchParams struct {
ParamsBase
Query string `json:"query"`
Limit int `json:"limit,omitempty"`
}

type searchResult struct {
ResultBase
Hits []string `json:"hits"`
Total int `json:"total"`
}

callCustom := func(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error {
return conn.Call(ctx, method, params).Await(ctx, result)
}

ctx := context.Background()
s := NewServer(testImpl, nil)

if err := AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) {
hits := []string{"result for " + params.Query}
return &searchResult{
Hits: hits,
Total: len(hits),
}, nil
}); err != nil {
t.Fatal(err)
}

ct, st := NewInMemoryTransports()
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = ss.Close() })

c := NewClient(testImpl, nil)
if err := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search"); err != nil {
t.Fatal(err)
}

cs, err := c.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = cs.Close() })

// Test raw JSON-RPC call.
var result1 searchResult
if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil {
t.Fatal(err)
}
if len(result1.Hits) != 1 || result1.Hits[0] != "result for hello" {
t.Errorf("raw call: unexpected hits: %v", result1.Hits)
}
if result1.Total != 1 {
t.Errorf("raw call: unexpected total: %d", result1.Total)
}

// Test the typed CallCustomMethod helper.
result2, err := CallCustomMethod[*searchParams, *searchResult](
ctx, cs, "acme/search", &searchParams{Query: "world"})
if err != nil {
t.Fatal(err)
}
if len(result2.Hits) != 1 || result2.Hits[0] != "result for world" {
t.Errorf("CallCustomMethod: unexpected hits: %v", result2.Hits)
}
if result2.Total != 1 {
t.Errorf("CallCustomMethod: unexpected total: %d", result2.Total)
}

// CallCustomMethod must reject methods that were never registered.
if _, err := CallCustomMethod[*searchParams, *searchResult](
ctx, cs, "acme/unregistered", &searchParams{Query: "x"}); err == nil {
t.Error("CallCustomMethod: expected error for unregistered method, got nil")
}
}

func TestAddCustomMethodRejectsStandardMethods(t *testing.T) {
type params struct{ ParamsBase }
type result struct{ ResultBase }

t.Run("server", func(t *testing.T) {
s := NewServer(testImpl, nil)
err := AddReceivingCustomMethod(s, "tools/call",
func(ctx context.Context, ss *ServerSession, p *params) (*result, error) {
return &result{}, nil
})
if err == nil {
t.Fatal("AddReceivingCustomMethod: expected error when shadowing a standard method, got nil")
}
})

t.Run("client", func(t *testing.T) {
c := NewClient(testImpl, nil)
if err := AddSendingCustomMethod[*params, *result](c, "tools/call"); err == nil {
t.Fatal("AddSendingCustomMethod: expected error when shadowing a standard method, got nil")
}
})
}

// TestCallCustomMethodTypedNilParams exercises the typed-nil params path.
// User param types embed ParamsBase, so the inherited isNil forwarder would
// dereference a typed-nil outer if injectRequestMeta were called with it.
// CallCustomMethod must allocate a fresh value before the meta-injection step.
func TestCallCustomMethodTypedNilParams(t *testing.T) {
type pingParams struct{ ParamsBase }
type pingResult struct{ ResultBase }

ctx := context.Background()
s := NewServer(testImpl, nil)
if err := AddReceivingCustomMethod(s, "acme/ping",
func(ctx context.Context, ss *ServerSession, p *pingParams) (*pingResult, error) {
return &pingResult{}, nil
}); err != nil {
t.Fatal(err)
}
ct, st := NewInMemoryTransports()
ss, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = ss.Close() })

c := NewClient(testImpl, nil)
if err := AddSendingCustomMethod[*pingParams, *pingResult](c, "acme/ping"); err != nil {
t.Fatal(err)
}
cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728})
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = cs.Close() })

var typedNil *pingParams
if _, err := CallCustomMethod[*pingParams, *pingResult](ctx, cs, "acme/ping", typedNil); err != nil {
t.Fatalf("CallCustomMethod with typed-nil params: %v", err)
}
}
Loading
Loading