Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ name = "test_progress_subscriber"
required-features = ["server", "client", "macros"]
path = "tests/test_progress_subscriber.rs"

[[test]]
name = "test_request_timeout_progress"
required-features = ["server", "client", "macros"]
path = "tests/test_request_timeout_progress.rs"

[[test]]
name = "test_elicitation"
required-features = ["elicitation", "client", "server"]
Expand Down
246 changes: 223 additions & 23 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>;
#[cfg(feature = "local")]
pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>;

#[cfg(feature = "server")]
use crate::model::ClientNotification;
#[cfg(feature = "server")]
use crate::model::ServerJsonRpcMessage;
#[cfg(feature = "client")]
use crate::model::ServerNotification;
use crate::{
error::ErrorData as McpError,
model::{
Expand Down Expand Up @@ -299,7 +303,37 @@ impl ProgressTokenProvider for AtomicU32Provider {
}
}

#[doc(hidden)]
pub trait ProgressNotificationToken {
fn progress_token(&self) -> Option<&ProgressToken>;
}

#[cfg(feature = "server")]
impl ProgressNotificationToken for ClientNotification {
fn progress_token(&self) -> Option<&ProgressToken> {
match self {
ClientNotification::ProgressNotification(notification) => {
Some(&notification.params.progress_token)
}
_ => None,
}
}
}

#[cfg(feature = "client")]
impl ProgressNotificationToken for ServerNotification {
fn progress_token(&self) -> Option<&ProgressToken> {
match self {
ServerNotification::ProgressNotification(notification) => {
Some(&notification.params.progress_token)
}
_ => None,
}
}
}

type Responder<T> = tokio::sync::oneshot::Sender<T>;
type ProgressTimeoutWatchers = Arc<tokio::sync::RwLock<HashMap<ProgressToken, mpsc::Sender<()>>>>;

/// A handle to a remote request
///
Expand All @@ -314,40 +348,142 @@ pub struct RequestHandle<R: ServiceRole> {
pub peer: Peer<R>,
pub id: RequestId,
pub progress_token: ProgressToken,
progress_timeout_watchers: ProgressTimeoutWatchers,
progress_reset_rx: Option<mpsc::Receiver<()>>,
}

impl<R: ServiceRole> RequestHandle<R> {
pub const REQUEST_TIMEOUT_REASON: &str = "request timeout";
pub async fn await_response(self) -> Result<R::PeerResp, ServiceError> {
if let Some(timeout) = self.options.timeout {
let timeout_result = tokio::time::timeout(timeout, async move {
self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
})
.await;
match timeout_result {
Ok(response) => response,
Err(_) => {
let error = Err(ServiceError::Timeout { timeout });
// cancel this request
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id,
reason: Some(Self::REQUEST_TIMEOUT_REASON.to_owned()),
},
method: crate::model::CancelledNotificationMethod,
extensions: Default::default(),
};
let _ = self.peer.send_notification(notification.into()).await;
error
pub const REQUEST_MAX_TOTAL_TIMEOUT_REASON: &str = "maximum total timeout exceeded";

async fn send_timeout_cancel_notification(&self, reason: &str) {
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id.clone(),
reason: Some(reason.to_owned()),
},
method: crate::model::CancelledNotificationMethod,
extensions: Default::default(),
};
let _ = self.peer.send_notification(notification.into()).await;
}

async fn cleanup_progress_timeout_watcher(
progress_timeout_watchers: &ProgressTimeoutWatchers,
progress_token: &ProgressToken,
has_progress_reset_rx: bool,
) {
if has_progress_reset_rx {
progress_timeout_watchers
.write()
.await
.remove(progress_token);
}
}

pub async fn await_response(mut self) -> Result<R::PeerResp, ServiceError> {
let timeout = self.options.timeout;
let max_total_timeout = self.options.max_total_timeout;
let reset_timeout_on_progress = self.options.reset_timeout_on_progress;

let has_progress_reset_rx = self.progress_reset_rx.is_some();
let progress_timeout_watchers = self.progress_timeout_watchers.clone();
let progress_token = self.progress_token.clone();

let result =
if timeout.is_some() && !reset_timeout_on_progress && max_total_timeout.is_none() {
let timeout = timeout.expect("timeout is checked above");
let timeout_result = tokio::time::timeout(timeout, &mut self.rx).await;
match timeout_result {
Ok(response) => response.map_err(|_e| ServiceError::TransportClosed)?,
Err(_) => {
let error = Err(ServiceError::Timeout { timeout });
// cancel this request
self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON)
.await;
error
}
}
} else if timeout.is_none() && max_total_timeout.is_none() {
(&mut self.rx)
.await
.map_err(|_e| ServiceError::TransportClosed)?
} else {
self.await_response_with_progress_timeout(
timeout,
max_total_timeout,
reset_timeout_on_progress,
)
.await
};

Self::cleanup_progress_timeout_watcher(
&progress_timeout_watchers,
&progress_token,
has_progress_reset_rx,
)
.await;
result
}

async fn await_response_with_progress_timeout(
&mut self,
timeout: Option<Duration>,
max_total_timeout: Option<Duration>,
reset_timeout_on_progress: bool,
) -> Result<R::PeerResp, ServiceError> {
let mut idle_sleep = timeout.map(tokio::time::sleep).map(Box::pin);
let mut max_total_sleep = max_total_timeout.map(tokio::time::sleep).map(Box::pin);

loop {
tokio::select! {
biased;

response = &mut self.rx => {
return response.map_err(|_e| ServiceError::TransportClosed)?;
}
_ = async {
if let Some(sleep) = idle_sleep.as_mut() {
sleep.as_mut().await;
}
}, if idle_sleep.is_some() => {
let timeout = timeout.expect("idle timeout exists when idle sleep exists");
self.send_timeout_cancel_notification(Self::REQUEST_TIMEOUT_REASON).await;
return Err(ServiceError::Timeout { timeout });
}
_ = async {
if let Some(sleep) = max_total_sleep.as_mut() {
sleep.as_mut().await;
}
}, if max_total_sleep.is_some() => {
let timeout = max_total_timeout.expect("max total timeout exists when max total sleep exists");
self.send_timeout_cancel_notification(Self::REQUEST_MAX_TOTAL_TIMEOUT_REASON).await;
return Err(ServiceError::Timeout { timeout });
}
progress = async {
match self.progress_reset_rx.as_mut() {
Some(rx) => rx.recv().await,
None => None,
}
}, if reset_timeout_on_progress && timeout.is_some() && self.progress_reset_rx.is_some() => {
if progress.is_some() {
if let (Some(timeout), Some(sleep)) = (timeout, idle_sleep.as_mut()) {
sleep.as_mut().reset(tokio::time::Instant::now() + timeout);
}
}
}
}
} else {
self.rx.await.map_err(|_e| ServiceError::TransportClosed)?
}
}

/// Cancel this request
pub async fn cancel(self, reason: Option<String>) -> Result<(), ServiceError> {
Self::cleanup_progress_timeout_watcher(
&self.progress_timeout_watchers,
&self.progress_token,
self.progress_reset_rx.is_some(),
)
.await;
let notification = CancelledNotification {
params: CancelledNotificationParam {
request_id: self.id,
Expand Down Expand Up @@ -384,6 +520,7 @@ pub struct Peer<R: ServiceRole> {
tx: mpsc::Sender<PeerSinkMessage<R>>,
request_id_provider: Arc<dyn RequestIdProvider>,
progress_token_provider: Arc<dyn ProgressTokenProvider>,
progress_timeout_watchers: ProgressTimeoutWatchers,
info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
}

Expand All @@ -403,12 +540,33 @@ type ProxyOutbound<R> = mpsc::Receiver<PeerSinkMessage<R>>;
pub struct PeerRequestOptions {
pub timeout: Option<Duration>,
pub meta: Option<Meta>,
/// Reset the request timeout when a matching progress notification is received.
pub reset_timeout_on_progress: bool,
/// Maximum total time to wait for the request, regardless of progress notifications.
pub max_total_timeout: Option<Duration>,
}

impl PeerRequestOptions {
pub fn no_options() -> Self {
Self::default()
}

pub fn with_timeout(timeout: Duration) -> Self {
Self {
timeout: Some(timeout),
..Self::default()
}
}

pub fn reset_timeout_on_progress(mut self) -> Self {
self.reset_timeout_on_progress = true;
self
}

pub fn with_max_total_timeout(mut self, timeout: Duration) -> Self {
self.max_total_timeout = Some(timeout);
self
}
}

impl<R: ServiceRole> Peer<R> {
Expand All @@ -423,6 +581,7 @@ impl<R: ServiceRole> Peer<R> {
tx,
request_id_provider,
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
progress_timeout_watchers: Default::default(),
info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
},
rx,
Expand Down Expand Up @@ -468,6 +627,16 @@ impl<R: ServiceRole> Peer<R> {
request.get_meta_mut().extend(meta);
}
let (responder, receiver) = tokio::sync::oneshot::channel();
let progress_reset_rx = if options.reset_timeout_on_progress && options.timeout.is_some() {
let (sender, receiver) = mpsc::channel(1);
self.progress_timeout_watchers
.write()
.await
.insert(progress_token.clone(), sender);
Some(receiver)
} else {
None
};
self.tx
.send(PeerSinkMessage::Request {
request,
Expand All @@ -482,8 +651,33 @@ impl<R: ServiceRole> Peer<R> {
progress_token,
options,
peer: self.clone(),
progress_timeout_watchers: self.progress_timeout_watchers.clone(),
progress_reset_rx,
})
}

async fn notify_progress_timeout_watcher(&self, progress_token: &ProgressToken) {
let sender = self
.progress_timeout_watchers
.read()
.await
.get(progress_token)
.cloned();
if let Some(sender) = sender {
match sender.try_send(()) {
Ok(()) => {}
Err(mpsc::error::TrySendError::Full(_)) => {
tracing::trace!(?progress_token, "progress timeout watcher channel is full");
}
Err(mpsc::error::TrySendError::Closed(_)) => {
self.progress_timeout_watchers
.write()
.await
.remove(progress_token);
}
}
}
}
pub fn peer_info(&self) -> Option<&R::PeerInfo> {
self.info.get()
}
Expand Down Expand Up @@ -692,6 +886,7 @@ pub fn serve_directly<R, S, T, E, A>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: IntoTransport<R, E, A>,
E: std::error::Error + Send + Sync + 'static,
Expand All @@ -708,6 +903,7 @@ pub fn serve_directly_with_ct<R, S, T, E, A>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: IntoTransport<R, E, A>,
E: std::error::Error + Send + Sync + 'static,
Expand Down Expand Up @@ -748,6 +944,7 @@ fn serve_inner<R, S, T>(
) -> RunningService<R, S>
where
R: ServiceRole,
R::PeerNot: ProgressNotificationToken,
S: Service<R>,
T: Transport<R> + 'static,
{
Expand Down Expand Up @@ -994,6 +1191,9 @@ where
}
Err(notification) => notification,
};
if let Some(progress_token) = notification.progress_token() {
peer.notify_progress_timeout_watcher(progress_token).await;
}
{
let service = shared_service.clone();
let mut extensions = Extensions::new();
Expand Down
4 changes: 4 additions & 0 deletions crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ macro_rules! method {
let options = crate::service::PeerRequestOptions {
timeout,
meta: None,
reset_timeout_on_progress: false,
max_total_timeout: None,
};
let result = self
.send_request_with_option(request, options)
Expand Down Expand Up @@ -383,6 +385,8 @@ macro_rules! method {
let options = crate::service::PeerRequestOptions {
timeout,
meta: None,
reset_timeout_on_progress: false,
max_total_timeout: None,
};
let result = self
.send_request_with_option(request, options)
Expand Down
Loading