
Commit 3e56d52 (1 parent: 6603c1f)

fix: include http_request_id in request-wise priming event IDs (#799)

* fix: include http_request_id in request-wise priming event IDs
* refactor: use Option::into_iter and usize::from for priming
* fix: retain event cache for completed request-wise channels
* fix: track completed_at for cache eviction and resume
* fix: log resume failures at warn level
* test: add completed_cache_ttl eviction test
* fix: return empty stream on failed resume
* test: add resume after completion test
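
The core of the fix is the event-ID scheme: IDs on a request-wise SSE stream must embed the http_request_id so that a client reconnecting with Last-Event-ID can be routed back to that stream's event cache instead of falling through to the common stream. A minimal sketch of the scheme implied by the diff below (the helper names are illustrative; the crate's own LastEventId parser is not shown in this commit):

// Illustrative only: event IDs are "{index}" on the common stream and
// "{index}/{http_request_id}" on request-wise streams, matching the
// `format!("0/{id}")` priming ID added in create_stream below.
fn format_event_id(index: usize, http_request_id: Option<u64>) -> String {
    match http_request_id {
        Some(id) => format!("{index}/{id}"),
        None => index.to_string(),
    }
}

fn parse_event_id(s: &str) -> Result<(usize, Option<u64>), std::num::ParseIntError> {
    match s.split_once('/') {
        Some((index, id)) => Ok((index.parse()?, Some(id.parse()?))),
        None => Ok((s.parse()?, None)),
    }
}

fn main() {
    assert_eq!(format_event_id(0, Some(42)), "0/42");
    assert_eq!(parse_event_id("0/42"), Ok((0, Some(42))));
    assert_eq!(parse_event_id("3"), Ok((3, None)));
}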

3 files changed: 467 additions & 89 deletions


crates/rmcp/src/transport/streamable_http_server/session/local.rs

Lines changed: 84 additions & 48 deletions
@@ -1,10 +1,10 @@
 use std::{
     collections::{HashMap, HashSet, VecDeque},
     num::ParseIntError,
-    time::Duration,
+    time::{Duration, Instant},
 };

-use futures::Stream;
+use futures::{Stream, StreamExt};
 use thiserror::Error;
 use tokio::sync::{
     mpsc::{Receiver, Sender},
@@ -86,10 +86,17 @@ impl SessionManager for LocalSessionManager {
             .get(id)
             .ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
         let receiver = handle.establish_request_wise_channel().await?;
-        handle
-            .push_message(message, receiver.http_request_id)
-            .await?;
-        Ok(ReceiverStream::new(receiver.inner))
+        let http_request_id = receiver.http_request_id;
+        handle.push_message(message, http_request_id).await?;
+
+        let priming = self.session_config.sse_retry.map(|retry| {
+            let event_id = match http_request_id {
+                Some(id) => format!("0/{id}"),
+                None => "0".into(),
+            };
+            ServerSseMessage::priming(event_id, retry)
+        });
+        Ok(futures::stream::iter(priming).chain(ReceiverStream::new(receiver.inner)))
     }

     async fn create_standalone_stream(
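
With this change the priming event on a request-wise stream identifies the stream itself. As a rough illustration of the SSE fields involved (the exact frame written by sse_stream_response is not shown in this diff), the client would now receive something like

id: 0/42
retry: 3000

where 42 is the http_request_id and 3000 is the configured retry interval in milliseconds, instead of the bare id: 0 that every stream used to share before this fix.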
@@ -188,23 +195,29 @@ struct CachedTx {
     cache: VecDeque<ServerSseMessage>,
     http_request_id: Option<HttpRequestId>,
     capacity: usize,
+    starting_index: usize,
 }

 impl CachedTx {
-    fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
+    fn new(
+        tx: Sender<ServerSseMessage>,
+        http_request_id: Option<HttpRequestId>,
+        starting_index: usize,
+    ) -> Self {
         Self {
             cache: VecDeque::with_capacity(tx.capacity()),
             capacity: tx.capacity(),
             tx,
             http_request_id,
+            starting_index,
         }
     }
     fn new_common(tx: Sender<ServerSseMessage>) -> Self {
-        Self::new(tx, None)
+        Self::new(tx, None, 0)
     }

     fn next_event_id(&self) -> EventId {
-        let index = self.cache.back().map_or(0, |m| {
+        let index = self.cache.back().map_or(self.starting_index, |m| {
             m.event_id
                 .as_deref()
                 .unwrap_or_default()
@@ -272,6 +285,7 @@ impl CachedTx {
 struct HttpRequestWise {
     resources: HashSet<ResourceKey>,
     tx: CachedTx,
+    completed_at: Option<Instant>,
 }

 type HttpRequestId = u64;
@@ -342,23 +356,27 @@ pub struct StreamableHttpMessageReceiver {

 impl LocalSessionWorker {
     fn unregister_resource(&mut self, resource: &ResourceKey) {
-        if let Some(http_request_id) = self.resource_router.remove(resource) {
-            tracing::trace!(?resource, http_request_id, "unregister resource");
-            if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
-                // It's okey to do so, since we don't handle batch json rpc request anymore
-                // and this can be refactored after the batch request is removed in the coming version.
-                if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
-                {
-                    tracing::debug!(http_request_id, "close http request wise channel");
-                    if let Some(channel) = self.tx_router.remove(&http_request_id) {
-                        for resource in channel.resources {
-                            self.resource_router.remove(&resource);
-                        }
-                    }
-                }
-            } else {
-                tracing::warn!(http_request_id, "http request wise channel not found");
-            }
+        let Some(http_request_id) = self.resource_router.remove(resource) else {
+            return;
+        };
+        tracing::trace!(?resource, http_request_id, "unregister resource");
+        let Some(channel) = self.tx_router.get_mut(&http_request_id) else {
+            tracing::warn!(http_request_id, "http request wise channel not found");
+            return;
+        };
+        if !channel.resources.is_empty() && !matches!(resource, ResourceKey::McpRequestId(_)) {
+            return;
+        }
+        tracing::debug!(http_request_id, "close http request wise channel");
+        let resources: Vec<_> = channel.resources.drain().collect();
+        channel.completed_at = Some(Instant::now());
+        // Close the sender so the client's SSE stream ends,
+        // but keep the entry so the cache is available for
+        // late resume requests.
+        let (closed_tx, _) = tokio::sync::mpsc::channel(1);
+        channel.tx.tx = closed_tx;
+        for resource in resources {
+            self.resource_router.remove(&resource);
         }
     }
     fn register_resource(&mut self, resource: ResourceKey, http_request_id: HttpRequestId) {
@@ -395,6 +413,11 @@ impl LocalSessionWorker {
             self.unregister_resource(&resource);
         }
     }
+    fn evict_expired_channels(&mut self) {
+        let ttl = self.session_config.completed_cache_ttl;
+        self.tx_router
+            .retain(|_, rw| rw.completed_at.is_none_or(|at| at.elapsed() < ttl));
+    }
     fn next_http_request_id(&mut self) -> HttpRequestId {
         let id = self.next_http_request_id;
         self.next_http_request_id = self.next_http_request_id.wrapping_add(1);
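
The eviction predicate above keeps active channels unconditionally and completed channels only while their completion is younger than completed_cache_ttl. A self-contained sketch of the same predicate with stand-in types (none of the crate's own):

use std::time::{Duration, Instant};

// Mirrors the retain closure in evict_expired_channels above.
fn should_retain(completed_at: Option<Instant>, ttl: Duration) -> bool {
    completed_at.is_none_or(|at| at.elapsed() < ttl)
}

fn main() {
    let ttl = Duration::from_secs(60);
    assert!(should_retain(None, ttl)); // still active: always kept
    assert!(should_retain(Some(Instant::now()), ttl)); // just completed: kept within TTL
    assert!(!should_retain(Some(Instant::now()), Duration::ZERO)); // TTL elapsed: evicted
}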
@@ -405,11 +428,13 @@ impl LocalSessionWorker {
     ) -> Result<StreamableHttpMessageReceiver, SessionError> {
         let http_request_id = self.next_http_request_id();
         let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
+        let starting_index = usize::from(self.session_config.sse_retry.is_some());
         self.tx_router.insert(
             http_request_id,
             HttpRequestWise {
                 resources: Default::default(),
-                tx: CachedTx::new(tx, Some(http_request_id)),
+                tx: CachedTx::new(tx, Some(http_request_id), starting_index),
+                completed_at: None,
             },
         );
         tracing::debug!(http_request_id, "establish new request wise channel");
@@ -524,28 +549,25 @@ impl LocalSessionWorker {

         match last_event_id.http_request_id {
             Some(http_request_id) => {
-                if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
-                    // Resume existing request-wise channel
-                    let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
-                    let (tx, rx) = channel;
-                    request_wise.tx.tx = tx;
-                    let index = last_event_id.index;
-                    // sync messages after index
-                    request_wise.tx.sync(index).await?;
-                    Ok(StreamableHttpMessageReceiver {
-                        http_request_id: Some(http_request_id),
-                        inner: rx,
-                    })
-                } else {
-                    // Request-wise channel completed (POST response already delivered).
-                    // The client's EventSource is reconnecting after the POST SSE stream
-                    // ended. Fall through to common channel handling below.
-                    tracing::debug!(
-                        http_request_id,
-                        "Request-wise channel completed, falling back to common channel"
-                    );
-                    self.resume_or_shadow_common(last_event_id.index).await
+                let request_wise = self
+                    .tx_router
+                    .get_mut(&http_request_id)
+                    .ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
+                let is_completed = request_wise.completed_at.is_some();
+                let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
+                request_wise.tx.tx = tx;
+                let index = last_event_id.index;
+                request_wise.tx.sync(index).await?;
+                if is_completed {
+                    // Drop the sender after replaying so the stream ends
+                    // instead of hanging indefinitely.
+                    let (closed_tx, _) = tokio::sync::mpsc::channel(1);
+                    request_wise.tx.tx = closed_tx;
                 }
+                Ok(StreamableHttpMessageReceiver {
+                    http_request_id: Some(http_request_id),
+                    inner: rx,
+                })
             }
             None => self.resume_or_shadow_common(last_event_id.index).await,
         }
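
The completed-channel branch above relies on a small channel trick: replay the cached events through a fresh sender, then swap in a sender whose receiver is already dropped, so the client's stream yields the replayed events and then terminates instead of hanging. A self-contained sketch of that pattern using only tokio types (the surrounding session types are omitted):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(4);
    // "Replay" two cached events into the fresh channel...
    tx.send(1).await.unwrap();
    tx.send(2).await.unwrap();
    // ...then drop the live sender (the diff does this by overwriting it with
    // a sender whose receiver was immediately dropped), ending the stream.
    drop(tx);
    assert_eq!(rx.recv().await, Some(1));
    assert_eq!(rx.recv().await, Some(2));
    assert_eq!(rx.recv().await, None); // stream ends instead of hanging
}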
@@ -955,6 +977,7 @@ impl Worker for LocalSessionWorker {
         let ct = context.cancellation_token.clone();
         let keep_alive = self.session_config.keep_alive.unwrap_or(Duration::MAX);
         loop {
+            self.evict_expired_channels();
             let keep_alive_timeout = tokio::time::sleep(keep_alive);
             let event = tokio::select! {
                 event = self.event_rx.recv() => {
@@ -1076,18 +1099,31 @@ pub struct SessionConfig {
     /// Defaults to 5 minutes. Set to `None` to disable (not recommended
     /// for long-running servers behind proxies).
     pub keep_alive: Option<Duration>,
+    /// SSE retry interval for priming events on request-wise streams.
+    /// When set, the session layer prepends a priming event with the correct
+    /// stream-identifying event ID to each request-wise SSE stream.
+    /// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
+    pub sse_retry: Option<Duration>,
+    /// How long to retain completed request-wise channel caches for late
+    /// resume requests. After this duration, completed entries are evicted
+    /// and resume will return an error. Default is 60 seconds.
+    pub completed_cache_ttl: Duration,
 }

 impl SessionConfig {
     pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
     pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
+    pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
+    pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
 }

 impl Default for SessionConfig {
     fn default() -> Self {
         Self {
             channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
             keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
+            sse_retry: Some(Self::DEFAULT_SSE_RETRY),
+            completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
         }
     }
 }
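
A hedged usage sketch of the new configuration knobs (assuming SessionConfig is in scope; its public re-export path and the surrounding service construction are not part of this diff):

use std::time::Duration;

fn make_session_config() -> SessionConfig {
    SessionConfig {
        // Prime SSE streams with a 5-second retry hint.
        sse_retry: Some(Duration::from_secs(5)),
        // Keep completed request-wise caches for two minutes so late
        // reconnects can still replay their events.
        completed_cache_ttl: Duration::from_secs(120),
        // channel_capacity and keep_alive stay at their defaults.
        ..Default::default()
    }
}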

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 45 additions & 39 deletions
@@ -478,40 +478,52 @@ where
             .and_then(|v| v.to_str().ok())
             .map(|s| s.to_owned());
         if let Some(last_event_id) = last_event_id {
-            // check if session has this event id
-            let stream = self
+            match self
                 .session_manager
                 .resume(&session_id, last_event_id)
                 .await
-                .map_err(internal_error_response("resume session"))?;
-            // Resume doesn't need priming - client already has the event ID
-            Ok(sse_stream_response(
-                stream,
-                self.config.sse_keep_alive,
-                self.config.cancellation_token.child_token(),
-            ))
-        } else {
-            // create standalone stream
-            let stream = self
-                .session_manager
-                .create_standalone_stream(&session_id)
-                .await
-                .map_err(internal_error_response("create standalone stream"))?;
-            // Prepend priming event if sse_retry configured
-            let stream = if let Some(retry) = self.config.sse_retry {
-                let priming = ServerSseMessage::priming("0", retry);
-                futures::stream::once(async move { priming })
-                    .chain(stream)
-                    .left_stream()
-            } else {
-                stream.right_stream()
-            };
-            Ok(sse_stream_response(
-                stream,
-                self.config.sse_keep_alive,
-                self.config.cancellation_token.child_token(),
-            ))
+            {
+                Ok(stream) => {
+                    return Ok(sse_stream_response(
+                        stream,
+                        self.config.sse_keep_alive,
+                        self.config.cancellation_token.child_token(),
+                    ));
+                }
+                Err(e) => {
+                    // Return 200 with an immediately-closed empty stream.
+                    // Returning an HTTP error would cause EventSource to retry
+                    // with the same Last-Event-ID in an infinite loop. An empty
+                    // 200 cleanly terminates the EventSource without delivering
+                    // events from a different stream.
+                    tracing::warn!("Resume failed ({e}), returning empty stream");
+                    return Ok(sse_stream_response(
+                        futures::stream::empty(),
+                        None,
+                        self.config.cancellation_token.child_token(),
+                    ));
+                }
+            }
         }
+        // No Last-Event-ID — create standalone stream
+        let stream = self
+            .session_manager
+            .create_standalone_stream(&session_id)
+            .await
+            .map_err(internal_error_response("create standalone stream"))?;
+        let stream = if let Some(retry) = self.config.sse_retry {
+            let priming = ServerSseMessage::priming("0", retry);
+            futures::stream::once(async move { priming })
+                .chain(stream)
+                .left_stream()
+        } else {
+            stream.right_stream()
+        };
+        Ok(sse_stream_response(
+            stream,
+            self.config.sse_keep_alive,
+            self.config.cancellation_token.child_token(),
+        ))
     }

     async fn handle_post<B>(&self, request: Request<B>) -> Result<BoxResponse, BoxResponse>
@@ -598,20 +610,14 @@ where

         match message {
             ClientJsonRpcMessage::Request(_) => {
+                // Priming for request-wise streams is handled by the
+                // session layer (SessionManager::create_stream) which
+                // has access to the http_request_id for correct event IDs.
                 let stream = self
                     .session_manager
                     .create_stream(&session_id, message)
                     .await
                     .map_err(internal_error_response("get session"))?;
-                // Prepend priming event if sse_retry configured
-                let stream = if let Some(retry) = self.config.sse_retry {
-                    let priming = ServerSseMessage::priming("0", retry);
-                    futures::stream::once(async move { priming })
-                        .chain(stream)
-                        .left_stream()
-                } else {
-                    stream.right_stream()
-                };
                 Ok(sse_stream_response(
                     stream,
                     self.config.sse_keep_alive,
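
For context, a client reaches the resume path above by reconnecting to the GET endpoint with a Last-Event-ID header. A minimal sketch of such a request, assuming a server already listening at http://127.0.0.1:8080/mcp (the endpoint path, the reqwest client, and the hard-coded IDs are illustrative, not taken from this commit):

use reqwest::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let resp = Client::new()
        .get("http://127.0.0.1:8080/mcp")
        .header("Accept", "text/event-stream")
        // Session id previously returned by the server during initialization.
        .header("Mcp-Session-Id", "example-session-id")
        // Last event seen on the request-wise stream: index 0, http_request_id 42.
        .header("Last-Event-ID", "0/42")
        .send()
        .await?;
    // On success the cached events after index 0 are replayed; if the cache
    // was already evicted, the server answers with an empty 200 stream so
    // the EventSource stops retrying.
    println!("status = {}", resp.status());
    Ok(())
}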
