From c99ca347f6d70988c155d6e75550f182938cf8de Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 11 Jun 2026 15:32:52 +0000 Subject: [PATCH 01/18] feat: implement subscriptions/listen with SSE support for long-lived streams and end-to-end testing --- mcp/client.go | 36 +++++++++ mcp/protocol.go | 48 ++++++++++++ mcp/requests.go | 1 + mcp/server.go | 162 ++++++++++++++++++++++++++++++++++++++--- mcp/streamable.go | 33 +++++++-- mcp/streamable_test.go | 134 ++++++++++++++++++++++++++++++++++ mcp/transport_test.go | 151 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 550 insertions(+), 15 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 1e0e18a4..e40f148a 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -300,6 +300,26 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } + subscribeParams := &SubscriptionsListenParams{} + if c.opts.ToolListChangedHandler != nil { + subscribeParams.Notifications.ToolsListChanged = true + } + if c.opts.PromptListChangedHandler != nil { + subscribeParams.Notifications.PromptsListChanged = true + } + if c.opts.ResourceListChangedHandler != nil { + subscribeParams.Notifications.ResourcesListChanged = true + } + if subscribeParams.Notifications.ToolsListChanged || + subscribeParams.Notifications.PromptsListChanged || + subscribeParams.Notifications.ResourcesListChanged { + // The listen blocks until the server cancels it (or the + // underlying connection ends). Run it in a goroutine so + // Connect can return; the SDK retires it on cs.Close(). + go func() { + _ = cs.subscriptionsListen(context.Background(), subscribeParams) + }() + } return cs, nil } @@ -1079,6 +1099,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationLoggingMessage: newClientMethodInfo(clientMethod((*Client).callLoggingHandler), notification), notificationProgress: newClientMethodInfo(clientSessionMethod((*ClientSession).callProgressNotificationHandler), notification), notificationElicitationComplete: newClientMethodInfo(clientMethod((*Client).callElicitationCompleteHandler), notification|missingParamsOK), + notificationSubscriptionsAck: newClientMethodInfo(clientMethod((*Client).callSubscriptionsAckHandler), notification|missingParamsOK), } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { @@ -1234,6 +1255,21 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar return err } +// SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream and +// blocks for the lifetime of the subscription. The server's first message on +// the stream is "notifications/subscriptions/acknowledged"; subsequent +// opted-in notifications (e.g. tools/list_changed) are delivered through the +// usual handlers registered in [ClientOptions]. +func (cs *ClientSession) subscriptionsListen(ctx context.Context, params *SubscriptionsListenParams) error { + params = injectRequestMeta(cs, params) + _, err := handleSend[*emptyResult](ctx, methodSubscriptionsListen, newClientRequest(cs, orZero[Params](params))) + return err +} + +func (c *Client) callSubscriptionsAckHandler(context.Context, *ClientRequest[*SubscriptionsAcknowledgedParams]) (Result, error) { + return nil, nil +} + func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { if h := c.opts.ToolListChangedHandler; h != nil { h(ctx, req) diff --git a/mcp/protocol.go b/mcp/protocol.go index 4ba8600e..6eca3698 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -1883,6 +1883,49 @@ type ResourceUpdatedNotificationParams struct { func (x *ResourceUpdatedNotificationParams) isParams() {} func (x *ResourceUpdatedNotificationParams) isNil() bool { return x == nil } +// NotificationSubscriptions describes the set of server-to-client +// notifications a client wishes to receive on a [SubscriptionsListenParams] +// stream. Each field is an explicit opt-in: a server MUST NOT push +// notifications of a type the client did not request. +type NotificationSubscriptions struct { + // ToolsListChanged opts in to "notifications/tools/list_changed". + ToolsListChanged bool `json:"toolsListChanged,omitempty"` + // PromptsListChanged opts in to "notifications/prompts/list_changed". + PromptsListChanged bool `json:"promptsListChanged,omitempty"` + // ResourcesListChanged opts in to "notifications/resources/list_changed". + ResourcesListChanged bool `json:"resourcesListChanged,omitempty"` + // ResourceSubscriptions enumerates the resource URIs for which the client + // wants "notifications/resources/updated". Replaces the legacy + // resources/subscribe RPC. + ResourceSubscriptions []string `json:"resourceSubscriptions,omitempty"` +} + +// SubscriptionsListenParams are the parameters for the +// "subscriptions/listen" RPC. +type SubscriptionsListenParams struct { + // Meta carries the per-request `_meta` triple. + Meta `json:"_meta,omitempty"` + // Notifications declares which notification types the client wants to + // receive on this stream. + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsListenParams) isParams() {} +func (x *SubscriptionsListenParams) isNil() bool { return x == nil } + +// SubscriptionsAcknowledgedParams are the parameters for the +// "notifications/subscriptions/acknowledged" notification, which the server +// MUST send as the first message on a subscriptions/listen stream. It carries +// the subset of the requested [NotificationSubscriptions] that the server has +// agreed to honor. +type SubscriptionsAcknowledgedParams struct { + Meta `json:"_meta,omitempty"` + Notifications NotificationSubscriptions `json:"notifications"` +} + +func (x *SubscriptionsAcknowledgedParams) isParams() {} +func (x *SubscriptionsAcknowledgedParams) isNil() bool { return x == nil } + // TODO(jba): add CompleteRequest and related types. // A request from the server to elicit additional information from the user via the client. @@ -2086,8 +2129,10 @@ const ( notificationRootsListChanged = "notifications/roots/list_changed" methodSetLevel = "logging/setLevel" methodSubscribe = "resources/subscribe" + methodSubscriptionsListen = "subscriptions/listen" notificationToolListChanged = "notifications/tools/list_changed" methodUnsubscribe = "resources/unsubscribe" + notificationSubscriptionsAck = "notifications/subscriptions/acknowledged" ) // Per-request _meta field names for the >= 2026-06-30 protocol version. @@ -2104,6 +2149,9 @@ const ( MetaKeyClientCapabilities = "io.modelcontextprotocol/clientCapabilities" // MetaKeyLogLevel identifies the desired log level for the request. MetaKeyLogLevel = "io.modelcontextprotocol/logLevel" + // MetaKeySubscriptionID identifies the subscriptions/listen request that an + // out-of-band notification belongs to. + MetaKeySubscriptionID = "io.modelcontextprotocol/subscriptionId" ) // UnsupportedProtocolVersionData is the SEP-2575 payload carried in the diff --git a/mcp/requests.go b/mcp/requests.go index 36368c99..64414caa 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -19,6 +19,7 @@ type ( ReadResourceRequest = ServerRequest[*ReadResourceParams] RootsListChangedRequest = ServerRequest[*RootsListChangedParams] SubscribeRequest = ServerRequest[*SubscribeParams] + SubscriptionsListenRequest = ServerRequest[*SubscriptionsListenParams] UnsubscribeRequest = ServerRequest[*UnsubscribeParams] ) diff --git a/mcp/server.go b/mcp/server.go index adcf698f..d4af3761 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -53,7 +53,10 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + toolsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + promptsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + resourcesSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -204,6 +207,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), + toolsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + promptsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + resourcesSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -643,12 +649,11 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } -// Map from notification name to its corresponding params. The params have no fields, -// so a single struct can be reused. -var changeNotificationParams = map[string]Params{ - notificationToolListChanged: &ToolListChangedParams{}, - notificationPromptListChanged: &PromptListChangedParams{}, - notificationResourceListChanged: &ResourceListChangedParams{}, +// Map from notification name to its corresponding params. +var changeNotificationParams = map[string]func() Params{ + notificationToolListChanged: func() Params { return &ToolListChangedParams{} }, + notificationPromptListChanged: func() Params { return &PromptListChangedParams{} }, + notificationResourceListChanged: func() Params { return &ResourceListChangedParams{} }, } // How long to wait before sending a change notification. @@ -687,7 +692,47 @@ func (s *Server) notifySessions(n string) { sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - notifySessions(sessions, n, changeNotificationParams[n], s.opts.Logger) + + // Only add legacy sessions for the notification, new ones use the new notification mechanism. + var legacySessions []*ServerSession + for _, s := range sessions { + if s.InitializeParams().isNil() || s.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, s) + } + } + notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) + + var activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + switch n { + case notificationToolListChanged: + activeSubscriptions = s.toolsSubscriptions + case notificationPromptListChanged: + activeSubscriptions = s.promptsSubscriptions + case notificationResourceListChanged: + activeSubscriptions = s.resourcesSubscriptions + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for session := range activeSubscriptions { + for reqID := range activeSubscriptions[session] { + params := changeNotificationParams[n]() + setSubscriptionID(params, reqID) + req := newRequest(session, params) + ctx = context.WithValue(ctx, idContextKey{}, reqID) + if err := handleNotify(ctx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + } + } + } +} + +func setSubscriptionID(params Params, reqID jsonrpc.ID) { + m := params.GetMeta() + if m == nil { + m = map[string]any{} + } + m[MetaKeySubscriptionID] = reqID.Raw() + params.SetMeta(m) } // shouldSendListChangedNotification checks if the server's capabilities allow @@ -1027,6 +1072,80 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } +func (s *Server) addSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + ids := index[ss] + if ids == nil { + ids = make(map[jsonrpc.ID]bool) + index[ss] = ids + } + ids[id] = true +} + +func (s *Server) removeSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + ids, ok := index[ss] + if !ok { + return + } + delete(ids, id) + if len(ids) == 0 { + delete(index, ss) + } +} + +func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) + } + + allowed := s.allowedSubscriptions(req.Params.Notifications) + if allowed.ToolsListChanged { + s.addSubscription(s.toolsSubscriptions, req.Session, requestID) + } + if allowed.PromptsListChanged { + s.addSubscription(s.promptsSubscriptions, req.Session, requestID) + } + if allowed.ResourcesListChanged { + s.addSubscription(s.resourcesSubscriptions, req.Session, requestID) + } + + ackParams := &SubscriptionsAcknowledgedParams{ + Notifications: allowed, + Meta: Meta{MetaKeySubscriptionID: requestID.Raw()}, + } + if err := req.Session.NotifySubscriptionAcked(ctx, ackParams); err != nil { + return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) + } + + defer func() { + s.removeSubscription(s.toolsSubscriptions, req.Session, requestID) + s.removeSubscription(s.promptsSubscriptions, req.Session, requestID) + s.removeSubscription(s.resourcesSubscriptions, req.Session, requestID) + }() + + <-ctx.Done() + return &emptyResult{}, nil +} + +func (s *Server) allowedSubscriptions(want NotificationSubscriptions) NotificationSubscriptions { + caps := s.capabilities() + agreed := NotificationSubscriptions{} + if want.ToolsListChanged && caps.Tools != nil && caps.Tools.ListChanged { + agreed.ToolsListChanged = true + } + if want.PromptsListChanged && caps.Prompts != nil && caps.Prompts.ListChanged { + agreed.PromptsListChanged = true + } + if want.ResourcesListChanged && caps.Resources != nil && caps.Resources.ListChanged { + agreed.ResourcesListChanged = true + } + return agreed +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -1096,6 +1215,10 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } + delete(s.toolsSubscriptions, cc) + delete(s.promptsSubscriptions, cc) + delete(s.resourcesSubscriptions, cc) + s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1196,6 +1319,12 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) } +// NotifySubscriptionAcked sends a subscription acknowledged notification from the server to the client +// associated with this session. +func (ss *ServerSession) NotifySubscriptionAcked(ctx context.Context, params *SubscriptionsAcknowledgedParams) error { + return handleNotify(ctx, notificationSubscriptionsAck, newServerRequest(ss, orZero[Params](params))) +} + func newServerRequest[P Params](ss *ServerSession, params P) *ServerRequest[P] { return &ServerRequest[P]{Session: ss, Params: params} } @@ -1497,6 +1626,7 @@ var serverMethodInfos = map[string]methodInfo{ methodReadResource: newServerMethodInfo(serverMethod((*Server).readResource), 0), methodSetLevel: newServerMethodInfo(serverSessionMethod((*ServerSession).setLevel), 0), methodSubscribe: newServerMethodInfo(serverMethod((*Server).subscribe), 0), + methodSubscriptionsListen: newServerMethodInfo(serverMethod((*Server).subscriptionsListen), 0), methodUnsubscribe: newServerMethodInfo(serverMethod((*Server).unsubscribe), 0), notificationCancelled: newServerMethodInfo(serverSessionMethod((*ServerSession).cancel), notification|missingParamsOK), notificationInitialized: newServerMethodInfo(serverSessionMethod((*ServerSession).initialized), notification|missingParamsOK), @@ -1577,7 +1707,7 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, } switch req.Method { - case methodInitialize, methodPing, notificationInitialized, methodSetLevel: + case methodInitialize, methodPing, notificationInitialized, methodSetLevel, methodSubscribe, methodUnsubscribe: if validatedMeta.usesNewProtocol { ss.server.opts.Logger.Error("method removed in the new protocol", "method", req.Method) return nil, &jsonrpc.Error{ @@ -1663,7 +1793,11 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(ctx context.Context, param *CancelledParams) (Result, error) { + server := ss.server + server.removeSubscription(server.toolsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.promptsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.resourcesSubscriptions, ss, param.RequestID.(jsonrpc.ID)) return nil, nil } @@ -1696,6 +1830,14 @@ func (ss *ServerSession) Close() error { ss.onClose() } + // Clean up session subscriptions + server := ss.server + server.mu.Lock() + delete(server.toolsSubscriptions, ss) + delete(server.promptsSubscriptions, ss) + delete(server.resourcesSubscriptions, ss) + server.mu.Unlock() + return err } diff --git a/mcp/streamable.go b/mcp/streamable.go index 5bc31771..2c851dd3 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1365,6 +1365,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques calls := make(map[jsonrpc.ID]struct{}) tokenInfo := auth.TokenInfoFromContext(req.Context()) isInitialize := false + isSubscriptionsListen := false var initializeProtocolVersion string headerVersion := protocolVersionFromContext(req.Context()) for _, msg := range incoming { @@ -1391,6 +1392,9 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques initializeProtocolVersion = params.ProtocolVersion } } + if jreq.Method == methodSubscriptionsListen { + isSubscriptionsListen = true + } // SEP-2575: requests carrying `_meta.protocolVersion` require the // Mcp-Protocol-Version HTTP header to be present and to match the // per-request `_meta.protocolVersion` value. @@ -1528,13 +1532,20 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques return } + // subscriptions/listen is inherently a long-lived SSE endpoint (SEP-2575): + // it has no synchronous result, the response stream stays open until the + // client cancels, and the server pushes notifications on it as they occur. + // Force SSE mode (bypassing JSONResponse) so the buffered application/json + // path doesn't deadlock waiting for a completion that won't come. + useSSE := !c.jsonResponse || isSubscriptionsListen + // Set response headers. Accept was checked in [StreamableHTTPHandler]. w.Header().Set("Cache-Control", "no-cache, no-transform") - if c.jsonResponse { - w.Header().Set("Content-Type", "application/json") - } else { + if useSSE { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") + } else { + w.Header().Set("Content-Type", "application/json") } if c.sessionID != "" && isInitialize { w.Header().Set(sessionIDHeader, c.sessionID) @@ -1545,7 +1556,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques done := make(chan struct{}) stream.done = done stream.protocolVersion = effectiveVersion - if c.jsonResponse { + if !useSSE { // JSON mode: collect messages in pendingJSONMessages until done. // Set pendingJSONMessages to a non-nil value to signal that this is an // application/json stream. @@ -1571,6 +1582,17 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) } } + // For subscriptions/listen, flush headers eagerly so the client sees + // the open SSE stream even when an HTTP/2-aware reverse proxy is + // buffering the HEADERS frame waiting for a DATA frame to coalesce + // with. The acknowledgment notification will follow shortly, but the + // client can begin reading event-stream framing immediately. See the + // equivalent comment in [streamableServerConn.acquireStream]. + if isSubscriptionsListen { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, ": ok\n\n") + _ = http.NewResponseController(w).Flush() + } } // TODO(rfindley): if we have no event store, we should really cancel all @@ -2460,7 +2482,8 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. - if jsonResp.ID == forCall.ID { + // The subscriptions/listen is now returning a response in the SSE stream. + if jsonResp.ID == forCall.ID && forCall.Method != methodSubscriptionsListen { return "", 0, true } } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..6e19336f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,3 +3716,137 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } + +// TestSubscriptionsListen_Streamable opens two concurrent subscriptions/listen +// streams on the same client session against a stateless HTTP server and +// verifies the full SEP-2575 contract end-to-end as observed by the client: +// +// - both listens are acknowledged, each tagged with its own request ID; +// - opt-in filtering: tool-list-changed reaches only the tools listen and +// prompt-list-changed reaches only the prompts listen; +// - the underlying SSE streams persist across multiple unrelated changes so +// notifications keep arriving until the session is closed; +// - subscription-ID tagging: every notification carries the originating +// listen's request ID in `_meta`, matching the ack's tag for that listen; +// - closing the client session ends all deliveries. +func TestSubscriptionsListen_Streamable(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "t1"}, sayHi) + server.AddPrompt(&Prompt{Name: "p1"}, nil) + + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + + type event struct { + kind string + id string + } + events := make(chan event, 64) + asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } + + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + cs, err := client.Connect(context.Background(), &StreamableClientTransport{Endpoint: httpServer.URL}, + &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + startListen := func(notifs NotificationSubscriptions) { + go func() { + _ = cs.SubscriptionsListen(context.Background(), &SubscriptionsListenParams{Notifications: notifs}) + }() + } + waitFor := func(kind string) event { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return event{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + startListen(NotificationSubscriptions{ToolsListChanged: true}) + startListen(NotificationSubscriptions{PromptsListChanged: true}) + + ack1 := waitFor("ack") + ack2 := waitFor("ack") + if ack1.id == ack2.id { + t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) + } + + // Trigger a tool change. The notification's subscription ID identifies + // which of the two acks belongs to the tools listen; the other ack + // therefore belongs to the prompts listen. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + toolEv := waitFor("tool") + if toolEv.id != ack1.id && toolEv.id != ack2.id { + t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) + } + toolSubID := toolEv.id + promptSubID := ack2.id + if ack1.id != toolSubID { + promptSubID = ack1.id + } + expectNoEvent(notificationDelay * 5) + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != toolSubID { + t.Errorf("second tool notif id = %s, want %s", e.id, toolSubID) + } + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("second prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + cs.Close() + time.Sleep(50 * time.Millisecond) + server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 515b8c19..d882998c 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -6,10 +6,14 @@ package mcp import ( "context" + "fmt" "io" + "slices" "strings" "testing" + "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -124,3 +128,150 @@ func TestIOConnRead(t *testing.T) { }) } } + +// TestSubscriptionsListen_InMemory verifies SEP-2575 subscriptions/listen +// over a single shared session (in-memory transport, semantically equivalent +// to STDIO). It exercises behavior that is harder to observe over streamable +// HTTP, where each listen lives in its own ephemeral session: +// +// - two concurrent listens on the SAME session both deliver notifications; +// - opt-in filtering: each listen receives only its opted-in notification +// types, tagged with its own subscription ID; +// - per-listen cancellation propagates over notifications/cancelled: when +// the client cancels one listen's context, the server stops fanning out +// notifications for that listen but keeps delivering to the other; +// - the remaining listen continues to work after the first cancellation. +func TestSubscriptionsListen_InMemory(t *testing.T) { + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) + + ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer topCancel() + + server := NewServer(testImpl, nil) + AddTool(server, &Tool{Name: "t1"}, sayHi) + server.AddPrompt(&Prompt{Name: "p1"}, nil) + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + type event struct { + kind string + id string + } + events := make(chan event, 64) + asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } + + client := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + startListen := func(notifs NotificationSubscriptions) context.CancelFunc { + lctx, cancel := context.WithCancel(ctx) + go func() { + _ = cs.SubscriptionsListen(lctx, &SubscriptionsListenParams{Notifications: notifs}) + }() + return cancel + } + waitFor := func(kind string) event { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return event{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + cancelTools := startListen(NotificationSubscriptions{ToolsListChanged: true}) + cancelPrompts := startListen(NotificationSubscriptions{PromptsListChanged: true}) + + ack1 := waitFor("ack") + ack2 := waitFor("ack") + if ack1.id == ack2.id { + t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) + } + + // Identify which ack belongs to which listen by triggering a tool change + // and observing the tagged subscription ID on the notification. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + toolEv := waitFor("tool") + if toolEv.id != ack1.id && toolEv.id != ack2.id { + t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) + } + toolSubID := toolEv.id + promptSubID := ack2.id + if ack1.id != toolSubID { + promptSubID = ack1.id + } + expectNoEvent(notificationDelay * 5) + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) + } + expectNoEvent(notificationDelay * 5) + + // Cancel the tools listen. The SDK sends a notifications/cancelled to the + // server, which flips the listen handler's ctx, which unblocks the + // goroutine that removes the subscription from the server's index. + cancelTools() + + // Give the cancellation a moment to propagate (notifications/cancelled + // → server-side cancel → cleanup goroutine). + time.Sleep(50 * time.Millisecond) + + // A new tool change must NOT reach the (cancelled) tools listen, while + // the prompts listen continues to receive its notifications. + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + expectNoEvent(notificationDelay * 20) + + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != promptSubID { + t.Errorf("prompt notif after tools-cancel id = %s, want %s", e.id, promptSubID) + } + + cancelPrompts() + time.Sleep(50 * time.Millisecond) + + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} From 0066ac920db4f16ef3a1b2501f02a499641f08e5 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 11 Jun 2026 15:55:27 +0000 Subject: [PATCH 02/18] refactor: remove subscription-based notification tests and cleanup unused imports --- mcp/client.go | 14 ++- mcp/mcp_test.go | 201 +++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 12 ++- mcp/streamable_test.go | 134 --------------------------- mcp/transport_test.go | 151 ------------------------------- 5 files changed, 219 insertions(+), 293 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index e40f148a..fff96859 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -313,11 +313,13 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if subscribeParams.Notifications.ToolsListChanged || subscribeParams.Notifications.PromptsListChanged || subscribeParams.Notifications.ResourcesListChanged { - // The listen blocks until the server cancels it (or the - // underlying connection ends). Run it in a goroutine so - // Connect can return; the SDK retires it on cs.Close(). + // The listen blocks until the server cancels it. Run it in + // a goroutine so Connect can return; ClientSession.Close + // cancels its context to send notifications/cancelled. + listenCtx, cancelListen := context.WithCancel(context.Background()) + cs.listenCancel = cancelListen go func() { - _ = cs.subscriptionsListen(context.Background(), subscribeParams) + _ = cs.subscriptionsListen(listenCtx, subscribeParams) }() } return cs, nil @@ -436,6 +438,7 @@ type ClientSession struct { conn *jsonrpc2.Connection client *Client keepaliveCancel context.CancelFunc + listenCancel context.CancelFunc mcpConn Connection // No mutex is (currently) required to guard the session state, because it is @@ -518,6 +521,9 @@ func (cs *ClientSession) Close() error { if cs.keepaliveCancel != nil { cs.keepaliveCancel() } + if cs.listenCancel != nil { + cs.listenCancel() + } err := cs.conn.Close() if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5c8e7d12..02f05618 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -12,6 +12,8 @@ import ( "fmt" "io" "log/slog" + "net/http" + "net/http/httptest" "net/url" "path/filepath" "runtime" @@ -2449,3 +2451,202 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmpopts.IgnoreUnexported(CallToolResult{}, GetPromptResult{}, ReadResourceResult{})} + +// runSubscriptionsListenTest exercises the SEP-2575 auto-listen flow end-to-end +// against the supplied transport pair. It captures every notification and the +// acknowledgment the client sees, then asserts: +// +// - the auto-listen issued by Client.Connect is acknowledged with a tagged +// subscription ID; +// - tool and prompt list-changed notifications are delivered to the matching +// handlers, each carrying the same subscription ID as the ack; +// - the subscription persists across multiple unrelated changes; +// - cs.Close() ends the subscription and further changes don't deliver. +func runSubscriptionsListenTest(t *testing.T, client *Client, server *Server, ct Transport, events chan subListenEvent) { + t.Helper() + + ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer topCancel() + + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + waitFor := func(kind string) subListenEvent { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return subListenEvent{} + } + } + expectNoEvent := func(d time.Duration) { + t.Helper() + select { + case e := <-events: + t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) + case <-time.After(d): + } + } + + ack := waitFor("ack") + if ack.id == "" { + t.Fatalf("acknowledgment missing subscription ID") + } + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("first tool notif id = %s, want %s", e.id, ack.id) + } + + server.AddPrompt(&Prompt{Name: "p2"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("first prompt notif id = %s, want %s", e.id, ack.id) + } + + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + if e := waitFor("tool"); e.id != ack.id { + t.Errorf("second tool notif id = %s, want %s", e.id, ack.id) + } + server.AddPrompt(&Prompt{Name: "p3"}, nil) + if e := waitFor("prompt"); e.id != ack.id { + t.Errorf("second prompt notif id = %s, want %s", e.id, ack.id) + } + expectNoEvent(notificationDelay * 5) + + cs.Close() + time.Sleep(50 * time.Millisecond) + + server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + server.AddPrompt(&Prompt{Name: "p4"}, nil) + expectNoEvent(notificationDelay * 20) +} + +type subListenEvent struct { + kind string // "ack", "tool", "prompt" + id string // subscription ID from _meta, stringified for cross-encoding equality +} + +// newSubListenClient returns a client wired to push every ack and every +// list-changed notification it receives into events, tagged with the kind +// and the subscription ID extracted from _meta. +func newSubListenClient(events chan subListenEvent) *Client { + asEvent := func(kind string, raw any) subListenEvent { + return subListenEvent{kind, fmt.Sprint(raw)} + } + c := NewClient(testImpl, &ClientOptions{ + ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { + events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) + }, + PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { + events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) + }, + }) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + return c +} + +func newSubListenServer() *Server { + s := NewServer(testImpl, nil) + AddTool(s, &Tool{Name: "t1"}, sayHi) + s.AddPrompt(&Prompt{Name: "p1"}, nil) + return s +} + +func enableNewProtocol(t *testing.T) { + t.Helper() + orig := supportedProtocolVersions + supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + t.Cleanup(func() { supportedProtocolVersions = orig }) +} + +// TestSubscriptionsListen_InMemory exercises the listen flow over the +// session-shared in-memory transport (semantically equivalent to STDIO). +// Cancellation here propagates via notifications/cancelled. +func TestSubscriptionsListen_InMemory(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, ct, events) +} + +// TestSubscriptionsListen_Streamable exercises the listen flow over a +// stateless HTTP server (SEP-2575). Each listen rides its own SSE response +// stream; cs.Close() tears it down. +func TestSubscriptionsListen_Streamable(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 64) + server := newSubListenServer() + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + defer httpServer.Close() + runSubscriptionsListenTest(t, newSubListenClient(events), server, + &StreamableClientTransport{Endpoint: httpServer.URL}, events) +} + +// TestSubscriptionsListen_NoHandlersNoListen verifies that a new-protocol +// client without any list-changed handlers registered does not open an +// auto-listen on connect, and therefore does not receive any acknowledgment +// or downstream notifications. +func TestSubscriptionsListen_NoHandlersNoListen(t *testing.T) { + enableNewProtocol(t) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + events := make(chan subListenEvent, 8) + server := newSubListenServer() + ct, st := NewInMemoryTransports() + ss, err := server.Connect(ctx, st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + c := NewClient(testImpl, nil) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + events <- subListenEvent{"ack", ""} + } + return next(ctx, method, req) + } + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + + select { + case e := <-events: + t.Fatalf("unexpected event %q on no-handler client", e.kind) + case <-time.After(notificationDelay * 10): + } +} diff --git a/mcp/server.go b/mcp/server.go index d4af3761..7cf2a9c5 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1793,11 +1793,15 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(ctx context.Context, param *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(_ context.Context, param *CancelledParams) (Result, error) { + id, err := jsonrpc.MakeID(param.RequestID) + if err != nil { + return nil, nil + } server := ss.server - server.removeSubscription(server.toolsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) - server.removeSubscription(server.promptsSubscriptions, ss, param.RequestID.(jsonrpc.ID)) - server.removeSubscription(server.resourcesSubscriptions, ss, param.RequestID.(jsonrpc.ID)) + server.removeSubscription(server.toolsSubscriptions, ss, id) + server.removeSubscription(server.promptsSubscriptions, ss, id) + server.removeSubscription(server.resourcesSubscriptions, ss, id) return nil, nil } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6e19336f..2efee20e 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3716,137 +3716,3 @@ func TestStreamableHTTP_E2E_DiscoverSuccess(t *testing.T) { t.Errorf("CallTool result[0] = %+v, want TextContent{Text:\"hello\"}", res.Content[0]) } } - -// TestSubscriptionsListen_Streamable opens two concurrent subscriptions/listen -// streams on the same client session against a stateless HTTP server and -// verifies the full SEP-2575 contract end-to-end as observed by the client: -// -// - both listens are acknowledged, each tagged with its own request ID; -// - opt-in filtering: tool-list-changed reaches only the tools listen and -// prompt-list-changed reaches only the prompts listen; -// - the underlying SSE streams persist across multiple unrelated changes so -// notifications keep arriving until the session is closed; -// - subscription-ID tagging: every notification carries the originating -// listen's request ID in `_meta`, matching the ack's tag for that listen; -// - closing the client session ends all deliveries. -func TestSubscriptionsListen_Streamable(t *testing.T) { - orig := supportedProtocolVersions - supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) - t.Cleanup(func() { supportedProtocolVersions = orig }) - - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "t1"}, sayHi) - server.AddPrompt(&Prompt{Name: "p1"}, nil) - - handler := NewStreamableHTTPHandler( - func(*http.Request) *Server { return server }, - &StreamableHTTPOptions{Stateless: true}, - ) - httpServer := httptest.NewServer(mustNotPanic(t, handler)) - defer httpServer.Close() - - type event struct { - kind string - id string - } - events := make(chan event, 64) - asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } - - client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { - events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) - }, - PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { - events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) - }, - }) - client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == notificationSubscriptionsAck { - if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { - events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) - } - } - return next(ctx, method, req) - } - }) - - cs, err := client.Connect(context.Background(), &StreamableClientTransport{Endpoint: httpServer.URL}, - &ClientSessionOptions{protocolVersion: protocolVersion20260630}) - if err != nil { - t.Fatalf("client connect: %v", err) - } - - startListen := func(notifs NotificationSubscriptions) { - go func() { - _ = cs.SubscriptionsListen(context.Background(), &SubscriptionsListenParams{Notifications: notifs}) - }() - } - waitFor := func(kind string) event { - t.Helper() - select { - case e := <-events: - if e.kind != kind { - t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) - } - return e - case <-time.After(5 * time.Second): - t.Fatalf("timed out waiting for %q event", kind) - return event{} - } - } - expectNoEvent := func(d time.Duration) { - t.Helper() - select { - case e := <-events: - t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) - case <-time.After(d): - } - } - - startListen(NotificationSubscriptions{ToolsListChanged: true}) - startListen(NotificationSubscriptions{PromptsListChanged: true}) - - ack1 := waitFor("ack") - ack2 := waitFor("ack") - if ack1.id == ack2.id { - t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) - } - - // Trigger a tool change. The notification's subscription ID identifies - // which of the two acks belongs to the tools listen; the other ack - // therefore belongs to the prompts listen. - server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - toolEv := waitFor("tool") - if toolEv.id != ack1.id && toolEv.id != ack2.id { - t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) - } - toolSubID := toolEv.id - promptSubID := ack2.id - if ack1.id != toolSubID { - promptSubID = ack1.id - } - expectNoEvent(notificationDelay * 5) - - server.AddPrompt(&Prompt{Name: "p2"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - if e := waitFor("tool"); e.id != toolSubID { - t.Errorf("second tool notif id = %s, want %s", e.id, toolSubID) - } - server.AddPrompt(&Prompt{Name: "p3"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("second prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - cs.Close() - time.Sleep(50 * time.Millisecond) - server.AddTool(&Tool{Name: "t4", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - server.AddPrompt(&Prompt{Name: "p4"}, nil) - expectNoEvent(notificationDelay * 20) -} diff --git a/mcp/transport_test.go b/mcp/transport_test.go index d882998c..515b8c19 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -6,14 +6,10 @@ package mcp import ( "context" - "fmt" "io" - "slices" "strings" "testing" - "time" - "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -128,150 +124,3 @@ func TestIOConnRead(t *testing.T) { }) } } - -// TestSubscriptionsListen_InMemory verifies SEP-2575 subscriptions/listen -// over a single shared session (in-memory transport, semantically equivalent -// to STDIO). It exercises behavior that is harder to observe over streamable -// HTTP, where each listen lives in its own ephemeral session: -// -// - two concurrent listens on the SAME session both deliver notifications; -// - opt-in filtering: each listen receives only its opted-in notification -// types, tagged with its own subscription ID; -// - per-listen cancellation propagates over notifications/cancelled: when -// the client cancels one listen's context, the server stops fanning out -// notifications for that listen but keeps delivering to the other; -// - the remaining listen continues to work after the first cancellation. -func TestSubscriptionsListen_InMemory(t *testing.T) { - orig := supportedProtocolVersions - supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) - t.Cleanup(func() { supportedProtocolVersions = orig }) - - ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer topCancel() - - server := NewServer(testImpl, nil) - AddTool(server, &Tool{Name: "t1"}, sayHi) - server.AddPrompt(&Prompt{Name: "p1"}, nil) - - ct, st := NewInMemoryTransports() - ss, err := server.Connect(ctx, st, nil) - if err != nil { - t.Fatalf("server connect: %v", err) - } - defer ss.Close() - - type event struct { - kind string - id string - } - events := make(chan event, 64) - asEvent := func(kind string, raw any) event { return event{kind, fmt.Sprint(raw)} } - - client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(_ context.Context, req *ToolListChangedRequest) { - events <- asEvent("tool", req.Params.Meta[MetaKeySubscriptionID]) - }, - PromptListChangedHandler: func(_ context.Context, req *PromptListChangedRequest) { - events <- asEvent("prompt", req.Params.Meta[MetaKeySubscriptionID]) - }, - }) - client.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == notificationSubscriptionsAck { - if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { - events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) - } - } - return next(ctx, method, req) - } - }) - - cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) - if err != nil { - t.Fatalf("client connect: %v", err) - } - defer cs.Close() - - startListen := func(notifs NotificationSubscriptions) context.CancelFunc { - lctx, cancel := context.WithCancel(ctx) - go func() { - _ = cs.SubscriptionsListen(lctx, &SubscriptionsListenParams{Notifications: notifs}) - }() - return cancel - } - waitFor := func(kind string) event { - t.Helper() - select { - case e := <-events: - if e.kind != kind { - t.Fatalf("got event %q (id=%s), want kind %q", e.kind, e.id, kind) - } - return e - case <-time.After(5 * time.Second): - t.Fatalf("timed out waiting for %q event", kind) - return event{} - } - } - expectNoEvent := func(d time.Duration) { - t.Helper() - select { - case e := <-events: - t.Fatalf("unexpected event %q (id=%s)", e.kind, e.id) - case <-time.After(d): - } - } - - cancelTools := startListen(NotificationSubscriptions{ToolsListChanged: true}) - cancelPrompts := startListen(NotificationSubscriptions{PromptsListChanged: true}) - - ack1 := waitFor("ack") - ack2 := waitFor("ack") - if ack1.id == ack2.id { - t.Fatalf("both acks share subscription ID %s; expected distinct per-listen IDs", ack1.id) - } - - // Identify which ack belongs to which listen by triggering a tool change - // and observing the tagged subscription ID on the notification. - server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - toolEv := waitFor("tool") - if toolEv.id != ack1.id && toolEv.id != ack2.id { - t.Fatalf("tool notif id %s matches neither ack (%s, %s)", toolEv.id, ack1.id, ack2.id) - } - toolSubID := toolEv.id - promptSubID := ack2.id - if ack1.id != toolSubID { - promptSubID = ack1.id - } - expectNoEvent(notificationDelay * 5) - - server.AddPrompt(&Prompt{Name: "p2"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif id = %s, want %s", e.id, promptSubID) - } - expectNoEvent(notificationDelay * 5) - - // Cancel the tools listen. The SDK sends a notifications/cancelled to the - // server, which flips the listen handler's ctx, which unblocks the - // goroutine that removes the subscription from the server's index. - cancelTools() - - // Give the cancellation a moment to propagate (notifications/cancelled - // → server-side cancel → cleanup goroutine). - time.Sleep(50 * time.Millisecond) - - // A new tool change must NOT reach the (cancelled) tools listen, while - // the prompts listen continues to receive its notifications. - server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) - expectNoEvent(notificationDelay * 20) - - server.AddPrompt(&Prompt{Name: "p3"}, nil) - if e := waitFor("prompt"); e.id != promptSubID { - t.Errorf("prompt notif after tools-cancel id = %s, want %s", e.id, promptSubID) - } - - cancelPrompts() - time.Sleep(50 * time.Millisecond) - - server.AddPrompt(&Prompt{Name: "p4"}, nil) - expectNoEvent(notificationDelay * 20) -} From 8974819e587c0535cb49fdb7172917fd2d7885cb Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 12 Jun 2026 07:21:07 +0000 Subject: [PATCH 03/18] wip --- mcp/server.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 7cf2a9c5..7cb383c2 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1793,15 +1793,7 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(_ context.Context, param *CancelledParams) (Result, error) { - id, err := jsonrpc.MakeID(param.RequestID) - if err != nil { - return nil, nil - } - server := ss.server - server.removeSubscription(server.toolsSubscriptions, ss, id) - server.removeSubscription(server.promptsSubscriptions, ss, id) - server.removeSubscription(server.resourcesSubscriptions, ss, id) +func (ss *ServerSession) cancel(_ context.Context, _ *CancelledParams) (Result, error) { return nil, nil } From d2803b70e8c7ffbd998346e330125c7a672dc8a4 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 12 Jun 2026 14:03:49 +0000 Subject: [PATCH 04/18] refactor: encapsulate subscription management with activeSubscriptions type and dedicated methods --- mcp/server.go | 106 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 68 insertions(+), 38 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 7cb383c2..ab7b6878 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -35,6 +35,9 @@ import ( // DefaultPageSize is the default for [ServerOptions.PageSize]. const DefaultPageSize = 1000 +// A map from sessions to the set of subscription IDs they have active for a given feature. +type activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP @@ -53,9 +56,9 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - toolsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - promptsSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - resourcesSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + toolsSubscriptions activeSubscriptions + promptsSubscriptions activeSubscriptions + resourcesSubscriptions activeSubscriptions pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } @@ -207,9 +210,9 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), - toolsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), - promptsSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), - resourcesSubscriptions: make(map[*ServerSession]map[jsonrpc.ID]bool), + toolsSubscriptions: make(activeSubscriptions), + promptsSubscriptions: make(activeSubscriptions), + resourcesSubscriptions: make(activeSubscriptions), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -702,7 +705,15 @@ func (s *Server) notifySessions(n string) { } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) - var activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool + // Notify sessions that subscribed to the notification with 'subscriptions/listen' method. + type subscription struct { + session *ServerSession + reqID jsonrpc.ID + } + var subs []subscription + + s.mu.Lock() + var activeSubscriptions activeSubscriptions switch n { case notificationToolListChanged: activeSubscriptions = s.toolsSubscriptions @@ -711,22 +722,27 @@ func (s *Server) notifySessions(n string) { case notificationResourceListChanged: activeSubscriptions = s.resourcesSubscriptions } + for session, reqIDs := range activeSubscriptions { + for reqID := range reqIDs { + subs = append(subs, subscription{session: session, reqID: reqID}) + } + } + s.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for session := range activeSubscriptions { - for reqID := range activeSubscriptions[session] { - params := changeNotificationParams[n]() - setSubscriptionID(params, reqID) - req := newRequest(session, params) - ctx = context.WithValue(ctx, idContextKey{}, reqID) - if err := handleNotify(ctx, n, req); err != nil { - s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) - } + for _, sub := range subs { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, sub.reqID) + req := newRequest(sub.session, params) + reqCtx := context.WithValue(ctx, idContextKey{}, sub.reqID) + if err := handleNotify(reqCtx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) } } } -func setSubscriptionID(params Params, reqID jsonrpc.ID) { +func injectMetaSubscriptionID(params Params, reqID jsonrpc.ID) { m := params.GetMeta() if m == nil { m = map[string]any{} @@ -1072,46 +1088,62 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } -func (s *Server) addSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - ids := index[ss] +func (m activeSubscriptions) add(ss *ServerSession, id jsonrpc.ID) { + ids := m[ss] if ids == nil { ids = make(map[jsonrpc.ID]bool) - index[ss] = ids + m[ss] = ids } ids[id] = true } -func (s *Server) removeSubscription(index map[*ServerSession]map[jsonrpc.ID]bool, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - ids, ok := index[ss] +func (m activeSubscriptions) remove(ss *ServerSession, id jsonrpc.ID) { + ids, ok := m[ss] if !ok { return } delete(ids, id) if len(ids) == 0 { - delete(index, ss) + delete(m, ss) } } -func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { - requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) - if !ok || !requestID.IsValid() { - return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) +func (s *Server) addSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() + if allowed.ToolsListChanged { + s.toolsSubscriptions.add(ss, id) + } + if allowed.PromptsListChanged { + s.promptsSubscriptions.add(ss, id) + } + if allowed.ResourcesListChanged { + s.resourcesSubscriptions.add(ss, id) } +} - allowed := s.allowedSubscriptions(req.Params.Notifications) +func (s *Server) removeSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { + s.mu.Lock() + defer s.mu.Unlock() if allowed.ToolsListChanged { - s.addSubscription(s.toolsSubscriptions, req.Session, requestID) + s.toolsSubscriptions.remove(ss, id) } if allowed.PromptsListChanged { - s.addSubscription(s.promptsSubscriptions, req.Session, requestID) + s.promptsSubscriptions.remove(ss, id) } if allowed.ResourcesListChanged { - s.addSubscription(s.resourcesSubscriptions, req.Session, requestID) + s.resourcesSubscriptions.remove(ss, id) } +} + +func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscriptions/listen requires a request ID", jsonrpc2.ErrInvalidRequest) + } + + allowed := s.allowedSubscriptions(req.Params.Notifications) + s.addSubscriptions(allowed, req.Session, requestID) ackParams := &SubscriptionsAcknowledgedParams{ Notifications: allowed, @@ -1122,9 +1154,7 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList } defer func() { - s.removeSubscription(s.toolsSubscriptions, req.Session, requestID) - s.removeSubscription(s.promptsSubscriptions, req.Session, requestID) - s.removeSubscription(s.resourcesSubscriptions, req.Session, requestID) + s.removeSubscriptions(allowed, req.Session, requestID) }() <-ctx.Done() From 3ca0f0c75698645722f734f6fe09230c1dd3ce11 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:19:06 +0000 Subject: [PATCH 05/18] refactor: move subscription management from Server to ServerSession to simplify notification handling --- mcp/server.go | 188 +++++++++++++++++++++----------------------------- 1 file changed, 80 insertions(+), 108 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index ab7b6878..fc19d5fb 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -35,9 +35,6 @@ import ( // DefaultPageSize is the default for [ServerOptions.PageSize]. const DefaultPageSize = 1000 -// A map from sessions to the set of subscription IDs they have active for a given feature. -type activeSubscriptions map[*ServerSession]map[jsonrpc.ID]bool - // A Server is an instance of an MCP server. // // Servers expose server-side MCP features, which can serve one or more MCP @@ -56,10 +53,7 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - toolsSubscriptions activeSubscriptions - promptsSubscriptions activeSubscriptions - resourcesSubscriptions activeSubscriptions - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -210,9 +204,6 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), - toolsSubscriptions: make(activeSubscriptions), - promptsSubscriptions: make(activeSubscriptions), - resourcesSubscriptions: make(activeSubscriptions), pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) @@ -690,6 +681,11 @@ func (s *Server) changeAndNotify(notification string, change func() bool) { // notifySessions sends the notification n to all existing sessions. // It is called asynchronously by changeAndNotify. +// +// Legacy (pre-SEP-2575) sessions receive the notification on the shared +// session channel. Sessions speaking the new protocol receive it only if they +// have an active subscriptions/listen stream that opted in to this +// notification type. func (s *Server) notifySessions(n string) { s.mu.Lock() sessions := slices.Clone(s.sessions) @@ -698,46 +694,24 @@ func (s *Server) notifySessions(n string) { // Only add legacy sessions for the notification, new ones use the new notification mechanism. var legacySessions []*ServerSession - for _, s := range sessions { - if s.InitializeParams().isNil() || s.InitializeParams().ProtocolVersion < protocolVersion20260630 { - legacySessions = append(legacySessions, s) + for _, session := range sessions { + if session.InitializeParams().isNil() || session.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, session) } } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) - // Notify sessions that subscribed to the notification with 'subscriptions/listen' method. - type subscription struct { - session *ServerSession - reqID jsonrpc.ID - } - var subs []subscription - - s.mu.Lock() - var activeSubscriptions activeSubscriptions - switch n { - case notificationToolListChanged: - activeSubscriptions = s.toolsSubscriptions - case notificationPromptListChanged: - activeSubscriptions = s.promptsSubscriptions - case notificationResourceListChanged: - activeSubscriptions = s.resourcesSubscriptions - } - for session, reqIDs := range activeSubscriptions { - for reqID := range reqIDs { - subs = append(subs, subscription{session: session, reqID: reqID}) - } - } - s.mu.Unlock() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, sub := range subs { - params := changeNotificationParams[n]() - injectMetaSubscriptionID(params, sub.reqID) - req := newRequest(sub.session, params) - reqCtx := context.WithValue(ctx, idContextKey{}, sub.reqID) - if err := handleNotify(reqCtx, n, req); err != nil { - s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + for _, session := range sessions { + for _, reqID := range session.subscribedListenIDs(n) { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, reqID) + req := newRequest(session, params) + reqCtx := context.WithValue(ctx, idContextKey{}, reqID) + if err := handleNotify(reqCtx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) + } } } } @@ -1088,54 +1062,6 @@ func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emp return &emptyResult{}, nil } -func (m activeSubscriptions) add(ss *ServerSession, id jsonrpc.ID) { - ids := m[ss] - if ids == nil { - ids = make(map[jsonrpc.ID]bool) - m[ss] = ids - } - ids[id] = true -} - -func (m activeSubscriptions) remove(ss *ServerSession, id jsonrpc.ID) { - ids, ok := m[ss] - if !ok { - return - } - delete(ids, id) - if len(ids) == 0 { - delete(m, ss) - } -} - -func (s *Server) addSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - if allowed.ToolsListChanged { - s.toolsSubscriptions.add(ss, id) - } - if allowed.PromptsListChanged { - s.promptsSubscriptions.add(ss, id) - } - if allowed.ResourcesListChanged { - s.resourcesSubscriptions.add(ss, id) - } -} - -func (s *Server) removeSubscriptions(allowed NotificationSubscriptions, ss *ServerSession, id jsonrpc.ID) { - s.mu.Lock() - defer s.mu.Unlock() - if allowed.ToolsListChanged { - s.toolsSubscriptions.remove(ss, id) - } - if allowed.PromptsListChanged { - s.promptsSubscriptions.remove(ss, id) - } - if allowed.ResourcesListChanged { - s.resourcesSubscriptions.remove(ss, id) - } -} - func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsListenRequest) (*emptyResult, error) { requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) if !ok || !requestID.IsValid() { @@ -1143,7 +1069,8 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList } allowed := s.allowedSubscriptions(req.Params.Notifications) - s.addSubscriptions(allowed, req.Session, requestID) + req.Session.addSubscription(requestID, allowed) + defer req.Session.removeSubscription(requestID) ackParams := &SubscriptionsAcknowledgedParams{ Notifications: allowed, @@ -1153,10 +1080,6 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) } - defer func() { - s.removeSubscriptions(allowed, req.Session, requestID) - }() - <-ctx.Done() return &emptyResult{}, nil } @@ -1176,6 +1099,55 @@ func (s *Server) allowedSubscriptions(want NotificationSubscriptions) Notificati return agreed } +// addSubscription registers a subscriptions/listen stream on this session. +func (ss *ServerSession) addSubscription(id jsonrpc.ID, allowed NotificationSubscriptions) { + if !allowed.ToolsListChanged && !allowed.PromptsListChanged && !allowed.ResourcesListChanged { + return + } + ss.mu.Lock() + defer ss.mu.Unlock() + if ss.subscriptions == nil { + ss.subscriptions = make(map[jsonrpc.ID]*listenSubscription) + } + ss.subscriptions[id] = &listenSubscription{ + toolsListChanged: allowed.ToolsListChanged, + promptsListChanged: allowed.PromptsListChanged, + resourcesListChanged: allowed.ResourcesListChanged, + } +} + +func (ss *ServerSession) removeSubscription(id jsonrpc.ID) { + ss.mu.Lock() + defer ss.mu.Unlock() + delete(ss.subscriptions, id) +} + +// subscribedListenIDs returns the listen-request IDs that opted in to the +// given list-changed notification. +func (ss *ServerSession) subscribedListenIDs(notification string) []jsonrpc.ID { + ss.mu.Lock() + defer ss.mu.Unlock() + if len(ss.subscriptions) == 0 { + return nil + } + var ids []jsonrpc.ID + for id, sub := range ss.subscriptions { + var match bool + switch notification { + case notificationToolListChanged: + match = sub.toolsListChanged + case notificationPromptListChanged: + match = sub.promptsListChanged + case notificationResourceListChanged: + match = sub.resourcesListChanged + } + if match { + ids = append(ids, id) + } + } + return ids +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -1245,9 +1217,6 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } - delete(s.toolsSubscriptions, cc) - delete(s.promptsSubscriptions, cc) - delete(s.resourcesSubscriptions, cc) s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1385,6 +1354,17 @@ type ServerSession struct { mu sync.Mutex state ServerSessionState + // subscriptions tracks the SEP-2575 subscriptions/listen streams opened by + // this session, keyed by the originating listen request ID. + subscriptions map[jsonrpc.ID]*listenSubscription +} + +// listenSubscription records the notification types a single +// subscriptions/listen stream has opted in to. +type listenSubscription struct { + toolsListChanged bool + promptsListChanged bool + resourcesListChanged bool } func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { @@ -1856,14 +1836,6 @@ func (ss *ServerSession) Close() error { ss.onClose() } - // Clean up session subscriptions - server := ss.server - server.mu.Lock() - delete(server.toolsSubscriptions, ss) - delete(server.promptsSubscriptions, ss) - delete(server.resourcesSubscriptions, ss) - server.mu.Unlock() - return err } From c729ae9118480831e01e337f58e9cfffd98810a1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:41:54 +0000 Subject: [PATCH 06/18] refactor: remove redundant SSE priming and simplify server session cancellation signature and request matching logic --- mcp/server.go | 2 +- mcp/streamable.go | 13 +------------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index fc19d5fb..535e0a6e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1803,7 +1803,7 @@ func (ss *ServerSession) ping(context.Context, *PingParams) (*emptyResult, error // It should never be invoked in practice because cancellation is preempted, // but having its signature here facilitates the construction of methodInfo // that can be used to validate incoming cancellation notifications. -func (ss *ServerSession) cancel(_ context.Context, _ *CancelledParams) (Result, error) { +func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, error) { return nil, nil } diff --git a/mcp/streamable.go b/mcp/streamable.go index 2c851dd3..c01e1eb2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1582,17 +1582,6 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques c.logger.Warn(fmt.Sprintf("Writing priming event: %v", err)) } } - // For subscriptions/listen, flush headers eagerly so the client sees - // the open SSE stream even when an HTTP/2-aware reverse proxy is - // buffering the HEADERS frame waiting for a DATA frame to coalesce - // with. The acknowledgment notification will follow shortly, but the - // client can begin reading event-stream framing immediately. See the - // equivalent comment in [streamableServerConn.acquireStream]. - if isSubscriptionsListen { - w.WriteHeader(http.StatusOK) - fmt.Fprint(w, ": ok\n\n") - _ = http.NewResponseController(w).Flush() - } } // TODO(rfindley): if we have no event store, we should really cancel all @@ -2483,7 +2472,7 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. // The subscriptions/listen is now returning a response in the SSE stream. - if jsonResp.ID == forCall.ID && forCall.Method != methodSubscriptionsListen { + if jsonResp.ID == forCall.ID { return "", 0, true } } From c9557ad958c901aa22afe8eccaacbad05f5fd132 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 15 Jun 2026 09:42:57 +0000 Subject: [PATCH 07/18] refactor: remove stale comment regarding SSE stream response handling --- mcp/streamable.go | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index c01e1eb2..f4d81c15 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -2471,7 +2471,6 @@ func (c *streamableClientConn) processStream(ctx context.Context, requestSummary if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil { // TODO: we should never get a response when forReq is nil (the standalone SSE request). // We should detect this case. - // The subscriptions/listen is now returning a response in the SSE stream. if jsonResp.ID == forCall.ID { return "", 0, true } From 5d5b8c8593c10206f8b7fdb3f51d37b00742b181 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 18 Jun 2026 07:13:55 +0000 Subject: [PATCH 08/18] feat: implement stream identification for subscriptions/listen and simplify notification handling --- mcp/server.go | 15 +++++++-------- mcp/streamable.go | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 0b866a4b..8a9bc35b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -694,22 +694,21 @@ func (s *Server) notifySessions(n string) { // Only add legacy sessions for the notification, new ones use the new notification mechanism. var legacySessions []*ServerSession - for _, session := range sessions { - if session.InitializeParams().isNil() || session.InitializeParams().ProtocolVersion < protocolVersion20260630 { - legacySessions = append(legacySessions, session) + for _, sess := range sessions { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, sess) } } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, session := range sessions { - for _, reqID := range session.subscribedListenIDs(n) { + for _, sess := range sessions { + for _, reqID := range sess.subscribedListenIDs(n) { params := changeNotificationParams[n]() injectMetaSubscriptionID(params, reqID) - req := newRequest(session, params) - reqCtx := context.WithValue(ctx, idContextKey{}, reqID) - if err := handleNotify(reqCtx, n, req); err != nil { + req := newRequest(sess, params) + if err := handleNotify(ctx, n, req); err != nil { s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) } } diff --git a/mcp/streamable.go b/mcp/streamable.go index 02c9a4fd..a589fd4d 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -908,6 +908,12 @@ type stream struct { // the spec and earlier there was a concept of batching, in which POST // payloads could hold multiple requests or responses. requests map[jsonrpc.ID]struct{} + + // isListen reports whether this stream was opened by a + // subscriptions/listen request. Listen streams are always SSE, live for + // the duration of the subscription, and act as the target for + // out-of-band notifications routed through this connection. + isListen bool } // close sends a 'close' event to the client (if protocolVersion >= 2025-11-25 @@ -1539,6 +1545,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques http.Error(w, fmt.Sprintf("storing stream: %v", err), http.StatusInternalServerError) return } + stream.isListen = isSubscriptionsListen // subscriptions/listen is inherently a long-lived SSE endpoint (SEP-2575): // it has no synchronous result, the response stream stays open until the @@ -1718,6 +1725,15 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e s = c.streams[streamID] } } else { + // In stateless mode there will always be only one stream per connection. + for _, stream := range c.streams { + if stream.isListen { + s = stream + break + } + } + } + if s == nil { s = c.streams[""] // standalone SSE stream } if responseTo.IsValid() { From 6eb46c7aec1dd1e32cfe3adff1d3542f28d2e504 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 18 Jun 2026 10:22:18 +0000 Subject: [PATCH 09/18] refactor: partition sessions by protocol version to restrict new notifications to modern sessions --- mcp/server.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mcp/server.go b/mcp/server.go index 8a9bc35b..e8a9bbc1 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -694,16 +694,19 @@ func (s *Server) notifySessions(n string) { // Only add legacy sessions for the notification, new ones use the new notification mechanism. var legacySessions []*ServerSession + var newSessions []*ServerSession for _, sess := range sessions { if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { legacySessions = append(legacySessions, sess) + } else { + newSessions = append(newSessions, sess) } } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, sess := range sessions { + for _, sess := range newSessions { for _, reqID := range sess.subscribedListenIDs(n) { params := changeNotificationParams[n]() injectMetaSubscriptionID(params, reqID) From ccb0a6e02261a904b19a770dcf0309d3af426934 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 18 Jun 2026 17:04:35 +0000 Subject: [PATCH 10/18] feat: implement SEP-2575 resource subscription management in Client and update Server state tracking --- mcp/client.go | 80 +++++++++++++++++-- mcp/mcp_test.go | 107 +++++++++++++++++++++++++ mcp/server.go | 208 +++++++++++++++++++++++++++--------------------- 3 files changed, 299 insertions(+), 96 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index fff96859..bae40b1d 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -455,6 +455,15 @@ type ClientSession struct { // It is used to look up x-mcp-header annotations when // constructing Mcp-Param-* headers for tools/call requests. toolCache map[string]*Tool + + // resourceSubsMu guards resourceSubs. + resourceSubsMu sync.Mutex + // resourceSubs maps a subscribed resource URI to the cancel func of the + // goroutine running its dedicated subscriptions/listen stream. Populated + // only under SEP-2575; the legacy protocol routes Subscribe and + // Unsubscribe straight to the resources/subscribe and resources/unsubscribe + // RPCs and leaves this map untouched. + resourceSubs map[string]context.CancelFunc } type clientSessionState struct { @@ -524,6 +533,7 @@ func (cs *ClientSession) Close() error { if cs.listenCancel != nil { cs.listenCancel() } + cs.cancelAllResourceSubscriptions() err := cs.conn.Close() if cs.onClose != nil && cs.calledOnClose.CompareAndSwap(false, true) { @@ -1250,15 +1260,73 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) ( // Subscribe sends a "resources/subscribe" request to the server, asking for // notifications when the specified resource changes. func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) - return err + if !cs.usesNewProtocol() { + _, err := handleSend[*emptyResult](ctx, methodSubscribe, newClientRequest(cs, orZero[Params](params))) + return err + } + if params == nil || params.URI == "" { + return fmt.Errorf("Subscribe: missing URI") + } + uri := params.URI + + cs.resourceSubsMu.Lock() + if _, exists := cs.resourceSubs[uri]; exists { + cs.resourceSubsMu.Unlock() + return nil + } + listenCtx, cancel := context.WithCancel(context.Background()) + if cs.resourceSubs == nil { + cs.resourceSubs = make(map[string]context.CancelFunc) + } + cs.resourceSubs[uri] = cancel + cs.resourceSubsMu.Unlock() + + go func() { + _ = cs.subscriptionsListen(listenCtx, &SubscriptionsListenParams{ + Notifications: NotificationSubscriptions{ + ResourceSubscriptions: []string{uri}, + }, + }) + }() + return nil } -// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling -// a previous subscription. +// Unsubscribe cancels a previous [ClientSession.Subscribe] for params.URI. +// +// Under the legacy protocol it sends a "resources/unsubscribe" request. +// +// Under SEP-2575 it cancels the background "subscriptions/listen" stream +// opened by Subscribe for the URI. Unsubscribe is idempotent: calling it for +// a URI that is not currently subscribed is a no-op. func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { - _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) - return err + if !cs.usesNewProtocol() { + _, err := handleSend[*emptyResult](ctx, methodUnsubscribe, newClientRequest(cs, orZero[Params](params))) + return err + } + if params == nil || params.URI == "" { + return fmt.Errorf("Unsubscribe: missing URI") + } + cs.resourceSubsMu.Lock() + cancel, ok := cs.resourceSubs[params.URI] + delete(cs.resourceSubs, params.URI) + cs.resourceSubsMu.Unlock() + if ok { + cancel() + } + return nil +} + +// cancelAllResourceSubscriptions cancels every active SEP-2575 resource +// subscription opened via Subscribe. The listen goroutines exit +// asynchronously as their contexts unwind. Called from Close. +func (cs *ClientSession) cancelAllResourceSubscriptions() { + cs.resourceSubsMu.Lock() + subs := cs.resourceSubs + cs.resourceSubs = nil + cs.resourceSubsMu.Unlock() + for _, cancel := range subs { + cancel() + } } // SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream and diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 02f05618..4c89569e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2650,3 +2650,110 @@ func TestSubscriptionsListen_NoHandlersNoListen(t *testing.T) { case <-time.After(notificationDelay * 10): } } + +// resourceSubServer builds a server that advertises resource subscriptions +// and records every Subscribe/Unsubscribe handler invocation through chans. +func resourceSubServer(t *testing.T, subCh, unsubCh chan string) *Server { + t.Helper() + s := NewServer(testImpl, &ServerOptions{ + SubscribeHandler: func(_ context.Context, r *SubscribeRequest) error { + subCh <- r.Params.URI + return nil + }, + UnsubscribeHandler: func(_ context.Context, r *UnsubscribeRequest) error { + unsubCh <- r.Params.URI + return nil + }, + }) + s.AddResource(&Resource{Name: "r1", URI: "file:///r1"}, nil) + return s +} + +// resourceSubEvent is one delivered notifications/resources/updated. +type resourceSubEvent struct { + uri string + id string // _meta subscription ID, stringified +} + +// TestResourceSubscriptionsSEP2575_Streamable verifies the Subscribe -> +// ResourceUpdated path on a stateless Streamable HTTP server. +// +// Caveat: per-subscription Unsubscribe is intentionally NOT verified here. +// In stateless Streamable HTTP mode the subscriptions/listen handler blocks +// on its request context, and neither the HTTP POST disconnect nor the +// separate notifications/cancelled POST currently propagates to that +// handler's context. The handler only unwinds when the server next attempts +// a write to the (now-dead) SSE stream and the writeErr branch in the +// jsonrpc2 layer cancels the in-flight request. To keep the test +// hermetic we therefore trigger a write at the end by adding a resource, +// which fires notifications/resources/list_changed on the auto-listen path +// (if any) and on the per-URI listen, causing the listen handler to unwind. +// The spec-correct fix is to plumb the POST's request context down to the +// subscriptionsListen handler so HTTP disconnect is observed directly; this +// is tracked separately. +func TestResourceSubscriptions_Streamable(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + handler := NewStreamableHTTPHandler( + func(*http.Request) *Server { return server }, + &StreamableHTTPOptions{Stateless: true}, + ) + httpServer := httptest.NewServer(mustNotPanic(t, handler)) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, + &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + select { + case got := <-subCh: + if got != "file:///r1" { + t.Fatalf("got URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for SubscribeHandler") + } + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + if e.uri != "file:///r1" { + t.Fatalf("got URI %q, want %q", e.uri, "file:///r1") + } + if e.id == "" || e.id == "" { + t.Fatalf("missing subscription ID on update (got %q)", e.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for resource update") + } + + // See test header comment for the explanation of this teardown ritual: + // close the client, then drop server-side TCP, then drive a write that + // will fail (any extra ResourceUpdated for our URI), to unblock the + // in-flight listen handler so httpServer.Close can return. + _ = cs.Close() + httpServer.CloseClientConnections() + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + httpServer.Close() +} diff --git a/mcp/server.go b/mcp/server.go index e8a9bbc1..338eedf7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -44,16 +44,19 @@ type Server struct { impl *Implementation opts ServerOptions - mu sync.Mutex - prompts *featureSet[*serverPrompt] - tools *featureSet[*serverTool] - resources *featureSet[*serverResource] - resourceTemplates *featureSet[*serverResourceTemplate] - sessions []*ServerSession - sendingMethodHandler_ MethodHandler - receivingMethodHandler_ MethodHandler - resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool - pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send + mu sync.Mutex + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] + sessions []*ServerSession + sendingMethodHandler_ MethodHandler + receivingMethodHandler_ MethodHandler + toolChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "tools/changed" + promptChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "prompts/changed" + resourceChangeSubscriptions map[*ServerSession]jsonrpc.ID // session -> requestID for "resources/changed" + resourceSubscriptions map[string]map[*ServerSession]jsonrpc.ID // uri -> session -> requestID + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -195,16 +198,19 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { } s := &Server{ - impl: impl, - opts: opts, - prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), - tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), - resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), - resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), - sendingMethodHandler_: defaultSendingMethodHandler, - receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], - resourceSubscriptions: make(map[string]map[*ServerSession]bool), - pendingNotifications: make(map[string]*time.Timer), + impl: impl, + opts: opts, + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), + sendingMethodHandler_: defaultSendingMethodHandler, + receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + toolChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + promptChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + resourceChangeSubscriptions: make(map[*ServerSession]jsonrpc.ID), + resourceSubscriptions: make(map[string]map[*ServerSession]jsonrpc.ID), + pendingNotifications: make(map[string]*time.Timer), } s.AddReceivingMiddleware(serverMultiRoundTripMiddleware()) return s @@ -690,30 +696,36 @@ func (s *Server) notifySessions(n string) { s.mu.Lock() sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil + var subscribers map[*ServerSession]jsonrpc.ID + switch n { + case notificationToolListChanged: + subscribers = maps.Clone(s.toolChangeSubscriptions) + case notificationPromptListChanged: + subscribers = maps.Clone(s.promptChangeSubscriptions) + case notificationResourceListChanged: + subscribers = maps.Clone(s.resourceChangeSubscriptions) + } s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - // Only add legacy sessions for the notification, new ones use the new notification mechanism. + // Legacy sessions receive list-changed notifications on the shared session channel without opt-in. var legacySessions []*ServerSession - var newSessions []*ServerSession for _, sess := range sessions { if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { legacySessions = append(legacySessions, sess) - } else { - newSessions = append(newSessions, sess) } } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) + // Sessions receive the notification only if they opened a + // subscriptions/listen stream that opted in to this notification type. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - for _, sess := range newSessions { - for _, reqID := range sess.subscribedListenIDs(n) { - params := changeNotificationParams[n]() - injectMetaSubscriptionID(params, reqID) - req := newRequest(sess, params) - if err := handleNotify(ctx, n, req); err != nil { - s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) - } + for sess, reqID := range subscribers { + params := changeNotificationParams[n]() + injectMetaSubscriptionID(params, reqID) + req := newRequest(sess, params) + if err := handleNotify(ctx, n, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", n, err)) } } } @@ -1018,12 +1030,38 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot subscribedSessions := s.resourceSubscriptions[params.URI] sessions := slices.Collect(maps.Keys(subscribedSessions)) s.mu.Unlock() - notifySessions(sessions, notificationResourceUpdated, params, s.opts.Logger) + // Only add legacy sessions for the notification, new ones use the new notification mechanism. + var legacySessions []*ServerSession + var newSessions []*ServerSession + for _, sess := range sessions { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, sess) + } else { + newSessions = append(newSessions, sess) + } + } + notifySessions(legacySessions, notificationResourceUpdated, params, s.opts.Logger) s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for _, sess := range newSessions { + reqID := subscribedSessions[sess] + p := *params + injectMetaSubscriptionID(&p, reqID) + req := newRequest(sess, &p) + if err := handleNotify(ctx, notificationResourceUpdated, req); err != nil { + s.opts.Logger.Warn(fmt.Sprintf("calling %s: %v", notificationResourceUpdated, err)) + } + } return nil } func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { + requestID, ok := ctx.Value(idContextKey{}).(jsonrpc.ID) + if !ok || !requestID.IsValid() { + return nil, fmt.Errorf("%w: subscribe requires a request ID", jsonrpc2.ErrInvalidRequest) + } if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } @@ -1034,10 +1072,10 @@ func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyRe s.mu.Lock() defer s.mu.Unlock() if s.resourceSubscriptions[req.Params.URI] == nil { - s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]bool) + s.resourceSubscriptions[req.Params.URI] = make(map[*ServerSession]jsonrpc.ID) } - s.resourceSubscriptions[req.Params.URI][req.Session] = true - s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID()) + s.resourceSubscriptions[req.Params.URI][req.Session] = requestID + s.opts.Logger.Info("resource subscribed", "uri", req.Params.URI, "session_id", req.Session.ID(), "request_id", requestID) return &emptyResult{}, nil } @@ -1071,8 +1109,44 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList } allowed := s.allowedSubscriptions(req.Params.Notifications) - req.Session.addSubscription(requestID, allowed) - defer req.Session.removeSubscription(requestID) + s.mu.Lock() + if allowed.ToolsListChanged { + s.toolChangeSubscriptions[req.Session] = requestID + } + if allowed.PromptsListChanged { + s.promptChangeSubscriptions[req.Session] = requestID + } + if allowed.ResourcesListChanged { + s.resourceChangeSubscriptions[req.Session] = requestID + } + s.mu.Unlock() + defer func() { + s.mu.Lock() + delete(s.toolChangeSubscriptions, req.Session) + delete(s.promptChangeSubscriptions, req.Session) + delete(s.resourceChangeSubscriptions, req.Session) + s.mu.Unlock() + }() + + for _, uri := range allowed.ResourceSubscriptions { + _, err := s.subscribe(ctx, &SubscribeRequest{ + Session: req.Session, + Params: &SubscribeParams{ + URI: uri, + Meta: req.Params.GetMeta(), + }, + }) + if err != nil { + return nil, err + } + defer s.unsubscribe(ctx, &UnsubscribeRequest{ + Session: req.Session, + Params: &UnsubscribeParams{ + URI: uri, + Meta: req.Params.GetMeta(), + }, + }) + } ackParams := &SubscriptionsAcknowledgedParams{ Notifications: allowed, @@ -1098,56 +1172,10 @@ func (s *Server) allowedSubscriptions(want NotificationSubscriptions) Notificati if want.ResourcesListChanged && caps.Resources != nil && caps.Resources.ListChanged { agreed.ResourcesListChanged = true } - return agreed -} - -// addSubscription registers a subscriptions/listen stream on this session. -func (ss *ServerSession) addSubscription(id jsonrpc.ID, allowed NotificationSubscriptions) { - if !allowed.ToolsListChanged && !allowed.PromptsListChanged && !allowed.ResourcesListChanged { - return - } - ss.mu.Lock() - defer ss.mu.Unlock() - if ss.subscriptions == nil { - ss.subscriptions = make(map[jsonrpc.ID]*listenSubscription) + if len(want.ResourceSubscriptions) > 0 && caps.Resources != nil && caps.Resources.Subscribe { + agreed.ResourceSubscriptions = slices.Clone(want.ResourceSubscriptions) } - ss.subscriptions[id] = &listenSubscription{ - toolsListChanged: allowed.ToolsListChanged, - promptsListChanged: allowed.PromptsListChanged, - resourcesListChanged: allowed.ResourcesListChanged, - } -} - -func (ss *ServerSession) removeSubscription(id jsonrpc.ID) { - ss.mu.Lock() - defer ss.mu.Unlock() - delete(ss.subscriptions, id) -} - -// subscribedListenIDs returns the listen-request IDs that opted in to the -// given list-changed notification. -func (ss *ServerSession) subscribedListenIDs(notification string) []jsonrpc.ID { - ss.mu.Lock() - defer ss.mu.Unlock() - if len(ss.subscriptions) == 0 { - return nil - } - var ids []jsonrpc.ID - for id, sub := range ss.subscriptions { - var match bool - switch notification { - case notificationToolListChanged: - match = sub.toolsListChanged - case notificationPromptListChanged: - match = sub.promptsListChanged - case notificationResourceListChanged: - match = sub.resourcesListChanged - } - if match { - ids = append(ids, id) - } - } - return ids + return agreed } // Run runs the server over the given transport, which must be persistent. @@ -1219,6 +1247,9 @@ func (s *Server) disconnect(cc *ServerSession) { for _, subscribedSessions := range s.resourceSubscriptions { delete(subscribedSessions, cc) } + delete(s.toolChangeSubscriptions, cc) + delete(s.promptChangeSubscriptions, cc) + delete(s.resourceChangeSubscriptions, cc) s.opts.Logger.Info("server session disconnected", "session_id", cc.ID()) } @@ -1356,9 +1387,6 @@ type ServerSession struct { mu sync.Mutex state ServerSessionState - // subscriptions tracks the SEP-2575 subscriptions/listen streams opened by - // this session, keyed by the originating listen request ID. - subscriptions map[jsonrpc.ID]*listenSubscription } // listenSubscription records the notification types a single From 894078e02e6b0aadd45f78ce97a424af37351d25 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 19 Jun 2026 15:18:51 +0000 Subject: [PATCH 11/18] feat: implement subscriptions/listen stream support for MCP protocol handling --- internal/jsonrpc2/conn.go | 6 +++++- mcp/client.go | 29 +++++++++++------------------ mcp/shared.go | 8 ++++++-- mcp/streamable.go | 26 +++++++++++++++++--------- mcp/streamable_test.go | 2 +- mcp/transport.go | 27 +++++++++++++++++++++++++++ 6 files changed, 67 insertions(+), 31 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 3b4bc57a..f7ea16bd 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -210,10 +210,14 @@ type ConnectionConfig struct { OnInternalError func(error) // optional } +type IsSubscriptionsListenContextKey struct{} + // NewConnection creates a new [Connection] object and starts processing // incoming messages. func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { - ctx = notDone{ctx} + if val, _ := ctx.Value(IsSubscriptionsListenContextKey{}).(bool); !val { + ctx = notDone{ctx} + } c := &Connection{ state: inFlightState{closer: cfg.Closer}, diff --git a/mcp/client.go b/mcp/client.go index bae40b1d..c1b88cd2 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -313,14 +313,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp if subscribeParams.Notifications.ToolsListChanged || subscribeParams.Notifications.PromptsListChanged || subscribeParams.Notifications.ResourcesListChanged { - // The listen blocks until the server cancels it. Run it in - // a goroutine so Connect can return; ClientSession.Close - // cancels its context to send notifications/cancelled. + // ClientSession.Close cancels the listenCtx context to send notifications/cancelled. listenCtx, cancelListen := context.WithCancel(context.Background()) cs.listenCancel = cancelListen - go func() { - _ = cs.subscriptionsListen(listenCtx, subscribeParams) - }() + cs.subscriptionsListen(listenCtx, subscribeParams) } return cs, nil } @@ -1281,14 +1277,11 @@ func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) cs.resourceSubs[uri] = cancel cs.resourceSubsMu.Unlock() - go func() { - _ = cs.subscriptionsListen(listenCtx, &SubscriptionsListenParams{ - Notifications: NotificationSubscriptions{ - ResourceSubscriptions: []string{uri}, - }, - }) - }() - return nil + return cs.subscriptionsListen(listenCtx, &SubscriptionsListenParams{ + Notifications: NotificationSubscriptions{ + ResourceSubscriptions: []string{uri}, + }, + }) } // Unsubscribe cancels a previous [ClientSession.Subscribe] for params.URI. @@ -1329,10 +1322,10 @@ func (cs *ClientSession) cancelAllResourceSubscriptions() { } } -// SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream and -// blocks for the lifetime of the subscription. The server's first message on -// the stream is "notifications/subscriptions/acknowledged"; subsequent -// opted-in notifications (e.g. tools/list_changed) are delivered through the +// SubscriptionsListen opens a SEP-2575 "subscriptions/listen" stream. +// +// The server's first message on the stream is "notifications/subscriptions/acknowledged"; +// subsequent opted-in notifications (e.g. tools/list_changed) are delivered through the // usual handlers registered in [ClientOptions]. func (cs *ClientSession) subscriptionsListen(ctx context.Context, params *SubscriptionsListenParams) error { params = injectRequestMeta(cs, params) diff --git a/mcp/shared.go b/mcp/shared.go index ae1ee3ca..67168491 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -123,8 +123,12 @@ func defaultSendingMethodHandler(ctx context.Context, method string, req Request // Create the result to unmarshal into. // The concrete type of the result is the return type of the receiving function. res := info.newResult() - if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { - return nil, err + if method == methodSubscriptionsListen { + callSubscriptionsListen(ctx, req.GetSession().getConn(), method, params) + } else { + if err := call(ctx, req.GetSession().getConn(), method, params, res); err != nil { + return nil, err + } } return res, nil } diff --git a/mcp/streamable.go b/mcp/streamable.go index a589fd4d..087b3c0c 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -360,7 +360,7 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } - connectOpts, usesNewProtocol, err := h.ephemeralConnectOpts(req) + connectOpts, usesNewProtocol, hasSubscriptionsListen, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -382,7 +382,11 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. logger: h.opts.Logger, } - session, err := connectStreamable(req.Context(), server, transport, connectOpts) + ctx := req.Context() + if hasSubscriptionsListen && usesNewProtocol { + ctx = context.WithValue(ctx, jsonrpc2.IsSubscriptionsListenContextKey{}, true) + } + session, err := connectStreamable(ctx, server, transport, connectOpts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) @@ -417,7 +421,7 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter // // The returned usesNewProtocol bool reports whether the protocol version // header indicates a protocol version >= 2026-06-30 (SEP-2575). -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, err error) { +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, isSubscriptionsListen bool, err error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 @@ -426,7 +430,7 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S var hasInitialize, hasInitialized bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, false, fmt.Errorf("failed to read body") + return nil, false, false, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -439,6 +443,8 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S hasInitialize = true case notificationInitialized: hasInitialized = true + case methodSubscriptionsListen: + isSubscriptionsListen = true } if protocolVersion >= protocolVersion20260630 { usesNewProtocol = true @@ -463,7 +469,7 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S } return &ServerSessionOptions{ State: state, - }, usesNewProtocol, nil + }, usesNewProtocol, isSubscriptionsListen, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -608,7 +614,7 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, _, err := h.ephemeralConnectOpts(req) + connectOpts, _, _, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1726,15 +1732,17 @@ func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) e } } else { // In stateless mode there will always be only one stream per connection. + // If that stream was open to listen for subscription notifications, + // automatically select as the one to write the notification to. for _, stream := range c.streams { if stream.isListen { s = stream break } } - } - if s == nil { - s = c.streams[""] // standalone SSE stream + if s == nil { + s = c.streams[""] // standalone SSE stream + } } if responseTo.IsValid() { // Once we've responded to a request, disallow related messages by removing diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 2efee20e..5bf15de7 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3305,7 +3305,7 @@ func TestEphemeralConnectOpts(t *testing.T) { } req.Header.Set(protocolVersionHeader, pver) req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) - opts, usesNew, err := h.ephemeralConnectOpts(req) + opts, usesNew, _, err := h.ephemeralConnectOpts(req) if err != nil { t.Fatal(err) } diff --git a/mcp/transport.go b/mcp/transport.go index f67c43d0..1469858e 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -230,6 +230,33 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result a return nil, jsonrpc2.ErrNotHandled } +func callSubscriptionsListen(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params) { + // The "%w"s in this function expose jsonrpc.Error as part of the API. + call := conn.Call(ctx, method, params) + + go func() { + <-ctx.Done() + notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) + defer cancelNotify() + _ = conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ + Reason: ctx.Err().Error(), + RequestID: call.ID().Raw(), + }) + // By default, the jsonrpc2 library waits for graceful shutdown when the + // connection is closed, meaning it expects all outgoing and incoming + // requests to complete. However, for MCP this expectation is unrealistic, + // and can lead to hanging shutdown. For example, if a streamable client is + // killed, the server will not be able to detect this event, except via + // keepalive pings (if they are configured), and so outgoing calls may hang + // indefinitely. + // + // Therefore, we choose to eagerly retire calls, removing them from the + // outgoingCalls map, when the caller context is cancelled: if the caller + // will never receive the response, there's no need to track it. + conn.Retire(call, ctx.Err()) + }() +} + // call executes and awaits a jsonrpc2 call on the given connection, // translating errors into the mcp domain. func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params, result Result) error { From c11e27e5e0d720d21d1084d5730f5588a7a2f1e7 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Fri, 19 Jun 2026 15:34:33 +0000 Subject: [PATCH 12/18] refactor: remove unused toolCache field from MCP client struct --- mcp/client.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 6a799a63..4148bfae 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -447,13 +447,6 @@ type ClientSession struct { pendingElicitationsMu sync.Mutex pendingElicitations map[string]chan struct{} - // toolCacheMu guards toolCache. - toolCacheMu sync.RWMutex - // toolCache stores tool definitions keyed by name. - // It is used to look up x-mcp-header annotations when - // constructing Mcp-Param-* headers for tools/call requests. - toolCache map[string]*Tool - // resourceSubsMu guards resourceSubs. resourceSubsMu sync.Mutex // resourceSubs maps a subscribed resource URI to the cancel func of the From d380bad608c31788dd0a708a25f47502c7d8c6f2 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 21 Jun 2026 10:11:00 +0000 Subject: [PATCH 13/18] fix: allow outgoing cancellation notifications during shutdown and ensure proper request cleanup on connection loss. --- internal/jsonrpc2/conn.go | 17 +- mcp/mcp_test.go | 482 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 496 insertions(+), 3 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index e0525b21..1a91af21 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -534,6 +534,14 @@ func (c *Connection) readIncoming(ctx context.Context, reader Reader, preempter ac.retire(&Response{ID: id, Error: err}) } s.outgoingCalls = nil + + // Cancel any incoming requests still in flight: with the reader gone we + // cannot receive cancellation notifications, and likely cannot write a + // response either, so parked handlers have nothing useful left to do. + // Mirrors the equivalent cleanup on write failure. + for _, r := range s.incomingByID { + r.cancel() + } }) } @@ -728,10 +736,13 @@ func (c *Connection) write(ctx context.Context, msg Message) error { var err error // Fail writes immediately if the connection is shutting down. // - // TODO(rfindley): should we allow cancellation notifications through? It - // could be the case that writes can still succeed. + // Allow outgoing "notifications/cancelled" to allow the remote end + // to cancel in-flight calls. c.updateInFlight(func(s *inFlightState) { - err = s.shuttingDown(ErrServerClosing) + req, ok := msg.(*Request) + if !ok || req.Method != "notifications/cancelled" { + err = s.shuttingDown(ErrServerClosing) + } }) if err == nil { err = c.writer.Write(ctx, msg) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4c89569e..f9aabcd3 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2757,3 +2757,485 @@ func TestResourceSubscriptions_Streamable(t *testing.T) { server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) httpServer.Close() } + +// TestResourceSubscriptions_InMemory mirrors TestResourceSubscriptions_Streamable +// over an in-memory (stdio-equivalent) transport: the per-URI Subscribe path +// uses notifications/cancelled rather than HTTP disconnect for teardown. +func TestResourceSubscriptions_InMemory(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + waitURI := func(ch chan string, want string) { + t.Helper() + select { + case got := <-ch: + if got != want { + t.Fatalf("got URI %q, want %q", got, want) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for URI %q", want) + } + } + waitURI(subCh, "file:///r1") + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + if e.uri != "file:///r1" { + t.Fatalf("got URI %q, want %q", e.uri, "file:///r1") + } + if e.id == "" || e.id == "" { + t.Fatalf("missing subscription ID on update (got %q)", e.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for resource update") + } + + // On stdio/in-memory, Unsubscribe sends notifications/cancelled which + // reaches the server preempter, cancels the listen handler ctx, and + // triggers the UnsubscribeHandler. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe r1: %v", err) + } + waitURI(unsubCh, "file:///r1") + + // Updates for an unsubscribed URI MUST NOT be delivered. Give the server + // a moment to process the cancellation first. + time.Sleep(50 * time.Millisecond) + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + select { + case e := <-events: + t.Fatalf("unexpected resource update after unsubscribe: %q (id=%s)", e.uri, e.id) + case <-time.After(notificationDelay * 10): + } +} + +// TestResourceSubscriptions_Subscribe_Idempotent verifies that calling +// Subscribe twice for the same URI in the same session is a no-op for the +// second call: it returns nil without invoking SubscribeHandler again and +// without opening a second listen stream. +func TestResourceSubscriptions_Subscribe_Idempotent(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + + server := resourceSubServer(t, subCh, unsubCh) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) {}, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("first subscribe: %v", err) + } + select { + case got := <-subCh: + if got != "file:///r1" { + t.Fatalf("got URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for first SubscribeHandler") + } + + // Second Subscribe for the same URI returns nil and does NOT fire + // SubscribeHandler again. + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("second subscribe should be no-op, got error: %v", err) + } + select { + case got := <-subCh: + t.Fatalf("duplicate Subscribe should not re-invoke SubscribeHandler (got %q)", got) + case <-time.After(notificationDelay * 10): + } + + // Subsequent Unsubscribe still works (verifies the URI is tracked + // correctly even though the second Subscribe was a no-op). + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe: %v", err) + } + select { + case <-unsubCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for UnsubscribeHandler") + } +} + +// TestResourceSubscriptions_MultipleURIs verifies that two concurrent +// Subscribe calls on the same session each open their own independent listen +// stream with a distinct subscription ID. Unsubscribing one does not affect +// the other. +func TestResourceSubscriptions_MultipleURIs(t *testing.T) { + enableNewProtocol(t) + + subCh := make(chan string, 8) + unsubCh := make(chan string, 8) + events := make(chan resourceSubEvent, 16) + + server := resourceSubServer(t, subCh, unsubCh) + server.AddResource(&Resource{Name: "r2", URI: "file:///r2"}, nil) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + c := NewClient(testImpl, &ClientOptions{ + ResourceUpdatedHandler: func(_ context.Context, req *ResourceUpdatedNotificationRequest) { + id := "" + if req.Params != nil && req.Params.Meta != nil { + id = fmt.Sprint(req.Params.Meta[MetaKeySubscriptionID]) + } + events <- resourceSubEvent{uri: req.Params.URI, id: id} + }, + }) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer cs.Close() + + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("subscribe r1: %v", err) + } + if err := cs.Subscribe(ctx, &SubscribeParams{URI: "file:///r2"}); err != nil { + t.Fatalf("subscribe r2: %v", err) + } + + // Each Subscribe MUST fire its own SubscribeHandler invocation. + gotURIs := map[string]bool{} + for range 2 { + select { + case got := <-subCh: + gotURIs[got] = true + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for SubscribeHandler") + } + } + if !gotURIs["file:///r1"] || !gotURIs["file:///r2"] { + t.Fatalf("missing SubscribeHandler invocation; got %v", gotURIs) + } + + // Each update delivers exactly one event tagged with that URI's distinct + // subscription ID. + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + ev1 := <-events + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r2"}) + ev2 := <-events + if ev1.uri != "file:///r1" || ev2.uri != "file:///r2" { + t.Fatalf("got URIs %q and %q, want r1 and r2", ev1.uri, ev2.uri) + } + if ev1.id == "" || ev2.id == "" { + t.Fatalf("missing subscription IDs: r1=%q r2=%q", ev1.id, ev2.id) + } + if ev1.id == ev2.id { + t.Fatalf("r1 and r2 should have distinct subscription IDs, both = %q", ev1.id) + } + + // Unsubscribe r1 only. r2 keeps working. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r1"}); err != nil { + t.Fatalf("unsubscribe r1: %v", err) + } + select { + case got := <-unsubCh: + if got != "file:///r1" { + t.Fatalf("got unsubscribe URI %q, want %q", got, "file:///r1") + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for UnsubscribeHandler") + } + time.Sleep(50 * time.Millisecond) // let server-side cancellation settle + + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r1"}) + server.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{URI: "file:///r2"}) + + // Only the r2 update should arrive. + select { + case e := <-events: + if e.uri != "file:///r2" || e.id != ev2.id { + t.Fatalf("post-unsubscribe got %q (id=%s), want r2 (id=%s)", e.uri, e.id, ev2.id) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for r2 update after r1 unsubscribe") + } + select { + case e := <-events: + t.Fatalf("unexpected event after r1 unsubscribe: %q (id=%s)", e.uri, e.id) + case <-time.After(notificationDelay * 10): + } + + // Explicit Unsubscribe r2 to verify it still works independently. + if err := cs.Unsubscribe(ctx, &UnsubscribeParams{URI: "file:///r2"}); err != nil { + t.Fatalf("unsubscribe r2: %v", err) + } + select { + case <-unsubCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for r2 UnsubscribeHandler") + } +} + +// TestSubscriptionsListen_MultipleSessions verifies that two concurrent +// client sessions on the same server are isolated: a list-changed +// notification is fanned out to BOTH sessions, each with its own subscription +// ID; closing one session does not affect deliveries to the other. +func TestSubscriptionsListen_MultipleSessions(t *testing.T) { + enableNewProtocol(t) + + server := newSubListenServer() + + // Open two clients on the same server. + open := func(t *testing.T) (*ClientSession, chan subListenEvent, *ServerSession) { + t.Helper() + events := make(chan subListenEvent, 16) + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + c := newSubListenClient(events) + cs, err := c.Connect(context.Background(), ct, + &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + return cs, events, ss + } + csA, evA, ssA := open(t) + defer ssA.Close() + csB, evB, ssB := open(t) + defer ssB.Close() + + waitFor := func(ch chan subListenEvent, kind string) subListenEvent { + t.Helper() + select { + case e := <-ch: + if e.kind != kind { + t.Fatalf("got event %q, want %q", e.kind, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q event", kind) + return subListenEvent{} + } + } + ackA := waitFor(evA, "ack") + ackB := waitFor(evB, "ack") + // Request IDs are per-connection, so both sessions may legitimately use + // the same ID number; we only require that each session's own + // notifications carry its own ack ID. + + // A single change fans out to BOTH sessions, each tagged with that + // session's own ack ID. + server.AddTool(&Tool{Name: "t2", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + gotA := waitFor(evA, "tool") + gotB := waitFor(evB, "tool") + if gotA.id != ackA.id { + t.Errorf("session A: tool notif id=%s, want %s", gotA.id, ackA.id) + } + if gotB.id != ackB.id { + t.Errorf("session B: tool notif id=%s, want %s", gotB.id, ackB.id) + } + + // Close A; B's subscription must keep delivering. + csA.Close() + time.Sleep(50 * time.Millisecond) + server.AddTool(&Tool{Name: "t3", InputSchema: &jsonschema.Schema{Type: "object"}}, nil) + gotB2 := waitFor(evB, "tool") + if gotB2.id != ackB.id { + t.Errorf("session B after A closed: tool notif id=%s, want %s", gotB2.id, ackB.id) + } + select { + case e := <-evA: + t.Fatalf("session A unexpected event after Close: %q", e.kind) + case <-time.After(notificationDelay * 5): + } + + csB.Close() +} + +// TestSubscriptionsListen_ResourceListChanged covers the resources/list_changed +// auto-listen branch (the other two list-changed types are covered by +// runSubscriptionsListenTest, but resources is not). +func TestSubscriptionsListen_ResourceListChanged(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 16) + + server := NewServer(testImpl, nil) + server.AddResource(&Resource{Name: "r1", URI: "file:///r1"}, nil) + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer ss.Close() + + asEvent := func(kind string, raw any) subListenEvent { + return subListenEvent{kind, fmt.Sprint(raw)} + } + c := NewClient(testImpl, &ClientOptions{ + ResourceListChangedHandler: func(_ context.Context, req *ResourceListChangedRequest) { + id := any(nil) + if req.Params != nil && req.Params.Meta != nil { + id = req.Params.Meta[MetaKeySubscriptionID] + } + events <- asEvent("resource", id) + }, + }) + c.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == notificationSubscriptionsAck { + if cr, ok := req.(*ClientRequest[*SubscriptionsAcknowledgedParams]); ok && cr.Params != nil { + events <- asEvent("ack", cr.Params.Meta[MetaKeySubscriptionID]) + } + } + return next(ctx, method, req) + } + }) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + waitFor := func(kind string) subListenEvent { + t.Helper() + select { + case e := <-events: + if e.kind != kind { + t.Fatalf("got event %q, want %q", e.kind, kind) + } + return e + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for %q", kind) + return subListenEvent{} + } + } + ack := waitFor("ack") + if ack.id == "" { + t.Fatal("acknowledgment missing subscription ID") + } + server.AddResource(&Resource{Name: "r2", URI: "file:///r2"}, nil) + got := waitFor("resource") + if got.id != ack.id { + t.Errorf("resource notif id=%s, want %s", got.id, ack.id) + } + cs.Close() +} + +// TestSubscriptionsListen_DisconnectScrubsMaps verifies that closing a +// session removes its entries from the server's three per-type subscription +// maps via Server.disconnect. +func TestSubscriptionsListen_DisconnectScrubsMaps(t *testing.T) { + enableNewProtocol(t) + events := make(chan subListenEvent, 16) + server := newSubListenServer() + + ct, st := NewInMemoryTransports() + ss, err := server.Connect(context.Background(), st, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + c := newSubListenClient(events) + cs, err := c.Connect(context.Background(), ct, + &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + if err != nil { + t.Fatalf("client connect: %v", err) + } + + // Wait for the auto-listen to actually register on the server side. + select { + case e := <-events: + if e.kind != "ack" { + t.Fatalf("got %q, want ack", e.kind) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for ack") + } + + // Server should now have this session in tools+prompts maps (the fixture + // client opts in to both via newSubListenClient). + server.mu.Lock() + if _, ok := server.toolChangeSubscriptions[ss]; !ok { + server.mu.Unlock() + t.Fatal("session missing from toolChangeSubscriptions before close") + } + if _, ok := server.promptChangeSubscriptions[ss]; !ok { + server.mu.Unlock() + t.Fatal("session missing from promptChangeSubscriptions before close") + } + server.mu.Unlock() + + cs.Close() + ss.Close() + // Closures complete asynchronously on the server side. + deadline := time.Now().Add(5 * time.Second) + for { + server.mu.Lock() + _, inTool := server.toolChangeSubscriptions[ss] + _, inPrompt := server.promptChangeSubscriptions[ss] + _, inResource := server.resourceChangeSubscriptions[ss] + server.mu.Unlock() + if !inTool && !inPrompt && !inResource { + return + } + if time.Now().After(deadline) { + t.Fatalf("subscription maps not scrubbed after Close: tool=%v prompt=%v resource=%v", + inTool, inPrompt, inResource) + } + time.Sleep(10 * time.Millisecond) + } +} From 558182690fbd1385929a9ddd945fddbd7b1cc054 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 21 Jun 2026 10:48:22 +0000 Subject: [PATCH 14/18] refactor: centralize call cancellation logic and add subscription notification filtering for legacy clients --- mcp/client.go | 5 +++- mcp/server.go | 26 +++++++++++------ mcp/transport.go | 72 +++++++++++++++++++++++------------------------- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 4148bfae..3f447b6b 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -316,7 +316,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, opts *ClientSessionOp // ClientSession.Close cancels the listenCtx context to send notifications/cancelled. listenCtx, cancelListen := context.WithCancel(context.Background()) cs.listenCancel = cancelListen - cs.subscriptionsListen(listenCtx, subscribeParams) + if err := cs.subscriptionsListen(listenCtx, subscribeParams); err != nil { + cancelListen() + return nil, fmt.Errorf("opening subscriptions/listen: %w", err) + } } return cs, nil } diff --git a/mcp/server.go b/mcp/server.go index 8f4202b7..ec87ee35 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -653,7 +653,9 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } -// Map from notification name to its corresponding params. +// Map from notification name to a function creating its corresponding Params. +// We need to create a fresh one each time to add the jsonrpc ID. See +// [injectMetaSubscriptionID]. var changeNotificationParams = map[string]func() Params{ notificationToolListChanged: func() Params { return &ToolListChangedParams{} }, notificationPromptListChanged: func() Params { return &PromptListChangedParams{} }, @@ -698,8 +700,15 @@ func (s *Server) changeAndNotify(notification string, change func() bool) { // notification type. func (s *Server) notifySessions(n string) { s.mu.Lock() - sessions := slices.Clone(s.sessions) s.pendingNotifications[n] = nil + // Legacy (pre-SEP-2575) sessions receive list-changed notifications on the + // shared session channel without opt-in; collect them while we hold the lock. + var legacySessions []*ServerSession + for _, sess := range s.sessions { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { + legacySessions = append(legacySessions, sess) + } + } var subscribers map[*ServerSession]jsonrpc.ID switch n { case notificationToolListChanged: @@ -711,13 +720,6 @@ func (s *Server) notifySessions(n string) { } s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. - // Legacy sessions receive list-changed notifications on the shared session channel without opt-in. - var legacySessions []*ServerSession - for _, sess := range sessions { - if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { - legacySessions = append(legacySessions, sess) - } - } notifySessions(legacySessions, n, changeNotificationParams[n](), s.opts.Logger) // Sessions receive the notification only if they opened a @@ -734,6 +736,12 @@ func (s *Server) notifySessions(n string) { } } +// injectMetaSubscriptionID stamps the listen request's JSON-RPC ID into the +// params' _meta under [MetaKeySubscriptionID], so the receiving client can +// correlate the notification with the [subscriptions/listen] stream it +// originated. +// +// [subscriptions/listen]: https://modelcontextprotocol.io/seps/2575-stateless-mcp#multiple-concurrent-subscriptions func injectMetaSubscriptionID(params Params, reqID jsonrpc.ID) { m := params.GetMeta() if m == nil { diff --git a/mcp/transport.go b/mcp/transport.go index 1469858e..b8370fa5 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -230,30 +230,21 @@ func (c *canceller) Preempt(ctx context.Context, req *jsonrpc.Request) (result a return nil, jsonrpc2.ErrNotHandled } +// callSubscriptionsListen issues a "subscriptions/listen" call (SEP-2575) +// without awaiting its JSON-RPC response. The call's logical lifetime is the +// stream of notifications that follow on the same channel — the empty +// response, if ever delivered, only marks subscription teardown — so the +// caller has nothing useful to block on. +// +// Cancellation is driven by ctx: when it is cancelled, a background goroutine +// sends a "notifications/cancelled" notification referencing the listen's +// request ID and retires the call from the connection's outgoing-calls map. func callSubscriptionsListen(ctx context.Context, conn *jsonrpc2.Connection, method string, params Params) { - // The "%w"s in this function expose jsonrpc.Error as part of the API. call := conn.Call(ctx, method, params) go func() { <-ctx.Done() - notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) - defer cancelNotify() - _ = conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ - Reason: ctx.Err().Error(), - RequestID: call.ID().Raw(), - }) - // By default, the jsonrpc2 library waits for graceful shutdown when the - // connection is closed, meaning it expects all outgoing and incoming - // requests to complete. However, for MCP this expectation is unrealistic, - // and can lead to hanging shutdown. For example, if a streamable client is - // killed, the server will not be able to detect this event, except via - // keepalive pings (if they are configured), and so outgoing calls may hang - // indefinitely. - // - // Therefore, we choose to eagerly retire calls, removing them from the - // outgoingCalls map, when the caller context is cancelled: if the caller - // will never receive the response, there's no need to track it. - conn.Retire(call, ctx.Err()) + _ = cancelCall(ctx, conn, call) }() } @@ -267,24 +258,7 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params case errors.Is(err, jsonrpc2.ErrClientClosing), errors.Is(err, jsonrpc2.ErrServerClosing): return fmt.Errorf("%w: calling %q: %v", ErrConnectionClosed, method, err) case ctx.Err() != nil: - notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) - defer cancelNotify() - err := conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ - Reason: ctx.Err().Error(), - RequestID: call.ID().Raw(), - }) - // By default, the jsonrpc2 library waits for graceful shutdown when the - // connection is closed, meaning it expects all outgoing and incoming - // requests to complete. However, for MCP this expectation is unrealistic, - // and can lead to hanging shutdown. For example, if a streamable client is - // killed, the server will not be able to detect this event, except via - // keepalive pings (if they are configured), and so outgoing calls may hang - // indefinitely. - // - // Therefore, we choose to eagerly retire calls, removing them from the - // outgoingCalls map, when the caller context is cancelled: if the caller - // will never receive the response, there's no need to track it. - conn.Retire(call, ctx.Err()) + err := cancelCall(ctx, conn, call) return errors.Join(ctx.Err(), err) case err != nil: return fmt.Errorf("calling %q: %w", method, err) @@ -292,6 +266,30 @@ func call(ctx context.Context, conn *jsonrpc2.Connection, method string, params return nil } +// cancelCall sends a "notifications/cancelled" notification for call and eagerly +// retires it from conn. +// +// By default, the jsonrpc2 library waits for graceful shutdown when the +// connection is closed, meaning it expects all outgoing and incoming requests +// to complete. However, for MCP this expectation is unrealistic, and can lead +// to hanging shutdown. For example, if a streamable client is killed, the +// server will not be able to detect this event, except via keepalive pings (if +// they are configured), and so outgoing calls may hang indefinitely. +// +// Therefore, we choose to eagerly retire calls, removing them from the +// outgoingCalls map, when the caller context is cancelled: if the caller will +// never receive the response, there's no need to track it. +func cancelCall(ctx context.Context, conn *jsonrpc2.Connection, call *jsonrpc2.AsyncCall) error { + notifyCtx, cancelNotify := context.WithTimeout(context.WithoutCancel(ctx), notifyCancellationTimeout) + defer cancelNotify() + err := conn.Notify(notifyCtx, notificationCancelled, &CancelledParams{ + Reason: ctx.Err().Error(), + RequestID: call.ID().Raw(), + }) + conn.Retire(call, ctx.Err()) + return err +} + // A LoggingTransport is a [Transport] that delegates to another transport, // writing RPC logs to an io.Writer. type LoggingTransport struct { From f1845d378c2527cf547ee08b489bb59f5d0f4269 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 21 Jun 2026 12:23:22 +0000 Subject: [PATCH 15/18] refactor: propagate request context to notification timeout and remove unused listenSubscription struct --- mcp/server.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index ec87ee35..514e7555 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1076,7 +1076,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot notifySessions(legacySessions, notificationResourceUpdated, params, s.opts.Logger) s.opts.Logger.Info("resource updated notification sent", "uri", params.URI, "subscriber_count", len(sessions)) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() for _, sess := range newSessions { reqID := subscribedSessions[sess] @@ -1422,14 +1422,6 @@ type ServerSession struct { state ServerSessionState } -// listenSubscription records the notification types a single -// subscriptions/listen stream has opted in to. -type listenSubscription struct { - toolsListChanged bool - promptsListChanged bool - resourcesListChanged bool -} - func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { ss.mu.Lock() mut(&ss.state) From 809e288f5e865a63f6138216eb906abb4c275fd1 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Sun, 21 Jun 2026 17:02:03 +0000 Subject: [PATCH 16/18] refactor: encapsulate ephemeral connection parameters into ephemeralConnectInfo struct --- mcp/streamable.go | 42 ++++++++++++++++++++++++------------------ mcp/streamable_test.go | 14 +++++++------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index 8c752487..64d075a8 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -366,14 +366,14 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. return } - connectOpts, usesNewProtocol, hasSubscriptionsListen, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } var sessionID string - if legacySessions && !usesNewProtocol { + if legacySessions && !info.usesNewProtocol { sessionID = req.Header.Get(sessionIDHeader) if sessionID == "" { sessionID = server.opts.GetSessionID() @@ -389,10 +389,10 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. } ctx := req.Context() - if hasSubscriptionsListen && usesNewProtocol { + if info.isSubscriptionsListen && info.usesNewProtocol { ctx = context.WithValue(ctx, jsonrpc2.IsSubscriptionsListenContextKey{}, true) } - session, err := connectStreamable(ctx, server, transport, connectOpts) + session, err := connectStreamable(ctx, server, transport, info.opts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) @@ -416,27 +416,29 @@ func (h *StreamableHTTPHandler) serveStatelessLegacyDELETE(w http.ResponseWriter w.WriteHeader(http.StatusNoContent) } -// ephemeralConnectOpts peeks at the request body to determine whether it -// contains an initialize or initialized message or whether the protocol version -// header indicates a protocol version >= 2026-06-30 (SEP-2575). +type ephemeralConnectInfo struct { + opts *ServerSessionOptions + usesNewProtocol bool + isSubscriptionsListen bool +} + +// ephemeralConnectOpts peeks at the request body to determine connection +// parameters and whether protocol version >= 2026-06-30 (SEP-2575). // // For old-protocol requests, default session state is synthesized so that // the session's init gate doesn't reject the request. // // It is used for both stateless servers and stateful servers with no session ID. -// -// The returned usesNewProtocol bool reports whether the protocol version -// header indicates a protocol version >= 2026-06-30 (SEP-2575). -func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *ServerSessionOptions, usesNewProtocol bool, isSubscriptionsListen bool, err error) { +func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (*ephemeralConnectInfo, error) { protocolVersion := protocolVersionFromContext(req.Context()) if protocolVersion == "" { protocolVersion = protocolVersion20250326 } - var hasInitialize, hasInitialized bool + var hasInitialize, hasInitialized, usesNewProtocol, isSubscriptionsListen bool body, err := io.ReadAll(req.Body) if err != nil { - return nil, false, false, fmt.Errorf("failed to read body") + return nil, fmt.Errorf("failed to read body") } req.Body.Close() req.Body = io.NopCloser(bytes.NewBuffer(body)) @@ -473,9 +475,13 @@ func (h *StreamableHTTPHandler) ephemeralConnectOpts(req *http.Request) (opts *S if !usesNewProtocol { state.LogLevel = "info" } - return &ServerSessionOptions{ - State: state, - }, usesNewProtocol, isSubscriptionsListen, nil + return &ephemeralConnectInfo{ + opts: &ServerSessionOptions{ + State: state, + }, + usesNewProtocol: usesNewProtocol, + isSubscriptionsListen: isSubscriptionsListen, + }, nil } func connectStreamable(ctx context.Context, server *Server, transport *StreamableServerTransport, opts *ServerSessionOptions) (*ServerSession, error) { @@ -620,12 +626,12 @@ func (h *StreamableHTTPHandler) serveStatefulPOST(w http.ResponseWriter, req *ht // that arrives before a session exists (e.g. initialize or ping) on a // server configured this way. if sessionID == "" { - connectOpts, _, _, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - session, err := connectStreamable(req.Context(), server, transport, connectOpts) + session, err := connectStreamable(req.Context(), server, transport, info.opts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index ba83c87a..e437046c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -3319,20 +3319,20 @@ func TestEphemeralConnectOpts(t *testing.T) { } req.Header.Set(protocolVersionHeader, pver) req = req.WithContext(context.WithValue(req.Context(), protocolVersionContextKey{}, pver)) - opts, usesNew, _, err := h.ephemeralConnectOpts(req) + info, err := h.ephemeralConnectOpts(req) if err != nil { t.Fatal(err) } - if usesNew != tt.wantUsesNew { - t.Errorf("usesNewProtocol = %v, want %v", usesNew, tt.wantUsesNew) + if info.usesNewProtocol != tt.wantUsesNew { + t.Errorf("usesNewProtocol = %v, want %v", info.usesNewProtocol, tt.wantUsesNew) } - if got := opts.State.InitializeParams != nil; got != tt.wantInitializeParams { + if got := info.opts.State.InitializeParams != nil; got != tt.wantInitializeParams { t.Errorf("InitializeParams non-nil = %v, want %v (value = %+v)", - got, tt.wantInitializeParams, opts.State.InitializeParams) + got, tt.wantInitializeParams, info.opts.State.InitializeParams) } - if got := opts.State.InitializedParams != nil; got != tt.wantInitializedParams { + if got := info.opts.State.InitializedParams != nil; got != tt.wantInitializedParams { t.Errorf("InitializedParams non-nil = %v, want %v (value = %+v)", - got, tt.wantInitializedParams, opts.State.InitializedParams) + got, tt.wantInitializedParams, info.opts.State.InitializedParams) } }) } From 34fc70249ae4d997a304e92a5cf797e77215fda7 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 22 Jun 2026 15:06:41 +0000 Subject: [PATCH 17/18] feat: introduce PropagateCancellation to jsonrpc2 and MCP to support handler cancellation on transport disconnect --- internal/jsonrpc2/conn.go | 38 +++++++++++++++----------- mcp/client.go | 20 ++++++++------ mcp/mcp_test.go | 20 +++++++------- mcp/server.go | 13 ++++----- mcp/streamable.go | 56 ++++++++++++++++++++++++++------------- mcp/transport.go | 18 +++++++++++++ 6 files changed, 107 insertions(+), 58 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index 1a91af21..b9a78726 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -201,21 +201,20 @@ type Writer interface { // A ConnectionConfig configures a bidirectional jsonrpc2 connection. type ConnectionConfig struct { - Reader Reader // required - Writer Writer // required - Closer io.Closer // required - Preempter Preempter // optional - Bind func(*Connection) Handler // required - OnDone func() // optional - OnInternalError func(error) // optional + Reader Reader // required + Writer Writer // required + Closer io.Closer // required + Preempter Preempter // optional + Bind func(*Connection) Handler // required + OnDone func() // optional + OnInternalError func(error) // optional + PropagateCancellation bool // optional } -type IsSubscriptionsListenContextKey struct{} - // NewConnection creates a new [Connection] object and starts processing // incoming messages. func NewConnection(ctx context.Context, cfg ConnectionConfig) *Connection { - if val, _ := ctx.Value(IsSubscriptionsListenContextKey{}).(bool); !val { + if !cfg.PropagateCancellation { ctx = notDone{ctx} } @@ -736,13 +735,13 @@ func (c *Connection) write(ctx context.Context, msg Message) error { var err error // Fail writes immediately if the connection is shutting down. // - // Allow outgoing "notifications/cancelled" to allow the remote end - // to cancel in-flight calls. + // Allow outgoing "notifications" forwarded by the Notify method. + // This will allow to send the cancelled notification when the client is shutting down. c.updateInFlight(func(s *inFlightState) { - req, ok := msg.(*Request) - if !ok || req.Method != "notifications/cancelled" { - err = s.shuttingDown(ErrServerClosing) + if req, ok := msg.(*Request); ok && !req.IsCall() && s.outgoingNotifications > 0 { + return } + err = s.shuttingDown(ErrServerClosing) }) if err == nil { err = c.writer.Write(ctx, msg) @@ -786,6 +785,15 @@ func (c *Connection) internalErrorf(format string, args ...any) error { } // notDone is a context.Context wrapper that returns a nil Done channel. +// +// Request handlers' contexts are derived from the connection's root context, +// which by default is wrapped in notDone so a transport-level cancellation +// does not implicitly cancel every in-flight handler. Cancellation of an +// in-flight handler is instead expected to flow only through the jsonrpc2 +// layer's explicit channels: the [Preempter] calling [Connection.Cancel] in +// response to the peer's cancel notification, or the transport itself +// failing (the read loop exits on EOF or a write fails) — both of which +// cancel every in-flight incoming request in turn. type notDone struct{ ctx context.Context } func (ic notDone) Value(key any) any { diff --git a/mcp/client.go b/mcp/client.go index fee07c4e..d8fd2c11 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -1324,17 +1324,21 @@ func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) } uri := params.URI + var listenCtx context.Context cs.resourceSubsMu.Lock() - if _, exists := cs.resourceSubs[uri]; exists { - cs.resourceSubsMu.Unlock() - return nil - } - listenCtx, cancel := context.WithCancel(context.Background()) - if cs.resourceSubs == nil { - cs.resourceSubs = make(map[string]context.CancelFunc) + if _, exists := cs.resourceSubs[uri]; !exists { + var cancel context.CancelFunc + listenCtx, cancel = context.WithCancel(context.Background()) + if cs.resourceSubs == nil { + cs.resourceSubs = make(map[string]context.CancelFunc) + } + cs.resourceSubs[uri] = cancel } - cs.resourceSubs[uri] = cancel cs.resourceSubsMu.Unlock() + if listenCtx == nil { + // Already subscribed to this URI + return nil + } return cs.subscriptionsListen(listenCtx, &SubscriptionsListenParams{ Notifications: NotificationSubscriptions{ diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index f9aabcd3..91cc57f9 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2468,7 +2468,7 @@ func runSubscriptionsListenTest(t *testing.T, client *Client, server *Server, ct ctx, topCancel := context.WithTimeout(context.Background(), 30*time.Second) defer topCancel() - cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := client.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -2571,7 +2571,7 @@ func newSubListenServer() *Server { func enableNewProtocol(t *testing.T) { t.Helper() orig := supportedProtocolVersions - supportedProtocolVersions = append([]string{protocolVersion20260630}, slices.Clone(orig)...) + supportedProtocolVersions = append([]string{protocolVersion20260728}, slices.Clone(orig)...) t.Cleanup(func() { supportedProtocolVersions = orig }) } @@ -2636,7 +2636,7 @@ func TestSubscriptionsListen_NoHandlersNoListen(t *testing.T) { return next(ctx, method, req) } }) - cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -2718,7 +2718,7 @@ func TestResourceSubscriptions_Streamable(t *testing.T) { }, }) cs, err := c.Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, - &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -2788,7 +2788,7 @@ func TestResourceSubscriptions_InMemory(t *testing.T) { events <- resourceSubEvent{uri: req.Params.URI, id: id} }, }) - cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -2866,7 +2866,7 @@ func TestResourceSubscriptions_Subscribe_Idempotent(t *testing.T) { c := NewClient(testImpl, &ClientOptions{ ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) {}, }) - cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -2939,7 +2939,7 @@ func TestResourceSubscriptions_MultipleURIs(t *testing.T) { events <- resourceSubEvent{uri: req.Params.URI, id: id} }, }) - cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -3045,7 +3045,7 @@ func TestSubscriptionsListen_MultipleSessions(t *testing.T) { } c := newSubListenClient(events) cs, err := c.Connect(context.Background(), ct, - &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -3146,7 +3146,7 @@ func TestSubscriptionsListen_ResourceListChanged(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + cs, err := c.Connect(ctx, ct, &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } @@ -3191,7 +3191,7 @@ func TestSubscriptionsListen_DisconnectScrubsMaps(t *testing.T) { } c := newSubListenClient(events) cs, err := c.Connect(context.Background(), ct, - &ClientSessionOptions{protocolVersion: protocolVersion20260630}) + &ClientSessionOptions{protocolVersion: protocolVersion20260728}) if err != nil { t.Fatalf("client connect: %v", err) } diff --git a/mcp/server.go b/mcp/server.go index 9c93bcfe..011ea56e 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -705,7 +705,7 @@ func (s *Server) notifySessions(n string) { // shared session channel without opt-in; collect them while we hold the lock. var legacySessions []*ServerSession for _, sess := range s.sessions { - if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260728 { legacySessions = append(legacySessions, sess) } } @@ -1067,7 +1067,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot var legacySessions []*ServerSession var newSessions []*ServerSession for _, sess := range sessions { - if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260630 { + if sess.InitializeParams().isNil() || sess.InitializeParams().ProtocolVersion < protocolVersion20260728 { legacySessions = append(legacySessions, sess) } else { newSessions = append(newSessions, sess) @@ -1185,7 +1185,7 @@ func (s *Server) subscriptionsListen(ctx context.Context, req *SubscriptionsList Notifications: allowed, Meta: Meta{MetaKeySubscriptionID: requestID.Raw()}, } - if err := req.Session.NotifySubscriptionAcked(ctx, ackParams); err != nil { + if err := req.Session.notifySubscriptionAcked(ctx, ackParams); err != nil { return nil, fmt.Errorf("sending subscriptions/acknowledged: %w", err) } @@ -1384,9 +1384,10 @@ func (ss *ServerSession) NotifyProgress(ctx context.Context, params *ProgressNot return handleNotify(ctx, notificationProgress, newServerRequest(ss, orZero[Params](params))) } -// NotifySubscriptionAcked sends a subscription acknowledged notification from the server to the client -// associated with this session. -func (ss *ServerSession) NotifySubscriptionAcked(ctx context.Context, params *SubscriptionsAcknowledgedParams) error { +// notifySubscriptionAcked sends a "notifications/subscriptions/acknowledged" +// notification on the listen stream represented by this session, indicating +// the subscription filter the server accepted (SEP-2575). +func (ss *ServerSession) notifySubscriptionAcked(ctx context.Context, params *SubscriptionsAcknowledgedParams) error { return handleNotify(ctx, notificationSubscriptionsAck, newServerRequest(ss, orZero[Params](params))) } diff --git a/mcp/streamable.go b/mcp/streamable.go index 8d315a33..b71bc774 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -381,18 +381,15 @@ func (h *StreamableHTTPHandler) serveStateless(w http.ResponseWriter, req *http. } transport := &StreamableServerTransport{ - SessionID: sessionID, - Stateless: true, - EventStore: h.opts.EventStore, - jsonResponse: h.opts.JSONResponse, - logger: h.opts.Logger, + SessionID: sessionID, + Stateless: true, + EventStore: h.opts.EventStore, + jsonResponse: h.opts.JSONResponse, + logger: h.opts.Logger, + shouldPropagateCancellation: info.isSubscriptionsListen && info.usesNewProtocol, } - ctx := req.Context() - if info.isSubscriptionsListen && info.usesNewProtocol { - ctx = context.WithValue(ctx, jsonrpc2.IsSubscriptionsListenContextKey{}, true) - } - session, err := connectStreamable(ctx, server, transport, info.opts) + session, err := connectStreamable(req.Context(), server, transport, info.opts) if err != nil { h.opts.Logger.Error(fmt.Sprintf("failed to connect: %v", err)) http.Error(w, "failed connection", http.StatusInternalServerError) @@ -787,6 +784,10 @@ type StreamableServerTransport struct { // to write their own streamable HTTP handler. logger *slog.Logger + // shouldPropagateCancellation is forwarded to the underlying + // [streamableServerConn]. See its docstring. + shouldPropagateCancellation bool + // connection is non-nil if and only if the transport has been connected. connection *streamableServerConn } @@ -797,15 +798,16 @@ func (t *StreamableServerTransport) Connect(ctx context.Context) (Connection, er return nil, fmt.Errorf("transport already connected") } t.connection = &streamableServerConn{ - sessionID: t.SessionID, - stateless: t.Stateless, - eventStore: t.EventStore, - jsonResponse: t.jsonResponse, - logger: ensureLogger(t.logger), // see #556: must be non-nil - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - streams: make(map[string]*stream), - requestStreams: make(map[jsonrpc.ID]string), + sessionID: t.SessionID, + stateless: t.Stateless, + eventStore: t.EventStore, + jsonResponse: t.jsonResponse, + logger: ensureLogger(t.logger), // see #556: must be non-nil + shouldPropagateCancellation: t.shouldPropagateCancellation, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + streams: make(map[string]*stream), + requestStreams: make(map[jsonrpc.ID]string), } // Stream 0 corresponds to the standalone SSE stream. // @@ -835,6 +837,13 @@ type streamableServerConn struct { jsonResponse bool eventStore EventStore + // shouldPropagateCancellation is true when the underlying HTTP request's + // lifetime IS the connection's cancellation signal (e.g., a stateless + // POST that owns a long-lived subscriptions/listen stream). It is read + // by the [cancellationPropagator] interface so the jsonrpc2 layer wires + // handler contexts to observe the carrier's cancellation. + shouldPropagateCancellation bool + logger *slog.Logger toolLookup func(name string) (*serverTool, bool) @@ -872,6 +881,15 @@ func (c *streamableServerConn) SessionID() string { return c.sessionID } +// propagateCancellation implements [cancellationPropagator]. It returns true +// when this connection is bound to a single HTTP request whose lifetime +// should drive request-handler cancellation — for example, a stateless POST +// carrying a long-lived subscriptions/listen stream that must unwind when +// the client TCP-disconnects. +func (c *streamableServerConn) propagateCancellation() bool { + return c.shouldPropagateCancellation +} + // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header. diff --git a/mcp/transport.go b/mcp/transport.go index b8370fa5..55c73d74 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -191,6 +191,13 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, preempter.conn = conn return jsonrpc2.HandlerFunc(h.handle) } + // Transports may opt in to propagating cancellation of ctx into request + // handler contexts when their own lifecycle IS the cancellation signal + // (e.g., a connection bound to a single HTTP request). + var propagateCancellation bool + if cp, ok := mcpConn.(cancellationPropagator); ok { + propagateCancellation = cp.propagateCancellation() + } _ = jsonrpc2.NewConnection(ctx, jsonrpc2.ConnectionConfig{ Reader: reader, Writer: writer, @@ -203,11 +210,22 @@ func connect[H handler, State any](ctx context.Context, t Transport, b binder[H, OnInternalError: func(err error) { logger.Error("jsonrpc2 internal error", "error", err) }, + PropagateCancellation: propagateCancellation, }) assert(preempter.conn != nil, "unbound preempter") return h, nil } +// cancellationPropagator is an optional interface implemented by a +// [Connection] whose own lifecycle should propagate cancellation into request +// handler contexts. The default jsonrpc2 behavior is to suppress propagation; +// transports that bind a connection to a single short-lived carrier (such as +// a one-shot HTTP request) should return true here so that handlers unwind +// when the carrier observes the peer going away. +type cancellationPropagator interface { + propagateCancellation() bool +} + // A canceller is a jsonrpc2.Preempter that cancels in-flight requests on MCP // cancelled notifications. type canceller struct { From 85a027705ead513184ec3c8ddb940a86e9db6a71 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Mon, 22 Jun 2026 15:30:05 +0000 Subject: [PATCH 18/18] docs: expand documentation for ConnectionConfig.PropagateCancellation field --- internal/jsonrpc2/conn.go | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/internal/jsonrpc2/conn.go b/internal/jsonrpc2/conn.go index b9a78726..4994c63b 100644 --- a/internal/jsonrpc2/conn.go +++ b/internal/jsonrpc2/conn.go @@ -201,14 +201,29 @@ type Writer interface { // A ConnectionConfig configures a bidirectional jsonrpc2 connection. type ConnectionConfig struct { - Reader Reader // required - Writer Writer // required - Closer io.Closer // required - Preempter Preempter // optional - Bind func(*Connection) Handler // required - OnDone func() // optional - OnInternalError func(error) // optional - PropagateCancellation bool // optional + Reader Reader // required + Writer Writer // required + Closer io.Closer // required + Preempter Preempter // optional + Bind func(*Connection) Handler // required + OnDone func() // optional + OnInternalError func(error) // optional + + // PropagateCancellation controls whether cancellation of the context + // passed to [NewConnection] is observable by request handlers. + // + // By default (false), the connection wraps that context (see [notDone]) + // so handlers' Done channels do not fire when the connection's root + // context is cancelled. Cancellation of an in-flight handler is then + // expected to flow only through the jsonrpc2 layer's explicit channels + // (the [Preempter] reacting to the peer's cancel notification, or a + // transport read/write failure cancelling every in-flight request). + // + // Set this true when cancellation of the connection's context is itself + // a meaningful signal that handlers should react to (for example, when + // the connection is tied to a carrier that owns it and whose end means + // the request is cancelled). + PropagateCancellation bool // optional } // NewConnection creates a new [Connection] object and starts processing