diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 8673a8bf..78c3be2f 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -184,9 +184,7 @@ macro_rules! server_handler_methods { request: InitializeRequestParams, context: RequestContext, ) -> impl Future> + MaybeSendFuture + '_ { - if context.peer.peer_info().is_none() { - context.peer.set_peer_info(request); - } + context.peer.set_peer_info(request); std::future::ready(Ok(self.get_info())) } fn complete( diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index d938cd66..08791e5e 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -384,7 +384,7 @@ pub struct Peer { tx: mpsc::Sender>, request_id_provider: Arc, progress_token_provider: Arc, - info: Arc>, + info: Arc>>>, } impl std::fmt::Debug for Peer { @@ -423,7 +423,7 @@ impl Peer { tx, request_id_provider, progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()), - info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)), + info: Arc::new(std::sync::RwLock::new(peer_info.map(Arc::new))), }, rx, ) @@ -484,16 +484,14 @@ impl Peer { peer: self.clone(), }) } - pub fn peer_info(&self) -> Option<&R::PeerInfo> { - self.info.get() + /// Snapshot of the peer's handshake info. + pub fn peer_info(&self) -> Option> { + self.info.read().expect("peer info lock poisoned").clone() } + /// Stores the peer's handshake info, overwriting any previous value. pub fn set_peer_info(&self, info: R::PeerInfo) { - if self.info.initialized() { - tracing::warn!("trying to set peer info, which is already initialized"); - } else { - let _ = self.info.set(info); - } + *self.info.write().expect("peer info lock poisoned") = Some(Arc::new(info)); } pub fn is_transport_closed(&self) -> bool { diff --git a/crates/rmcp/tests/test_server_initialization.rs b/crates/rmcp/tests/test_server_initialization.rs index e2e04896..6542e295 100644 --- a/crates/rmcp/tests/test_server_initialization.rs +++ b/crates/rmcp/tests/test_server_initialization.rs @@ -299,6 +299,67 @@ async fn server_pinned_version_used_as_fallback_for_unknown_client_request() { assert_eq!(negotiated, ProtocolVersion::V_2025_06_18); } +fn duplicate_init_request(id: u64, version: &str) -> ClientJsonRpcMessage { + msg(&format!( + r#"{{ + "jsonrpc": "2.0", + "id": {id}, + "method": "initialize", + "params": {{ + "protocolVersion": "{version}", + "capabilities": {{ "sampling": {{}} }}, + "clientInfo": {{ "name": "renegotiated-client", "version": "9.9.9" }} + }} + }}"# + )) +} + +#[tokio::test] +async fn server_accepts_duplicate_initialize() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(initialized_notification()).await.unwrap(); + + client + .send(duplicate_init_request(2, "2025-11-25")) + .await + .unwrap(); + let response = client.receive().await.unwrap(); + assert!( + matches!(response, ServerJsonRpcMessage::Response(_)), + "expected successful InitializeResult, got: {response:?}" + ); +} + +#[tokio::test] +async fn server_session_remains_usable_after_renegotiation() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(initialized_notification()).await.unwrap(); + client + .send(duplicate_init_request(2, "2025-11-25")) + .await + .unwrap(); + let _renegotiated = client.receive().await.unwrap(); + + client.send(ping_request(3)).await.unwrap(); + let pong = client.receive().await.unwrap(); + assert!( + matches!( + pong, + ServerJsonRpcMessage::Response(ref r) + if matches!(r.result, ServerResult::EmptyResult(_)) + ), + "expected EmptyResult ping after renegotiation, got: {pong:?}" + ); +} + // Server buffers multiple requests before initialized and processes them in order. #[tokio::test] async fn server_init_buffers_multiple_requests_before_initialized() {