From f22f011d77559ff73c7c2e8dbed19bbe2cfa3cc7 Mon Sep 17 00:00:00 2001 From: anaslimem Date: Thu, 2 Apr 2026 15:16:07 +0100 Subject: [PATCH] Fix pending request cleanup and improve IO fairness --- src/agent-client-protocol/src/rpc.rs | 71 ++++++++++++++++++---- src/agent-client-protocol/src/rpc_tests.rs | 51 ++++++++++++++++ 2 files changed, 110 insertions(+), 12 deletions(-) diff --git a/src/agent-client-protocol/src/rpc.rs b/src/agent-client-protocol/src/rpc.rs index 42acdcc..34fdc7f 100644 --- a/src/agent-client-protocol/src/rpc.rs +++ b/src/agent-client-protocol/src/rpc.rs @@ -2,11 +2,15 @@ use std::{ any::Any, borrow::Cow, collections::HashMap, + future::Future, + marker::PhantomData, + pin::Pin, rc::Rc, sync::{ Arc, Mutex, atomic::{AtomicI64, Ordering}, }, + task::{Context, Poll}, }; use agent_client_protocol_schema::{ @@ -22,7 +26,6 @@ use futures::{ }, future::LocalBoxFuture, io::BufReader, - select_biased, }; use serde::{Deserialize, de::DeserializeOwned}; use serde_json::value::RawValue; @@ -43,6 +46,43 @@ struct PendingResponse { respond: oneshot::Sender>>, } +pub(crate) struct PendingRequest { + id: RequestId, + pending_responses: Arc>>, + rx: oneshot::Receiver>>, + _marker: PhantomData, +} + +impl Future for PendingRequest +where + Out: Send + 'static, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + match Pin::new(&mut this.rx).poll(cx) { + Poll::Ready(result) => { + let result = result + .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))??; + let result = result + .downcast::() + .map_err(|_| Error::internal_error().data("failed to deserialize response"))?; + Poll::Ready(Ok(*result)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl Unpin for PendingRequest {} + +impl Drop for PendingRequest { + fn drop(&mut self) { + drop(self.pending_responses.lock().unwrap().remove(&self.id)); + } +} + impl RpcConnection where Local: Side + 'static, @@ -113,7 +153,7 @@ where &self, method: impl Into>, params: Option, - ) -> Result>> { + ) -> Result> { let (tx, rx) = oneshot::channel(); let id = self.next_id.fetch_add(1, Ordering::SeqCst); let id = RequestId::Number(id); @@ -143,14 +183,11 @@ where Error::internal_error().data("connection closed before request could be sent") ); } - Ok(async move { - let result = rx - .await - .map_err(|_| Error::internal_error().data("server shut down unexpectedly"))?? - .downcast::() - .map_err(|_| Error::internal_error().data("failed to deserialize response"))?; - - Ok(*result) + Ok(PendingRequest { + id, + pending_responses: self.pending_responses.clone(), + rx, + _marker: PhantomData, }) } @@ -167,7 +204,7 @@ where let mut outgoing_line = Vec::new(); let mut incoming_line = String::new(); loop { - select_biased! { + futures::select! { message = outgoing_rx.next() => { if let Some(message) = message { outgoing_line.clear(); @@ -236,7 +273,9 @@ where pending_response.respond.send(result).ok(); } } else { - log::error!("received response for unknown request id: {id:?}"); + log::debug!( + "received response for unknown request id: {id:?} (possibly cancelled)" + ); } } else if let Some(method) = message.method { // Notification @@ -315,6 +354,14 @@ where } } +#[cfg(test)] +impl RpcConnection { + // Test-only visibility into pending request tracking for drop cleanup assertions. + pub(crate) fn pending_response_count(&self) -> usize { + self.pending_responses.lock().unwrap().len() + } +} + #[derive(Debug, Deserialize)] pub struct RawIncomingMessage<'a> { id: Option, diff --git a/src/agent-client-protocol/src/rpc_tests.rs b/src/agent-client-protocol/src/rpc_tests.rs index 288f946..de73ece 100644 --- a/src/agent-client-protocol/src/rpc_tests.rs +++ b/src/agent-client-protocol/src/rpc_tests.rs @@ -982,3 +982,54 @@ async fn test_set_session_config_option() { }) .await; } + +#[tokio::test] +async fn test_pending_response_cleanup_on_drop() { + struct NoopHandler; + + impl MessageHandler for NoopHandler { + fn handle_request( + &self, + _request: AgentRequest, + ) -> impl std::future::Future> { + async { Err(Error::internal_error()) } + } + + fn handle_notification( + &self, + _notification: AgentNotification, + ) -> impl std::future::Future> { + async { Ok(()) } + } + } + + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async { + let (_client_to_agent_rx, client_to_agent_tx) = piper::pipe(1024); + let (agent_to_client_rx, _agent_to_client_tx) = piper::pipe(1024); + + let (conn, _io_task) = RpcConnection::::new( + NoopHandler, + client_to_agent_tx, + agent_to_client_rx, + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + let pending = conn + .request::( + AGENT_METHOD_NAMES.initialize, + Some(ClientRequest::InitializeRequest(InitializeRequest::new( + ProtocolVersion::LATEST, + ))), + ) + .expect("request should enqueue pending response"); + + assert_eq!(conn.pending_response_count(), 1); + drop(pending); + assert_eq!(conn.pending_response_count(), 0); + }) + .await; +}