diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 074ceab5..4994c63b 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -208,12 +208,30 @@ type ConnectionConfig struct { Bind func(*Connection) Handler // required OnDone func() // optional OnInternalError func(error) // optional + + // PropagateCancellation controls whether cancellation of the context + // passed to [NewConnection] is observable by request handlers. + // + // By default (false), the connection wraps that context (see [notDone]) + // so handlers' Done channels do not fire when the connection's root + // context is cancelled. Cancellation of an in-flight handler is then + // expected to flow only through the jsonrpc2 layer's explicit channels + // (the [Preempter] reacting to the peer's cancel notification, or a + // transport read/write failure cancelling every in-flight request). + // + // Set this true when cancellation of the connection's context is itself + // a meaningful signal that handlers should react to (for example, when + // the connection is tied to a carrier that owns it and whose end means + // the request is cancelled). + PropagateCancellation bool // optional } // NewConnection creates a new [Connection] object and starts processing // incoming messages. func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { - ctx = notDone{ctx} + if !cfg.PropagateCancellation { + ctx = notDone{ctx} + } c := &Connection{ state: inFlightState{closer: cfg.Closer}, @@ -530,6 +548,14 @@ func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter ac.retire(&Response{ID: id, Error: err}) } s.outgoingCalls = nil + + // Cancel any incoming requests still in flight: with the reader gone we + // cannot receive cancellation notifications, and likely cannot write a + // response either, so parked handlers have nothing useful left to do. + // Mirrors the equivalent cleanup on write failure. + for _, r := range s.incomingByID { + r.cancel() + } }) } @@ -724,9 +750,12 @@ func (c *Connection) write(ctx context.Context, msg Message) error { var err error // Fail writes immediately if the connection is shutting down. // - // TODO(rfindley): should we allow cancellation notifications through? It - // could be the case that writes can still succeed. + // Allow outgoing "notifications" forwarded by the Notify method. + // This will allow to send the cancelled notification when the client is shutting down. c.updateInFlight(func(s *inFlightState) { + if req, ok := msg.(*Request); ok && !req.IsCall() && s.outgoingNotifications > 0 { + return + } err = s.shuttingDown(ErrServerClosing) }) if err == nil { @@ -771,6 +800,15 @@ func (c *Connection) internalErrorf(format string, args ...any) error { } // notDone is a context.Context wrapper that returns a nil Done channel. +// +// Request handlers' contexts are derived from the connection's root context, +// which by default is wrapped in notDone so a transport-level cancellation +// does not implicitly cancel every in-flight handler. Cancellation of an +// in-flight handler is instead expected to flow only through the jsonrpc2 +// layer's explicit channels: the [Preempter] calling [Connection.Cancel] in +// response to the peer's cancel notification, or the transport itself +// failing (the read loop exits on EOF or a write fails) — both of which +// cancel every in-flight incoming request in turn. type notDone struct{ ctx context.Context } func (ic notDone) Value(key any) any { diff --git a/mcp/client.go b/mcp/client.go index e64b8ff3..d8fd2c11 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -300,6 +300,27 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } + subscribeParams := &SubscriptionsListenParams{} + if c.opts.ToolListChangedHandler != nil { + subscribeParams.Notifications.ToolsListChanged = true + } + if c.opts.PromptListChangedHandler != nil { + subscribeParams.Notifications.PromptsListChanged = true + } + if c.opts.ResourceListChangedHandler != nil { + subscribeParams.Notifications.ResourcesListChanged = true + } + if subscribeParams.Notifications.ToolsListChanged || + subscribeParams.Notifications.PromptsListChanged || + subscribeParams.Notifications.ResourcesListChanged { + // ClientSession.Close cancels the listenCtx context to send notifications/cancelled. + listenCtx, cancelListen := context.WithCancel(context.Background()) + cs.listenCancel = cancelListen + if err := cs.subscriptionsListen(listenCtx, subscribeParams); err != nil { + cancelListen() + return nil, fmt.Errorf("opening subscriptions/listen: %w", err) + } + } return cs, nil } @@ -414,6 +435,7 @@ type ClientSession struct { conn *jsonrpc2.Connection client *Client keepaliveCancel context.CancelFunc + listenCancel context.CancelFunc mcpConn Connection // No mutex is (currently) required to guard the session state, because it is @@ -430,6 +452,15 @@ type ClientSession struct { // Pending URL elicitations waiting for completion notifications. pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} + + // resourceSubsMu guards resourceSubs. + resourceSubsMu sync.Mutex + // resourceSubs maps a subscribed resource URI to the cancel func of the + // goroutine running its dedicated subscriptions/listen stream. Populated + // only under SEP-2575; the legacy protocol routes Subscribe and + // Unsubscribe straight to the resources/subscribe and resources/unsubscribe + // RPCs and leaves this map untouched. + resourceSubs map[string]context.CancelFunc } type clientSessionState struct { @@ -496,6 +527,10 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } + if cs.listenCancel != nil { + cs.listenCancel() + } + cs.cancelAllResourceSubscriptions() err := cs.conn.Close() if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { @@ -1083,6 +1118,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), + notificationSubscriptionsAck: newClientMethodInfo(clientMethod((*Client).callSubscriptionsAckHandler), notification|missingParamsOK), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -1279,17 +1315,91 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) ( // Subscribe sends a "resources/subscribe" request to the server, asking for // notifications when the specified resource changes. func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) - return err + if !cs.usesNewProtocol() { + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) + return err + } + if params == nil || params.URI == "" { + return fmt.Errorf("Subscribe: missing URI") + } + uri := params.URI + + var listenCtx context.Context + cs.resourceSubsMu.Lock() + if _, exists := cs.resourceSubs[uri]; !exists { + var cancel context.CancelFunc + listenCtx, cancel = context.WithCancel(context.Background()) + if cs.resourceSubs == nil { + cs.resourceSubs = make(map[string]context.CancelFunc) + } + cs.resourceSubs[uri] = cancel + } + cs.resourceSubsMu.Unlock() + if listenCtx == nil { + // Already subscribed to this URI + return nil + } + + return cs.subscriptionsListen(listenCtx, &SubscriptionsListenParams{ + Notifications: NotificationSubscriptions{ + ResourceSubscriptions: []string{uri}, + }, + }) } -// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling -// a previous subscription. +// Unsubscribe cancels a previous [ClientSession.Subscribe] for params.URI. +// +// Under the legacy protocol it sends a "resources/unsubscribe" request. +// +// Under SEP-2575 it cancels the background "subscriptions/listen" stream +// opened by Subscribe for the URI. Unsubscribe is idempotent: calling it for +// a URI that is not currently subscribed is a no-op. func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + if !cs.usesNewProtocol() { + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + return err + } + if params == nil || params.URI == "" { + return fmt.Errorf("Unsubscribe: missing URI") + } + cs.resourceSubsMu.Lock() + cancel, ok := cs.resourceSubs[params.URI] + delete(cs.resourceSubs, params.URI) + cs.resourceSubsMu.Unlock() + if ok { + cancel() + } + return nil +} + +// cancelAllResourceSubscriptions cancels every active SEP-2575 resource +// subscription opened via Subscribe. The listen goroutines exit +// asynchronously as their contexts unwind. Called from Close. +func (cs *ClientSession) cancelAllResourceSubscriptions() { + cs.resourceSubsMu.Lock() + subs := cs.resourceSubs + cs.resourceSubs = nil + cs.resourceSubsMu.Unlock() + for _, cancel := range subs { + cancel() + } +} + +// SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream. +// +// The server's first message on the stream is "notifications/subscriptions/acknowledged"; +// subsequent opted-in notifications (e.g. tools/list_changed) are delivered through the +// usual handlers registered in [ClientOptions]. +func (cs *ClientSession) subscriptionsListen(ctx context.Context, params *SubscriptionsListenParams) error { + params = injectRequestMeta(cs, params) + _, err := handleSend[*emptyResult](ctx, methodSubscriptionsListen, newClientRequest(cs, orZero[Params](params))) return err } +func (c *Client) callSubscriptionsAckHandler(context.Context, *ClientRequest[*SubscriptionsAcknowledgedParams]) (Result, error) { + return nil, nil +} + func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { if cs, ok := req.GetSession().(*ClientSession); ok { cs.toolsCache.invalidate() diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5c8e7d12..91cc57f9 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -12,6 +12,8 @@ import ( "fmt" "io" "log/slog" + "net/http" + "net/http/httptest" "net/url" "path/filepath" "runtime" @@ -2449,3 +2451,791 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} + +// runSubscriptionsListenTest exercises the SEP-2575 auto-listen flow end-to-end +// against the supplied transport pair. It captures every notification and the +// acknowledgment the client sees, then asserts: +// +// - the auto-listen issued by Client.Connect is acknowledged with a tagged +// subscription ID; +// - tool and prompt list-changed notifications are delivered to the matching +// handlers, each carrying the same subscription ID as the ack; +// - the subscription persists across multiple unrelated changes; +// - cs.Close() ends the subscription and further changes don't deliver. +func runSubscriptionsListenTest(t *testing.T, client *Client, server *Server, ct Transport, events chan subListenEvent) { + t.Helper() + + ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer topCancel() + + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + waitFor := func(kind string) subListenEvent { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return subListenEvent{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + ack := waitFor("ack") + if ack.id == "" { + t.Fatalf("acknowledgment missing subscription ID") + } + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("first tool notif id = %s, want %s", e.id, ack.id) + } + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("first prompt notif id = %s, want %s", e.id, ack.id) + } + + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("second tool notif id = %s, want %s", e.id, ack.id) + } + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("second prompt notif id = %s, want %s", e.id, ack.id) + } + expectNoEvent(notificationDelay * 5) + + cs.Close() + time.Sleep(50 * time.Millisecond) + + server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} + +type subListenEvent struct { + kind string // "ack", "tool", "prompt" + id string // subscription ID from _meta, stringified for cross-encoding equality +} + +// newSubListenClient returns a client wired to push every ack and every +// list-changed notification it receives into events, tagged with the kind +// and the subscription ID extracted from _meta. +func newSubListenClient(events chan subListenEvent) *Client { + asEvent := func(kind string, raw any) subListenEvent { + return subListenEvent{kind, fmt.Sprint(raw)} + } + c := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + return c +} + +func newSubListenServer() *Server { + s := NewServer(testImpl, nil) + AddTool(s, &Tool{Name: "t1"}, sayHi) + s.AddPrompt(&Prompt{Name: "p1"}, nil) + return s +} + +func enableNewProtocol(t *testing.T) { + t.Helper() + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260728}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) +} + +// TestSubscriptionsListen_InMemory exercises the listen flow over the +// session-shared in-memory transport (semantically equivalent to STDIO). +// Cancellation here propagates via notifications/cancelled. +func TestSubscriptionsListen_InMemory(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, ct, events) +} + +// TestSubscriptionsListen_Streamable exercises the listen flow over a +// stateless HTTP server (SEP-2575). Each listen rides its own SSE response +// stream; cs.Close() tears it down. +func TestSubscriptionsListen_Streamable(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, + &StreamableClientTransport{Endpoint: httpServer.URL}, events) +} + +// TestSubscriptionsListen_NoHandlersNoListen verifies that a new-protocol +// client without any list-changed handlers registered does not open an +// auto-listen on connect, and therefore does not receive any acknowledgment +// or downstream notifications. +func TestSubscriptionsListen_NoHandlersNoListen(t *testing.T) { + enableNewProtocol(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := make(chan subListenEvent, 8) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + c := NewClient(testImpl, nil) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + events <- subListenEvent{"ack", ""} + } + return next(ctx, method, req) + } + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + + select { + case e := <-events: + t.Fatalf("unexpected event %q on no-handler client", e.kind) + case <-time.After(notificationDelay * 10): + } +} + +// resourceSubServer builds a server that advertises resource subscriptions +// and records every Subscribe/Unsubscribe handler invocation through chans. +func resourceSubServer(t *testing.T, subCh, unsubCh chan string) *Server { + t.Helper() + s := NewServer(testImpl, &ServerOptions{ + SubscribeHandler: func(_ context.Context, r *SubscribeRequest) error { + subCh <- r.Params.URI + return nil + }, + UnsubscribeHandler: func(_ context.Context, r *UnsubscribeRequest) error { + unsubCh <- r.Params.URI + return nil + }, + }) + s.AddResource(&Resource{Name: "r1", URI: "file:///r1"}, nil) + return s +} + +// resourceSubEvent is one delivered notifications/resources/updated. +type resourceSubEvent struct { + uri string + id string // _meta subscription ID, stringified +} + +// TestResourceSubscriptionsSEP2575_Streamable verifies the Subscribe -> +// ResourceUpdated path on a stateless Streamable HTTP server. +// +// Caveat: per-subscription Unsubscribe is intentionally NOT verified here. +// In stateless Streamable HTTP mode the subscriptions/listen handler blocks +// on its request context, and neither the HTTP POST disconnect nor the +// separate notifications/cancelled POST currently propagates to that +// handler's context. The handler only unwinds when the server next attempts +// a write to the (now-dead) SSE stream and the writeErr branch in the +// jsonrpc2 layer cancels the in-flight request. To keep the test +// hermetic we therefore trigger a write at the end by adding a resource, +// which fires notifications/resources/list_changed on the auto-listen path +// (if any) and on the per-URI listen, causing the listen handler to unwind. +// The spec-correct fix is to plumb the POST's request context down to the +// subscriptionsListen handler so HTTP disconnect is observed directly; this +// is tracked separately. +func TestResourceSubscriptions_Streamable(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + select { + case got := <-subCh: + if got != "file:///r1" { + t.Fatalf("got URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for SubscribeHandler") + } + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + if e.uri != "file:///r1" { + t.Fatalf("got URI %q, want %q", e.uri, "file:///r1") + } + if e.id == "" || e.id == "" { + t.Fatalf("missing subscription ID on update (got %q)", e.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for resource update") + } + + // See test header comment for the explanation of this teardown ritual: + // close the client, then drop server-side TCP, then drive a write that + // will fail (any extra ResourceUpdated for our URI), to unblock the + // in-flight listen handler so httpServer.Close can return. + _ = cs.Close() + httpServer.CloseClientConnections() + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + httpServer.Close() +} + +// TestResourceSubscriptions_InMemory mirrors TestResourceSubscriptions_Streamable +// over an in-memory (stdio-equivalent) transport: the per-URI Subscribe path +// uses notifications/cancelled rather than HTTP disconnect for teardown. +func TestResourceSubscriptions_InMemory(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + waitURI := func(ch chan string, want string) { + t.Helper() + select { + case got := <-ch: + if got != want { + t.Fatalf("got URI %q, want %q", got, want) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for URI %q", want) + } + } + waitURI(subCh, "file:///r1") + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + if e.uri != "file:///r1" { + t.Fatalf("got URI %q, want %q", e.uri, "file:///r1") + } + if e.id == "" || e.id == "" { + t.Fatalf("missing subscription ID on update (got %q)", e.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for resource update") + } + + // On stdio/in-memory, Unsubscribe sends notifications/cancelled which + // reaches the server preempter, cancels the listen handler ctx, and + // triggers the UnsubscribeHandler. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe r1: %v", err) + } + waitURI(unsubCh, "file:///r1") + + // Updates for an unsubscribed URI MUST NOT be delivered. Give the server + // a moment to process the cancellation first. + time.Sleep(50 * time.Millisecond) + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + t.Fatalf("unexpected resource update after unsubscribe: %q (id=%s)", e.uri, e.id) + case <-time.After(notificationDelay * 10): + } +} + +// TestResourceSubscriptions_Subscribe_Idempotent verifies that calling +// Subscribe twice for the same URI in the same session is a no-op for the +// second call: it returns nil without invoking SubscribeHandler again and +// without opening a second listen stream. +func TestResourceSubscriptions_Subscribe_Idempotent(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + + server := resourceSubServer(t, subCh, unsubCh) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) {}, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("first subscribe: %v", err) + } + select { + case got := <-subCh: + if got != "file:///r1" { + t.Fatalf("got URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for first SubscribeHandler") + } + + // Second Subscribe for the same URI returns nil and does NOT fire + // SubscribeHandler again. + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("second subscribe should be no-op, got error: %v", err) + } + select { + case got := <-subCh: + t.Fatalf("duplicate Subscribe should not re-invoke SubscribeHandler (got %q)", got) + case <-time.After(notificationDelay * 10): + } + + // Subsequent Unsubscribe still works (verifies the URI is tracked + // correctly even though the second Subscribe was a no-op). + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe: %v", err) + } + select { + case <-unsubCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for UnsubscribeHandler") + } +} + +// TestResourceSubscriptions_MultipleURIs verifies that two concurrent +// Subscribe calls on the same session each open their own independent listen +// stream with a distinct subscription ID. Unsubscribing one does not affect +// the other. +func TestResourceSubscriptions_MultipleURIs(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + server.AddResource(&Resource{Name: "r2", URI: "file:///r2"}, nil) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r2"}); err != nil { + t.Fatalf("subscribe r2: %v", err) + } + + // Each Subscribe MUST fire its own SubscribeHandler invocation. + gotURIs := map[string]bool{} + for range 2 { + select { + case got := <-subCh: + gotURIs[got] = true + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for SubscribeHandler") + } + } + if !gotURIs["file:///r1"] || !gotURIs["file:///r2"] { + t.Fatalf("missing SubscribeHandler invocation; got %v", gotURIs) + } + + // Each update delivers exactly one event tagged with that URI's distinct + // subscription ID. + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + ev1 := <-events + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r2"}) + ev2 := <-events + if ev1.uri != "file:///r1" || ev2.uri != "file:///r2" { + t.Fatalf("got URIs %q and %q, want r1 and r2", ev1.uri, ev2.uri) + } + if ev1.id == "" || ev2.id == "" { + t.Fatalf("missing subscription IDs: r1=%q r2=%q", ev1.id, ev2.id) + } + if ev1.id == ev2.id { + t.Fatalf("r1 and r2 should have distinct subscription IDs, both = %q", ev1.id) + } + + // Unsubscribe r1 only. r2 keeps working. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe r1: %v", err) + } + select { + case got := <-unsubCh: + if got != "file:///r1" { + t.Fatalf("got unsubscribe URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for UnsubscribeHandler") + } + time.Sleep(50 * time.Millisecond) // let server-side cancellation settle + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r2"}) + + // Only the r2 update should arrive. + select { + case e := <-events: + if e.uri != "file:///r2" || e.id != ev2.id { + t.Fatalf("post-unsubscribe got %q (id=%s), want r2 (id=%s)", e.uri, e.id, ev2.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for r2 update after r1 unsubscribe") + } + select { + case e := <-events: + t.Fatalf("unexpected event after r1 unsubscribe: %q (id=%s)", e.uri, e.id) + case <-time.After(notificationDelay * 10): + } + + // Explicit Unsubscribe r2 to verify it still works independently. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r2"}); err != nil { + t.Fatalf("unsubscribe r2: %v", err) + } + select { + case <-unsubCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for r2 UnsubscribeHandler") + } +} + +// TestSubscriptionsListen_MultipleSessions verifies that two concurrent +// client sessions on the same server are isolated: a list-changed +// notification is fanned out to BOTH sessions, each with its own subscription +// ID; closing one session does not affect deliveries to the other. +func TestSubscriptionsListen_MultipleSessions(t *testing.T) { + enableNewProtocol(t) + + server := newSubListenServer() + + // Open two clients on the same server. + open := func(t *testing.T) (*ClientSession, chan subListenEvent, *ServerSession) { + t.Helper() + events := make(chan subListenEvent, 16) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + c := newSubListenClient(events) + cs, err := c.Connect(context.Background(), ct, + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + return cs, events, ss + } + csA, evA, ssA := open(t) + defer ssA.Close() + csB, evB, ssB := open(t) + defer ssB.Close() + + waitFor := func(ch chan subListenEvent, kind string) subListenEvent { + t.Helper() + select { + case e := <-ch: + if e.kind != kind { + t.Fatalf("got event %q, want %q", e.kind, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return subListenEvent{} + } + } + ackA := waitFor(evA, "ack") + ackB := waitFor(evB, "ack") + // Request IDs are per-connection, so both sessions may legitimately use + // the same ID number; we only require that each session's own + // notifications carry its own ack ID. + + // A single change fans out to BOTH sessions, each tagged with that + // session's own ack ID. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + gotA := waitFor(evA, "tool") + gotB := waitFor(evB, "tool") + if gotA.id != ackA.id { + t.Errorf("session A: tool notif id=%s, want %s", gotA.id, ackA.id) + } + if gotB.id != ackB.id { + t.Errorf("session B: tool notif id=%s, want %s", gotB.id, ackB.id) + } + + // Close A; B's subscription must keep delivering. + csA.Close() + time.Sleep(50 * time.Millisecond) + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + gotB2 := waitFor(evB, "tool") + if gotB2.id != ackB.id { + t.Errorf("session B after A closed: tool notif id=%s, want %s", gotB2.id, ackB.id) + } + select { + case e := <-evA: + t.Fatalf("session A unexpected event after Close: %q", e.kind) + case <-time.After(notificationDelay * 5): + } + + csB.Close() +} + +// TestSubscriptionsListen_ResourceListChanged covers the resources/list_changed +// auto-listen branch (the other two list-changed types are covered by +// runSubscriptionsListenTest, but resources is not). +func TestSubscriptionsListen_ResourceListChanged(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 16) + + server := NewServer(testImpl, nil) + server.AddResource(&Resource{Name: "r1", URI: "file:///r1"}, nil) + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + asEvent := func(kind string, raw any) subListenEvent { + return subListenEvent{kind, fmt.Sprint(raw)} + } + c := NewClient(testImpl, &ClientOptions{ + ResourceListChangedHandler: func(_ context.Context, req *ResourceListChangedRequest) { + id := any(nil) + if req.Params != nil && req.Params.Meta != nil { + id = req.Params.Meta[MetaKeySubscriptionID] + } + events <- asEvent("resource", id) + }, + }) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + waitFor := func(kind string) subListenEvent { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q, want %q", e.kind, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q", kind) + return subListenEvent{} + } + } + ack := waitFor("ack") + if ack.id == "" { + t.Fatal("acknowledgment missing subscription ID") + } + server.AddResource(&Resource{Name: "r2", URI: "file:///r2"}, nil) + got := waitFor("resource") + if got.id != ack.id { + t.Errorf("resource notif id=%s, want %s", got.id, ack.id) + } + cs.Close() +} + +// TestSubscriptionsListen_DisconnectScrubsMaps verifies that closing a +// session removes its entries from the server's three per-type subscription +// maps via Server.disconnect. +func TestSubscriptionsListen_DisconnectScrubsMaps(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 16) + server := newSubListenServer() + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + c := newSubListenClient(events) + cs, err := c.Connect(context.Background(), ct, + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + // Wait for the auto-listen to actually register on the server side. + select { + case e := <-events: + if e.kind != "ack" { + t.Fatalf("got %q, want ack", e.kind) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for ack") + } + + // Server should now have this session in tools+prompts maps (the fixture + // client opts in to both via newSubListenClient). + server.mu.Lock() + if _, ok := server.toolChangeSubscriptions[ss]; !ok { + server.mu.Unlock() + t.Fatal("session missing from toolChangeSubscriptions before close") + } + if _, ok := server.promptChangeSubscriptions[ss]; !ok { + server.mu.Unlock() + t.Fatal("session missing from promptChangeSubscriptions before close") + } + server.mu.Unlock() + + cs.Close() + ss.Close() + // Closures complete asynchronously on the server side. + deadline := time.Now().Add(5 * time.Second) + for { + server.mu.Lock() + _, inTool := server.toolChangeSubscriptions[ss] + _, inPrompt := server.promptChangeSubscriptions[ss] + _, inResource := server.resourceChangeSubscriptions[ss] + server.mu.Unlock() + if !inTool && !inPrompt && !inResource { + return + } + if time.Now().After(deadline) { + t.Fatalf("subscription maps not scrubbed after Close: tool=%v prompt=%v resource=%v", + inTool, inPrompt, inResource) + } + time.Sleep(10 * time.Millisecond) + } +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 4b0872be..35539259 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1931,6 +1931,49 @@ type ResourceUpdatedNotificationParams struct { func (x *ResourceUpdatedNotificationParams) isParams() {} func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } +// NotificationSubscriptions describes the set of server-to-client +// notifications a client wishes to receive on a [SubscriptionsListenParams] +// stream. Each field is an explicit opt-in: a server MUST NOT push +// notifications of a type the client did not request. +type NotificationSubscriptions struct { + // ToolsListChanged opts in to "notifications/tools/list_changed". + ToolsListChanged bool `json:"toolsListChanged,omitempty"` + // PromptsListChanged opts in to "notifications/prompts/list_changed". + PromptsListChanged bool `json:"promptsListChanged,omitempty"` + // ResourcesListChanged opts in to "notifications/resources/list_changed". + ResourcesListChanged bool `json:"resourcesListChanged,omitempty"` + // ResourceSubscriptions enumerates the resource URIs for which the client + // wants "notifications/resources/updated". Replaces the legacy + // resources/subscribe RPC. + ResourceSubscriptions []string `json:"resourceSubscriptions,omitempty"` +} + +// SubscriptionsListenParams are the parameters for the +// "subscriptions/listen" RPC. +type SubscriptionsListenParams struct { + // Meta carries the per-request `_meta` triple. + Meta `json:"_meta,omitempty"` + // Notifications declares which notification types the client wants to + // receive on this stream. + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsListenParams) isParams() {} +func (x *SubscriptionsListenParams) isNil() bool { return x == nil } + +// SubscriptionsAcknowledgedParams are the parameters for the +// "notifications/subscriptions/acknowledged" notification, which the server +// MUST send as the first message on a subscriptions/listen stream. It carries +// the subset of the requested [NotificationSubscriptions] that the server has +// agreed to honor. +type SubscriptionsAcknowledgedParams struct { + Meta `json:"_meta,omitempty"` + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsAcknowledgedParams) isParams() {} +func (x *SubscriptionsAcknowledgedParams) isNil() bool { return x == nil } + // TODO(jba): add CompleteRequest and related types. // A request from the server to elicit additional information from the user via the client. @@ -2134,8 +2177,10 @@ const ( notificationRootsListChanged = "notifications/roots/list_changed" methodSetLevel = "logging/setLevel" methodSubscribe = "resources/subscribe" + methodSubscriptionsListen = "subscriptions/listen" notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" + notificationSubscriptionsAck = "notifications/subscriptions/acknowledged" ) // Per-request _meta field names for the >= 2026-07-28 protocol version. @@ -2152,6 +2197,9 @@ const ( MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" // MetaKeyLogLevel identifies the desired log level for the request. MetaKeyLogLevel = "io.modelcontextprotocol/logLevel" + // MetaKeySubscriptionID identifies the subscriptions/listen request that an + // out-of-band notification belongs to. + MetaKeySubscriptionID = "io.modelcontextprotocol/subscriptionId" ) // UnsupportedProtocolVersionData is the SEP-2575 payload carried in the diff --git a/mcp/requests.go b/mcp/requests.go index 36368c99..64414caa 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -19,6 +19,7 @@ type ( ReadResourceRequest = ServerRequest[*ReadResourceParams] RootsListChangedRequest = ServerRequest[*RootsListChangedParams] SubscribeRequest = ServerRequest[*SubscribeParams] + SubscriptionsListenRequest = ServerRequest[*SubscriptionsListenParams] UnsubscribeRequest = ServerRequest[*UnsubscribeParams] ) diff --git a/mcp/server.go b/mcp/server.go index af59b40b..011ea56e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -44,16 +44,19 @@ type Server struct { impl *Implementation opts ServerOptions - mu sync.Mutex - prompts *featureSet[*serverPrompt] - tools *featureSet[*serverTool] - resources *featureSet[*serverResource] - resourceTemplates *featureSet[*serverResourceTemplate] - sessions []*ServerSession - sendingMethodHandler_ MethodHandler - receivingMethodHandler_ MethodHandler - resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + mu sync.Mutex + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] + sessions []*ServerSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler + toolChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "tools/changed" + promptChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "prompts/changed" + resourceChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "resources/changed" + resourceSubscriptions map[string]map[*ServerSession]jsonrpc.ID // uri -> session -> requestID + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -195,16 +198,19 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { } s := &Server{ - impl: impl, - opts: opts, - prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), - tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), - resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), - resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), - sendingMethodHandler_: defaultSendingMethodHandler, - receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], - resourceSubscriptions: make(map[string]map[*ServerSession]bool), - pendingNotifications: make(map[string]*time.Timer), + impl: impl, + opts: opts, + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + toolChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + promptChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + resourceChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + resourceSubscriptions: make(map[string]map[*ServerSession]jsonrpc.ID), + pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) return s @@ -647,12 +653,13 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } -// Map from notification name to its corresponding params. The params have no fields, -// so a single struct can be reused. -var changeNotificationParams = map[string]Params{ - notificationToolListChanged: &ToolListChangedParams{}, - notificationPromptListChanged: &PromptListChangedParams{}, - notificationResourceListChanged: &ResourceListChangedParams{}, +// Map from notification name to a function creating its corresponding Params. +// We need to create a fresh one each time to add the jsonrpc ID. See +// [injectMetaSubscriptionID]. +var changeNotificationParams = map[string]func() Params{ + notificationToolListChanged: func() Params { return &ToolListChangedParams{} }, + notificationPromptListChanged: func() Params { return &PromptListChangedParams{} }, + notificationResourceListChanged: func() Params { return &ResourceListChangedParams{} }, } // How long to wait before sending a change notification. @@ -686,12 +693,62 @@ func (s *Server) changeAndNotify(notification string, change func() bool) { // notifySessions sends the notification n to all existing sessions. // It is called asynchronously by changeAndNotify. +// +// Legacy (pre-SEP-2575) sessions receive the notification on the shared +// session channel. Sessions speaking the new protocol receive it only if they +// have an active subscriptions/listen stream that opted in to this +// notification type. func (s *Server) notifySessions(n string) { s.mu.Lock() - sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil + // Legacy (pre-SEP-2575) sessions receive list-changed notifications on the + // shared session channel without opt-in; collect them while we hold the lock. + var legacySessions []*ServerSession + for _, sess := range s.sessions { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260728 { + legacySessions = append(legacySessions, sess) + } + } + var subscribers map[*ServerSession]jsonrpc.ID + switch n { + case notificationToolListChanged: + subscribers = maps.Clone(s.toolChangeSubscriptions) + case notificationPromptListChanged: + subscribers = maps.Clone(s.promptChangeSubscriptions) + case notificationResourceListChanged: + subscribers = maps.Clone(s.resourceChangeSubscriptions) + } s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) + + notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) + + // Sessions receive the notification only if they opened a + // subscriptions/listen stream that opted in to this notification type. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for sess, reqID := range subscribers { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, reqID) + req := newRequest(sess, params) + if err := handleNotify(ctx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + } + } +} + +// injectMetaSubscriptionID stamps the listen request's JSON-RPC ID into the +// params' _meta under [MetaKeySubscriptionID], so the receiving client can +// correlate the notification with the [subscriptions/listen] stream it +// originated. +// +// [subscriptions/listen]: https://modelcontextprotocol.io/seps/2575-stateless-mcp#multiple-concurrent-subscriptions +func injectMetaSubscriptionID(params Params, reqID jsonrpc.ID) { + m := params.GetMeta() + if m == nil { + m = map[string]any{} + } + m[MetaKeySubscriptionID] = reqID.Raw() + params.SetMeta(m) } // shouldSendListChangedNotification checks if the server's capabilities allow @@ -1006,12 +1063,38 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot subscribedSessions := s.resourceSubscriptions[params.URI] sessions := slices.Collect(maps.Keys(subscribedSessions)) s.mu.Unlock() - notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) + // Only add legacy sessions for the notification, new ones use the new notification mechanism. + var legacySessions []*ServerSession + var newSessions []*ServerSession + for _, sess := range sessions { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260728 { + legacySessions = append(legacySessions, sess) + } else { + newSessions = append(newSessions, sess) + } + } + notifySessions(legacySessions, notificationResourceUpdated, params, s.opts.Logger) s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + for _, sess := range newSessions { + reqID := subscribedSessions[sess] + p := *params + injectMetaSubscriptionID(&p, reqID) + req := newRequest(sess, &p) + if err := handleNotify(ctx, notificationResourceUpdated, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", notificationResourceUpdated, err)) + } + } return nil } func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscribe requires a request ID", jsonrpc2.ErrInvalidRequest) + } if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } @@ -1022,10 +1105,10 @@ func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyRe s.mu.Lock() defer s.mu.Unlock() if s.resourceSubscriptions[req.Params.URI] == nil { - s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]jsonrpc.ID) } - s.resourceSubscriptions[req.Params.URI][req.Session] = true - s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + s.resourceSubscriptions[req.Params.URI][req.Session] = requestID + s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID(), "request_id", requestID) return &emptyResult{}, nil } @@ -1052,6 +1135,82 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } +func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) + } + + allowed := s.allowedSubscriptions(req.Params.Notifications) + s.mu.Lock() + if allowed.ToolsListChanged { + s.toolChangeSubscriptions[req.Session] = requestID + } + if allowed.PromptsListChanged { + s.promptChangeSubscriptions[req.Session] = requestID + } + if allowed.ResourcesListChanged { + s.resourceChangeSubscriptions[req.Session] = requestID + } + s.mu.Unlock() + defer func() { + s.mu.Lock() + delete(s.toolChangeSubscriptions, req.Session) + delete(s.promptChangeSubscriptions, req.Session) + delete(s.resourceChangeSubscriptions, req.Session) + s.mu.Unlock() + }() + + for _, uri := range allowed.ResourceSubscriptions { + _, err := s.subscribe(ctx, &SubscribeRequest{ + Session: req.Session, + Params: &SubscribeParams{ + URI: uri, + Meta: req.Params.GetMeta(), + }, + }) + if err != nil { + return nil, err + } + defer s.unsubscribe(ctx, &UnsubscribeRequest{ + Session: req.Session, + Params: &UnsubscribeParams{ + URI: uri, + Meta: req.Params.GetMeta(), + }, + }) + } + + ackParams := &SubscriptionsAcknowledgedParams{ + Notifications: allowed, + Meta: Meta{MetaKeySubscriptionID: requestID.Raw()}, + } + if err := req.Session.notifySubscriptionAcked(ctx, ackParams); err != nil { + return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) + } + + <-ctx.Done() + return &emptyResult{}, nil +} + +func (s *Server) allowedSubscriptions(want NotificationSubscriptions) NotificationSubscriptions { + caps := s.capabilities() + agreed := NotificationSubscriptions{} + if want.ToolsListChanged && caps.Tools != nil && caps.Tools.ListChanged { + agreed.ToolsListChanged = true + } + if want.PromptsListChanged && caps.Prompts != nil && caps.Prompts.ListChanged { + agreed.PromptsListChanged = true + } + if want.ResourcesListChanged && caps.Resources != nil && caps.Resources.ListChanged { + agreed.ResourcesListChanged = true + } + if len(want.ResourceSubscriptions) > 0 && caps.Resources != nil && caps.Resources.Subscribe { + agreed.ResourceSubscriptions = slices.Clone(want.ResourceSubscriptions) + } + return agreed +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -1121,6 +1280,10 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } + delete(s.toolChangeSubscriptions, cc) + delete(s.promptChangeSubscriptions, cc) + delete(s.resourceChangeSubscriptions, cc) + s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1221,6 +1384,13 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) } +// notifySubscriptionAcked sends a "notifications/subscriptions/acknowledged" +// notification on the listen stream represented by this session, indicating +// the subscription filter the server accepted (SEP-2575). +func (ss *ServerSession) notifySubscriptionAcked(ctx context.Context, params *SubscriptionsAcknowledgedParams) error { + return handleNotify(ctx, notificationSubscriptionsAck, newServerRequest(ss, orZero[Params](params))) +} + func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { return &ServerRequest[P]{Session: ss, Params: params} } @@ -1522,6 +1692,7 @@ var serverMethodInfos = map[string]methodInfo{ methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodSubscriptionsListen: newServerMethodInfo(serverMethod((*Server).subscriptionsListen), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), @@ -1598,7 +1769,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, } switch req.Method { - case methodInitialize, methodPing, notificationInitialized, methodSetLevel: + case methodInitialize, methodPing, notificationInitialized, methodSetLevel, methodSubscribe, methodUnsubscribe: if validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) return nil, &jsonrpc.Error{ diff --git a/mcp/shared.go b/mcp/shared.go index a15ec201..ac07ee46 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -123,8 +123,12 @@ func defaultSendingMethodHandler(ctx context.Context, method string, req Request // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() - if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { - return nil, err + if method == methodSubscriptionsListen { + callSubscriptionsListen(ctx, req.GetSession().getConn(), method, params) + } else { + if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { + return nil, err + } } return res, nil } diff --git a/mcp/streamable.go b/mcp/streamable.go index f01266a8..b71bc774 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -366,14 +366,14 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } - connectOpts, usesNewProtocol, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } var sessionID string - if legacySessions && !usesNewProtocol { + if legacySessions && !info.usesNewProtocol { sessionID = req.Header.Get(sessionIDHeader) if sessionID == "" { sessionID = server.opts.GetSessionID() @@ -381,14 +381,15 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. } transport := &StreamableServerTransport{ - SessionID: sessionID, - Stateless: true, - EventStore: h.opts.EventStore, - jsonResponse: h.opts.JSONResponse, - logger: h.opts.Logger, + SessionID: sessionID, + Stateless: true, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + shouldPropagateCancellation: info.isSubscriptionsListen && info.usesNewProtocol, } - session, err := connectStreamable(req.Context(), server, transport, connectOpts) + session, err := connectStreamable(req.Context(), server, transport, info.opts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) @@ -412,27 +413,29 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter w.WriteHeader(http.StatusNoContent) } -// ephemeralConnectOpts peeks at the request body to determine whether it -// contains an initialize or initialized message or whether the protocol version -// header indicates a protocol version >= 2026-07-28 (SEP-2575). +type ephemeralConnectInfo struct { + opts *ServerSessionOptions + usesNewProtocol bool + isSubscriptionsListen bool +} + +// ephemeralConnectOpts peeks at the request body to determine connection +// parameters and whether protocol version >= 2026-06-30 (SEP-2575). // // For old-protocol requests, default session state is synthesized so that // the session's init gate doesn't reject the request. // // It is used for both stateless servers and stateful servers with no session ID. -// -// The returned usesNewProtocol bool reports whether the protocol version -// header indicates a protocol version >= 2026-07-28 (SEP-2575). -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, err error) { +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ephemeralConnectInfo, error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } - var hasInitialize, hasInitialized bool + var hasInitialize, hasInitialized, usesNewProtocol, isSubscriptionsListen bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, false, fmt.Errorf("failed to read body") + return nil, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -445,6 +448,8 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S hasInitialize = true case notificationInitialized: hasInitialized = true + case methodSubscriptionsListen: + isSubscriptionsListen = true } if protocolVersion >= protocolVersion20260728 { usesNewProtocol = true @@ -467,9 +472,13 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S if !usesNewProtocol { state.LogLevel = "info" } - return &ServerSessionOptions{ - State: state, - }, usesNewProtocol, nil + return &ephemeralConnectInfo{ + opts: &ServerSessionOptions{ + State: state, + }, + usesNewProtocol: usesNewProtocol, + isSubscriptionsListen: isSubscriptionsListen, + }, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -614,12 +623,12 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, _, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - session, err := connectStreamable(req.Context(), server, transport, connectOpts) + session, err := connectStreamable(req.Context(), server, transport, info.opts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) @@ -775,6 +784,10 @@ type StreamableServerTransport struct { // to write their own streamable HTTP handler. logger *slog.Logger + // shouldPropagateCancellation is forwarded to the underlying + // [streamableServerConn]. See its docstring. + shouldPropagateCancellation bool + // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn } @@ -785,15 +798,16 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er return nil, fmt.Errorf("transport already connected") } t.connection = &streamableServerConn{ - sessionID: t.SessionID, - stateless: t.Stateless, - eventStore: t.EventStore, - jsonResponse: t.jsonResponse, - logger: ensureLogger(t.logger), // see #556: must be non-nil - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - streams: make(map[string]*stream), - requestStreams: make(map[jsonrpc.ID]string), + sessionID: t.SessionID, + stateless: t.Stateless, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, + logger: ensureLogger(t.logger), // see #556: must be non-nil + shouldPropagateCancellation: t.shouldPropagateCancellation, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), } // Stream 0 corresponds to the standalone SSE stream. // @@ -823,6 +837,13 @@ type streamableServerConn struct { jsonResponse bool eventStore EventStore + // shouldPropagateCancellation is true when the underlying HTTP request's + // lifetime IS the connection's cancellation signal (e.g., a stateless + // POST that owns a long-lived subscriptions/listen stream). It is read + // by the [cancellationPropagator] interface so the jsonrpc2 layer wires + // handler contexts to observe the carrier's cancellation. + shouldPropagateCancellation bool + logger *slog.Logger toolLookup func(name string) (*serverTool, bool) @@ -860,6 +881,15 @@ func (c *streamableServerConn) SessionID() string { return c.sessionID } +// propagateCancellation implements [cancellationPropagator]. It returns true +// when this connection is bound to a single HTTP request whose lifetime +// should drive request-handler cancellation — for example, a stateless POST +// carrying a long-lived subscriptions/listen stream that must unwind when +// the client TCP-disconnects. +func (c *streamableServerConn) propagateCancellation() bool { + return c.shouldPropagateCancellation +} + // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. @@ -914,6 +944,12 @@ type stream struct { // the spec and earlier there was a concept of batching, in which POST // payloads could hold multiple requests or responses. requests map[jsonrpc.ID]struct{} + + // isListen reports whether this stream was opened by a + // subscriptions/listen request. Listen streams are always SSE, live for + // the duration of the subscription, and act as the target for + // out-of-band notifications routed through this connection. + isListen bool } // close sends a 'close' event to the client (if protocolVersion >= 2025-11-25 @@ -1372,6 +1408,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + isSubscriptionsListen := false var initializeProtocolVersion string for _, msg := range incoming { if jreq, ok := msg.(*jsonrpc.Request); ok { @@ -1397,6 +1434,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + if jreq.Method == methodSubscriptionsListen { + isSubscriptionsListen = true + } // SEP-2575: requests carrying `_meta.protocolVersion` require the // Mcp-Protocol-Version HTTP header to be present and to match the // per-request `_meta.protocolVersion` value. @@ -1542,14 +1582,22 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return } + stream.isListen = isSubscriptionsListen + + // subscriptions/listen is inherently a long-lived SSE endpoint (SEP-2575): + // it has no synchronous result, the response stream stays open until the + // client cancels, and the server pushes notifications on it as they occur. + // Force SSE mode (bypassing JSONResponse) so the buffered application/json + // path doesn't deadlock waiting for a completion that won't come. + useSSE := !c.jsonResponse || isSubscriptionsListen // Set response headers. Accept was checked in [StreamableHTTPHandler]. w.Header().Set("Cache-Control", "no-cache, no-transform") - if c.jsonResponse { - w.Header().Set("Content-Type", "application/json") - } else { + if useSSE { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") + } else { + w.Header().Set("Content-Type", "application/json") } if c.sessionID != "" && isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) @@ -1584,7 +1632,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // remaining requests here, since the client will never get the results. defer stream.release() - if c.jsonResponse { + if !useSSE { // JSON mode: collect messages in pendingJSONMessages until done. // Set pendingJSONMessages to a non-nil value to signal that this is an // application/json stream. @@ -1724,7 +1772,18 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e s = c.streams[streamID] } } else { - s = c.streams[""] // standalone SSE stream + // In stateless mode there will always be only one stream per connection. + // If that stream was open to listen for subscription notifications, + // automatically select as the one to write the notification to. + for _, stream := range c.streams { + if stream.isListen { + s = stream + break + } + } + if s == nil { + s = c.streams[""] // standalone SSE stream + } } if responseTo.IsValid() { // Once we've responded to a request, disallow related messages by removing diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 539789b9..b5c0eb96 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3319,20 +3319,20 @@ func TestEphemeralConnectOpts(t *testing.T) { } req.Header.Set(protocolVersionHeader, pver) req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) - opts, usesNew, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { t.Fatal(err) } - if usesNew != tt.wantUsesNew { - t.Errorf("usesNewProtocol = %v, want %v", usesNew, tt.wantUsesNew) + if info.usesNewProtocol != tt.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", info.usesNewProtocol, tt.wantUsesNew) } - if got := opts.State.InitializeParams != nil; got != tt.wantInitializeParams { + if got := info.opts.State.InitializeParams != nil; got != tt.wantInitializeParams { t.Errorf("InitializeParams non-nil = %v, want %v (value = %+v)", - got, tt.wantInitializeParams, opts.State.InitializeParams) + got, tt.wantInitializeParams, info.opts.State.InitializeParams) } - if got := opts.State.InitializedParams != nil; got != tt.wantInitializedParams { + if got := info.opts.State.InitializedParams != nil; got != tt.wantInitializedParams { t.Errorf("InitializedParams non-nil = %v, want %v (value = %+v)", - got, tt.wantInitializedParams, opts.State.InitializedParams) + got, tt.wantInitializedParams, info.opts.State.InitializedParams) } }) } diff --git a/mcp/transport.go b/mcp/transport.go index f67c43d0..55c73d74 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -191,6 +191,13 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } + // Transports may opt in to propagating cancellation of ctx into request + // handler contexts when their own lifecycle IS the cancellation signal + // (e.g., a connection bound to a single HTTP request). + var propagateCancellation bool + if cp, ok := mcpConn.(cancellationPropagator); ok { + propagateCancellation = cp.propagateCancellation() + } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, @@ -203,11 +210,22 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, OnInternalError: func(err error) { logger.Error("jsonrpc2 internal error", "error", err) }, + PropagateCancellation: propagateCancellation, }) assert(preempter.conn != nil, "unbound preempter") return h, nil } +// cancellationPropagator is an optional interface implemented by a +// [Connection] whose own lifecycle should propagate cancellation into request +// handler contexts. The default jsonrpc2 behavior is to suppress propagation; +// transports that bind a connection to a single short-lived carrier (such as +// a one-shot HTTP request) should return true here so that handlers unwind +// when the carrier observes the peer going away. +type cancellationPropagator interface { + propagateCancellation() bool +} + // A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP // cancelled notifications. type canceller struct { @@ -230,6 +248,24 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result a return nil, jsonrpc2.ErrNotHandled } +// callSubscriptionsListen issues a "subscriptions/listen" call (SEP-2575) +// without awaiting its JSON-RPC response. The call's logical lifetime is the +// stream of notifications that follow on the same channel — the empty +// response, if ever delivered, only marks subscription teardown — so the +// caller has nothing useful to block on. +// +// Cancellation is driven by ctx: when it is cancelled, a background goroutine +// sends a "notifications/cancelled" notification referencing the listen's +// request ID and retires the call from the connection's outgoing-calls map. +func callSubscriptionsListen(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params) { + call := conn.Call(ctx, method, params) + + go func() { + <-ctx.Done() + _ = cancelCall(ctx, conn, call) + }() +} + // call executes and awaits a jsonrpc2 call on the given connection, // translating errors into the mcp domain. func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { @@ -240,24 +276,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: - notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) - defer cancelNotify() - err := conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ - Reason: ctx.Err().Error(), - RequestID: call.ID().Raw(), - }) - // By default, the jsonrpc2 library waits for graceful shutdown when the - // connection is closed, meaning it expects all outgoing and incoming - // requests to complete. However, for MCP this expectation is unrealistic, - // and can lead to hanging shutdown. For example, if a streamable client is - // killed, the server will not be able to detect this event, except via - // keepalive pings (if they are configured), and so outgoing calls may hang - // indefinitely. - // - // Therefore, we choose to eagerly retire calls, removing them from the - // outgoingCalls map, when the caller context is cancelled: if the caller - // will never receive the response, there's no need to track it. - conn.Retire(call, ctx.Err()) + err := cancelCall(ctx, conn, call) return errors.Join(ctx.Err(), err) case err != nil: return fmt.Errorf("calling %q: %w", method, err) @@ -265,6 +284,30 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params return nil } +// cancelCall sends a "notifications/cancelled" notification for call and eagerly +// retires it from conn. +// +// By default, the jsonrpc2 library waits for graceful shutdown when the +// connection is closed, meaning it expects all outgoing and incoming requests +// to complete. However, for MCP this expectation is unrealistic, and can lead +// to hanging shutdown. For example, if a streamable client is killed, the +// server will not be able to detect this event, except via keepalive pings (if +// they are configured), and so outgoing calls may hang indefinitely. +// +// Therefore, we choose to eagerly retire calls, removing them from the +// outgoingCalls map, when the caller context is cancelled: if the caller will +// never receive the response, there's no need to track it. +func cancelCall(ctx context.Context, conn *jsonrpc2.Connection, call *jsonrpc2.AsyncCall) error { + notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) + defer cancelNotify() + err := conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ + Reason: ctx.Err().Error(), + RequestID: call.ID().Raw(), + }) + conn.Retire(call, ctx.Err()) + return err +} + // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. type LoggingTransport struct {