From 43ff95d2ca3252a55b8d65662649df7342db080f Mon Sep 17 00:00:00 2001 From: Guy Lichtman <1395797+glicht@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:14:27 +0200 Subject: [PATCH 1/2] feat: optional session store --- crates/rmcp/Cargo.toml | 9 + .../src/transport/streamable_http_server.rs | 2 +- .../streamable_http_server/session.rs | 51 +++ .../streamable_http_server/session/local.rs | 16 +- .../streamable_http_server/session/store.rs | 62 +++ .../transport/streamable_http_server/tower.rs | 406 +++++++++++++++-- .../test_streamable_http_session_store.rs | 407 ++++++++++++++++++ 7 files changed, 909 insertions(+), 44 deletions(-) create mode 100644 crates/rmcp/src/transport/streamable_http_server/session/store.rs create mode 100644 crates/rmcp/tests/test_streamable_http_session_store.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index cbf02ea48..a3c27be4c 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -303,3 +303,12 @@ required-features = [ ] path = "tests/test_streamable_http_stale_session.rs" +[[test]] +name = "test_streamable_http_session_store" +required-features = [ + "client", + "server", + "transport-streamable-http-client-reqwest", + "transport-streamable-http-server", +] +path = "tests/test_streamable_http_session_store.rs" diff --git a/crates/rmcp/src/transport/streamable_http_server.rs b/crates/rmcp/src/transport/streamable_http_server.rs index 9cbb63cc0..df1945ab2 100644 --- a/crates/rmcp/src/transport/streamable_http_server.rs +++ b/crates/rmcp/src/transport/streamable_http_server.rs @@ -1,6 +1,6 @@ pub mod session; #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] pub mod tower; -pub use session::{SessionId, SessionManager}; +pub use session::{RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker}; #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; diff --git 
a/crates/rmcp/src/transport/streamable_http_server/session.rs b/crates/rmcp/src/transport/streamable_http_server/session.rs index dcdb25c86..48ad8b001 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session.rs @@ -30,6 +30,39 @@ use crate::{ pub mod local; pub mod never; +pub mod store; + +pub use store::{SessionState, SessionStore, SessionStoreError}; + +/// Extension marker inserted into the `initialize` request extensions during a +/// session restore replay. Handlers can check for its presence to distinguish a +/// cross-instance restore from a genuine client-initiated `initialize` request. +/// +/// ```rust,ignore +/// if req.extensions().get::().is_some() { +/// // this is a restore replay, not a fresh client connection +/// } +/// ``` +#[derive(Debug, Clone)] +pub struct SessionRestoreMarker { + pub id: SessionId, +} + +/// The outcome of a [`SessionManager::restore_session`] call. +#[derive(Debug)] +pub enum RestoreOutcome { + /// The session was just re-created from external state; the caller must + /// spawn an MCP handler against the returned transport and replay the + /// `initialize` handshake. + Restored(T), + /// The session was already present in memory (e.g. a concurrent request + /// already restored it). The caller should proceed as if `has_session` + /// had returned `true` — no further action is required. + AlreadyPresent, + /// This session manager does not support external-store restore. + /// The caller should fall through to the normal 404 response. + NotSupported, +} /// Controls how MCP sessions are created, validated, and closed. /// @@ -98,4 +131,22 @@ pub trait SessionManager: Send + Sync + 'static { ) -> impl Future< Output = Result + Send + Sync + 'static, Self::Error>, > + Send; + + /// Attempt to restore a previously-known session from external state, + /// creating a fresh in-memory session worker with the given `id`. 
+ /// + /// See [`RestoreOutcome`] for the three possible results: + /// - [`RestoreOutcome::Restored`] — session re-created; caller must spawn + /// an MCP handler and replay the `initialize` handshake. + /// - [`RestoreOutcome::AlreadyPresent`] — session is already in memory + /// (e.g. a concurrent request restored it first); caller proceeds + /// normally. + /// - [`RestoreOutcome::NotSupported`] (default) — this session manager + /// does not support external-store restore; caller returns 404. + fn restore_session( + &self, + _id: SessionId, + ) -> impl Future, Self::Error>> + Send { + futures::future::ready(Ok(RestoreOutcome::NotSupported)) + } } diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index cad533802..3dc1b6339 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -128,6 +128,20 @@ impl SessionManager for LocalSessionManager { handle.push_message(message, None).await?; Ok(()) } + + async fn restore_session( + &self, + id: SessionId, + ) -> Result, Self::Error> { + let mut sessions = self.sessions.write().await; + if sessions.contains_key(&id) { + // A concurrent request already restored this session. 
+ return Ok(RestoreOutcome::AlreadyPresent); + } + let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); + sessions.insert(id, handle); + Ok(RestoreOutcome::Restored(WorkerTransport::spawn(worker))) + } } /// `/request_id>` @@ -179,7 +193,7 @@ impl std::str::FromStr for EventId { } } -use super::{ServerSseMessage, SessionManager}; +use super::{RestoreOutcome, ServerSseMessage, SessionManager}; struct CachedTx { tx: Sender, diff --git a/crates/rmcp/src/transport/streamable_http_server/session/store.rs b/crates/rmcp/src/transport/streamable_http_server/session/store.rs new file mode 100644 index 000000000..d89a9c162 --- /dev/null +++ b/crates/rmcp/src/transport/streamable_http_server/session/store.rs @@ -0,0 +1,62 @@ +use crate::model::InitializeRequestParams; + +/// State persisted to an external store for cross-instance session recovery. +/// +/// When a client reconnects to a different server instance, the new instance +/// loads this state to transparently replay the `initialize` handshake without +/// the client needing to re-initialize. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct SessionState { + /// Parameters from the client's original `initialize` request. + pub initialize_params: InitializeRequestParams, +} + +/// Type alias for boxed session store errors. +pub type SessionStoreError = Box; + +/// Pluggable external session store for cross-instance recovery. +/// +/// Implement this trait to back sessions with Redis, a database, or any +/// key-value store. The simplest usage is to set +/// [`StreamableHttpServerConfig::session_store`] to an `Arc`. 
+/// +/// # Example (in-memory, for testing) +/// +/// ```rust,ignore +/// use std::{collections::HashMap, sync::Arc}; +/// use tokio::sync::RwLock; +/// use rmcp::transport::streamable_http_server::session::store::{ +/// SessionState, SessionStore, SessionStoreError, +/// }; +/// +/// #[derive(Default)] +/// struct InMemoryStore(Arc>>); +/// +/// #[async_trait::async_trait] +/// impl SessionStore for InMemoryStore { +/// async fn load(&self, id: &str) -> Result, SessionStoreError> { +/// Ok(self.0.read().await.get(id).cloned()) +/// } +/// async fn store(&self, id: &str, state: &SessionState) -> Result<(), SessionStoreError> { +/// self.0.write().await.insert(id.to_owned(), state.clone()); +/// Ok(()) +/// } +/// async fn delete(&self, id: &str) -> Result<(), SessionStoreError> { +/// self.0.write().await.remove(id); +/// Ok(()) +/// } +/// } +/// ``` +#[async_trait::async_trait] +pub trait SessionStore: Send + Sync + 'static { + /// Load session state for the given `session_id`. + /// + /// Returns `Ok(None)` when no entry exists (i.e. session is unknown to the store). + async fn load(&self, session_id: &str) -> Result, SessionStoreError>; + + /// Persist session state for the given `session_id`. + async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError>; + + /// Remove session state for the given `session_id`. 
+ async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError>; +} diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 0130467df..ec4617a71 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; +use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; @@ -8,10 +8,15 @@ use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; -use super::session::SessionManager; +use super::session::{ + RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore, +}; use crate::{ RoleServer, - model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion}, + model::{ + ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest, + InitializedNotification, ProtocolVersion, + }, serve_server, service::serve_directly, transport::{ @@ -29,7 +34,7 @@ use crate::{ }, }; -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct StreamableHttpServerConfig { /// The ping message duration for SSE connections. pub sse_keep_alive: Option, @@ -48,6 +53,44 @@ pub struct StreamableHttpServerConfig { /// When this token is cancelled, all active sessions are terminated and /// the server stops accepting new requests. pub cancellation_token: CancellationToken, + /// Optional external session store for cross-instance recovery. + /// + /// When set, [`SessionState`] (the client's `initialize` parameters) is + /// persisted after a successful handshake and deleted when the session + /// closes. 
On any subsequent request that arrives at an instance with no + /// in-memory session, the store is consulted: if an entry is found the + /// session is transparently restored so the client does not need to + /// re-initialize. + /// + /// # Example + /// ```rust,ignore + /// use std::sync::Arc; + /// use rmcp::transport::streamable_http_server::{ + /// StreamableHttpServerConfig, session::SessionStore, + /// }; + /// + /// let config = StreamableHttpServerConfig { + /// session_store: Some(Arc::new(MyRedisStore::new())), + /// ..Default::default() + /// }; + /// ``` + pub session_store: Option>, +} + +impl std::fmt::Debug for StreamableHttpServerConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamableHttpServerConfig") + .field("sse_keep_alive", &self.sse_keep_alive) + .field("sse_retry", &self.sse_retry) + .field("stateful_mode", &self.stateful_mode) + .field("json_response", &self.json_response) + .field("cancellation_token", &self.cancellation_token) + .field( + "session_store", + &self.session_store.as_ref().map(|_| ""), + ) + .finish() + } } impl Default for StreamableHttpServerConfig { @@ -58,6 +101,7 @@ impl Default for StreamableHttpServerConfig { stateful_mode: true, json_response: false, cancellation_token: CancellationToken::new(), + session_store: None, } } } @@ -189,6 +233,13 @@ pub struct StreamableHttpService { pub config: StreamableHttpServerConfig, session_manager: Arc, service_factory: Arc Result + Send + Sync>, + /// Tracks in-progress session restores so that concurrent requests for the + /// same unknown session ID wait for the first restore to complete rather + /// than racing to replay the initialize handshake. `None` when no external + /// session store is configured (avoids allocating the map). 
+ pending_restores: Option< + Arc>>>>, + >, } impl Clone for StreamableHttpService { @@ -197,6 +248,7 @@ impl Clone for StreamableHttpService { config: self.config.clone(), session_manager: self.session_manager.clone(), service_factory: self.service_factory.clone(), + pending_restores: self.pending_restores.clone(), } } } @@ -237,15 +289,249 @@ where session_manager: Arc, config: StreamableHttpServerConfig, ) -> Self { + let pending_restores = config.session_store.is_some().then(|| { + Arc::new(tokio::sync::RwLock::new(HashMap::< + SessionId, + tokio::sync::watch::Sender>, + >::new())) + }); Self { config, session_manager, service_factory: Arc::new(service_factory), + pending_restores, } } fn get_service(&self) -> Result { (self.service_factory)() } + + /// Spawn a task that runs `serve_server` for the given session, waits for + /// it to finish, and then calls `close_session`. + /// + /// `init_done_tx`: when `Some`, the sender is fired after `serve_server` + /// returns successfully, signalling to the caller that the MCP handshake + /// is complete. Used by `try_restore_from_store` to synchronise with the + /// restore `initialize` replay; `handle_post` passes `None`. + fn spawn_session_worker( + session_manager: Arc, + session_id: SessionId, + service: S, + transport: M::Transport, + init_done_tx: Option>, + ) where + S: crate::Service + Send + 'static, + M: SessionManager, + { + tokio::spawn(async move { + let svc = + serve_server::(service, transport) + .await; + match svc { + Ok(svc) => { + if let Some(tx) = init_done_tx { + let _ = tx.send(()); + } + let _ = svc.waiting().await; + } + Err(e) => { + tracing::error!("Failed to serve session: {e}"); + // Dropping init_done_tx (if Some) signals failure to the caller. + } + } + let _ = session_manager + .close_session(&session_id) + .await + .inspect_err(|e| { + tracing::error!("Failed to close session {session_id}: {e}"); + }); + }); + } + + /// Attempt to restore a session from the external store. 
+ /// + /// Returns `true` when the session is available and ready to serve the + /// current request (either just restored or already in memory). Returns + /// `false` when no store is configured or the session ID is unknown. + /// + /// Concurrent requests for the same unknown session ID are serialized: the + /// first caller performs the full restore and handshake replay while others + /// subscribe to a `watch` channel and wait, avoiding duplicate handshakes. + async fn try_restore_from_store( + &self, + session_id: &SessionId, + parts: &http::request::Parts, + ) -> Result + where + S: crate::Service + Send + 'static, + M: SessionManager, + { + // Both fields are Some iff a session store is configured. + let (Some(pending_restores), Some(store)) = + (&self.pending_restores, &self.config.session_store) + else { + return Ok(false); + }; + + // Serialize concurrent restores for the same session ID. + // Write-lock once: if another task is already restoring, subscribe and wait; + // otherwise, register ourselves as the restoring task. + // Channel value: None = in progress, Some(true) = restored, Some(false) = not found/failed. + let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::); + { + let mut pending = pending_restores.write().await; + if let Some(tx) = pending.get(session_id) { + let mut rx = tx.subscribe(); + drop(pending); + // Wait for the restore to finish, then propagate the outcome. + let result = rx + .wait_for(|r| r.is_some()) + .await + .map(|r| r.unwrap_or(false)) + .unwrap_or(false); + return Ok(result); + } + pending.insert(session_id.clone(), watch_tx.clone()); + } + + // Helper: signal waiters with the outcome and remove from the pending map. 
+ let finish = { + let pending_restores = pending_restores.clone(); + let session_id = session_id.clone(); + move |result: bool| { + let pending_restores = pending_restores.clone(); + let session_id = session_id.clone(); + tokio::spawn(async move { + if let Some(tx) = pending_restores.write().await.remove(&session_id) { + let _ = tx.send(Some(result)); + } + }); + } + }; + + // --- Step 3: load from external store --- + let state = match store.load(session_id.as_ref()).await { + Ok(Some(s)) => s, + Ok(None) => { + finish(false); + return Ok(false); + } + Err(e) => { + tracing::error!( + session_id = session_id.as_ref(), + error = %e, + "session store load failed during restore" + ); + finish(false); + return Err(std::io::Error::other(e)); + } + }; + + // --- Step 4: ask the session manager to allocate an in-memory worker --- + let transport = match self + .session_manager + .restore_session(session_id.clone()) + .await + .map_err(|e| std::io::Error::other(e.to_string())) + { + Ok(RestoreOutcome::Restored(t)) => t, + Ok(RestoreOutcome::AlreadyPresent) => { + // Invariant violation: pending_restores ensures only one task can call + // restore_session per session ID, so AlreadyPresent is impossible here. + finish(false); + return Err(std::io::Error::other( + "restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API", + )); + } + Ok(RestoreOutcome::NotSupported) => { + finish(false); + return Ok(false); + } + Err(e) => { + finish(false); + return Err(e); + } + }; + + // --- Step 5: replay the MCP initialize handshake --- + let service = match self.get_service() { + Ok(s) => s, + Err(e) => { + finish(false); + return Err(e); + } + }; + + // `serve_server` requires both the `initialize` request and the + // `notifications/initialized` notification before transitioning to + // the running state — we must send both before returning. 
+ let mut restore_init = ClientJsonRpcMessage::request( + ClientRequest::InitializeRequest(InitializeRequest { + params: state.initialize_params, + ..Default::default() + }), + crate::model::NumberOrString::Number(0), + ); + restore_init.insert_extension(parts.clone()); + restore_init.insert_extension(SessionRestoreMarker { + id: session_id.clone(), + }); + let mut restore_initialized = ClientJsonRpcMessage::notification( + ClientNotification::InitializedNotification(InitializedNotification { + ..Default::default() + }), + ); + restore_initialized.insert_extension(parts.clone()); + restore_initialized.insert_extension(SessionRestoreMarker { + id: session_id.clone(), + }); + // Signal from the spawned task once serve_server finishes initialising. + let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>(); + + Self::spawn_session_worker( + self.session_manager.clone(), + session_id.clone(), + service, + transport, + Some(init_done_tx), + ); + + if let Err(e) = self + .session_manager + .initialize_session(session_id, restore_init) + .await + .map_err(|e| std::io::Error::other(e.to_string())) + { + finish(false); + return Err(e); + } + + if let Err(e) = self + .session_manager + .accept_message(session_id, restore_initialized) + .await + .map_err(|e| std::io::Error::other(e.to_string())) + { + finish(false); + return Err(e); + } + + if init_done_rx.await.is_err() { + finish(false); + return Err(std::io::Error::other( + "serve_server initialization failed during restore", + )); + } + + // Restore complete — wake any waiting concurrent requests. 
+ finish(true); + + tracing::debug!( + session_id = session_id.as_ref(), + "session restored from external store" + ); + Ok(true) + } pub async fn handle(&self, request: Request) -> Response> where B: Body + Send + 'static, @@ -317,18 +603,26 @@ where .has_session(&session_id) .await .map_err(internal_error_response("check session"))?; + let (parts, _) = request.into_parts(); if !has_session { - // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions - return Ok(Response::builder() - .status(http::StatusCode::NOT_FOUND) - .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) - .expect("valid response")); + // Attempt transparent cross-instance restore from external store. + let restored = self + .try_restore_from_store(&session_id, &parts) + .await + .map_err(internal_error_response("restore session"))?; + if !restored { + // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions + return Ok(Response::builder() + .status(http::StatusCode::NOT_FOUND) + .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) + .expect("valid response")); + } } // Validate MCP-Protocol-Version header (per 2025-06-18 spec) - validate_protocol_version_header(request.headers())?; + validate_protocol_version_header(&parts.headers)?; // check if last event id is provided - let last_event_id = request - .headers() + let last_event_id = parts + .headers .get(HEADER_LAST_EVENT_ID) .and_then(|v| v.to_str().ok()) .map(|s| s.to_owned()); @@ -432,11 +726,18 @@ where .await .map_err(internal_error_response("check session"))?; if !has_session { - // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions - return Ok(Response::builder() - .status(http::StatusCode::NOT_FOUND) - .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) - .expect("valid response")); + // Attempt transparent cross-instance restore from external store. 
+ let restored = self + .try_restore_from_store(&session_id, &part) + .await + .map_err(internal_error_response("restore session"))?; + if !restored { + // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions + return Ok(Response::builder() + .status(http::StatusCode::NOT_FOUND) + .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) + .expect("valid response")); + } } // Validate MCP-Protocol-Version header (per 2025-06-18 spec) @@ -498,6 +799,21 @@ where .create_session() .await .map_err(internal_error_response("create session"))?; + // Capture init params for external store persistence before + // extensions are injected (which would require Clone). + let stored_init_params = if self.config.session_store.is_some() { + if let ClientJsonRpcMessage::Request(req) = &message { + if let ClientRequest::InitializeRequest(init_req) = &req.request { + Some(init_req.params.clone()) + } else { + None + } + } else { + None + } + } else { + None + }; if let ClientJsonRpcMessage::Request(req) = &mut message { if !matches!(req.request, ClientRequest::InitializeRequest(_)) { return Err(unexpected_message_response("initialize request")); @@ -511,37 +827,36 @@ where .get_service() .map_err(internal_error_response("get service"))?; // spawn a task to serve the session - tokio::spawn({ - let session_manager = self.session_manager.clone(); - let session_id = session_id.clone(); - async move { - let service = serve_server::( - service, transport, - ) - .await; - match service { - Ok(service) => { - // on service created - let _ = service.waiting().await; - } - Err(e) => { - tracing::error!("Failed to create service: {e}"); - } - } - let _ = session_manager - .close_session(&session_id) - .await - .inspect_err(|e| { - tracing::error!("Failed to close session {session_id}: {e}"); - }); - } - }); + Self::spawn_session_worker( + self.session_manager.clone(), + session_id.clone(), + service, + transport, + None, + ); // get initialize response 
let response = self .session_manager .initialize_session(&session_id, message) .await .map_err(internal_error_response("create stream"))?; + // Persist session state to external store after a successful handshake. + if let (Some(store), Some(params)) = + (&self.config.session_store, stored_init_params) + { + let state = SessionState { + initialize_params: params, + }; + let _ = store + .store(session_id.as_ref(), &state) + .await + .inspect_err(|e| { + tracing::warn!( + "Failed to persist session {} to store: {e}", + session_id + ); + }); + } let stream = futures::stream::once(async move { ServerSseMessage { event_id: None, @@ -677,6 +992,13 @@ where .close_session(&session_id) .await .map_err(internal_error_response("close session"))?; + // Remove from external store: a DELETE means the client intentionally + // ends the session, so the store entry is no longer needed. + if let Some(store) = &self.config.session_store { + let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| { + tracing::warn!("Failed to delete session {} from store: {e}", session_id); + }); + } Ok(accepted_response()) } } diff --git a/crates/rmcp/tests/test_streamable_http_session_store.rs b/crates/rmcp/tests/test_streamable_http_session_store.rs new file mode 100644 index 000000000..0fdcce0c0 --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_session_store.rs @@ -0,0 +1,407 @@ +#![cfg(all( + feature = "client", + feature = "server", + feature = "transport-streamable-http-client-reqwest", + feature = "transport-streamable-http-server", + not(feature = "local") +))] + +use std::{collections::HashMap, sync::Arc}; + +use rmcp::{ + ServiceExt, + transport::{ + StreamableHttpClientTransport, + streamable_http_client::StreamableHttpClientTransportConfig, + streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, + session::{SessionState, SessionStore, SessionStoreError, local::LocalSessionManager}, + }, + }, +}; +use tokio::sync::RwLock; +use 
tokio_util::sync::CancellationToken; + +mod common; +use common::calculator::Calculator; + +// --------------------------------------------------------------------------- +// Shared in-memory store used across tests +// --------------------------------------------------------------------------- + +#[derive(Default, Clone)] +struct InMemorySessionStore(Arc>>); + +impl InMemorySessionStore { + fn new() -> Self { + Self::default() + } + + async fn len(&self) -> usize { + self.0.read().await.len() + } +} + +#[async_trait::async_trait] +impl SessionStore for InMemorySessionStore { + async fn load(&self, session_id: &str) -> Result, SessionStoreError> { + Ok(self.0.read().await.get(session_id).cloned()) + } + + async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError> { + self.0 + .write() + .await + .insert(session_id.to_owned(), state.clone()); + Ok(()) + } + + async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError> { + self.0.write().await.remove(session_id); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Helper: spin up a StreamableHttpService backed by the given store and +// return the bound address together with the cancellation token. 
+// --------------------------------------------------------------------------- + +fn make_service( + session_store: Arc, + ct: &CancellationToken, +) -> StreamableHttpService { + StreamableHttpService::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + session_store: Some(session_store), + ..Default::default() + }, + ) +} + +// --------------------------------------------------------------------------- +// Test 1 — state is persisted to the store after a successful handshake +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_session_state_persisted_to_store() -> anyhow::Result<()> { + let store = Arc::new(InMemorySessionStore::new()); + let ct = CancellationToken::new(); + let service = make_service(store.clone(), &ct); + + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + // Connect a full client — this performs the initialize + initialized handshake. + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + ); + let client = ().serve(transport).await?; + + // Make a real request so the session is fully active. + let _resources = client.list_all_resources().await?; + + // The store should now contain exactly one session entry. + assert_eq!( + store.len().await, + 1, + "session state should be persisted to the store after initialization" + ); + + // Verify the stored state contains the expected client info. 
+ let entries = store.0.read().await; + let state = entries.values().next().expect("store entry should exist"); + assert_eq!( + state.initialize_params.client_info.name, "rmcp", + "stored client_info.name should match the rmcp client" + ); + + let _ = client.cancel().await; + ct.cancel(); + handle.await?; + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Test 2 — store entry is removed when the client sends HTTP DELETE +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_session_state_deleted_from_store_on_delete() -> anyhow::Result<()> { + let store = Arc::new(InMemorySessionStore::new()); + let session_manager = Arc::new(LocalSessionManager::default()); + let ct = CancellationToken::new(); + + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + session_store: Some(store.clone()), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + let transport = StreamableHttpClientTransport::from_config( + StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), + ); + let client = ().serve(transport).await?; + let _resources = client.list_all_resources().await?; + + assert_eq!(store.len().await, 1, "store should have one entry"); + + // Get the session ID from the server's in-memory map. 
+ let session_id = { + let sessions = session_manager.sessions.read().await; + sessions + .keys() + .next() + .cloned() + .expect("session should exist") + }; + + // Send an explicit HTTP DELETE — this is the signal to remove from store. + let http_client = reqwest::Client::new(); + let response = http_client + .delete(format!("http://{addr}/mcp")) + .header("mcp-session-id", session_id.as_ref()) + .send() + .await?; + assert_eq!(response.status(), 202); + + assert_eq!( + store.len().await, + 0, + "store entry should be removed after explicit DELETE" + ); + + let _ = client.cancel().await; + ct.cancel(); + handle.await?; + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Helper: spin up a server on an ephemeral port and return its address and +// the join handle. The server shuts down when `ct` is cancelled. +// --------------------------------------------------------------------------- + +fn spawn_server( + session_store: Option>, + session_manager: Arc, + ct: &CancellationToken, +) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { + let svc = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager, + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + session_store, + ..Default::default() + }, + ); + // Use std::net::TcpListener so the port is bound synchronously before + // we return — avoids a race between returning the addr and the server + // actually starting to accept connections. 
+ let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + std_listener.set_nonblocking(true).unwrap(); + let addr = std_listener.local_addr().unwrap(); + let listener = tokio::net::TcpListener::from_std(std_listener).unwrap(); + let router = axum::Router::new().nest_service("/mcp", svc); + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + (addr, handle) +} + +// --------------------------------------------------------------------------- +// Test 3 — cross-instance session restore +// +// Both halves follow the same structure: +// +// Instance A initializes the session (session state may be saved to store) +// Instance A is fully shut down +// Instance B (fresh, no in-memory state) receives a request for the old ID +// +// Without a store → 404. With a shared store → transparent restore. +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_cross_instance_session_restore() -> anyhow::Result<()> { + let http = reqwest::Client::new(); + + // ----------------------------------------------------------------------- + // Negative check: no session store → instance B returns 404. 
+ // ----------------------------------------------------------------------- + { + // --- Instance A (no store): initialize --- + let ct_a = CancellationToken::new(); + let (addr_a, srv_a) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_a); + + let init_resp = http + .post(format!("http://{addr_a}/mcp")) + .header("accept", "application/json, text/event-stream") + .header("content-type", "application/json") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#) + .send() + .await?; + assert_eq!( + init_resp.status(), + 200, + "instance A: initialize should succeed" + ); + let session_id = init_resp + .headers() + .get("mcp-session-id") + .expect("session ID header must be present") + .to_str()? + .to_owned(); + + // Shut down instance A completely. + ct_a.cancel(); + srv_a.await?; + + // --- Instance B (no store, fresh state): send request --- + let ct_b = CancellationToken::new(); + let (addr_b, srv_b) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_b); + + let resp = http + .post(format!("http://{addr_b}/mcp")) + .header("accept", "application/json, text/event-stream") + .header("content-type", "application/json") + .header("mcp-session-id", &session_id) + .body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#) + .send() + .await?; + assert_eq!( + resp.status(), + reqwest::StatusCode::NOT_FOUND, + "without a session store, instance B must return 404 for an unknown session ID" + ); + + ct_b.cancel(); + srv_b.await?; + } + + // ----------------------------------------------------------------------- + // Positive check: shared session store → instance B restores transparently. 
+    // -----------------------------------------------------------------------
+    {
+        let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new());
+
+        // --- Instance A (with store): initialize ---
+        let ct_a = CancellationToken::new();
+        let sm_a = Arc::new(LocalSessionManager::default());
+        let (addr_a, srv_a) = spawn_server(Some(store.clone()), sm_a.clone(), &ct_a);
+
+        let init_resp = http
+            .post(format!("http://{addr_a}/mcp"))
+            .header("accept", "application/json, text/event-stream")
+            .header("content-type", "application/json")
+            .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#)
+            .send()
+            .await?;
+        assert_eq!(
+            init_resp.status(),
+            200,
+            "instance A: initialize should succeed"
+        );
+        let original_session_id = init_resp
+            .headers()
+            .get("mcp-session-id")
+            .expect("session ID header must be present")
+            .to_str()?
+            .to_owned();
+
+        // Confirm the session was persisted.
+        let store_ref = store
+            .load(&original_session_id)
+            .await
+            .expect("store load should not error");
+        assert!(
+            store_ref.is_some(),
+            "store should hold the session after initialization"
+        );
+
+        // Shut down instance A completely — session lives only in the store now.
+        ct_a.cancel();
+        srv_a.await?;
+
+        // --- Instance B (same store, fresh in-memory state): send request ---
+        let ct_b = CancellationToken::new();
+        let sm_b = Arc::new(LocalSessionManager::default());
+        let (addr_b, srv_b) = spawn_server(Some(store.clone()), sm_b.clone(), &ct_b);
+
+        let resp = http
+            .post(format!("http://{addr_b}/mcp"))
+            .header("accept", "application/json, text/event-stream")
+            .header("content-type", "application/json")
+            .header("mcp-session-id", &original_session_id)
+            .body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#)
+            .send()
+            .await?;
+        assert_eq!(
+            resp.status(),
+            200,
+            "instance B: request must succeed after transparent restore"
+        );
+
+        // The session must be in instance B's memory under the ORIGINAL ID.
+        {
+            let sessions = sm_b.sessions.read().await;
+            let restored_id = sessions
+                .keys()
+                .next()
+                .expect("session should exist in instance B after restore");
+            assert_eq!(
+                restored_id.as_ref(),
+                original_session_id.as_str(),
+                "restored session must keep the original session ID"
+            );
+        }
+
+        ct_b.cancel();
+        srv_b.await?;
+    }
+
+    Ok(())
+}

From b241cd3078e3f8a6a49d9b93dc676d957276e371 Mon Sep 17 00:00:00 2001
From: Guy Lichtman <1395797+glicht@users.noreply.github.com>
Date: Wed, 25 Mar 2026 18:07:05 +0200
Subject: [PATCH 2/2] fix: docs

---
 .../rmcp/src/transport/streamable_http_server/session/store.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/rmcp/src/transport/streamable_http_server/session/store.rs b/crates/rmcp/src/transport/streamable_http_server/session/store.rs
index d89a9c162..382bdfded 100644
--- a/crates/rmcp/src/transport/streamable_http_server/session/store.rs
+++ b/crates/rmcp/src/transport/streamable_http_server/session/store.rs
@@ -18,7 +18,7 @@ pub type SessionStoreError = Box<dyn std::error::Error + Send + Sync>;
 ///
 /// Implement this trait to back sessions with Redis, a database, or any
 /// key-value store. The simplest usage is to set
-/// [`StreamableHttpServerConfig::session_store`] to an `Arc`.
+/// `StreamableHttpServerConfig::session_store` to an `Arc`.
 ///
 /// # Example (in-memory, for testing)
 ///