Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ macro_rules! server_handler_methods {
request: InitializeRequestParams,
context: RequestContext<RoleServer>,
) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
if context.peer.peer_info().is_none() {
context.peer.set_peer_info(request);
}
context.peer.set_peer_info(request);
std::future::ready(Ok(self.get_info()))
}
fn complete(
Expand Down
16 changes: 7 additions & 9 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ pub struct Peer<R: ServiceRole> {
tx: mpsc::Sender<PeerSinkMessage<R>>,
request_id_provider: Arc<dyn RequestIdProvider>,
progress_token_provider: Arc<dyn ProgressTokenProvider>,
info: Arc<tokio::sync::OnceCell<R::PeerInfo>>,
info: Arc<std::sync::RwLock<Option<Arc<R::PeerInfo>>>>,
}

impl<R: ServiceRole> std::fmt::Debug for Peer<R> {
Expand Down Expand Up @@ -423,7 +423,7 @@ impl<R: ServiceRole> Peer<R> {
tx,
request_id_provider,
progress_token_provider: Arc::new(AtomicU32ProgressTokenProvider::default()),
info: Arc::new(tokio::sync::OnceCell::new_with(peer_info)),
info: Arc::new(std::sync::RwLock::new(peer_info.map(Arc::new))),
},
rx,
)
Expand Down Expand Up @@ -484,16 +484,14 @@ impl<R: ServiceRole> Peer<R> {
peer: self.clone(),
})
}
pub fn peer_info(&self) -> Option<&R::PeerInfo> {
self.info.get()
/// Snapshot of the peer's handshake info.
pub fn peer_info(&self) -> Option<Arc<R::PeerInfo>> {
self.info.read().expect("peer info lock poisoned").clone()
}

/// Stores the peer's handshake info, overwriting any previous value.
pub fn set_peer_info(&self, info: R::PeerInfo) {
if self.info.initialized() {
tracing::warn!("trying to set peer info, which is already initialized");
} else {
let _ = self.info.set(info);
}
*self.info.write().expect("peer info lock poisoned") = Some(Arc::new(info));
}

pub fn is_transport_closed(&self) -> bool {
Expand Down
61 changes: 61 additions & 0 deletions crates/rmcp/tests/test_server_initialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,67 @@ async fn server_pinned_version_used_as_fallback_for_unknown_client_request() {
assert_eq!(negotiated, ProtocolVersion::V_2025_06_18);
}

fn duplicate_init_request(id: u64, version: &str) -> ClientJsonRpcMessage {
msg(&format!(
r#"{{
"jsonrpc": "2.0",
"id": {id},
"method": "initialize",
"params": {{
"protocolVersion": "{version}",
"capabilities": {{ "sampling": {{}} }},
"clientInfo": {{ "name": "renegotiated-client", "version": "9.9.9" }}
}}
}}"#
))
}

#[tokio::test]
async fn server_accepts_duplicate_initialize() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(initialized_notification()).await.unwrap();

client
.send(duplicate_init_request(2, "2025-11-25"))
.await
.unwrap();
let response = client.receive().await.unwrap();
assert!(
matches!(response, ServerJsonRpcMessage::Response(_)),
"expected successful InitializeResult, got: {response:?}"
);
}

#[tokio::test]
async fn server_session_remains_usable_after_renegotiation() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(initialized_notification()).await.unwrap();
client
.send(duplicate_init_request(2, "2025-11-25"))
.await
.unwrap();
let _renegotiated = client.receive().await.unwrap();

client.send(ping_request(3)).await.unwrap();
let pong = client.receive().await.unwrap();
assert!(
matches!(
pong,
ServerJsonRpcMessage::Response(ref r)
if matches!(r.result, ServerResult::EmptyResult(_))
),
"expected EmptyResult ping after renegotiation, got: {pong:?}"
);
}

// Server buffers multiple requests before initialized and processes them in order.
#[tokio::test]
async fn server_init_buffers_multiple_requests_before_initialized() {
Expand Down