diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 668d766..9ec4a5a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,7 +40,9 @@ jobs: config: ./typos.toml - name: Lint - run: cargo clippy + run: | + cargo clippy + cargo clippy --all-targets --all-features - name: Build run: cargo build --all-targets --all-features diff --git a/md/SUMMARY.md b/md/SUMMARY.md index 7ee0349..6750364 100644 --- a/md/SUMMARY.md +++ b/md/SUMMARY.md @@ -6,6 +6,7 @@ - [Design Overview](./design.md) - [Protocol Reference](./protocol.md) +- [Protocol V2](./protocol-v2.md) # Conductor (agent-client-protocol-conductor) diff --git a/md/protocol-v2.md b/md/protocol-v2.md new file mode 100644 index 0000000..76387ec --- /dev/null +++ b/md/protocol-v2.md @@ -0,0 +1,75 @@ +# Protocol V2 + +The core SDK can opt into the draft ACP protocol v2 surface with the +`unstable_protocol_v2` crate feature: + +```toml +agent-client-protocol = { version = "...", features = ["unstable_protocol_v2"] } +``` + +This feature is separate from the broad `unstable` feature because protocol v2 +is a versioning experiment, not just an unstable method family. + +By default, `Client.builder()` and `Agent.builder()` continue to expose the +stable v1 API and advertise protocol v1. To use the v2 API for a connection, +construct the builder with `Client.v2()` or `Agent.v2()`: + +```rust +use agent_client_protocol::schema::{ProtocolVersion, v2}; +use agent_client_protocol::{Agent, Client}; + +# async fn run(agent_transport: impl agent_client_protocol::ConnectTo) -> agent_client_protocol::Result<()> { +Client + .v2() + .connect_with(agent_transport, async |cx| { + let initialize = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + Ok(()) + }) + .await?; +# Ok(()) +# } + +# async fn serve(client_transport: impl agent_client_protocol::ConnectTo) -> agent_client_protocol::Result<()> { +Agent + .v2() + .on_receive_request( + async |initialize: v2::InitializeRequest, responder, _cx| { + responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .connect_to(client_transport) + .await?; +# Ok(()) +# } +``` + +When v2 mode is enabled, application code should use types from +`agent_client_protocol::schema::v2`. The flat `agent_client_protocol::schema::*` +exports remain the stable v1 schema. This will likely change as v2 gets closer +to release. + +The SDK handles the `initialize` negotiation at the JSON-RPC boundary: + +- A v2 client advertises protocol v2 as its latest supported version. +- A v2 client requires a v2 agent. If the agent responds with v1, the + `initialize` request resolves with an error and the caller must explicitly + fall back to a v1 client implementation if that is acceptable. +- A v2 agent responds with v2 when the client supports it, or v1 when the client + only supports v1. Agent handlers still receive v2 schema types; the SDK tracks + the negotiated wire version separately and adapts supported behavior at the + transport boundary. +- If the agent responds with any other unsupported version, the request resolves + with an error so the client can close the connection. +- After initialization, the SDK converts supported messages and responses between + the local API version and the negotiated wire version. + +That means an agent can be implemented against v2 request and response types +while still serving v1 clients. The goal is for agent-side v1 compatibility to +live in the SDK wherever it can be represented as protocol adaptation. Clients +should opt into v2 separately and should not assume v2 behavior from v1 agents. diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 65e6a98..7946942 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -37,6 +37,7 @@ unstable_session_delete = ["agent-client-protocol-schema/unstable_session_delete unstable_session_fork = ["agent-client-protocol-schema/unstable_session_fork"] unstable_session_model = ["agent-client-protocol-schema/unstable_session_model"] unstable_session_usage = ["agent-client-protocol-schema/unstable_session_usage"] +unstable_protocol_v2 = ["agent-client-protocol-schema/unstable_protocol_v2"] [dependencies] agent-client-protocol-schema.workspace = true diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 588ef6f..73c553d 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -6,6 +6,7 @@ pub use jsonrpcmsg; // Types re-exported from crate root use serde::{Deserialize, Serialize}; +use std::any::TypeId; use std::fmt::Debug; use std::panic::Location; use std::pin::pin; @@ -19,6 +20,7 @@ mod dynamic_handler; pub(crate) mod handlers; mod incoming_actor; mod outgoing_actor; +mod protocol_compat; pub(crate) mod run; mod task_actor; mod transport_actor; @@ -28,6 +30,7 @@ pub use crate::jsonrpc::handlers::NullHandler; use crate::jsonrpc::handlers::{ChainedHandler, NamedHandler}; use crate::jsonrpc::handlers::{MessageHandler, NotificationHandler, RequestHandler}; use crate::jsonrpc::outgoing_actor::{OutgoingMessageTx, send_raw_message}; +use crate::jsonrpc::protocol_compat::{ProtocolCompat, ProtocolMode}; use crate::jsonrpc::run::SpawnedRun; use crate::jsonrpc::run::{ChainRun, NullRun, RunWithConnectionTo}; use crate::jsonrpc::task_actor::{Task, TaskTx}; @@ -554,6 +557,21 @@ where /// Responder for background tasks. responder: Runner, + + /// Protocol version mode for the public API and wire compatibility layer. + protocol_mode: ProtocolMode, +} + +fn default_protocol_mode() -> ProtocolMode { + let role = TypeId::of::(); + + if role == TypeId::of::() { + ProtocolMode::v1_agent() + } else if role == TypeId::of::() { + ProtocolMode::v1_client() + } else { + ProtocolMode::disabled() + } } impl Builder { @@ -566,6 +584,7 @@ impl Builder { name: None, handler: NullHandler, responder: NullRun, + protocol_mode: default_protocol_mode::(), } } } @@ -581,6 +600,7 @@ where name: None, handler, responder: NullRun, + protocol_mode: default_protocol_mode::(), } } } @@ -597,6 +617,28 @@ impl< self } + pub(crate) fn v1_agent(mut self) -> Self { + self.protocol_mode = ProtocolMode::v1_agent(); + self + } + + pub(crate) fn v1_client(mut self) -> Self { + self.protocol_mode = ProtocolMode::v1_client(); + self + } + + #[cfg(feature = "unstable_protocol_v2")] + pub(crate) fn v2_agent(mut self) -> Self { + self.protocol_mode = ProtocolMode::v2_agent(); + self + } + + #[cfg(feature = "unstable_protocol_v2")] + pub(crate) fn v2_client(mut self) -> Self { + self.protocol_mode = ProtocolMode::v2_client(); + self + } + /// Merge another [`Builder`] into this one. /// /// Prefer [`Self::on_receive_request`] or [`Self::on_receive_notification`]. @@ -613,14 +655,22 @@ impl< impl HandleDispatchFrom, impl RunWithConnectionTo, > { + let Builder { + name: other_name, + handler: other_handler, + responder: other_responder, + protocol_mode: other_protocol_mode, + host: _, + } = other; Builder { host: self.host, name: self.name, handler: ChainedHandler::new( self.handler, - NamedHandler::new(other.name, other.handler), + NamedHandler::new(other_name, other_handler), ), - responder: ChainRun::new(self.responder, other.responder), + responder: ChainRun::new(self.responder, other_responder), + protocol_mode: self.protocol_mode.merge(other_protocol_mode), } } @@ -637,6 +687,7 @@ impl< name: self.name, handler: ChainedHandler::new(self.handler, handler), responder: self.responder, + protocol_mode: self.protocol_mode, } } @@ -653,6 +704,7 @@ impl< name: self.name, handler: self.handler, responder: ChainRun::new(self.responder, responder), + protocol_mode: self.protocol_mode, } } @@ -1173,6 +1225,7 @@ impl< handler, responder, host: me, + protocol_mode, } = self; let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); @@ -1198,6 +1251,7 @@ impl< } = transport_channel; let (reply_tx, reply_rx) = mpsc::unbounded(); + let protocol_compat = ProtocolCompat::new(protocol_mode); let future = crate::util::instrument_with_connection_name(name, { let connection = connection.clone(); @@ -1211,6 +1265,7 @@ impl< outgoing_rx, reply_tx.clone(), transport_outgoing_tx, + protocol_compat.clone(), ), // Protocol layer: jsonrpcmsg::Message → handler/reply routing incoming_actor::incoming_protocol_actor( @@ -1220,6 +1275,7 @@ impl< dynamic_handler_rx, reply_rx, handler, + protocol_compat, ), task_actor::task_actor(new_task_rx, &connection), responder.run_with_connection_to(connection.clone()), @@ -1341,6 +1397,9 @@ enum OutgoingMessage { Response { id: jsonrpcmsg::Id, + /// Method of the incoming request this response completes. + method: String, + response: Result, }, @@ -1907,6 +1966,7 @@ impl Responder { /// The response will be serialized to JSON and sent over the wire. fn new(message_tx: OutgoingMessageTx, method: String, id: jsonrpcmsg::Id) -> Self { let id_clone = id.clone(); + let method_clone = method.clone(); Self { method, id, @@ -1915,6 +1975,7 @@ impl Responder { &message_tx, OutgoingMessage::Response { id: id_clone, + method: method_clone, response, }, ) diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 0aa2a7a..302554f 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -20,6 +20,7 @@ use crate::jsonrpc::ResponseRouter; use crate::jsonrpc::dynamic_handler::DynHandleDispatchFrom; use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage; use crate::jsonrpc::outgoing_actor::send_raw_message; +use crate::jsonrpc::protocol_compat::ProtocolCompat; use crate::role::Role; @@ -50,6 +51,7 @@ pub(super) async fn incoming_protocol_actor( dynamic_handler_rx: mpsc::UnboundedReceiver>, reply_rx: mpsc::UnboundedReceiver, mut handler: impl HandleDispatchFrom, + protocol_compat: ProtocolCompat, ) -> Result<(), crate::Error> { let mut my_rx = transport_rx .map(IncomingProtocolMsg::Transport) @@ -96,6 +98,7 @@ pub(super) async fn incoming_protocol_actor( for pending_message in pending_messages { tracing::trace!(method = pending_message.method(), handler = ?handler.dyn_describe_chain(), "Retrying message"); let id = pending_message.id(); + let method = pending_message.method().to_string(); match handler .dyn_handle_dispatch_from(pending_message, connection.clone()) .await @@ -112,7 +115,7 @@ pub(super) async fn incoming_protocol_actor( } Err(err) => { tracing::warn!(?err, handler = ?handler.dyn_describe_chain(), "Dynamic handler errored on pending message, reporting back"); - report_handler_error(connection, id, err)?; + report_handler_error(connection, id, method, err)?; } } } @@ -130,16 +133,29 @@ pub(super) async fn incoming_protocol_actor( Ok(message) => match message { jsonrpcmsg::Message::Request(request) => { tracing::trace!(method = %request.method, id = ?request.id, "Handling request"); - let dispatch = dispatch_from_request(connection, request); - dispatch_dispatch( - counterpart.clone(), - connection, - dispatch, - &mut dynamic_handlers, - &mut handler, - &mut pending_messages, - ) - .await?; + let request_method = request.method.clone(); + let request_id = request.id.clone(); + match dispatch_from_request(connection, request, &protocol_compat) { + Ok(dispatch) => { + dispatch_dispatch( + counterpart.clone(), + connection, + dispatch, + &mut dynamic_handlers, + &mut handler, + &mut pending_messages, + ) + .await?; + } + Err(error) => { + report_handler_error( + connection, + request_id.map(|id| serde_json::to_value(&id).unwrap()), + request_method, + error, + )?; + } + } } jsonrpcmsg::Message::Response(response) => { tracing::trace!(id = ?response.id, has_result = response.result.is_some(), has_error = response.error.is_some(), "Handling response"); @@ -156,6 +172,8 @@ pub(super) async fn incoming_protocol_actor( let id_json = serde_json::to_value(&id).unwrap(); if let Some(pending_reply) = pending_replies.remove(&id_json) { + let result = protocol_compat + .incoming_response(&pending_reply.method, result); // Route the response through the handler chain let dispatch = dispatch_from_response(id, pending_reply, result); dispatch_dispatch( @@ -199,19 +217,21 @@ enum IncomingProtocolMsg { fn dispatch_from_request( connection: &ConnectionTo, request: jsonrpcmsg::Request, -) -> Dispatch { + protocol_compat: &ProtocolCompat, +) -> Result { let message = UntypedMessage::new(&request.method, &request.params).expect("well-formed JSON"); + let message = protocol_compat.incoming_message(message)?; match &request.id { - Some(id) => Dispatch::Request( + Some(id) => Ok(Dispatch::Request( message, Responder::new( connection.message_tx.clone(), request.method.clone(), id.clone(), ), - ), - None => Dispatch::Notification(message), + )), + None => Ok(Dispatch::Notification(message)), } } @@ -275,7 +295,7 @@ async fn dispatch_dispatch( Err(err) => { tracing::warn!(?method, ?id, ?err, handler = ?handler.describe_chain(), "Handler errored, reporting back to remote"); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } @@ -299,7 +319,7 @@ async fn dispatch_dispatch( Err(err) => { tracing::warn!(?method, ?id, ?err, handler = ?dynamic_handler.dyn_describe_chain(), "Dynamic handler errored, reporting back to remote"); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } } @@ -327,7 +347,7 @@ async fn dispatch_dispatch( handler = "default", "Default handler errored, reporting back to remote" ); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } @@ -367,6 +387,7 @@ async fn dispatch_dispatch( fn report_handler_error( connection: &ConnectionTo, id: Option, + method: String, error: crate::Error, ) -> Result<(), crate::Error> { match id { @@ -377,6 +398,7 @@ fn report_handler_error( &connection.message_tx, OutgoingMessage::Response { id: jsonrpc_id, + method, response: Err(error), }, ) diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 7eee713..0b54ff7 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -4,6 +4,7 @@ use futures::channel::mpsc; use crate::jsonrpc::OutgoingMessage; use crate::jsonrpc::ReplyMessage; +use crate::jsonrpc::protocol_compat::ProtocolCompat; pub type OutgoingMessageTx = mpsc::UnboundedSender; @@ -27,6 +28,7 @@ pub(super) async fn outgoing_protocol_actor( mut outgoing_rx: mpsc::UnboundedReceiver, reply_tx: mpsc::UnboundedSender, transport_tx: mpsc::UnboundedSender>, + protocol_compat: ProtocolCompat, ) -> Result<(), crate::Error> { while let Some(message) = outgoing_rx.next().await { tracing::debug!(?message, "outgoing_protocol_actor"); @@ -40,6 +42,18 @@ pub(super) async fn outgoing_protocol_actor( untyped, response_tx, } => { + let request = match protocol_compat + .outgoing_message(untyped) + .and_then(|untyped| untyped.into_jsonrpc_msg(Some(id.clone()))) + { + Ok(request) => request, + Err(error) => { + tracing::warn!(?id, %method, ?error, "Failed to convert outgoing request"); + complete_request_with_error(response_tx, error); + continue; + } + }; + // Record where the reply should be sent once it arrives. reply_tx .unbounded_send(ReplyMessage::Subscribe { @@ -50,35 +64,47 @@ pub(super) async fn outgoing_protocol_actor( }) .map_err(crate::Error::into_internal_error)?; - jsonrpcmsg::Message::Request(untyped.into_jsonrpc_msg(Some(id))?) + jsonrpcmsg::Message::Request(request) } OutgoingMessage::Notification { untyped } => { - let msg = untyped.into_jsonrpc_msg(None)?; + let msg = match protocol_compat + .outgoing_message(untyped) + .and_then(|untyped| untyped.into_jsonrpc_msg(None)) + { + Ok(msg) => msg, + Err(error) => { + tracing::warn!( + ?error, + "Dropping outgoing notification after conversion failed" + ); + continue; + } + }; jsonrpcmsg::Message::Request(msg) } OutgoingMessage::Response { id, - response: Ok(value), - } => { - tracing::debug!(?id, "Sending success response"); - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(value, Some(id))) - } - OutgoingMessage::Response { - id, - response: Err(error), - } => { - tracing::warn!(?id, ?error, "Sending error response"); - // Convert crate::Error to jsonrpcmsg::Error - let jsonrpc_error = jsonrpcmsg::Error { - code: error.code.into(), - message: error.message, - data: error.data, - }; - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2( - jsonrpc_error, - Some(id), - )) - } + method, + response, + } => match protocol_compat.outgoing_response(&method, response) { + Ok(value) => { + tracing::debug!(?id, "Sending success response"); + jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(value, Some(id))) + } + Err(error) => { + tracing::warn!(?id, %method, ?error, "Sending error response"); + // Convert crate::Error to jsonrpcmsg::Error + let jsonrpc_error = jsonrpcmsg::Error { + code: error.code.into(), + message: error.message, + data: error.data, + }; + jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2( + jsonrpc_error, + Some(id), + )) + } + }, OutgoingMessage::Error { error } => { // Convert crate::Error to jsonrpcmsg::Error let jsonrpc_error = jsonrpcmsg::Error { @@ -99,3 +125,113 @@ pub(super) async fn outgoing_protocol_actor( } Ok(()) } + +fn complete_request_with_error( + response_tx: futures::channel::oneshot::Sender, + error: crate::Error, +) { + if response_tx + .send(crate::jsonrpc::ResponsePayload { + result: Err(error), + ack_tx: None, + }) + .is_err() + { + tracing::debug!("Dropped failed outgoing request because receiver was gone"); + } +} + +#[cfg(all(test, feature = "unstable_protocol_v2"))] +mod tests { + use futures::StreamExt as _; + use futures::channel::{mpsc, oneshot}; + + use super::*; + use crate::Role as _; + + fn malformed_v2_known_method() -> Result { + crate::UntypedMessage::new("session/new", serde_json::json!({})) + } + + #[tokio::test(flavor = "current_thread")] + async fn failed_request_conversion_completes_request_locally() -> Result<(), crate::Error> { + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + let (reply_tx, mut reply_rx) = mpsc::unbounded(); + let (transport_tx, mut transport_rx) = mpsc::unbounded(); + let (response_tx, response_rx) = oneshot::channel(); + + outgoing_tx + .unbounded_send(OutgoingMessage::Request { + id: jsonrpcmsg::Id::Number(1), + role_id: crate::Agent.role_id(), + method: "session/new".into(), + untyped: malformed_v2_known_method()?, + response_tx, + }) + .map_err(crate::Error::into_internal_error)?; + drop(outgoing_tx); + + outgoing_protocol_actor( + outgoing_rx, + reply_tx, + transport_tx, + ProtocolCompat::new(crate::jsonrpc::protocol_compat::ProtocolMode::v2_agent()), + ) + .await?; + + let response = response_rx + .await + .map_err(crate::Error::into_internal_error)?; + assert!( + response.result.is_err(), + "conversion failure should complete the local request" + ); + assert!(response.ack_tx.is_none()); + assert!(reply_rx.next().await.is_none()); + assert!(transport_rx.next().await.is_none()); + + Ok(()) + } + + #[tokio::test(flavor = "current_thread")] + async fn failed_notification_conversion_does_not_stop_actor() -> Result<(), crate::Error> { + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + let (reply_tx, _reply_rx) = mpsc::unbounded(); + let (transport_tx, mut transport_rx) = mpsc::unbounded(); + + outgoing_tx + .unbounded_send(OutgoingMessage::Notification { + untyped: malformed_v2_known_method()?, + }) + .map_err(crate::Error::into_internal_error)?; + outgoing_tx + .unbounded_send(OutgoingMessage::Notification { + untyped: crate::UntypedMessage::new( + "_local/notify", + serde_json::json!({ "ok": true }), + )?, + }) + .map_err(crate::Error::into_internal_error)?; + drop(outgoing_tx); + + outgoing_protocol_actor( + outgoing_rx, + reply_tx, + transport_tx, + ProtocolCompat::new(crate::jsonrpc::protocol_compat::ProtocolMode::v2_agent()), + ) + .await?; + + let message = transport_rx + .next() + .await + .expect("valid notification should still be sent")?; + let jsonrpcmsg::Message::Request(request) = message else { + panic!("expected outgoing notification request, got {message:?}"); + }; + assert_eq!(request.method, "_local/notify"); + assert!(transport_rx.next().await.is_none()); + + Ok(()) + } +} diff --git a/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs b/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs new file mode 100644 index 0000000..1d13142 --- /dev/null +++ b/src/agent-client-protocol/src/jsonrpc/protocol_compat.rs @@ -0,0 +1,749 @@ +#[cfg(not(feature = "unstable_protocol_v2"))] +mod imp { + #![allow(clippy::unused_self, clippy::unnecessary_wraps)] + use crate::UntypedMessage; + + #[derive(Clone, Copy, Debug, Default)] + pub(crate) struct ProtocolMode; + + impl ProtocolMode { + pub(crate) fn disabled() -> Self { + Self + } + + pub(crate) fn v1_agent() -> Self { + Self + } + + pub(crate) fn v1_client() -> Self { + Self + } + + pub(crate) fn merge(self, _other: Self) -> Self { + self + } + } + + #[derive(Clone, Debug, Default)] + pub(crate) struct ProtocolCompat; + + impl ProtocolCompat { + pub(crate) fn new(_mode: ProtocolMode) -> Self { + Self + } + + pub(crate) fn incoming_message( + &self, + message: UntypedMessage, + ) -> Result { + Ok(message) + } + + pub(crate) fn outgoing_message( + &self, + message: UntypedMessage, + ) -> Result { + Ok(message) + } + + pub(crate) fn incoming_response( + &self, + _method: &str, + result: Result, + ) -> Result { + result + } + + pub(crate) fn outgoing_response( + &self, + _method: &str, + result: Result, + ) -> Result { + result + } + } +} + +#[cfg(feature = "unstable_protocol_v2")] +mod imp { + use std::sync::{Arc, Mutex}; + + use agent_client_protocol_schema::v2::{ + self, + conversion::{IntoV1, IntoV2, v1_to_v2, v2_to_v1}, + }; + + use crate::schema::{ + AgentNotification, AgentRequest, AgentResponse, ClientNotification, ClientRequest, + ClientResponse, ErrorCode, ProtocolVersion, + }; + use crate::{JsonRpcMessage, JsonRpcResponse, UntypedMessage}; + + #[derive(Clone, Copy, Debug)] + pub(crate) enum ProtocolMode { + Disabled, + Acp(AcpProtocolMode), + } + + #[derive(Clone, Copy, Debug)] + pub(crate) struct AcpProtocolMode { + api: ProtocolVersionKind, + latest_supported: ProtocolVersionKind, + require_latest_response: bool, + } + + impl ProtocolMode { + pub(crate) fn disabled() -> Self { + Self::Disabled + } + + pub(crate) fn v1_agent() -> Self { + Self::Acp(AcpProtocolMode { + api: ProtocolVersionKind::V1, + latest_supported: ProtocolVersionKind::V1, + require_latest_response: false, + }) + } + + pub(crate) fn v1_client() -> Self { + Self::Acp(AcpProtocolMode { + api: ProtocolVersionKind::V1, + latest_supported: ProtocolVersionKind::V1, + require_latest_response: true, + }) + } + + pub(crate) fn v2_agent() -> Self { + Self::Acp(AcpProtocolMode { + api: ProtocolVersionKind::V2, + latest_supported: ProtocolVersionKind::V2, + require_latest_response: false, + }) + } + + pub(crate) fn v2_client() -> Self { + Self::Acp(AcpProtocolMode { + api: ProtocolVersionKind::V2, + latest_supported: ProtocolVersionKind::V2, + require_latest_response: true, + }) + } + + pub(crate) fn merge(self, other: Self) -> Self { + match (self, other) { + (Self::Disabled, other) => other, + (this, Self::Disabled) => this, + (Self::Acp(this), Self::Acp(other)) => { + assert_eq!( + this.api, other.api, + "cannot merge ACP builders with different API protocol versions; \ + handler chains share a single API surface", + ); + if this.latest_supported >= other.latest_supported { + Self::Acp(this) + } else { + Self::Acp(other) + } + } + } + } + } + + #[derive(Clone, Debug)] + pub(crate) struct ProtocolCompat { + mode: Option, + state: Arc>, + } + + #[derive(Debug)] + struct ProtocolState { + negotiated: ProtocolVersionKind, + pending_initialize: Option, + } + + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] + enum ProtocolVersionKind { + V1, + V2, + } + + impl ProtocolVersionKind { + fn as_protocol_version(self) -> ProtocolVersion { + match self { + Self::V1 => ProtocolVersion::V1, + Self::V2 => ProtocolVersion::V2, + } + } + + fn from_protocol_version(version: ProtocolVersion) -> Option { + if version == ProtocolVersion::V1 { + Some(Self::V1) + } else if version == ProtocolVersion::V2 { + Some(Self::V2) + } else { + None + } + } + } + + impl ProtocolCompat { + pub(crate) fn new(mode: ProtocolMode) -> Self { + Self { + mode: match mode { + ProtocolMode::Disabled => None, + ProtocolMode::Acp(mode) => Some(mode), + }, + state: Arc::new(Mutex::new(ProtocolState { + negotiated: ProtocolVersionKind::V1, + pending_initialize: None, + })), + } + } + + pub(crate) fn incoming_message( + &self, + message: UntypedMessage, + ) -> Result { + let Some(mode) = self.mode else { + return Ok(message); + }; + + if message.method() == "initialize" { + return self.incoming_initialize_request(mode, message); + } + + convert_message(message, self.active_wire_version(), mode.api) + } + + pub(crate) fn outgoing_message( + &self, + mut message: UntypedMessage, + ) -> Result { + let Some(mode) = self.mode else { + return Ok(message); + }; + + let wire_version = if message.method() == "initialize" { + set_protocol_version(&mut message.params, mode.latest_supported)?; + self.set_pending_initialize(mode.latest_supported); + mode.latest_supported + } else { + self.active_wire_version() + }; + + convert_message(message, mode.api, wire_version) + } + + pub(crate) fn incoming_response( + &self, + method: &str, + result: Result, + ) -> Result { + let Some(mode) = self.mode else { + return result; + }; + + if method == "initialize" { + return self.incoming_initialize_response(mode, result); + } + + let value = result?; + convert_response(method, value, self.active_wire_version(), mode.api) + } + + pub(crate) fn outgoing_response( + &self, + method: &str, + result: Result, + ) -> Result { + let Some(mode) = self.mode else { + return result; + }; + + // Always drain any pending initialize state so a failed initialize + // doesn't leak negotiation state to a subsequent request. + let pending_initialize = if method == "initialize" { + self.take_pending_initialize() + } else { + None + }; + + let mut value = result?; + + let wire_version = if method == "initialize" { + let negotiated = pending_initialize.unwrap_or_else(|| { + protocol_version_from_value(&value) + .and_then(ProtocolVersionKind::from_protocol_version) + .unwrap_or(mode.latest_supported) + }); + set_protocol_version(&mut value, negotiated)?; + self.set_negotiated(negotiated); + negotiated + } else { + self.active_wire_version() + }; + + convert_response(method, value, mode.api, wire_version) + } + + fn incoming_initialize_request( + &self, + mode: AcpProtocolMode, + mut message: UntypedMessage, + ) -> Result { + let requested = required_protocol_version_from_value(message.params())?; + let requested_kind = ProtocolVersionKind::from_protocol_version(requested); + let wire_version = requested_kind.unwrap_or(mode.latest_supported); + let negotiated = self.negotiate(requested); + self.set_pending_initialize(negotiated); + + message = convert_message(message, wire_version, mode.api)?; + set_protocol_version(&mut message.params, mode.api)?; + Ok(message) + } + + fn incoming_initialize_response( + &self, + mode: AcpProtocolMode, + result: Result, + ) -> Result { + let _pending_initialize = self.take_pending_initialize(); + let mut value = result?; + let response_version = required_protocol_version_from_value(&value)?; + if !self.supports(response_version) { + return Err(unsupported_protocol_version(response_version)); + } + + let wire_version = ProtocolVersionKind::from_protocol_version(response_version) + .ok_or_else(|| unsupported_protocol_version(response_version))?; + if mode.require_latest_response && wire_version != mode.latest_supported { + return Err(required_protocol_version( + mode.latest_supported, + wire_version, + )); + } + self.set_negotiated(wire_version); + + value = convert_response("initialize", value, wire_version, mode.api)?; + set_protocol_version(&mut value, wire_version)?; + Ok(value) + } + + fn supports(&self, version: ProtocolVersion) -> bool { + let Some(mode) = self.mode else { + return false; + }; + + version == ProtocolVersion::V1 + || (mode.latest_supported == ProtocolVersionKind::V2 + && version == ProtocolVersion::V2) + } + + fn negotiate(&self, requested: ProtocolVersion) -> ProtocolVersionKind { + let mode = self + .mode + .expect("ACP protocol mode should be enabled while negotiating"); + + if self.supports(requested) { + ProtocolVersionKind::from_protocol_version(requested) + .unwrap_or(mode.latest_supported) + } else { + mode.latest_supported + } + } + + fn active_wire_version(&self) -> ProtocolVersionKind { + let state = self + .state + .lock() + .expect("protocol compatibility state mutex poisoned"); + state.pending_initialize.unwrap_or(state.negotiated) + } + + fn set_negotiated(&self, negotiated: ProtocolVersionKind) { + self.state + .lock() + .expect("protocol compatibility state mutex poisoned") + .negotiated = negotiated; + } + + fn set_pending_initialize(&self, negotiated: ProtocolVersionKind) { + self.state + .lock() + .expect("protocol compatibility state mutex poisoned") + .pending_initialize = Some(negotiated); + } + + fn take_pending_initialize(&self) -> Option { + self.state + .lock() + .expect("protocol compatibility state mutex poisoned") + .pending_initialize + .take() + } + } + + fn protocol_version_from_value(value: &serde_json::Value) -> Option { + serde_json::from_value(value.get("protocolVersion")?.clone()).ok() + } + + fn required_protocol_version_from_value( + value: &serde_json::Value, + ) -> Result { + let Some(version) = value.get("protocolVersion") else { + return Err(invalid_initialize_protocol_version()); + }; + + serde_json::from_value(version.clone()).map_err(|_| invalid_initialize_protocol_version()) + } + + fn invalid_initialize_protocol_version() -> crate::Error { + crate::Error::invalid_params() + .data("initialize.protocolVersion must be a valid ACP protocol version") + } + + fn set_protocol_version( + value: &mut serde_json::Value, + version: ProtocolVersionKind, + ) -> Result<(), crate::Error> { + if let serde_json::Value::Object(object) = value { + object.insert( + "protocolVersion".into(), + serde_json::to_value(version.as_protocol_version()) + .map_err(crate::Error::into_internal_error)?, + ); + } + Ok(()) + } + + fn convert_message( + message: UntypedMessage, + from: ProtocolVersionKind, + to: ProtocolVersionKind, + ) -> Result { + if message.method().starts_with('_') || from == to { + return Ok(message); + } + + match (from, to) { + (ProtocolVersionKind::V1, ProtocolVersionKind::V2) => public_to_v2_message(message), + (ProtocolVersionKind::V2, ProtocolVersionKind::V1) => v2_to_public_message(message), + _ => Ok(message), + } + } + + fn convert_response( + method: &str, + value: serde_json::Value, + from: ProtocolVersionKind, + to: ProtocolVersionKind, + ) -> Result { + if method.starts_with('_') || from == to { + return Ok(value); + } + + match (from, to) { + (ProtocolVersionKind::V1, ProtocolVersionKind::V2) => { + public_to_v2_response(method, value) + } + (ProtocolVersionKind::V2, ProtocolVersionKind::V1) => { + v2_to_public_response(method, value) + } + _ => Ok(value), + } + } + + fn public_to_v2_message(message: UntypedMessage) -> Result { + let UntypedMessage { method, params } = message; + + if let Some(message) = try_convert_message_to_v2::(&method, ¶ms)? { + return Ok(message); + } + if let Some(message) = try_convert_message_to_v2::(&method, ¶ms)? { + return Ok(message); + } + if let Some(message) = try_convert_message_to_v2::(&method, ¶ms)? { + return Ok(message); + } + if let Some(message) = try_convert_message_to_v2::(&method, ¶ms)? { + return Ok(message); + } + + Ok(UntypedMessage { method, params }) + } + + fn v2_to_public_message(message: UntypedMessage) -> Result { + let UntypedMessage { method, params } = message; + + if let Some(message) = try_convert_message_to_v1::(&method, ¶ms)? { + return Ok(message); + } + if let Some(message) = try_convert_message_to_v1::(&method, ¶ms)? { + return Ok(message); + } + if let Some(message) = + try_convert_message_to_v1::(&method, ¶ms)? + { + return Ok(message); + } + if let Some(message) = try_convert_message_to_v1::(&method, ¶ms)? + { + return Ok(message); + } + + Ok(UntypedMessage { method, params }) + } + + fn public_to_v2_response( + method: &str, + value: serde_json::Value, + ) -> Result { + if let Some(value) = try_convert_response_to_v2::(method, &value)? { + return Ok(value); + } + if let Some(value) = try_convert_response_to_v2::(method, &value)? { + return Ok(value); + } + + Ok(value) + } + + fn v2_to_public_response( + method: &str, + value: serde_json::Value, + ) -> Result { + if let Some(value) = try_convert_response_to_v1::(method, &value)? { + return Ok(value); + } + if let Some(value) = try_convert_response_to_v1::(method, &value)? { + return Ok(value); + } + + Ok(value) + } + + fn try_convert_message_to_v2( + method: &str, + params: &serde_json::Value, + ) -> Result, crate::Error> + where + T: JsonRpcMessage + IntoV2, + T::Output: JsonRpcMessage, + { + let Some(message) = try_parse_message::(method, params)? else { + return Ok(None); + }; + let wire_message = v1_to_v2(message)?; + wire_message.to_untyped_message().map(Some) + } + + fn try_convert_message_to_v1( + method: &str, + params: &serde_json::Value, + ) -> Result, crate::Error> + where + T: JsonRpcMessage + IntoV1, + T::Output: JsonRpcMessage, + { + let Some(message) = try_parse_message::(method, params)? else { + return Ok(None); + }; + let public_message = v2_to_v1(message)?; + public_message.to_untyped_message().map(Some) + } + + fn try_parse_message( + method: &str, + params: &serde_json::Value, + ) -> Result, crate::Error> { + match T::parse_message(method, params) { + Ok(message) => Ok(Some(message)), + Err(error) if error.code == ErrorCode::MethodNotFound => Ok(None), + Err(error) => Err(error), + } + } + + fn try_convert_response_to_v2( + method: &str, + value: &serde_json::Value, + ) -> Result, crate::Error> + where + T: JsonRpcResponse + IntoV2, + T::Output: JsonRpcResponse, + { + let Some(response) = try_parse_response::(method, value)? else { + return Ok(None); + }; + let wire_response = v1_to_v2(response)?; + wire_response.into_json(method).map(Some) + } + + fn try_convert_response_to_v1( + method: &str, + value: &serde_json::Value, + ) -> Result, crate::Error> + where + T: JsonRpcResponse + IntoV1, + T::Output: JsonRpcResponse, + { + let Some(response) = try_parse_response::(method, value)? else { + return Ok(None); + }; + let public_response = v2_to_v1(response)?; + public_response.into_json(method).map(Some) + } + + fn try_parse_response( + method: &str, + value: &serde_json::Value, + ) -> Result, crate::Error> { + match T::from_value(method, value.clone()) { + Ok(response) => Ok(Some(response)), + Err(error) if error.code == ErrorCode::MethodNotFound => Ok(None), + Err(error) => Err(error), + } + } + + fn unsupported_protocol_version(version: ProtocolVersion) -> crate::Error { + crate::Error::invalid_request().data(format!( + "unsupported ACP protocol version {version}; this endpoint does not support that version" + )) + } + + fn required_protocol_version( + required: ProtocolVersionKind, + negotiated: ProtocolVersionKind, + ) -> crate::Error { + crate::Error::invalid_request().data(format!( + "required ACP protocol version {} but peer negotiated {}; use a v1 client implementation for v1 agents", + required.as_protocol_version(), + negotiated.as_protocol_version(), + )) + } + + #[cfg(test)] + mod tests { + use super::*; + + fn negotiated(compat: &ProtocolCompat) -> ProtocolVersionKind { + compat + .state + .lock() + .expect("protocol compatibility state mutex poisoned") + .negotiated + } + + #[test] + fn initialize_request_sets_active_wire_version_before_response() -> Result<(), crate::Error> + { + let compat = ProtocolCompat::new(ProtocolMode::v2_agent()); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1); + + compat.incoming_message(UntypedMessage::new( + "initialize", + v2::InitializeRequest::new(ProtocolVersion::V2), + )?)?; + + assert_eq!(negotiated(&compat), ProtocolVersionKind::V1); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2); + + compat.outgoing_response( + "initialize", + Ok(serde_json::to_value(v2::InitializeResponse::new( + ProtocolVersion::V2, + ))?), + )?; + + assert_eq!(negotiated(&compat), ProtocolVersionKind::V2); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2); + Ok(()) + } + + #[test] + fn outgoing_initialize_sets_active_wire_version_before_response() -> Result<(), crate::Error> + { + let compat = ProtocolCompat::new(ProtocolMode::v2_client()); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1); + + compat.outgoing_message(UntypedMessage::new( + "initialize", + v2::InitializeRequest::new(ProtocolVersion::V1), + )?)?; + + assert_eq!(negotiated(&compat), ProtocolVersionKind::V1); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2); + + compat.incoming_response( + "initialize", + Ok(serde_json::to_value(v2::InitializeResponse::new( + ProtocolVersion::V2, + ))?), + )?; + + assert_eq!(negotiated(&compat), ProtocolVersionKind::V2); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2); + Ok(()) + } + + #[test] + fn failed_incoming_initialize_response_clears_pending_wire_version() + -> Result<(), crate::Error> { + let compat = ProtocolCompat::new(ProtocolMode::v2_client()); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1); + + compat.outgoing_message(UntypedMessage::new( + "initialize", + v2::InitializeRequest::new(ProtocolVersion::V1), + )?)?; + + assert_eq!(negotiated(&compat), ProtocolVersionKind::V1); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V2); + + let result = compat.incoming_response( + "initialize", + Err(crate::Error::invalid_request().data("initialize failed")), + ); + + assert!(result.is_err()); + assert_eq!(negotiated(&compat), ProtocolVersionKind::V1); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1); + Ok(()) + } + + #[test] + fn incoming_initialize_response_requires_protocol_version() -> Result<(), crate::Error> { + for value in [ + serde_json::json!({}), + serde_json::json!({ "protocolVersion": 100_000 }), + ] { + let compat = ProtocolCompat::new(ProtocolMode::v2_client()); + compat.outgoing_message(UntypedMessage::new( + "initialize", + v2::InitializeRequest::new(ProtocolVersion::V1), + )?)?; + + let error = compat + .incoming_response("initialize", Ok(value)) + .expect_err("initialize responses must declare an ACP protocol version"); + let data = error + .data + .as_ref() + .and_then(|data| data.as_str()) + .unwrap_or_default(); + assert!(data.contains("protocolVersion"), "{error:?}"); + assert_eq!(negotiated(&compat), ProtocolVersionKind::V1); + assert_eq!(compat.active_wire_version(), ProtocolVersionKind::V1); + } + + Ok(()) + } + + #[test] + #[should_panic(expected = "cannot merge ACP builders with different API protocol versions")] + fn merging_different_api_protocol_modes_panics() { + let _ = ProtocolMode::v1_agent().merge(ProtocolMode::v2_agent()); + } + } +} + +pub(crate) use imp::{ProtocolCompat, ProtocolMode}; diff --git a/src/agent-client-protocol/src/role/acp.rs b/src/agent-client-protocol/src/role/acp.rs index 3ffb334..b103890 100644 --- a/src/agent-client-protocol/src/role/acp.rs +++ b/src/agent-client-protocol/src/role/acp.rs @@ -17,6 +17,10 @@ pub struct Client; impl Role for Client { type Counterpart = Agent; + fn builder(self) -> Builder { + Builder::new(self).v1_client() + } + async fn default_handle_dispatch_from( &self, message: Dispatch, @@ -40,7 +44,19 @@ impl Role for Client { impl Client { /// Create a connection builder for a client. pub fn builder(self) -> Builder { - Builder::new(self) + ::builder(self) + } + + /// Create a client builder that requires an ACP protocol v2 agent. + /// + /// If the agent negotiates v1 during initialization, the initialize + /// request resolves with an error so callers can choose an explicit v1 + /// fallback path. + /// + /// Requires the `unstable_protocol_v2` crate feature. + #[cfg(feature = "unstable_protocol_v2")] + pub fn v2(self) -> Builder { + self.builder().v2_client() } /// Connect to `agent` and run `main_fn` with the [`ConnectionTo`]. @@ -72,6 +88,10 @@ pub struct Agent; impl Role for Agent { type Counterpart = Client; + fn builder(self) -> Builder { + Builder::new(self).v1_agent() + } + fn role_id(&self) -> RoleId { RoleId::from_singleton(self) } @@ -105,7 +125,19 @@ impl Role for Agent { impl Agent { /// Create a connection builder for an agent. pub fn builder(self) -> Builder { - Builder::new(self) + ::builder(self) + } + + /// Create an agent builder that uses the ACP protocol v2 API. + /// + /// The SDK will negotiate v1 or v2 during initialization and convert + /// supported messages at the transport boundary, so handlers can be written + /// against v2 types while still serving v1 clients. + /// + /// Requires the `unstable_protocol_v2` crate feature. + #[cfg(feature = "unstable_protocol_v2")] + pub fn v2(self) -> Builder { + self.builder().v2_agent() } } diff --git a/src/agent-client-protocol/src/schema/enum_impls.rs b/src/agent-client-protocol/src/schema/enum_impls.rs index 854ef2e..f0f9674 100644 --- a/src/agent-client-protocol/src/schema/enum_impls.rs +++ b/src/agent-client-protocol/src/schema/enum_impls.rs @@ -1,7 +1,10 @@ //! JsonRpcMessage and JsonRpcNotification/JsonRpcRequest implementations for //! the ACP enum types from agent-client-protocol-schema. -use crate::schema::{AgentNotification, AgentRequest, ClientNotification, ClientRequest}; +use crate::schema::{ + AgentNotification, AgentRequest, AgentResponse, ClientNotification, ClientRequest, + ClientResponse, +}; // ============================================================================ // Agent side (messages that agents receive) @@ -26,11 +29,39 @@ impl_jsonrpc_request_enum!(ClientRequest { PromptRequest => "session/prompt", #[cfg(feature = "unstable_session_model")] SetSessionModelRequest => "session/set_model", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpRequest => "mcp/message", [ext] ExtMethodRequest, }); +impl_jsonrpc_response_enum!(AgentResponse { + InitializeResponse => "initialize", + AuthenticateResponse => "authenticate", + #[cfg(feature = "unstable_logout")] + LogoutResponse => "logout", + NewSessionResponse => "session/new", + LoadSessionResponse => "session/load", + ListSessionsResponse => "session/list", + #[cfg(feature = "unstable_session_delete")] + DeleteSessionResponse => "session/delete", + #[cfg(feature = "unstable_session_fork")] + ForkSessionResponse => "session/fork", + ResumeSessionResponse => "session/resume", + CloseSessionResponse => "session/close", + SetSessionModeResponse => "session/set_mode", + SetSessionConfigOptionResponse => "session/set_config_option", + PromptResponse => "session/prompt", + #[cfg(feature = "unstable_session_model")] + SetSessionModelResponse => "session/set_model", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpResponse => "mcp/message", + [ext] ExtMethodResponse, +}); + impl_jsonrpc_notification_enum!(ClientNotification { CancelNotification => "session/cancel", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpNotification => "mcp/message", [ext] ExtNotification, }); @@ -47,10 +78,36 @@ impl_jsonrpc_request_enum!(AgentRequest { ReleaseTerminalRequest => "terminal/release", WaitForTerminalExitRequest => "terminal/wait_for_exit", KillTerminalRequest => "terminal/kill", + #[cfg(feature = "unstable_mcp_over_acp")] + ConnectMcpRequest => "mcp/connect", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpRequest => "mcp/message", + #[cfg(feature = "unstable_mcp_over_acp")] + DisconnectMcpRequest => "mcp/disconnect", [ext] ExtMethodRequest, }); +impl_jsonrpc_response_enum!(ClientResponse { + WriteTextFileResponse => "fs/write_text_file", + ReadTextFileResponse => "fs/read_text_file", + RequestPermissionResponse => "session/request_permission", + CreateTerminalResponse => "terminal/create", + TerminalOutputResponse => "terminal/output", + ReleaseTerminalResponse => "terminal/release", + WaitForTerminalExitResponse => "terminal/wait_for_exit", + KillTerminalResponse => "terminal/kill", + #[cfg(feature = "unstable_mcp_over_acp")] + ConnectMcpResponse => "mcp/connect", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpResponse => "mcp/message", + #[cfg(feature = "unstable_mcp_over_acp")] + DisconnectMcpResponse => "mcp/disconnect", + [ext] ExtMethodResponse, +}); + impl_jsonrpc_notification_enum!(AgentNotification { SessionNotification => "session/update", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpNotification => "mcp/message", [ext] ExtNotification, }); diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index 8367031..6279701 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -219,11 +219,47 @@ macro_rules! impl_jsonrpc_notification_enum { }; } +/// Implement `JsonRpcResponse` for an enum that dispatches across multiple +/// response types, with an extension method fallback. +macro_rules! impl_jsonrpc_response_enum { + ($enum:ty { + $( $(#[$meta:meta])* $variant:ident => $method:literal, )* + [ext] $ext_variant:ident, + }) => { + impl $crate::JsonRpcResponse for $enum { + fn into_json( + self, + _method: &str, + ) -> Result { + serde_json::to_value(self).map_err($crate::Error::into_internal_error) + } + + fn from_value( + method: &str, + value: serde_json::Value, + ) -> Result { + match method { + $( $(#[$meta])* $method => $crate::util::json_cast(value).map(Self::$variant), )* + _ => { + if method.starts_with('_') { + $crate::util::json_cast(value).map(Self::$ext_variant) + } else { + Err($crate::Error::method_not_found()) + } + } + } + } + } + }; +} + // Internal organization mod agent_to_client; mod client_to_agent; mod enum_impls; mod proxy_protocol; +#[cfg(feature = "unstable_protocol_v2")] +mod v2_impls; // Re-export everything from agent_client_protocol_schema pub use agent_client_protocol_schema::*; diff --git a/src/agent-client-protocol/src/schema/v2_impls.rs b/src/agent-client-protocol/src/schema/v2_impls.rs new file mode 100644 index 0000000..5b87c8a --- /dev/null +++ b/src/agent-client-protocol/src/schema/v2_impls.rs @@ -0,0 +1,424 @@ +//! JSON-RPC trait implementations for the experimental schema v2 namespace. + +use crate::schema::v2; +use crate::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, UntypedMessage}; + +macro_rules! impl_v2_jsonrpc_request { + ($req:ty, $resp:ty, $method:literal) => { + impl JsonRpcMessage for $req { + fn matches_method(method: &str) -> bool { + method == $method + } + + fn method(&self) -> &str { + $method + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new($method, self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + if method != $method { + return Err(crate::Error::method_not_found()); + } + crate::util::json_cast_params(params) + } + } + + impl JsonRpcRequest for $req { + type Response = $resp; + } + + impl JsonRpcResponse for $resp { + fn into_json(self, _method: &str) -> Result { + serde_json::to_value(self).map_err(crate::Error::into_internal_error) + } + + fn from_value(_method: &str, value: serde_json::Value) -> Result { + crate::util::json_cast(value) + } + } + }; +} + +macro_rules! impl_v2_jsonrpc_notification { + ($notif:ty, $method:literal) => { + impl JsonRpcMessage for $notif { + fn matches_method(method: &str) -> bool { + method == $method + } + + fn method(&self) -> &str { + $method + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new($method, self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + if method != $method { + return Err(crate::Error::method_not_found()); + } + crate::util::json_cast_params(params) + } + } + + impl JsonRpcNotification for $notif {} + }; +} + +macro_rules! impl_v2_jsonrpc_request_enum { + ($enum:ty { + $( $(#[$meta:meta])* $variant:ident => $method:literal, )* + [ext] $ext_variant:ident, + }) => { + impl JsonRpcMessage for $enum { + fn matches_method(_method: &str) -> bool { + true + } + + fn method(&self) -> &str { + match self { + $( $(#[$meta])* Self::$variant(_) => $method, )* + Self::$ext_variant(ext) => &ext.method, + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + match method { + $( $(#[$meta])* $method => crate::util::json_cast_params(params).map(Self::$variant), )* + _ => { + if method.starts_with('_') { + crate::util::json_cast_params(params).map( + |ext_req: v2::ExtRequest| { + Self::$ext_variant(v2::ExtRequest::new( + method.to_string(), + ext_req.params, + )) + }, + ) + } else { + Err(crate::Error::method_not_found()) + } + } + } + } + } + + impl JsonRpcRequest for $enum { + type Response = serde_json::Value; + } + }; +} + +macro_rules! impl_v2_jsonrpc_notification_enum { + ($enum:ty { + $( $(#[$meta:meta])* $variant:ident => $method:literal, )* + [ext] $ext_variant:ident, + }) => { + impl JsonRpcMessage for $enum { + fn matches_method(_method: &str) -> bool { + true + } + + fn method(&self) -> &str { + match self { + $( $(#[$meta])* Self::$variant(_) => $method, )* + Self::$ext_variant(ext) => &ext.method, + _ => "_unknown", + } + } + + fn to_untyped_message(&self) -> Result { + UntypedMessage::new(self.method(), self) + } + + fn parse_message( + method: &str, + params: &impl serde::Serialize, + ) -> Result { + match method { + $( $(#[$meta])* $method => crate::util::json_cast_params(params).map(Self::$variant), )* + _ => { + if method.starts_with('_') { + crate::util::json_cast_params(params).map( + |ext_notif: v2::ExtNotification| { + Self::$ext_variant(v2::ExtNotification::new( + method.to_string(), + ext_notif.params, + )) + }, + ) + } else { + Err(crate::Error::method_not_found()) + } + } + } + } + } + + impl JsonRpcNotification for $enum {} + }; +} + +macro_rules! impl_v2_jsonrpc_response_enum { + ($enum:ty { + $( $(#[$meta:meta])* $variant:ident => $method:literal, )* + [ext] $ext_variant:ident, + }) => { + impl JsonRpcResponse for $enum { + fn into_json( + self, + _method: &str, + ) -> Result { + serde_json::to_value(self).map_err(crate::Error::into_internal_error) + } + + fn from_value( + method: &str, + value: serde_json::Value, + ) -> Result { + match method { + $( $(#[$meta])* $method => crate::util::json_cast(value).map(Self::$variant), )* + _ => { + if method.starts_with('_') { + crate::util::json_cast(value).map(Self::$ext_variant) + } else { + Err(crate::Error::method_not_found()) + } + } + } + } + } + }; +} + +impl_v2_jsonrpc_request!(v2::InitializeRequest, v2::InitializeResponse, "initialize"); +impl_v2_jsonrpc_request!( + v2::AuthenticateRequest, + v2::AuthenticateResponse, + "authenticate" +); +#[cfg(feature = "unstable_logout")] +impl_v2_jsonrpc_request!(v2::LogoutRequest, v2::LogoutResponse, "logout"); +impl_v2_jsonrpc_request!(v2::NewSessionRequest, v2::NewSessionResponse, "session/new"); +impl_v2_jsonrpc_request!( + v2::LoadSessionRequest, + v2::LoadSessionResponse, + "session/load" +); +impl_v2_jsonrpc_request!( + v2::ListSessionsRequest, + v2::ListSessionsResponse, + "session/list" +); +#[cfg(feature = "unstable_session_delete")] +impl_v2_jsonrpc_request!( + v2::DeleteSessionRequest, + v2::DeleteSessionResponse, + "session/delete" +); +#[cfg(feature = "unstable_session_fork")] +impl_v2_jsonrpc_request!( + v2::ForkSessionRequest, + v2::ForkSessionResponse, + "session/fork" +); +impl_v2_jsonrpc_request!( + v2::ResumeSessionRequest, + v2::ResumeSessionResponse, + "session/resume" +); +impl_v2_jsonrpc_request!( + v2::CloseSessionRequest, + v2::CloseSessionResponse, + "session/close" +); +impl_v2_jsonrpc_request!( + v2::SetSessionModeRequest, + v2::SetSessionModeResponse, + "session/set_mode" +); +impl_v2_jsonrpc_request!( + v2::SetSessionConfigOptionRequest, + v2::SetSessionConfigOptionResponse, + "session/set_config_option" +); +impl_v2_jsonrpc_request!(v2::PromptRequest, v2::PromptResponse, "session/prompt"); +#[cfg(feature = "unstable_session_model")] +impl_v2_jsonrpc_request!( + v2::SetSessionModelRequest, + v2::SetSessionModelResponse, + "session/set_model" +); +#[cfg(feature = "unstable_mcp_over_acp")] +impl_v2_jsonrpc_request!(v2::MessageMcpRequest, v2::MessageMcpResponse, "mcp/message"); + +impl_v2_jsonrpc_notification!(v2::CancelNotification, "session/cancel"); +#[cfg(feature = "unstable_mcp_over_acp")] +impl_v2_jsonrpc_notification!(v2::MessageMcpNotification, "mcp/message"); + +impl_v2_jsonrpc_request!( + v2::WriteTextFileRequest, + v2::WriteTextFileResponse, + "fs/write_text_file" +); +impl_v2_jsonrpc_request!( + v2::ReadTextFileRequest, + v2::ReadTextFileResponse, + "fs/read_text_file" +); +impl_v2_jsonrpc_request!( + v2::RequestPermissionRequest, + v2::RequestPermissionResponse, + "session/request_permission" +); +impl_v2_jsonrpc_request!( + v2::CreateTerminalRequest, + v2::CreateTerminalResponse, + "terminal/create" +); +impl_v2_jsonrpc_request!( + v2::TerminalOutputRequest, + v2::TerminalOutputResponse, + "terminal/output" +); +impl_v2_jsonrpc_request!( + v2::ReleaseTerminalRequest, + v2::ReleaseTerminalResponse, + "terminal/release" +); +impl_v2_jsonrpc_request!( + v2::WaitForTerminalExitRequest, + v2::WaitForTerminalExitResponse, + "terminal/wait_for_exit" +); +impl_v2_jsonrpc_request!( + v2::KillTerminalRequest, + v2::KillTerminalResponse, + "terminal/kill" +); +#[cfg(feature = "unstable_mcp_over_acp")] +impl_v2_jsonrpc_request!(v2::ConnectMcpRequest, v2::ConnectMcpResponse, "mcp/connect"); +#[cfg(feature = "unstable_mcp_over_acp")] +impl_v2_jsonrpc_request!( + v2::DisconnectMcpRequest, + v2::DisconnectMcpResponse, + "mcp/disconnect" +); + +impl_v2_jsonrpc_notification!(v2::SessionNotification, "session/update"); + +impl_v2_jsonrpc_request_enum!(v2::ClientRequest { + InitializeRequest => "initialize", + AuthenticateRequest => "authenticate", + #[cfg(feature = "unstable_logout")] + LogoutRequest => "logout", + NewSessionRequest => "session/new", + LoadSessionRequest => "session/load", + ListSessionsRequest => "session/list", + #[cfg(feature = "unstable_session_delete")] + DeleteSessionRequest => "session/delete", + #[cfg(feature = "unstable_session_fork")] + ForkSessionRequest => "session/fork", + ResumeSessionRequest => "session/resume", + CloseSessionRequest => "session/close", + SetSessionModeRequest => "session/set_mode", + SetSessionConfigOptionRequest => "session/set_config_option", + PromptRequest => "session/prompt", + #[cfg(feature = "unstable_session_model")] + SetSessionModelRequest => "session/set_model", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpRequest => "mcp/message", + [ext] ExtMethodRequest, +}); + +impl_v2_jsonrpc_response_enum!(v2::AgentResponse { + InitializeResponse => "initialize", + AuthenticateResponse => "authenticate", + #[cfg(feature = "unstable_logout")] + LogoutResponse => "logout", + NewSessionResponse => "session/new", + LoadSessionResponse => "session/load", + ListSessionsResponse => "session/list", + #[cfg(feature = "unstable_session_delete")] + DeleteSessionResponse => "session/delete", + #[cfg(feature = "unstable_session_fork")] + ForkSessionResponse => "session/fork", + ResumeSessionResponse => "session/resume", + CloseSessionResponse => "session/close", + SetSessionModeResponse => "session/set_mode", + SetSessionConfigOptionResponse => "session/set_config_option", + PromptResponse => "session/prompt", + #[cfg(feature = "unstable_session_model")] + SetSessionModelResponse => "session/set_model", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpResponse => "mcp/message", + [ext] ExtMethodResponse, +}); + +impl_v2_jsonrpc_notification_enum!(v2::ClientNotification { + CancelNotification => "session/cancel", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpNotification => "mcp/message", + [ext] ExtNotification, +}); + +impl_v2_jsonrpc_request_enum!(v2::AgentRequest { + WriteTextFileRequest => "fs/write_text_file", + ReadTextFileRequest => "fs/read_text_file", + RequestPermissionRequest => "session/request_permission", + CreateTerminalRequest => "terminal/create", + TerminalOutputRequest => "terminal/output", + ReleaseTerminalRequest => "terminal/release", + WaitForTerminalExitRequest => "terminal/wait_for_exit", + KillTerminalRequest => "terminal/kill", + #[cfg(feature = "unstable_mcp_over_acp")] + ConnectMcpRequest => "mcp/connect", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpRequest => "mcp/message", + #[cfg(feature = "unstable_mcp_over_acp")] + DisconnectMcpRequest => "mcp/disconnect", + [ext] ExtMethodRequest, +}); + +impl_v2_jsonrpc_response_enum!(v2::ClientResponse { + WriteTextFileResponse => "fs/write_text_file", + ReadTextFileResponse => "fs/read_text_file", + RequestPermissionResponse => "session/request_permission", + CreateTerminalResponse => "terminal/create", + TerminalOutputResponse => "terminal/output", + ReleaseTerminalResponse => "terminal/release", + WaitForTerminalExitResponse => "terminal/wait_for_exit", + KillTerminalResponse => "terminal/kill", + #[cfg(feature = "unstable_mcp_over_acp")] + ConnectMcpResponse => "mcp/connect", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpResponse => "mcp/message", + #[cfg(feature = "unstable_mcp_over_acp")] + DisconnectMcpResponse => "mcp/disconnect", + [ext] ExtMethodResponse, +}); + +impl_v2_jsonrpc_notification_enum!(v2::AgentNotification { + SessionNotification => "session/update", + #[cfg(feature = "unstable_mcp_over_acp")] + MessageMcpNotification => "mcp/message", + [ext] ExtNotification, +}); diff --git a/src/agent-client-protocol/tests/protocol_v2.rs b/src/agent-client-protocol/tests/protocol_v2.rs new file mode 100644 index 0000000..d6692ec --- /dev/null +++ b/src/agent-client-protocol/tests/protocol_v2.rs @@ -0,0 +1,549 @@ +#![cfg(feature = "unstable_protocol_v2")] + +use std::path::PathBuf; + +use agent_client_protocol::schema::{self, ProtocolVersion, v2}; +use agent_client_protocol::{ + Agent, Builder, Client, ConnectTo, Error, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + NullHandler, Role, UntypedRole, jsonrpcmsg, +}; +use agent_client_protocol_test::testy::Testy; +use futures::StreamExt as _; +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcRequest)] +#[request(method = "initialize", response = ForeignInitializeResponse)] +struct ForeignInitializeRequest { + #[serde(rename = "protocolVersion")] + protocol_version: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcResponse)] +struct ForeignInitializeResponse { + #[serde(rename = "protocolVersion")] + protocol_version: String, +} + +struct ForeignPeer; + +impl ConnectTo for ForeignPeer { + async fn connect_to(self, client: impl ConnectTo) -> Result<(), Error> { + UntypedRole + .builder() + .on_receive_request( + async |request: ForeignInitializeRequest, responder, _cx| { + assert_eq!(request.protocol_version, "2025-06-18"); + responder.respond(ForeignInitializeResponse { + protocol_version: request.protocol_version, + }) + }, + agent_client_protocol::on_receive_request!(), + ) + .connect_to(client) + .await + } +} + +fn cwd() -> Result { + std::env::current_dir().map_err(Error::into_internal_error) +} + +#[cfg(feature = "unstable_mcp_over_acp")] +fn json_value(value: impl Serialize) -> Result { + serde_json::to_value(value).map_err(Error::into_internal_error) +} + +async fn assert_malformed_initialize_rejected(params: Map) -> Result<(), Error> { + let agent = Agent.v2().on_receive_request( + async |_initialize: v2::InitializeRequest, responder, _cx| { + responder.respond_with_internal_error("handler should not run") + }, + agent_client_protocol::on_receive_request!(), + ); + let (mut channel, agent_future) = ConnectTo::::into_channel_and_future(agent); + let agent_task = tokio::spawn(agent_future); + + channel + .tx + .unbounded_send(Ok(jsonrpcmsg::Message::Request( + jsonrpcmsg::Request::new_v2( + "initialize".into(), + Some(jsonrpcmsg::Params::Object(params)), + Some(jsonrpcmsg::Id::Number(1)), + ), + ))) + .map_err(Error::into_internal_error)?; + + while let Some(message) = channel.rx.next().await { + let message = message?; + let jsonrpcmsg::Message::Response(response) = message else { + continue; + }; + let error = response.error.expect("malformed initialize should fail"); + assert_eq!(error.code, -32602); + let data = error + .data + .as_ref() + .and_then(|data| data.as_str()) + .unwrap_or_default(); + assert!(data.contains("protocolVersion"), "{error:?}"); + agent_task.abort(); + return Ok(()); + } + + agent_task.abort(); + Err(agent_client_protocol::util::internal_error( + "agent did not respond to malformed initialize", + )) +} + +async fn assert_v2_client_rejected_by_v1_agent(agent: impl ConnectTo) -> Result<(), Error> { + Client + .v2() + .connect_with(agent, async |cx| { + let error = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V2)) + .block_task() + .await + .expect_err("v1 agent protocol mode should reject v2 clients"); + let data = error + .data + .as_ref() + .and_then(|data| data.as_str()) + .unwrap_or_default(); + assert!( + data.contains("required ACP protocol version 2"), + "{error:?}" + ); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn non_acp_initialize_is_not_rewritten() -> Result<(), Error> { + UntypedRole + .builder() + .connect_with(ForeignPeer, async |cx| { + let response = cx + .send_request(ForeignInitializeRequest { + protocol_version: "2025-06-18".into(), + }) + .block_task() + .await?; + + assert_eq!(response.protocol_version, "2025-06-18"); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_agent_rejects_initialize_without_protocol_version() -> Result<(), Error> { + assert_malformed_initialize_rejected(Map::new()).await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_agent_rejects_initialize_with_malformed_protocol_version() -> Result<(), Error> { + let mut params = Map::new(); + params.insert("protocolVersion".into(), serde_json::json!(100_000)); + + assert_malformed_initialize_rejected(params).await +} + +#[tokio::test(flavor = "current_thread")] +async fn role_builder_v1_agent_rejects_v2_client_negotiation() -> Result<(), Error> { + let agent = ::builder(Agent).on_receive_request( + async |initialize: schema::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + responder.respond(schema::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ); + + Client + .v2() + .connect_with(agent, async |cx| { + let error = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V2)) + .block_task() + .await + .expect_err("Role::builder should preserve v1 agent protocol mode"); + let data = error + .data + .as_ref() + .and_then(|data| data.as_str()) + .unwrap_or_default(); + assert!( + data.contains("required ACP protocol version 2"), + "{error:?}" + ); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn builder_new_v1_agent_rejects_v2_client_negotiation() -> Result<(), Error> { + let agent = Builder::new(Agent).on_receive_request( + async |initialize: schema::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + responder.respond(schema::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ); + + assert_v2_client_rejected_by_v1_agent(agent).await +} + +#[tokio::test(flavor = "current_thread")] +async fn builder_new_with_v1_agent_rejects_v2_client_negotiation() -> Result<(), Error> { + let agent = Builder::new_with(Agent, NullHandler).on_receive_request( + async |initialize: schema::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + responder.respond(schema::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ); + + assert_v2_client_rejected_by_v1_agent(agent).await +} + +#[tokio::test(flavor = "current_thread")] +async fn role_builder_v1_client_downgrades_initialize_for_v2_agent() -> Result<(), Error> { + let agent = Agent.v2().on_receive_request( + async |initialize: v2::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ); + + ::builder(Client) + .connect_with(agent, async |cx| { + let initialize = cx + .send_request(schema::InitializeRequest::new(ProtocolVersion::V2)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + Ok(()) + }) + .await +} + +#[test] +fn v2_extension_enum_parsing_preserves_method_prefix() -> Result<(), Error> { + let params = serde_json::json!({ "payload": true }); + + let request = v2::ClientRequest::parse_message("_vendor/request", ¶ms)?; + assert_eq!(request.method(), "_vendor/request"); + let untyped_request = request.to_untyped_message()?; + assert_eq!(untyped_request.method(), "_vendor/request"); + assert_eq!(untyped_request.params(), ¶ms); + + let notification = v2::AgentNotification::parse_message("_vendor/notify", ¶ms)?; + assert_eq!(notification.method(), "_vendor/notify"); + let untyped_notification = notification.to_untyped_message()?; + assert_eq!(untyped_notification.method(), "_vendor/notify"); + assert_eq!(untyped_notification.params(), ¶ms); + + Ok(()) +} + +#[cfg(feature = "unstable_mcp_over_acp")] +#[test] +fn mcp_over_acp_variants_are_jsonrpc_mapped() -> Result<(), Error> { + fn assert_request() {} + fn assert_notification() {} + + macro_rules! assert_message_mapping { + ($ty:ty, $method:literal, $params:expr, $pattern:pat) => {{ + let message = <$ty as JsonRpcMessage>::parse_message($method, &$params)?; + assert_eq!(message.method(), $method); + assert_eq!(message.to_untyped_message()?.method(), $method); + assert!(matches!(message, $pattern)); + }}; + } + + macro_rules! assert_response_mapping { + ($ty:ty, $method:literal, $value:expr, $pattern:pat) => {{ + let response = <$ty as JsonRpcResponse>::from_value($method, $value)?; + assert!(matches!(response, $pattern)); + }}; + } + + assert_request::(); + assert_request::(); + assert_request::(); + assert_notification::(); + + assert_message_mapping!( + schema::ClientRequest, + "mcp/message", + json_value(schema::MessageMcpRequest::new("conn-1", "tools/list"))?, + schema::ClientRequest::MessageMcpRequest(_) + ); + assert_response_mapping!( + schema::AgentResponse, + "mcp/message", + serde_json::json!({ "tools": [] }), + schema::AgentResponse::MessageMcpResponse(_) + ); + assert_message_mapping!( + schema::ClientNotification, + "mcp/message", + json_value(schema::MessageMcpNotification::new( + "conn-1", + "notifications/tools/list" + ))?, + schema::ClientNotification::MessageMcpNotification(_) + ); + assert_message_mapping!( + schema::AgentRequest, + "mcp/connect", + json_value(schema::ConnectMcpRequest::new("server-1"))?, + schema::AgentRequest::ConnectMcpRequest(_) + ); + assert_message_mapping!( + schema::AgentRequest, + "mcp/message", + json_value(schema::MessageMcpRequest::new("conn-1", "tools/list"))?, + schema::AgentRequest::MessageMcpRequest(_) + ); + assert_message_mapping!( + schema::AgentRequest, + "mcp/disconnect", + json_value(schema::DisconnectMcpRequest::new("conn-1"))?, + schema::AgentRequest::DisconnectMcpRequest(_) + ); + assert_response_mapping!( + schema::ClientResponse, + "mcp/connect", + json_value(schema::ConnectMcpResponse::new("conn-1"))?, + schema::ClientResponse::ConnectMcpResponse(_) + ); + assert_response_mapping!( + schema::ClientResponse, + "mcp/message", + serde_json::json!({ "tools": [] }), + schema::ClientResponse::MessageMcpResponse(_) + ); + assert_response_mapping!( + schema::ClientResponse, + "mcp/disconnect", + serde_json::json!({}), + schema::ClientResponse::DisconnectMcpResponse(_) + ); + assert_message_mapping!( + schema::AgentNotification, + "mcp/message", + json_value(schema::MessageMcpNotification::new( + "conn-1", + "notifications/tools/list" + ))?, + schema::AgentNotification::MessageMcpNotification(_) + ); + + assert_message_mapping!( + v2::MessageMcpRequest, + "mcp/message", + json_value(v2::MessageMcpRequest::new("conn-1", "tools/list"))?, + v2::MessageMcpRequest { .. } + ); + assert_message_mapping!( + v2::MessageMcpNotification, + "mcp/message", + json_value(v2::MessageMcpNotification::new( + "conn-1", + "notifications/tools/list" + ))?, + v2::MessageMcpNotification { .. } + ); + assert_message_mapping!( + v2::ConnectMcpRequest, + "mcp/connect", + json_value(v2::ConnectMcpRequest::new("server-1"))?, + v2::ConnectMcpRequest { .. } + ); + assert_message_mapping!( + v2::DisconnectMcpRequest, + "mcp/disconnect", + json_value(v2::DisconnectMcpRequest::new("conn-1"))?, + v2::DisconnectMcpRequest { .. } + ); + + assert_message_mapping!( + v2::ClientRequest, + "mcp/message", + json_value(v2::MessageMcpRequest::new("conn-1", "tools/list"))?, + v2::ClientRequest::MessageMcpRequest(_) + ); + assert_response_mapping!( + v2::AgentResponse, + "mcp/message", + serde_json::json!({ "tools": [] }), + v2::AgentResponse::MessageMcpResponse(_) + ); + assert_message_mapping!( + v2::ClientNotification, + "mcp/message", + json_value(v2::MessageMcpNotification::new( + "conn-1", + "notifications/tools/list" + ))?, + v2::ClientNotification::MessageMcpNotification(_) + ); + assert_message_mapping!( + v2::AgentRequest, + "mcp/connect", + json_value(v2::ConnectMcpRequest::new("server-1"))?, + v2::AgentRequest::ConnectMcpRequest(_) + ); + assert_message_mapping!( + v2::AgentRequest, + "mcp/message", + json_value(v2::MessageMcpRequest::new("conn-1", "tools/list"))?, + v2::AgentRequest::MessageMcpRequest(_) + ); + assert_message_mapping!( + v2::AgentRequest, + "mcp/disconnect", + json_value(v2::DisconnectMcpRequest::new("conn-1"))?, + v2::AgentRequest::DisconnectMcpRequest(_) + ); + assert_response_mapping!( + v2::ClientResponse, + "mcp/connect", + json_value(v2::ConnectMcpResponse::new("conn-1"))?, + v2::ClientResponse::ConnectMcpResponse(_) + ); + assert_response_mapping!( + v2::ClientResponse, + "mcp/message", + serde_json::json!({ "tools": [] }), + v2::ClientResponse::MessageMcpResponse(_) + ); + assert_response_mapping!( + v2::ClientResponse, + "mcp/disconnect", + serde_json::json!({}), + v2::ClientResponse::DisconnectMcpResponse(_) + ); + assert_message_mapping!( + v2::AgentNotification, + "mcp/message", + json_value(v2::MessageMcpNotification::new( + "conn-1", + "notifications/tools/list" + ))?, + v2::AgentNotification::MessageMcpNotification(_) + ); + + Ok(()) +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_agent_serves_v1_client_with_v2_handlers() -> Result<(), Error> { + let agent = Agent + .v2() + .on_receive_request( + async |initialize: v2::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + // The compatibility layer should force this back to the negotiated v1 wire version. + responder.respond(v2::InitializeResponse::new(ProtocolVersion::V2)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async |request: v2::NewSessionRequest, responder, _cx| { + assert!(request.cwd.is_absolute()); + responder.respond(v2::NewSessionResponse::new(v2::SessionId::new( + "v2-session", + ))) + }, + agent_client_protocol::on_receive_request!(), + ); + + Client + .builder() + .connect_with(agent, async |cx| { + let initialize = cx + .send_request(schema::InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V1); + + let session = cx + .send_request(schema::NewSessionRequest::new(cwd()?)) + .block_task() + .await?; + assert_eq!(session.session_id.0.as_ref(), "v2-session"); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_client_rejects_v1_agent() -> Result<(), Error> { + Client + .v2() + .connect_with(Testy::new(), async |cx| { + let error = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await + .expect_err("v2 clients require a v2 agent"); + let data = error + .data + .as_ref() + .and_then(|data| data.as_str()) + .unwrap_or_default(); + assert!( + data.contains("required ACP protocol version 2"), + "{error:?}" + ); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_client_and_agent_negotiate_v2() -> Result<(), Error> { + let agent = Agent + .v2() + .on_receive_request( + async |initialize: v2::InitializeRequest, responder, _cx| { + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + responder.respond(v2::InitializeResponse::new(initialize.protocol_version)) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async |request: v2::NewSessionRequest, responder, _cx| { + assert!(request.cwd.is_absolute()); + responder.respond(v2::NewSessionResponse::new(v2::SessionId::new( + "v2-native-session", + ))) + }, + agent_client_protocol::on_receive_request!(), + ); + + Client + .v2() + .connect_with(agent, async |cx| { + let initialize = cx + .send_request(v2::InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + assert_eq!(initialize.protocol_version, ProtocolVersion::V2); + + let session = cx + .send_request(v2::NewSessionRequest::new(cwd()?)) + .block_task() + .await?; + assert_eq!(session.session_id.0.as_ref(), "v2-native-session"); + Ok(()) + }) + .await +}