diff --git a/md/SUMMARY.md b/md/SUMMARY.md index 7ee0349..3e67924 100644 --- a/md/SUMMARY.md +++ b/md/SUMMARY.md @@ -6,6 +6,7 @@ - [Design Overview](./design.md) - [Protocol Reference](./protocol.md) +- [Protocol Versioning](./protocol-versions.md) # Conductor (agent-client-protocol-conductor) diff --git a/md/protocol-versions.md b/md/protocol-versions.md new file mode 100644 index 0000000..63c4bd5 --- /dev/null +++ b/md/protocol-versions.md @@ -0,0 +1,30 @@ +# Protocol Versioning + +The SDK normally exposes the stable ACP v1 schema types through `agent_client_protocol::schema::*`. + +For experiments against the draft ACP v2 schema, enable the `unstable_protocol_v2` feature on +`agent-client-protocol`. With that feature enabled, `schema::*` resolves to the schema crate's +`v2` types by default, while the connection layer still speaks either v1 or v2 on the wire. + +## Negotiation + +Protocol version negotiation is driven by the normal `initialize` request and response: + +- Before any `initialize` message is observed, non-initialize ACP messages are treated as v2 when the v2 feature is enabled. +- While an `initialize` request is in flight, the requested `protocolVersion` is used provisionally for wire conversion. +- The `initialize` request is encoded according to the requested `protocolVersion`. +- The `initialize` response records the negotiated wire version on the connection. +- Later known ACP requests, responses, and notifications are downgraded to v1 or left as v2 based on that negotiated version. + +The conversion is internal to the SDK. Agent and client handlers continue to use the feature-selected +`schema::*` types. With `unstable_protocol_v2` enabled, that means user code handles v2 types even +when the remote side negotiated v1. + +## Scope + +The adapter converts known ACP payloads at the untyped JSON-RPC boundary using the schema crate's v2 +conversion module. Custom JSON-RPC methods and extension methods are passed through unchanged. + +The v2 feature is intentionally separate from the existing `unstable` umbrella because it changes the +SDK's default Rust type namespace. It should be enabled explicitly by experiments that are ready to +compile against the draft v2 types. diff --git a/src/agent-client-protocol-conductor/src/conductor.rs b/src/agent-client-protocol-conductor/src/conductor.rs index 25ae367..f871cf6 100644 --- a/src/agent-client-protocol-conductor/src/conductor.rs +++ b/src/agent-client-protocol-conductor/src/conductor.rs @@ -834,7 +834,13 @@ where conductor_tx .send(ConductorMessage::LeftToRight { target_component_index: component_index + 1, - message: dispatch.map(|r, cx| (r.message, cx), |n| n.message), + message: dispatch.map( + |r, responder| { + let method = r.message.method().to_string(); + (r.message, responder.wrap_method(method)) + }, + |n| n.message, + ), }) .await .map_err(agent_client_protocol::util::internal_error) diff --git a/src/agent-client-protocol/CHANGELOG.md b/src/agent-client-protocol/CHANGELOG.md index 6c0b977..25f5f5f 100644 --- a/src/agent-client-protocol/CHANGELOG.md +++ b/src/agent-client-protocol/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Added + +- Add `unstable_protocol_v2` support that uses draft v2 schema types by default while negotiating v1 or v2 wire payloads internally. + ## [0.11.1](https://github.com/agentclientprotocol/rust-sdk/compare/v0.11.0...v0.11.1) - 2026-04-21 ### Fixed diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 14731dc..96727ea 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -33,6 +33,7 @@ unstable_session_additional_directories = ["agent-client-protocol-schema/unstabl unstable_session_fork = ["agent-client-protocol-schema/unstable_session_fork"] unstable_session_model = ["agent-client-protocol-schema/unstable_session_model"] unstable_session_usage = ["agent-client-protocol-schema/unstable_session_usage"] +unstable_protocol_v2 = ["agent-client-protocol-schema/unstable_protocol_v2"] [dependencies] agent-client-protocol-schema.workspace = true diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 588ef6f..dd95155 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1,6 +1,5 @@ //! Core JSON-RPC server support. -use agent_client_protocol_schema::SessionId; // Re-export jsonrpcmsg for use in public API pub use jsonrpcmsg; @@ -34,6 +33,7 @@ use crate::jsonrpc::task_actor::{Task, TaskTx}; use crate::mcp_server::McpServer; use crate::role::HasPeer; use crate::role::Role; +use crate::schema::{METHOD_SUCCESSOR_MESSAGE, SessionId}; use crate::util::json_cast; use crate::{Agent, Client, ConnectTo, RoleId}; @@ -1178,11 +1178,15 @@ impl< let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (new_task_tx, new_task_rx) = mpsc::unbounded(); let (dynamic_handler_tx, dynamic_handler_rx) = mpsc::unbounded(); + #[cfg(feature = "unstable_protocol_v2")] + let protocol_state = crate::schema::v2_compat::ProtocolState::default(); let connection = ConnectionTo::new( me.counterpart(), outgoing_tx, new_task_tx, dynamic_handler_tx, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state.clone(), ); // Convert transport into server - this returns a channel for us to use @@ -1211,6 +1215,8 @@ impl< outgoing_rx, reply_tx.clone(), transport_outgoing_tx, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state.clone(), ), // Protocol layer: jsonrpcmsg::Message → handler/reply routing incoming_actor::incoming_protocol_actor( @@ -1220,6 +1226,8 @@ impl< dynamic_handler_rx, reply_rx, handler, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state.clone(), ), task_actor::task_actor(new_task_rx, &connection), responder.run_with_connection_to(connection.clone()), @@ -1341,6 +1349,8 @@ enum OutgoingMessage { Response { id: jsonrpcmsg::Id, + method: String, + response: Result, }, @@ -1424,6 +1434,8 @@ pub struct ConnectionTo { message_tx: OutgoingMessageTx, task_tx: TaskTx, dynamic_handler_tx: mpsc::UnboundedSender>, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state: crate::schema::v2_compat::ProtocolState, } impl ConnectionTo { @@ -1432,12 +1444,16 @@ impl ConnectionTo { message_tx: mpsc::UnboundedSender, task_tx: mpsc::UnboundedSender, dynamic_handler_tx: mpsc::UnboundedSender>, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state: crate::schema::v2_compat::ProtocolState, ) -> Self { Self { counterpart, message_tx, task_tx, dynamic_handler_tx, + #[cfg(feature = "unstable_protocol_v2")] + protocol_state, } } @@ -1446,6 +1462,15 @@ impl ConnectionTo { self.counterpart.clone() } + /// Return the protocol version negotiated by the initialize handshake. + /// + /// This is `None` until an `initialize` response is sent or received. + #[cfg(feature = "unstable_protocol_v2")] + #[must_use] + pub fn negotiated_protocol_version(&self) -> Option { + self.protocol_state.negotiated_protocol_version() + } + /// Spawns a task that will run so long as the JSON-RPC connection is being served. /// /// This is the primary mechanism for offloading expensive work from handler callbacks @@ -1662,12 +1687,14 @@ impl ConnectionTo { let (response_tx, response_rx) = oneshot::channel(); let role_id = peer.role_id(); let remote_style = self.counterpart.remote_style(peer); + let mut response_method = method.clone(); match remote_style.transform_outgoing_message(request) { Ok(untyped) => { + response_method = response_method_for_outgoing_request(&method, &untyped); // Transform the message for the target role let message = OutgoingMessage::Request { id: id.clone(), - method: method.clone(), + method: response_method.clone(), role_id, untyped, response_tx, @@ -1709,8 +1736,13 @@ impl ConnectionTo { } } - SentRequest::new(id, method.clone(), self.task_tx.clone(), response_rx) - .map(move |json| ::from_value(&method, json)) + SentRequest::new( + id, + response_method.clone(), + self.task_tx.clone(), + response_rx, + ) + .map(move |json| ::from_value(&response_method, json)) } /// Send an outgoing notification to the default counterpart peer (no reply expected). @@ -1811,6 +1843,16 @@ impl ConnectionTo { } } +fn response_method_for_outgoing_request(method: &str, untyped: &UntypedMessage) -> String { + if untyped.method == METHOD_SUCCESSOR_MESSAGE + && let Some(serde_json::Value::String(inner_method)) = untyped.params.get("method") + { + return inner_method.clone(); + } + + method.to_string() +} + #[derive(Clone, Debug)] pub struct DynamicHandlerRegistration { uuid: Uuid, @@ -1888,7 +1930,7 @@ pub struct Responder { /// /// For incoming requests: serializes to JSON and sends over the wire. /// For incoming responses: sends to the waiting oneshot channel. - send_fn: Box) -> Result<(), crate::Error> + Send>, + send_fn: Box) -> Result<(), crate::Error> + Send>, } impl std::fmt::Debug for Responder { @@ -1910,15 +1952,18 @@ impl Responder { Self { method, id, - send_fn: Box::new(move |response: Result| { - send_raw_message( - &message_tx, - OutgoingMessage::Response { - id: id_clone, - response, - }, - ) - }), + send_fn: Box::new( + move |method, response: Result| { + send_raw_message( + &message_tx, + OutgoingMessage::Response { + id: id_clone, + method, + response, + }, + ) + }, + ), } } @@ -1970,13 +2015,17 @@ impl Responder { self, wrap_fn: impl FnOnce(&str, Result) -> Result + Send + 'static, ) -> Responder { - let method = self.method.clone(); + let Self { + method, + id, + send_fn, + } = self; Responder { - method: self.method, - id: self.id, - send_fn: Box::new(move |input: Result| { + method, + id, + send_fn: Box::new(move |method, input: Result| { let t_value = wrap_fn(&method, input); - (self.send_fn)(t_value) + send_fn(method, t_value) }), } } @@ -1986,8 +2035,13 @@ impl Responder { self, response: Result, ) -> Result<(), crate::Error> { - tracing::debug!(id = ?self.id, "respond called"); - (self.send_fn)(response) + let Self { + method, + id, + send_fn, + } = self; + tracing::debug!(id = ?id, "respond called"); + send_fn(method, response) } /// Respond to the JSON-RPC request with a value. diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 0aa2a7a..eb47f2c 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -20,6 +20,8 @@ use crate::jsonrpc::ResponseRouter; use crate::jsonrpc::dynamic_handler::DynHandleDispatchFrom; use crate::jsonrpc::dynamic_handler::DynamicHandlerMessage; use crate::jsonrpc::outgoing_actor::send_raw_message; +#[cfg(feature = "unstable_protocol_v2")] +use crate::schema::v2_compat::ProtocolState; use crate::role::Role; @@ -50,6 +52,7 @@ pub(super) async fn incoming_protocol_actor( dynamic_handler_rx: mpsc::UnboundedReceiver>, reply_rx: mpsc::UnboundedReceiver, mut handler: impl HandleDispatchFrom, + #[cfg(feature = "unstable_protocol_v2")] protocol_state: ProtocolState, ) -> Result<(), crate::Error> { let mut my_rx = transport_rx .map(IncomingProtocolMsg::Transport) @@ -96,6 +99,7 @@ pub(super) async fn incoming_protocol_actor( for pending_message in pending_messages { tracing::trace!(method = pending_message.method(), handler = ?handler.dyn_describe_chain(), "Retrying message"); let id = pending_message.id(); + let method = pending_message.method().to_string(); match handler .dyn_handle_dispatch_from(pending_message, connection.clone()) .await @@ -112,7 +116,7 @@ pub(super) async fn incoming_protocol_actor( } Err(err) => { tracing::warn!(?err, handler = ?handler.dyn_describe_chain(), "Dynamic handler errored on pending message, reporting back"); - report_handler_error(connection, id, err)?; + report_handler_error(connection, id, method, err)?; } } } @@ -131,6 +135,16 @@ pub(super) async fn incoming_protocol_actor( jsonrpcmsg::Message::Request(request) => { tracing::trace!(method = %request.method, id = ?request.id, "Handling request"); let dispatch = dispatch_from_request(connection, request); + #[cfg(feature = "unstable_protocol_v2")] + let dispatch = match protocol_state.convert_incoming_dispatch(dispatch) { + Ok(dispatch) => dispatch, + Err(error) => { + error + .dispatch + .respond_with_error(error.error, connection.clone())?; + continue; + } + }; dispatch_dispatch( counterpart.clone(), connection, @@ -158,6 +172,18 @@ pub(super) async fn incoming_protocol_actor( if let Some(pending_reply) = pending_replies.remove(&id_json) { // Route the response through the handler chain let dispatch = dispatch_from_response(id, pending_reply, result); + #[cfg(feature = "unstable_protocol_v2")] + let dispatch = match protocol_state + .convert_incoming_dispatch(dispatch) + { + Ok(dispatch) => dispatch, + Err(error) => { + error + .dispatch + .respond_with_error(error.error, connection.clone())?; + continue; + } + }; dispatch_dispatch( counterpart.clone(), connection, @@ -275,7 +301,7 @@ async fn dispatch_dispatch( Err(err) => { tracing::warn!(?method, ?id, ?err, handler = ?handler.describe_chain(), "Handler errored, reporting back to remote"); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } @@ -299,7 +325,7 @@ async fn dispatch_dispatch( Err(err) => { tracing::warn!(?method, ?id, ?err, handler = ?dynamic_handler.dyn_describe_chain(), "Dynamic handler errored, reporting back to remote"); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } } @@ -327,7 +353,7 @@ async fn dispatch_dispatch( handler = "default", "Default handler errored, reporting back to remote" ); - return report_handler_error(connection, id, err); + return report_handler_error(connection, id, method, err); } } @@ -367,6 +393,7 @@ async fn dispatch_dispatch( fn report_handler_error( connection: &ConnectionTo, id: Option, + method: String, error: crate::Error, ) -> Result<(), crate::Error> { match id { @@ -377,6 +404,7 @@ fn report_handler_error( &connection.message_tx, OutgoingMessage::Response { id: jsonrpc_id, + method, response: Err(error), }, ) diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 7eee713..748c398 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -4,6 +4,10 @@ use futures::channel::mpsc; use crate::jsonrpc::OutgoingMessage; use crate::jsonrpc::ReplyMessage; +#[cfg(feature = "unstable_protocol_v2")] +use crate::jsonrpc::ResponsePayload; +#[cfg(feature = "unstable_protocol_v2")] +use crate::schema::v2_compat::ProtocolState; pub type OutgoingMessageTx = mpsc::UnboundedSender; @@ -27,6 +31,7 @@ pub(super) async fn outgoing_protocol_actor( mut outgoing_rx: mpsc::UnboundedReceiver, reply_tx: mpsc::UnboundedSender, transport_tx: mpsc::UnboundedSender>, + #[cfg(feature = "unstable_protocol_v2")] protocol_state: ProtocolState, ) -> Result<(), crate::Error> { while let Some(message) = outgoing_rx.next().await { tracing::debug!(?message, "outgoing_protocol_actor"); @@ -40,6 +45,18 @@ pub(super) async fn outgoing_protocol_actor( untyped, response_tx, } => { + #[cfg(feature = "unstable_protocol_v2")] + let untyped = match protocol_state.convert_outgoing_request(untyped) { + Ok(untyped) => untyped, + Err(error) => { + drop(response_tx.send(ResponsePayload { + result: Err(error), + ack_tx: None, + })); + continue; + } + }; + // Record where the reply should be sent once it arrives. reply_tx .unbounded_send(ReplyMessage::Subscribe { @@ -53,31 +70,55 @@ pub(super) async fn outgoing_protocol_actor( jsonrpcmsg::Message::Request(untyped.into_jsonrpc_msg(Some(id))?) } OutgoingMessage::Notification { untyped } => { + #[cfg(feature = "unstable_protocol_v2")] + let untyped = match protocol_state.convert_outgoing_notification(untyped) { + Ok(untyped) => untyped, + Err(error) => { + tracing::warn!( + ?error, + "dropping outgoing notification that cannot be converted" + ); + continue; + } + }; + let msg = untyped.into_jsonrpc_msg(None)?; jsonrpcmsg::Message::Request(msg) } OutgoingMessage::Response { id, - response: Ok(value), - } => { - tracing::debug!(?id, "Sending success response"); - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(value, Some(id))) - } - OutgoingMessage::Response { - id, - response: Err(error), + method, + response, } => { - tracing::warn!(?id, ?error, "Sending error response"); - // Convert crate::Error to jsonrpcmsg::Error - let jsonrpc_error = jsonrpcmsg::Error { - code: error.code.into(), - message: error.message, - data: error.data, - }; - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2( - jsonrpc_error, - Some(id), - )) + #[cfg(not(feature = "unstable_protocol_v2"))] + drop(method); + + #[cfg(feature = "unstable_protocol_v2")] + let response = protocol_state + .convert_outgoing_response(&method, response) + .unwrap_or_else(Err); + + match response { + Ok(value) => { + tracing::debug!(?id, "Sending success response"); + jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2( + value, + Some(id), + )) + } + Err(error) => { + tracing::warn!(?id, ?error, "Sending error response"); + let jsonrpc_error = jsonrpcmsg::Error { + code: error.code.into(), + message: error.message, + data: error.data, + }; + jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2( + jsonrpc_error, + Some(id), + )) + } + } } OutgoingMessage::Error { error } => { // Convert crate::Error to jsonrpcmsg::Error diff --git a/src/agent-client-protocol/src/mcp_server/server.rs b/src/agent-client-protocol/src/mcp_server/server.rs index 5d5fd3f..53c6994 100644 --- a/src/agent-client-protocol/src/mcp_server/server.rs +++ b/src/agent-client-protocol/src/mcp_server/server.rs @@ -2,7 +2,6 @@ use std::{marker::PhantomData, sync::Arc}; -use agent_client_protocol_schema::NewSessionRequest; use futures::{StreamExt, channel::mpsc}; use uuid::Uuid; @@ -18,6 +17,7 @@ use crate::{ builder::McpServerBuilder, }, role::{self, HasPeer}, + schema::NewSessionRequest, util::MatchDispatchFrom, }; diff --git a/src/agent-client-protocol/src/role.rs b/src/agent-client-protocol/src/role.rs index e1c28aa..da12cd5 100644 --- a/src/agent-client-protocol/src/role.rs +++ b/src/agent-client-protocol/src/role.rs @@ -230,8 +230,16 @@ where "Response variant cannot be unwrapped as SuccessorMessage", ) })?; - let SuccessorMessage { message, meta } = json_cast(untyped_message.params())?; - let successor_dispatch = dispatch.try_map_message(|_| Ok(message))?; + let SuccessorMessage { message, meta }: SuccessorMessage = + json_cast(untyped_message.params())?; + let method = message.method().to_string(); + let successor_dispatch = match dispatch { + Dispatch::Request(_, responder) => { + Dispatch::Request(message, responder.wrap_method(method)) + } + Dispatch::Notification(_) => Dispatch::Notification(message), + Dispatch::Response(_, _) => unreachable!("response dispatches were rejected above"), + }; tracing::trace!( unwrapped_method = %successor_dispatch.method(), "handle_incoming_dispatch: unwrapped to inner message" diff --git a/src/agent-client-protocol/src/role/acp.rs b/src/agent-client-protocol/src/role/acp.rs index 3ffb334..46056cc 100644 --- a/src/agent-client-protocol/src/role/acp.rs +++ b/src/agent-client-protocol/src/role/acp.rs @@ -1,10 +1,11 @@ use std::{fmt::Debug, hash::Hash}; -use agent_client_protocol_schema::{NewSessionRequest, NewSessionResponse, SessionId}; - use crate::jsonrpc::{Builder, handlers::NullHandler, run::NullRun}; use crate::role::{HasPeer, RemoteStyle}; -use crate::schema::{InitializeProxyRequest, InitializeRequest, METHOD_INITIALIZE_PROXY}; +use crate::schema::{ + InitializeProxyRequest, InitializeRequest, METHOD_INITIALIZE_PROXY, NewSessionRequest, + NewSessionResponse, SessionId, +}; use crate::util::MatchDispatchFrom; use crate::{ConnectTo, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, Role, RoleId}; diff --git a/src/agent-client-protocol/src/schema/mod.rs b/src/agent-client-protocol/src/schema/mod.rs index 8367031..29b57f6 100644 --- a/src/agent-client-protocol/src/schema/mod.rs +++ b/src/agent-client-protocol/src/schema/mod.rs @@ -224,9 +224,18 @@ mod agent_to_client; mod client_to_agent; mod enum_impls; mod proxy_protocol; +#[cfg(feature = "unstable_protocol_v2")] +pub(crate) mod v2_compat; // Re-export everything from agent_client_protocol_schema +#[cfg(feature = "unstable_protocol_v2")] +pub use agent_client_protocol_schema::v2::*; +#[cfg(not(feature = "unstable_protocol_v2"))] pub use agent_client_protocol_schema::*; +#[cfg(feature = "unstable_protocol_v2")] +pub use agent_client_protocol_schema::{ + IntoMaybeUndefined, IntoOption, MaybeUndefined, ProtocolVersion, SkipListener, +}; // Re-export proxy/MCP protocol types pub use proxy_protocol::*; diff --git a/src/agent-client-protocol/src/schema/proxy_protocol.rs b/src/agent-client-protocol/src/schema/proxy_protocol.rs index aa35dd1..3852840 100644 --- a/src/agent-client-protocol/src/schema/proxy_protocol.rs +++ b/src/agent-client-protocol/src/schema/proxy_protocol.rs @@ -2,8 +2,8 @@ //! //! These types are intended to become part of the ACP protocol specification. +use crate::schema::{InitializeRequest, InitializeResponse}; use crate::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, UntypedMessage}; -use agent_client_protocol_schema::InitializeResponse; use serde::{Deserialize, Serialize}; // ============================================================================= @@ -197,11 +197,11 @@ pub const METHOD_INITIALIZE_PROXY: &str = "_proxy/initialize"; pub struct InitializeProxyRequest { /// The underlying initialize request data. #[serde(flatten)] - pub initialize: agent_client_protocol_schema::InitializeRequest, + pub initialize: InitializeRequest, } -impl From for InitializeProxyRequest { - fn from(initialize: agent_client_protocol_schema::InitializeRequest) -> Self { +impl From for InitializeProxyRequest { + fn from(initialize: InitializeRequest) -> Self { Self { initialize } } } diff --git a/src/agent-client-protocol/src/schema/v2_compat.rs b/src/agent-client-protocol/src/schema/v2_compat.rs new file mode 100644 index 0000000..161e23f --- /dev/null +++ b/src/agent-client-protocol/src/schema/v2_compat.rs @@ -0,0 +1,507 @@ +//! Runtime conversion between v1 wire payloads and the SDK's v2 default types. + +use std::sync::{ + Arc, + atomic::{AtomicU8, Ordering}, +}; + +use agent_client_protocol_schema::v2; +use agent_client_protocol_schema::v2::conversion::{IntoV1, IntoV2}; +use agent_client_protocol_schema::{self as v1, ProtocolVersion}; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::Value; + +use crate::{Dispatch, UntypedMessage}; + +const STATE_UNKNOWN: u8 = 0; +const STATE_V1: u8 = 1; +const STATE_V2: u8 = 2; + +const METHOD_INITIALIZE: &str = "initialize"; +const METHOD_INITIALIZE_PROXY: &str = "_proxy/initialize"; +const METHOD_SUCCESSOR_MESSAGE: &str = "_proxy/successor"; + +#[derive(Clone, Debug, Default)] +pub(crate) struct ProtocolState { + active_wire_version: Arc, + negotiated_wire_version: Arc, +} + +impl ProtocolState { + pub(crate) fn negotiated_protocol_version(&self) -> Option { + match self.negotiated_wire_version.load(Ordering::SeqCst) { + STATE_V1 => Some(ProtocolVersion::V1), + STATE_V2 => Some(ProtocolVersion::V2), + STATE_UNKNOWN => None, + _ => unreachable!("invalid protocol state"), + } + } + + pub(crate) fn convert_incoming_dispatch( + &self, + dispatch: Dispatch, + ) -> Result { + match dispatch { + Dispatch::Request(message, responder) => { + let original = message.clone(); + match self.convert_incoming_message(MessageKind::Request, message) { + Ok(message) => Ok(Dispatch::Request(message, responder)), + Err(error) => Err(IncomingConversionError { + dispatch: Dispatch::Request(original, responder), + error, + }), + } + } + Dispatch::Notification(message) => { + let original = message.clone(); + match self.convert_incoming_message(MessageKind::Notification, message) { + Ok(message) => Ok(Dispatch::Notification(message)), + Err(error) => Err(IncomingConversionError { + dispatch: Dispatch::Notification(original), + error, + }), + } + } + Dispatch::Response(result, router) => { + let result = self + .convert_incoming_response(router.method(), result) + .unwrap_or_else(Err); + Ok(Dispatch::Response(result, router)) + } + } + } + + pub(crate) fn convert_outgoing_request( + &self, + message: UntypedMessage, + ) -> Result { + self.convert_outgoing_message(MessageKind::Request, message) + } + + pub(crate) fn convert_outgoing_notification( + &self, + message: UntypedMessage, + ) -> Result { + self.convert_outgoing_message(MessageKind::Notification, message) + } + + pub(crate) fn convert_outgoing_response( + &self, + method: &str, + response: Result, + ) -> Result, crate::Error> { + let Ok(value) = response else { + return Ok(response); + }; + + let wire_version = self.outgoing_response_wire_version(method, &value); + let value = match wire_version { + WireVersion::V1 => response_v2_to_v1(method, value)?, + WireVersion::V2 => value, + }; + Ok(Ok(value)) + } + + fn convert_incoming_message( + &self, + kind: MessageKind, + message: UntypedMessage, + ) -> Result { + let wire_version = self.incoming_message_wire_version(&message); + let params = match wire_version { + WireVersion::V1 => match kind { + MessageKind::Request => request_v1_to_v2(&message.method, message.params)?, + MessageKind::Notification => { + notification_v1_to_v2(&message.method, message.params)? + } + }, + WireVersion::V2 => message.params, + }; + Ok(UntypedMessage { + method: message.method, + params, + }) + } + + fn convert_outgoing_message( + &self, + kind: MessageKind, + message: UntypedMessage, + ) -> Result { + let wire_version = self.outgoing_message_wire_version(&message); + let params = match wire_version { + WireVersion::V1 => match kind { + MessageKind::Request => request_v2_to_v1(&message.method, message.params)?, + MessageKind::Notification => { + notification_v2_to_v1(&message.method, message.params)? + } + }, + WireVersion::V2 => message.params, + }; + Ok(UntypedMessage { + method: message.method, + params, + }) + } + + fn convert_incoming_response( + &self, + method: &str, + response: Result, + ) -> Result, crate::Error> { + let Ok(value) = response else { + return Ok(response); + }; + + let wire_version = self.incoming_response_wire_version(method, &value); + let value = match wire_version { + WireVersion::V1 => response_v1_to_v2(method, value)?, + WireVersion::V2 => value, + }; + Ok(Ok(value)) + } + + fn incoming_message_wire_version(&self, message: &UntypedMessage) -> WireVersion { + if is_initialize_method(&message.method) { + let wire_version = wire_version_from_params(&message.params); + self.set_provisional_wire_version(wire_version); + return wire_version; + } + + self.current_wire_version().unwrap_or(WireVersion::V2) + } + + fn outgoing_message_wire_version(&self, message: &UntypedMessage) -> WireVersion { + if is_initialize_method(&message.method) { + let wire_version = wire_version_from_params(&message.params); + self.set_provisional_wire_version(wire_version); + return wire_version; + } + + self.current_wire_version().unwrap_or(WireVersion::V2) + } + + fn incoming_response_wire_version(&self, method: &str, value: &Value) -> WireVersion { + if is_initialize_method(method) { + let wire_version = wire_version_from_params(value); + self.set_wire_version(wire_version); + return wire_version; + } + + self.current_wire_version().unwrap_or(WireVersion::V2) + } + + fn outgoing_response_wire_version(&self, method: &str, value: &Value) -> WireVersion { + if is_initialize_method(method) { + let wire_version = wire_version_from_params(value); + self.set_wire_version(wire_version); + return wire_version; + } + + self.current_wire_version().unwrap_or(WireVersion::V2) + } + + fn current_wire_version(&self) -> Option { + match self.active_wire_version.load(Ordering::SeqCst) { + STATE_UNKNOWN => None, + STATE_V1 => Some(WireVersion::V1), + STATE_V2 => Some(WireVersion::V2), + _ => unreachable!("invalid protocol state"), + } + } + + fn set_provisional_wire_version(&self, version: WireVersion) { + self.active_wire_version + .store(version as u8, Ordering::SeqCst); + } + + fn set_wire_version(&self, version: WireVersion) { + self.set_provisional_wire_version(version); + self.negotiated_wire_version + .store(version as u8, Ordering::SeqCst); + } +} + +#[derive(Debug)] +pub(crate) struct IncomingConversionError { + pub(crate) dispatch: Dispatch, + pub(crate) error: crate::Error, +} + +#[derive(Clone, Copy, Debug)] +enum MessageKind { + Request, + Notification, +} + +#[derive(Clone, Copy, Debug)] +enum WireVersion { + V1 = STATE_V1 as isize, + V2 = STATE_V2 as isize, +} + +fn is_initialize_method(method: &str) -> bool { + matches!(method, METHOD_INITIALIZE | METHOD_INITIALIZE_PROXY) +} + +fn wire_version_for_protocol_version(version: ProtocolVersion) -> WireVersion { + if version >= ProtocolVersion::V2 { + WireVersion::V2 + } else { + WireVersion::V1 + } +} + +fn wire_version_from_params(params: &Value) -> WireVersion { + protocol_version_from_params(params).map_or(WireVersion::V1, wire_version_for_protocol_version) +} + +fn protocol_version_from_params(params: &Value) -> Option { + params + .get("protocolVersion") + .cloned() + .and_then(|value| serde_json::from_value(value).ok()) +} + +fn conversion_error(error: v2::conversion::ProtocolConversionError) -> crate::Error { + crate::Error::internal_error().data(error.to_string()) +} + +fn into_v1_value(value: Value) -> Result +where + T: DeserializeOwned + IntoV1, + T::Output: Serialize, +{ + let typed = serde_json::from_value::(value)?; + let converted = typed.into_v1().map_err(conversion_error)?; + serde_json::to_value(converted).map_err(crate::Error::into_internal_error) +} + +fn into_v2_value(value: Value) -> Result +where + T: DeserializeOwned + IntoV2, + T::Output: Serialize, +{ + let typed = serde_json::from_value::(value)?; + let converted = typed.into_v2().map_err(conversion_error)?; + serde_json::to_value(converted).map_err(crate::Error::into_internal_error) +} + +macro_rules! convert_request_to_v1 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v1_value::<$ty>($params),)* + _ if $method == METHOD_SUCCESSOR_MESSAGE => { + convert_successor_message($params, request_v2_to_v1) + } + _ => Ok($params), + } + }; +} + +macro_rules! convert_request_to_v2 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v2_value::<$ty>($params),)* + _ if $method == METHOD_SUCCESSOR_MESSAGE => { + convert_successor_message($params, request_v1_to_v2) + } + _ => Ok($params), + } + }; +} + +macro_rules! convert_notification_to_v1 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v1_value::<$ty>($params),)* + _ if $method == METHOD_SUCCESSOR_MESSAGE => { + convert_successor_message($params, notification_v2_to_v1) + } + _ => Ok($params), + } + }; +} + +macro_rules! convert_notification_to_v2 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v2_value::<$ty>($params),)* + _ if $method == METHOD_SUCCESSOR_MESSAGE => { + convert_successor_message($params, notification_v1_to_v2) + } + _ => Ok($params), + } + }; +} + +macro_rules! convert_response_to_v1 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v1_value::<$ty>($params),)* + _ => Ok($params), + } + }; +} + +macro_rules! convert_response_to_v2 { + ($method:expr, $params:expr, {$($(#[$meta:meta])* $name:pat => $ty:ty,)*}) => { + match $method { + $($(#[$meta])* $name => into_v2_value::<$ty>($params),)* + _ => Ok($params), + } + }; +} + +fn request_v2_to_v1(method: &str, params: Value) -> Result { + convert_request_to_v1!(method, params, { + METHOD_INITIALIZE => v2::InitializeRequest, + METHOD_INITIALIZE_PROXY => v2::InitializeRequest, + "authenticate" => v2::AuthenticateRequest, + #[cfg(feature = "unstable_logout")] + "logout" => v2::LogoutRequest, + "session/new" => v2::NewSessionRequest, + "session/load" => v2::LoadSessionRequest, + "session/list" => v2::ListSessionsRequest, + #[cfg(feature = "unstable_session_fork")] + "session/fork" => v2::ForkSessionRequest, + "session/resume" => v2::ResumeSessionRequest, + "session/close" => v2::CloseSessionRequest, + "session/set_mode" => v2::SetSessionModeRequest, + "session/set_config_option" => v2::SetSessionConfigOptionRequest, + "session/prompt" => v2::PromptRequest, + #[cfg(feature = "unstable_session_model")] + "session/set_model" => v2::SetSessionModelRequest, + "fs/write_text_file" => v2::WriteTextFileRequest, + "fs/read_text_file" => v2::ReadTextFileRequest, + "session/request_permission" => v2::RequestPermissionRequest, + "terminal/create" => v2::CreateTerminalRequest, + "terminal/output" => v2::TerminalOutputRequest, + "terminal/release" => v2::ReleaseTerminalRequest, + "terminal/wait_for_exit" => v2::WaitForTerminalExitRequest, + "terminal/kill" => v2::KillTerminalRequest, + }) +} + +fn request_v1_to_v2(method: &str, params: Value) -> Result { + convert_request_to_v2!(method, params, { + METHOD_INITIALIZE => v1::InitializeRequest, + METHOD_INITIALIZE_PROXY => v1::InitializeRequest, + "authenticate" => v1::AuthenticateRequest, + #[cfg(feature = "unstable_logout")] + "logout" => v1::LogoutRequest, + "session/new" => v1::NewSessionRequest, + "session/load" => v1::LoadSessionRequest, + "session/list" => v1::ListSessionsRequest, + #[cfg(feature = "unstable_session_fork")] + "session/fork" => v1::ForkSessionRequest, + "session/resume" => v1::ResumeSessionRequest, + "session/close" => v1::CloseSessionRequest, + "session/set_mode" => v1::SetSessionModeRequest, + "session/set_config_option" => v1::SetSessionConfigOptionRequest, + "session/prompt" => v1::PromptRequest, + #[cfg(feature = "unstable_session_model")] + "session/set_model" => v1::SetSessionModelRequest, + "fs/write_text_file" => v1::WriteTextFileRequest, + "fs/read_text_file" => v1::ReadTextFileRequest, + "session/request_permission" => v1::RequestPermissionRequest, + "terminal/create" => v1::CreateTerminalRequest, + "terminal/output" => v1::TerminalOutputRequest, + "terminal/release" => v1::ReleaseTerminalRequest, + "terminal/wait_for_exit" => v1::WaitForTerminalExitRequest, + "terminal/kill" => v1::KillTerminalRequest, + }) +} + +fn notification_v2_to_v1(method: &str, params: Value) -> Result { + convert_notification_to_v1!(method, params, { + "session/cancel" => v2::CancelNotification, + "session/update" => v2::SessionNotification, + }) +} + +fn notification_v1_to_v2(method: &str, params: Value) -> Result { + convert_notification_to_v2!(method, params, { + "session/cancel" => v1::CancelNotification, + "session/update" => v1::SessionNotification, + }) +} + +fn response_v2_to_v1(method: &str, params: Value) -> Result { + convert_response_to_v1!(method, params, { + METHOD_INITIALIZE => v2::InitializeResponse, + METHOD_INITIALIZE_PROXY => v2::InitializeResponse, + "authenticate" => v2::AuthenticateResponse, + #[cfg(feature = "unstable_logout")] + "logout" => v2::LogoutResponse, + "session/new" => v2::NewSessionResponse, + "session/load" => v2::LoadSessionResponse, + "session/list" => v2::ListSessionsResponse, + #[cfg(feature = "unstable_session_fork")] + "session/fork" => v2::ForkSessionResponse, + "session/resume" => v2::ResumeSessionResponse, + "session/close" => v2::CloseSessionResponse, + "session/set_mode" => v2::SetSessionModeResponse, + "session/set_config_option" => v2::SetSessionConfigOptionResponse, + "session/prompt" => v2::PromptResponse, + #[cfg(feature = "unstable_session_model")] + "session/set_model" => v2::SetSessionModelResponse, + "fs/write_text_file" => v2::WriteTextFileResponse, + "fs/read_text_file" => v2::ReadTextFileResponse, + "session/request_permission" => v2::RequestPermissionResponse, + "terminal/create" => v2::CreateTerminalResponse, + "terminal/output" => v2::TerminalOutputResponse, + "terminal/release" => v2::ReleaseTerminalResponse, + "terminal/wait_for_exit" => v2::WaitForTerminalExitResponse, + "terminal/kill" => v2::KillTerminalResponse, + }) +} + +fn response_v1_to_v2(method: &str, params: Value) -> Result { + convert_response_to_v2!(method, params, { + METHOD_INITIALIZE => v1::InitializeResponse, + METHOD_INITIALIZE_PROXY => v1::InitializeResponse, + "authenticate" => v1::AuthenticateResponse, + #[cfg(feature = "unstable_logout")] + "logout" => v1::LogoutResponse, + "session/new" => v1::NewSessionResponse, + "session/load" => v1::LoadSessionResponse, + "session/list" => v1::ListSessionsResponse, + #[cfg(feature = "unstable_session_fork")] + "session/fork" => v1::ForkSessionResponse, + "session/resume" => v1::ResumeSessionResponse, + "session/close" => v1::CloseSessionResponse, + "session/set_mode" => v1::SetSessionModeResponse, + "session/set_config_option" => v1::SetSessionConfigOptionResponse, + "session/prompt" => v1::PromptResponse, + #[cfg(feature = "unstable_session_model")] + "session/set_model" => v1::SetSessionModelResponse, + "fs/write_text_file" => v1::WriteTextFileResponse, + "fs/read_text_file" => v1::ReadTextFileResponse, + "session/request_permission" => v1::RequestPermissionResponse, + "terminal/create" => v1::CreateTerminalResponse, + "terminal/output" => v1::TerminalOutputResponse, + "terminal/release" => v1::ReleaseTerminalResponse, + "terminal/wait_for_exit" => v1::WaitForTerminalExitResponse, + "terminal/kill" => v1::KillTerminalResponse, + }) +} + +fn convert_successor_message( + params: Value, + convert_inner: fn(&str, Value) -> Result, +) -> Result { + let Value::Object(mut message) = params else { + return Ok(params); + }; + let Some(Value::String(method)) = message.get("method").cloned() else { + return Ok(Value::Object(message)); + }; + let Some(params) = message.remove("params") else { + return Ok(Value::Object(message)); + }; + + let params = convert_inner(&method, params)?; + message.insert("params".to_string(), params); + Ok(Value::Object(message)) +} diff --git a/src/agent-client-protocol/src/session.rs b/src/agent-client-protocol/src/session.rs index 144d403..dca6c23 100644 --- a/src/agent-client-protocol/src/session.rs +++ b/src/agent-client-protocol/src/session.rs @@ -1,9 +1,5 @@ use std::{future::Future, marker::PhantomData, path::Path}; -use agent_client_protocol_schema::{ - ContentBlock, ContentChunk, NewSessionRequest, NewSessionResponse, PromptRequest, - PromptResponse, SessionModeState, SessionNotification, SessionUpdate, StopReason, -}; use futures::channel::{mpsc, oneshot}; use crate::{ @@ -14,7 +10,11 @@ use crate::{ }, mcp_server::McpServer, role::{HasPeer, acp::ProxySessionMessages}, - schema::SessionId, + schema::{ + ContentBlock, ContentChunk, NewSessionRequest, NewSessionResponse, PromptRequest, + PromptResponse, SessionId, SessionModeState, SessionNotification, SessionUpdate, + StopReason, + }, util::{MatchDispatch, MatchDispatchFrom, run_until}, }; diff --git a/src/agent-client-protocol/tests/protocol_v2.rs b/src/agent-client-protocol/tests/protocol_v2.rs new file mode 100644 index 0000000..8ae053b --- /dev/null +++ b/src/agent-client-protocol/tests/protocol_v2.rs @@ -0,0 +1,443 @@ +#![cfg(feature = "unstable_protocol_v2")] + +use agent_client_protocol::schema::{ + AgentCapabilities, InitializeProxyRequest, InitializeRequest, InitializeResponse, + ListSessionsRequest, ListSessionsResponse, NewSessionRequest, NewSessionResponse, + ProtocolVersion, SessionId, SuccessorMessage, +}; +use agent_client_protocol::{ + Agent, Channel, Client, Conductor, ConnectionTo, Handled, Proxy, Responder, UntypedMessage, +}; +use std::sync::{Arc, Mutex}; + +async fn run_initialize_test( + protocol_version: ProtocolVersion, +) -> agent_client_protocol::Result<()> { + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + assert!( + std::any::type_name::().contains("::v2::"), + "unstable_protocol_v2 should make schema::* resolve to v2 types" + ); + + let (client_channel, agent_channel) = Channel::duplex(); + let expected_version = protocol_version; + + let agent = Agent + .builder() + .on_receive_request( + async move |initialize: InitializeRequest, + responder: Responder, + cx: ConnectionTo| { + assert_eq!(cx.negotiated_protocol_version(), None); + responder.respond( + InitializeResponse::new(initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |_request: NewSessionRequest, + responder: Responder, + cx: ConnectionTo| { + assert_eq!(cx.negotiated_protocol_version(), Some(expected_version)); + responder.respond(NewSessionResponse::new(SessionId::new("session-1"))) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + agent.connect_to(agent_channel).await.ok(); + }); + + Client + .builder() + .connect_with(client_channel, async move |cx| { + let initialize = cx + .send_request(InitializeRequest::new(protocol_version)) + .block_task() + .await?; + + assert_eq!(initialize.protocol_version, protocol_version); + assert_eq!(cx.negotiated_protocol_version(), Some(protocol_version)); + + let new_session = cx + .send_request(NewSessionRequest::new( + std::env::current_dir() + .map_err(agent_client_protocol::Error::into_internal_error)?, + )) + .block_task() + .await?; + + assert_eq!(new_session.session_id, SessionId::new("session-1")); + Ok(()) + }) + .await + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_schema_can_negotiate_v1_wire_version() -> agent_client_protocol::Result<()> { + run_initialize_test(ProtocolVersion::V1).await +} + +#[tokio::test(flavor = "current_thread")] +async fn v2_schema_can_negotiate_v2_wire_version() -> agent_client_protocol::Result<()> { + run_initialize_test(ProtocolVersion::V2).await +} + +#[tokio::test(flavor = "current_thread")] +async fn successor_request_responses_use_inner_method_for_conversion() +-> agent_client_protocol::Result<()> { + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + let (conductor_channel, proxy_channel) = Channel::duplex(); + + let proxy = Proxy + .builder() + .on_receive_request_from( + Client, + async move |initialize: InitializeProxyRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond( + InitializeResponse::new(initialize.initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request_from( + Agent, + async move |_request: NewSessionRequest, + responder: Responder, + _cx: ConnectionTo| { + let method = responder.method().to_string(); + if method != "session/new" { + return responder.respond_with_error( + agent_client_protocol::Error::internal_error().data(method), + ); + } + + responder.respond(NewSessionResponse::new(SessionId::new("session-1"))) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + proxy.connect_to(proxy_channel).await.ok(); + }); + + Conductor + .builder() + .connect_with(conductor_channel, async move |cx| { + cx.send_request(InitializeProxyRequest::from(InitializeRequest::new( + ProtocolVersion::V1, + ))) + .block_task() + .await?; + + let request = SuccessorMessage { + message: NewSessionRequest::new( + std::env::current_dir() + .map_err(agent_client_protocol::Error::into_internal_error)?, + ), + meta: None, + }; + + let response = cx.send_request(request); + assert_eq!(response.method(), "session/new"); + + let response = response.block_task().await?; + assert_eq!(response.session_id, SessionId::new("session-1")); + Ok(()) + }) + .await + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn successor_response_conversion_errors_use_inner_method() -> agent_client_protocol::Result<()> +{ + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + let (conductor_channel, proxy_channel) = Channel::duplex(); + + let proxy = Proxy + .builder() + .on_receive_request_from( + Client, + async move |initialize: InitializeProxyRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond( + InitializeResponse::new(initialize.initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request_from( + Agent, + async move |request: UntypedMessage, + responder: Responder, + _cx: ConnectionTo| { + if request.method() != "session/new" { + return Ok(Handled::No { + message: (request, responder), + retry: false, + }); + } + + responder.respond(serde_json::json!({"invalid": true}))?; + Ok(Handled::Yes) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + proxy.connect_to(proxy_channel).await.ok(); + }); + + Conductor + .builder() + .connect_with(conductor_channel, async move |cx| { + cx.send_request(InitializeProxyRequest::from(InitializeRequest::new( + ProtocolVersion::V1, + ))) + .block_task() + .await?; + + let request = SuccessorMessage { + message: NewSessionRequest::new( + std::env::current_dir() + .map_err(agent_client_protocol::Error::into_internal_error)?, + ), + meta: None, + }; + + let response = cx.send_request(request); + assert_eq!(response.method(), "session/new"); + + let error = response.block_task().await.unwrap_err(); + assert_eq!(error.code, agent_client_protocol::ErrorCode::InvalidParams); + Ok(()) + }) + .await + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn initialize_request_sets_provisional_wire_version() -> agent_client_protocol::Result<()> { + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + let (client_channel, agent_channel) = Channel::duplex(); + let (error_tx, error_rx) = tokio::sync::oneshot::channel(); + let error_tx = Arc::new(Mutex::new(Some(error_tx))); + + let agent = Agent.builder().on_receive_request( + { + let error_tx = error_tx.clone(); + async move |initialize: InitializeRequest, + responder: Responder, + cx: ConnectionTo| { + assert_eq!(cx.negotiated_protocol_version(), None); + + let bad_request = UntypedMessage::new( + "session/new", + serde_json::json!({"invalid": true}), + )?; + + if let Some(error_tx) = error_tx.lock().unwrap().take() { + cx.send_request(bad_request).on_receiving_result( + async move |result| { + let error = result.unwrap_err(); + error_tx + .send(error.code) + .map_err(|_| agent_client_protocol::Error::internal_error()) + }, + )?; + } + + responder.respond( + InitializeResponse::new(initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + } + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + agent.connect_to(agent_channel).await.ok(); + }); + + Client + .builder() + .connect_with(client_channel, async move |cx| { + cx.send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + + assert_eq!(cx.negotiated_protocol_version(), Some(ProtocolVersion::V1)); + let error_code = error_rx + .await + .map_err(agent_client_protocol::Error::into_internal_error)?; + assert_eq!(error_code, agent_client_protocol::ErrorCode::InvalidParams); + Ok(()) + }) + .await + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn outgoing_request_conversion_error_is_returned_to_sender() +-> agent_client_protocol::Result<()> { + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + let (client_channel, agent_channel) = Channel::duplex(); + + let agent = Agent + .builder() + .on_receive_request( + async move |initialize: InitializeRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond( + InitializeResponse::new(initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |_request: ListSessionsRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond(ListSessionsResponse::new(vec![])) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + agent.connect_to(agent_channel).await.ok(); + }); + + Client + .builder() + .connect_with(client_channel, async move |cx| { + cx.send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + + let bad_request = + UntypedMessage::new("session/new", serde_json::json!({"invalid": true}))?; + let error = cx.send_request(bad_request).block_task().await.unwrap_err(); + assert_eq!(error.code, agent_client_protocol::ErrorCode::InvalidParams); + + let sessions = cx + .send_request(ListSessionsRequest::new()) + .block_task() + .await?; + assert!(sessions.sessions.is_empty()); + Ok(()) + }) + .await + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn outgoing_response_conversion_error_is_sent_as_json_rpc_error() +-> agent_client_protocol::Result<()> { + let local = tokio::task::LocalSet::new(); + + local + .run_until(async move { + let (client_channel, agent_channel) = Channel::duplex(); + + let agent = Agent + .builder() + .on_receive_request( + async move |initialize: InitializeRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond( + InitializeResponse::new(initialize.protocol_version) + .agent_capabilities(AgentCapabilities::new()), + ) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |_request: ListSessionsRequest, + responder: Responder, + _cx: ConnectionTo| { + responder.respond(ListSessionsResponse::new(vec![])) + }, + agent_client_protocol::on_receive_request!(), + ) + .on_receive_request( + async move |request: UntypedMessage, + responder: Responder, + _cx: ConnectionTo| { + if request.method() != "session/new" { + return Ok(Handled::No { + message: (request, responder), + retry: false, + }); + } + + responder.respond(serde_json::json!({"invalid": true}))?; + Ok(Handled::Yes) + }, + agent_client_protocol::on_receive_request!(), + ); + + tokio::task::spawn_local(async move { + agent.connect_to(agent_channel).await.ok(); + }); + + Client + .builder() + .connect_with(client_channel, async move |cx| { + cx.send_request(InitializeRequest::new(ProtocolVersion::V1)) + .block_task() + .await?; + + let error = cx + .send_request(NewSessionRequest::new( + std::env::current_dir() + .map_err(agent_client_protocol::Error::into_internal_error)?, + )) + .block_task() + .await + .unwrap_err(); + assert_eq!(error.code, agent_client_protocol::ErrorCode::InvalidParams); + + let sessions = cx + .send_request(ListSessionsRequest::new()) + .block_task() + .await?; + assert!(sessions.sessions.is_empty()); + Ok(()) + }) + .await + }) + .await +}