From 01a01e55d98cf9d853f81fb86cb7090ccb457387 Mon Sep 17 00:00:00 2001 From: ContextVM Date: Thu, 21 May 2026 12:30:11 +0200 Subject: [PATCH] feat: add progress-aware request timeouts --- crates/rmcp/Cargo.toml | 5 + crates/rmcp/src/service.rs | 246 ++++++++++++++++-- crates/rmcp/src/service/server.rs | 4 + .../tests/test_request_timeout_progress.rs | 211 +++++++++++++++ 4 files changed, 443 insertions(+), 23 deletions(-) create mode 100644 crates/rmcp/tests/test_request_timeout_progress.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 8a5bd63a..63867981 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -252,6 +252,11 @@ name = "test_progress_subscriber" required-features = ["server", "client", "macros"] path = "tests/test_progress_subscriber.rs" +[[test]] +name = "test_request_timeout_progress" +required-features = ["server", "client", "macros"] +path = "tests/test_request_timeout_progress.rs" + [[test]] name = "test_elicitation" required-features = ["elicitation", "client", "server"] diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index d938cd66..e3e132d1 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -42,8 +42,12 @@ pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>; #[cfg(feature = "local")] pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>; +#[cfg(feature = "server")] +use crate::model::ClientNotification; #[cfg(feature = "server")] use crate::model::ServerJsonRpcMessage; +#[cfg(feature = "client")] +use crate::model::ServerNotification; use crate::{ error::ErrorData as McpError, model::{ @@ -299,7 +303,37 @@ impl ProgressTokenProvider for AtomicU32Provider { } } +#[doc(hidden)] +pub trait ProgressNotificationToken { + fn progress_token(&self) -> Option<&ProgressToken>; +} + +#[cfg(feature = "server")] +impl ProgressNotificationToken for ClientNotification { + fn progress_token(&self) -> Option<&ProgressToken> { + match self { + ClientNotification::ProgressNotification(notification) => { + Some(¬ification.params.progress_token) + } + _ => None, + } + } +} + +#[cfg(feature = "client")] +impl ProgressNotificationToken for ServerNotification { + fn progress_token(&self) -> Option<&ProgressToken> { + match self { + ServerNotification::ProgressNotification(notification) => { + Some(¬ification.params.progress_token) + } + _ => None, + } + } +} + type Responder = tokio::sync::oneshot::Sender; +type ProgressTimeoutWatchers = Arc>>>; /// A handle to a remote request /// @@ -314,40 +348,142 @@ pub struct RequestHandle { pub peer: Peer, pub id: RequestId, pub progress_token: ProgressToken, + progress_timeout_watchers: ProgressTimeoutWatchers, + progress_reset_rx: Option>, } impl RequestHandle { pub const REQUEST_TIMEOUT_REASON: &str = "request timeout"; - pub async fn await_response(self) -> Result { - if let Some(timeout) = self.options.timeout { - let timeout_result = tokio::time::timeout(timeout, async move { - self.rx.await.map_err(|_e| ServiceError::TransportClosed)? - }) - .await; - match timeout_result { - Ok(response) => response, - Err(_) => { - let error = Err(ServiceError::Timeout { timeout }); - // cancel this request - let notification = CancelledNotification { - params: CancelledNotificationParam { - request_id: self.id, - reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()), - }, - method: crate::model::CancelledNotificationMethod, - extensions: Default::default(), - }; - let _ = self.peer.send_notification(notification.into()).await; - error + pub const REQUEST_MAX_TOTAL_TIMEOUT_REASON: &str = "maximum total timeout exceeded"; + + async fn send_timeout_cancel_notification(&self, reason: &str) { + let notification = CancelledNotification { + params: CancelledNotificationParam { + request_id: self.id.clone(), + reason: Some(reason.to_owned()), + }, + method: crate::model::CancelledNotificationMethod, + extensions: Default::default(), + }; + let _ = self.peer.send_notification(notification.into()).await; + } + + async fn cleanup_progress_timeout_watcher( + progress_timeout_watchers: &ProgressTimeoutWatchers, + progress_token: &ProgressToken, + has_progress_reset_rx: bool, + ) { + if has_progress_reset_rx { + progress_timeout_watchers + .write() + .await + .remove(progress_token); + } + } + + pub async fn await_response(mut self) -> Result { + let timeout = self.options.timeout; + let max_total_timeout = self.options.max_total_timeout; + let reset_timeout_on_progress = self.options.reset_timeout_on_progress; + + let has_progress_reset_rx = self.progress_reset_rx.is_some(); + let progress_timeout_watchers = self.progress_timeout_watchers.clone(); + let progress_token = self.progress_token.clone(); + + let result = + if timeout.is_some() && !reset_timeout_on_progress && max_total_timeout.is_none() { + let timeout = timeout.expect("timeout is checked above"); + let timeout_result = tokio::time::timeout(timeout, &mut self.rx).await; + match timeout_result { + Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?, + Err(_) => { + let error = Err(ServiceError::Timeout { timeout }); + // cancel this request + self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON) + .await; + error + } + } + } else if timeout.is_none() && max_total_timeout.is_none() { + (&mut self.rx) + .await + .map_err(|_e| ServiceError::TransportClosed)? + } else { + self.await_response_with_progress_timeout( + timeout, + max_total_timeout, + reset_timeout_on_progress, + ) + .await + }; + + Self::cleanup_progress_timeout_watcher( + &progress_timeout_watchers, + &progress_token, + has_progress_reset_rx, + ) + .await; + result + } + + async fn await_response_with_progress_timeout( + &mut self, + timeout: Option, + max_total_timeout: Option, + reset_timeout_on_progress: bool, + ) -> Result { + let mut idle_sleep = timeout.map(tokio::time::sleep).map(Box::pin); + let mut max_total_sleep = max_total_timeout.map(tokio::time::sleep).map(Box::pin); + + loop { + tokio::select! { + biased; + + response = &mut self.rx => { + return response.map_err(|_e| ServiceError::TransportClosed)?; + } + _ = async { + if let Some(sleep) = idle_sleep.as_mut() { + sleep.as_mut().await; + } + }, if idle_sleep.is_some() => { + let timeout = timeout.expect("idle timeout exists when idle sleep exists"); + self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await; + return Err(ServiceError::Timeout { timeout }); + } + _ = async { + if let Some(sleep) = max_total_sleep.as_mut() { + sleep.as_mut().await; + } + }, if max_total_sleep.is_some() => { + let timeout = max_total_timeout.expect("max total timeout exists when max total sleep exists"); + self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await; + return Err(ServiceError::Timeout { timeout }); + } + progress = async { + match self.progress_reset_rx.as_mut() { + Some(rx) => rx.recv().await, + None => None, + } + }, if reset_timeout_on_progress && timeout.is_some() && self.progress_reset_rx.is_some() => { + if progress.is_some() { + if let (Some(timeout), Some(sleep)) = (timeout, idle_sleep.as_mut()) { + sleep.as_mut().reset(tokio::time::Instant::now() + timeout); + } + } } } - } else { - self.rx.await.map_err(|_e| ServiceError::TransportClosed)? } } /// Cancel this request pub async fn cancel(self, reason: Option) -> Result<(), ServiceError> { + Self::cleanup_progress_timeout_watcher( + &self.progress_timeout_watchers, + &self.progress_token, + self.progress_reset_rx.is_some(), + ) + .await; let notification = CancelledNotification { params: CancelledNotificationParam { request_id: self.id, @@ -384,6 +520,7 @@ pub struct Peer { tx: mpsc::Sender>, request_id_provider: Arc, progress_token_provider: Arc, + progress_timeout_watchers: ProgressTimeoutWatchers, info: Arc>, } @@ -403,12 +540,33 @@ type ProxyOutbound = mpsc::Receiver>; pub struct PeerRequestOptions { pub timeout: Option, pub meta: Option, + /// Reset the request timeout when a matching progress notification is received. + pub reset_timeout_on_progress: bool, + /// Maximum total time to wait for the request, regardless of progress notifications. + pub max_total_timeout: Option, } impl PeerRequestOptions { pub fn no_options() -> Self { Self::default() } + + pub fn with_timeout(timeout: Duration) -> Self { + Self { + timeout: Some(timeout), + ..Self::default() + } + } + + pub fn reset_timeout_on_progress(mut self) -> Self { + self.reset_timeout_on_progress = true; + self + } + + pub fn with_max_total_timeout(mut self, timeout: Duration) -> Self { + self.max_total_timeout = Some(timeout); + self + } } impl Peer { @@ -423,6 +581,7 @@ impl Peer { tx, request_id_provider, progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()), + progress_timeout_watchers: Default::default(), info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)), }, rx, @@ -468,6 +627,16 @@ impl Peer { request.get_meta_mut().extend(meta); } let (responder, receiver) = tokio::sync::oneshot::channel(); + let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() { + let (sender, receiver) = mpsc::channel(1); + self.progress_timeout_watchers + .write() + .await + .insert(progress_token.clone(), sender); + Some(receiver) + } else { + None + }; self.tx .send(PeerSinkMessage::Request { request, @@ -482,8 +651,33 @@ impl Peer { progress_token, options, peer: self.clone(), + progress_timeout_watchers: self.progress_timeout_watchers.clone(), + progress_reset_rx, }) } + + async fn notify_progress_timeout_watcher(&self, progress_token: &ProgressToken) { + let sender = self + .progress_timeout_watchers + .read() + .await + .get(progress_token) + .cloned(); + if let Some(sender) = sender { + match sender.try_send(()) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + tracing::trace!(?progress_token, "progress timeout watcher channel is full"); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + self.progress_timeout_watchers + .write() + .await + .remove(progress_token); + } + } + } + } pub fn peer_info(&self) -> Option<&R::PeerInfo> { self.info.get() } @@ -692,6 +886,7 @@ pub fn serve_directly( ) -> RunningService where R: ServiceRole, + R::PeerNot: ProgressNotificationToken, S: Service, T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -708,6 +903,7 @@ pub fn serve_directly_with_ct( ) -> RunningService where R: ServiceRole, + R::PeerNot: ProgressNotificationToken, S: Service, T: IntoTransport, E: std::error::Error + Send + Sync + 'static, @@ -748,6 +944,7 @@ fn serve_inner( ) -> RunningService where R: ServiceRole, + R::PeerNot: ProgressNotificationToken, S: Service, T: Transport + 'static, { @@ -994,6 +1191,9 @@ where } Err(notification) => notification, }; + if let Some(progress_token) = notification.progress_token() { + peer.notify_progress_timeout_watcher(progress_token).await; + } { let service = shared_service.clone(); let mut extensions = Extensions::new(); diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index c185696e..2676391a 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -356,6 +356,8 @@ macro_rules! method { let options = crate::service::PeerRequestOptions { timeout, meta: None, + reset_timeout_on_progress: false, + max_total_timeout: None, }; let result = self .send_request_with_option(request, options) @@ -383,6 +385,8 @@ macro_rules! method { let options = crate::service::PeerRequestOptions { timeout, meta: None, + reset_timeout_on_progress: false, + max_total_timeout: None, }; let result = self .send_request_with_option(request, options) diff --git a/crates/rmcp/tests/test_request_timeout_progress.rs b/crates/rmcp/tests/test_request_timeout_progress.rs new file mode 100644 index 00000000..95eea34c --- /dev/null +++ b/crates/rmcp/tests/test_request_timeout_progress.rs @@ -0,0 +1,211 @@ +#![cfg(not(feature = "local"))] + +use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, +}; + +use rmcp::{ + ClientHandler, Peer, RoleServer, ServerHandler, ServiceError, ServiceExt, + handler::server::tool::ToolRouter, + model::{CallToolRequestParams, ClientRequest, Meta, ProgressNotificationParam, Request}, + service::PeerRequestOptions, + tool, tool_handler, tool_router, +}; + +#[derive(Clone, Default)] +struct ProgressCountingClient { + progress_count: Arc, +} + +impl ClientHandler for ProgressCountingClient { + async fn on_progress( + &self, + _params: ProgressNotificationParam, + _context: rmcp::service::NotificationContext, + ) { + self.progress_count.fetch_add(1, Ordering::SeqCst); + } +} + +struct ProgressTimeoutServer { + tool_router: ToolRouter, +} + +impl ProgressTimeoutServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl ProgressTimeoutServer { + #[tool] + async fn delayed_without_progress(&self) -> Result<(), rmcp::ErrorData> { + tokio::time::sleep(Duration::from_millis(250)).await; + Ok(()) + } + + #[tool] + async fn delayed_with_progress( + &self, + meta: Meta, + client: Peer, + ) -> Result<(), rmcp::ErrorData> { + let progress_token = meta + .get_progress_token() + .ok_or(rmcp::ErrorData::invalid_params( + "Progress token is required", + None, + ))?; + + for step in 0..4 { + tokio::time::sleep(Duration::from_millis(50)).await; + let _ = client + .notify_progress(ProgressNotificationParam { + progress_token: progress_token.clone(), + progress: step as f64, + total: Some(4.0), + message: Some("working".into()), + }) + .await; + } + + Ok(()) + } + + #[tool] + async fn delayed_with_unrelated_progress( + &self, + client: Peer, + ) -> Result<(), rmcp::ErrorData> { + for step in 0..4 { + tokio::time::sleep(Duration::from_millis(50)).await; + let _ = client + .notify_progress(ProgressNotificationParam { + progress_token: rmcp::model::ProgressToken( + rmcp::model::NumberOrString::Number(999_999), + ), + progress: step as f64, + total: Some(4.0), + message: Some("unrelated".into()), + }) + .await; + } + + Ok(()) + } +} + +#[tool_handler] +impl ServerHandler for ProgressTimeoutServer {} + +async fn start_pair() +-> anyhow::Result> { + let server = ProgressTimeoutServer::new(); + let client = ProgressCountingClient::default(); + let (transport_server, transport_client) = tokio::io::duplex(4096); + + tokio::spawn(async move { + let service = server.serve(transport_server).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + + Ok(client.serve(transport_client).await?) +} + +async fn call_tool_with_options( + client: &rmcp::service::RunningService, + name: &str, + options: PeerRequestOptions, +) -> Result { + client + .send_request_with_option( + ClientRequest::CallToolRequest(Request::new(CallToolRequestParams::new( + name.to_owned(), + ))), + options, + ) + .await? + .await_response() + .await +} + +#[tokio::test] +async fn request_timeout_still_expires_without_progress() -> anyhow::Result<()> { + let client = start_pair().await?; + let result = call_tool_with_options( + &client, + "delayed_without_progress", + PeerRequestOptions::with_timeout(Duration::from_millis(75)), + ) + .await; + + assert!(matches!(result, Err(ServiceError::Timeout { .. }))); + Ok(()) +} + +#[tokio::test] +async fn progress_does_not_reset_timeout_by_default() -> anyhow::Result<()> { + let client = start_pair().await?; + let result = call_tool_with_options( + &client, + "delayed_with_progress", + PeerRequestOptions::with_timeout(Duration::from_millis(75)), + ) + .await; + + assert!(matches!(result, Err(ServiceError::Timeout { .. }))); + Ok(()) +} + +#[tokio::test] +async fn matching_progress_resets_timeout_when_enabled() -> anyhow::Result<()> { + let client = start_pair().await?; + let result = call_tool_with_options( + &client, + "delayed_with_progress", + PeerRequestOptions::with_timeout(Duration::from_millis(75)).reset_timeout_on_progress(), + ) + .await; + + assert!(result.is_ok()); + assert!(client.service().progress_count.load(Ordering::SeqCst) > 0); + Ok(()) +} + +#[tokio::test] +async fn max_total_timeout_wins_over_progress_reset() -> anyhow::Result<()> { + let client = start_pair().await?; + let result = call_tool_with_options( + &client, + "delayed_with_progress", + PeerRequestOptions::with_timeout(Duration::from_millis(75)) + .reset_timeout_on_progress() + .with_max_total_timeout(Duration::from_millis(125)), + ) + .await; + + assert!(matches!(result, Err(ServiceError::Timeout { .. }))); + Ok(()) +} + +#[tokio::test] +async fn unrelated_progress_does_not_reset_timeout() -> anyhow::Result<()> { + let client = start_pair().await?; + let result = call_tool_with_options( + &client, + "delayed_with_unrelated_progress", + PeerRequestOptions::with_timeout(Duration::from_millis(75)).reset_timeout_on_progress(), + ) + .await; + + assert!(matches!(result, Err(ServiceError::Timeout { .. }))); + Ok(()) +}