11use std:: {
22 collections:: { HashMap , HashSet , VecDeque } ,
33 num:: ParseIntError ,
4- time:: Duration ,
4+ time:: { Duration , Instant } ,
55} ;
66
7- use futures:: Stream ;
7+ use futures:: { Stream , StreamExt } ;
88use thiserror:: Error ;
99use tokio:: sync:: {
1010 mpsc:: { Receiver , Sender } ,
@@ -86,10 +86,17 @@ impl SessionManager for LocalSessionManager {
8686 . get ( id)
8787 . ok_or ( LocalSessionManagerError :: SessionNotFound ( id. clone ( ) ) ) ?;
8888 let receiver = handle. establish_request_wise_channel ( ) . await ?;
89- handle
90- . push_message ( message, receiver. http_request_id )
91- . await ?;
92- Ok ( ReceiverStream :: new ( receiver. inner ) )
89+ let http_request_id = receiver. http_request_id ;
90+ handle. push_message ( message, http_request_id) . await ?;
91+
92+ let priming = self . session_config . sse_retry . map ( |retry| {
93+ let event_id = match http_request_id {
94+ Some ( id) => format ! ( "0/{id}" ) ,
95+ None => "0" . into ( ) ,
96+ } ;
97+ ServerSseMessage :: priming ( event_id, retry)
98+ } ) ;
99+ Ok ( futures:: stream:: iter ( priming) . chain ( ReceiverStream :: new ( receiver. inner ) ) )
93100 }
94101
95102 async fn create_standalone_stream (
@@ -188,23 +195,29 @@ struct CachedTx {
188195 cache : VecDeque < ServerSseMessage > ,
189196 http_request_id : Option < HttpRequestId > ,
190197 capacity : usize ,
198+ starting_index : usize ,
191199}
192200
193201impl CachedTx {
194- fn new ( tx : Sender < ServerSseMessage > , http_request_id : Option < HttpRequestId > ) -> Self {
202+ fn new (
203+ tx : Sender < ServerSseMessage > ,
204+ http_request_id : Option < HttpRequestId > ,
205+ starting_index : usize ,
206+ ) -> Self {
195207 Self {
196208 cache : VecDeque :: with_capacity ( tx. capacity ( ) ) ,
197209 capacity : tx. capacity ( ) ,
198210 tx,
199211 http_request_id,
212+ starting_index,
200213 }
201214 }
    /// Wraps `tx` as the common (standalone, non-request-wise) channel:
    /// no HTTP request id and event-id numbering starting at 0
    /// (no priming event is prepended to the common stream).
    fn new_common(tx: Sender<ServerSseMessage>) -> Self {
        Self::new(tx, None, 0)
    }
205218
206219 fn next_event_id ( & self ) -> EventId {
207- let index = self . cache . back ( ) . map_or ( 0 , |m| {
220+ let index = self . cache . back ( ) . map_or ( self . starting_index , |m| {
208221 m. event_id
209222 . as_deref ( )
210223 . unwrap_or_default ( )
@@ -272,6 +285,7 @@ impl CachedTx {
/// Per-HTTP-request SSE channel state tracked by the session worker.
struct HttpRequestWise {
    /// Resources (e.g. MCP request ids) still routed to this channel.
    resources: HashSet<ResourceKey>,
    /// Cached sender; its cache is used to replay events on resume.
    tx: CachedTx,
    /// Set when the channel completes. The entry is retained (sender closed,
    /// cache kept) until the configured TTL elapses so late resume requests
    /// can still replay, then it is evicted.
    completed_at: Option<Instant>,
}
276290
/// Identifier of a request-wise channel; allocated monotonically (wrapping).
type HttpRequestId = u64;
@@ -342,23 +356,27 @@ pub struct StreamableHttpMessageReceiver {
342356
343357impl LocalSessionWorker {
    /// Removes `resource` from the routing tables and, when it was the last
    /// resource of its request-wise channel — or is an MCP request id, which
    /// always completes the request — closes that channel's sender while
    /// retaining the entry so its cache can serve late resume requests.
    fn unregister_resource(&mut self, resource: &ResourceKey) {
        let Some(http_request_id) = self.resource_router.remove(resource) else {
            return;
        };
        tracing::trace!(?resource, http_request_id, "unregister resource");
        let Some(channel) = self.tx_router.get_mut(&http_request_id) else {
            tracing::warn!(http_request_id, "http request wise channel not found");
            return;
        };
        // Keep the channel open while it still owns other resources, unless
        // the removed resource is an MCP request id (which ends the request).
        if !channel.resources.is_empty() && !matches!(resource, ResourceKey::McpRequestId(_)) {
            return;
        }
        tracing::debug!(http_request_id, "close http request wise channel");
        let resources: Vec<_> = channel.resources.drain().collect();
        // Timestamp completion so `evict_expired_channels` can drop the entry
        // after the configured TTL.
        channel.completed_at = Some(Instant::now());
        // Close the sender so the client's SSE stream ends,
        // but keep the entry so the cache is available for
        // late resume requests.
        // (The fresh channel's receiver is dropped immediately, so the
        // replacement sender is effectively closed.)
        let (closed_tx, _) = tokio::sync::mpsc::channel(1);
        channel.tx.tx = closed_tx;
        for resource in resources {
            self.resource_router.remove(&resource);
        }
    }
364382 fn register_resource ( & mut self , resource : ResourceKey , http_request_id : HttpRequestId ) {
@@ -395,6 +413,11 @@ impl LocalSessionWorker {
395413 self . unregister_resource ( & resource) ;
396414 }
397415 }
416+ fn evict_expired_channels ( & mut self ) {
417+ let ttl = self . session_config . completed_cache_ttl ;
418+ self . tx_router
419+ . retain ( |_, rw| rw. completed_at . is_none_or ( |at| at. elapsed ( ) < ttl) ) ;
420+ }
398421 fn next_http_request_id ( & mut self ) -> HttpRequestId {
399422 let id = self . next_http_request_id ;
400423 self . next_http_request_id = self . next_http_request_id . wrapping_add ( 1 ) ;
@@ -405,11 +428,13 @@ impl LocalSessionWorker {
405428 ) -> Result < StreamableHttpMessageReceiver , SessionError > {
406429 let http_request_id = self . next_http_request_id ( ) ;
407430 let ( tx, rx) = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
431+ let starting_index = usize:: from ( self . session_config . sse_retry . is_some ( ) ) ;
408432 self . tx_router . insert (
409433 http_request_id,
410434 HttpRequestWise {
411435 resources : Default :: default ( ) ,
412- tx : CachedTx :: new ( tx, Some ( http_request_id) ) ,
436+ tx : CachedTx :: new ( tx, Some ( http_request_id) , starting_index) ,
437+ completed_at : None ,
413438 } ,
414439 ) ;
415440 tracing:: debug!( http_request_id, "establish new request wise channel" ) ;
@@ -524,28 +549,25 @@ impl LocalSessionWorker {
524549
525550 match last_event_id. http_request_id {
526551 Some ( http_request_id) => {
527- if let Some ( request_wise) = self . tx_router . get_mut ( & http_request_id) {
528- // Resume existing request-wise channel
529- let channel = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
530- let ( tx, rx) = channel;
531- request_wise. tx . tx = tx;
532- let index = last_event_id. index ;
533- // sync messages after index
534- request_wise. tx . sync ( index) . await ?;
535- Ok ( StreamableHttpMessageReceiver {
536- http_request_id : Some ( http_request_id) ,
537- inner : rx,
538- } )
539- } else {
540- // Request-wise channel completed (POST response already delivered).
541- // The client's EventSource is reconnecting after the POST SSE stream
542- // ended. Fall through to common channel handling below.
543- tracing:: debug!(
544- http_request_id,
545- "Request-wise channel completed, falling back to common channel"
546- ) ;
547- self . resume_or_shadow_common ( last_event_id. index ) . await
552+ let request_wise = self
553+ . tx_router
554+ . get_mut ( & http_request_id)
555+ . ok_or ( SessionError :: ChannelClosed ( Some ( http_request_id) ) ) ?;
556+ let is_completed = request_wise. completed_at . is_some ( ) ;
557+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( self . session_config . channel_capacity ) ;
558+ request_wise. tx . tx = tx;
559+ let index = last_event_id. index ;
560+ request_wise. tx . sync ( index) . await ?;
561+ if is_completed {
562+ // Drop the sender after replaying so the stream ends
563+ // instead of hanging indefinitely.
564+ let ( closed_tx, _) = tokio:: sync:: mpsc:: channel ( 1 ) ;
565+ request_wise. tx . tx = closed_tx;
548566 }
567+ Ok ( StreamableHttpMessageReceiver {
568+ http_request_id : Some ( http_request_id) ,
569+ inner : rx,
570+ } )
549571 }
550572 None => self . resume_or_shadow_common ( last_event_id. index ) . await ,
551573 }
@@ -955,6 +977,7 @@ impl Worker for LocalSessionWorker {
955977 let ct = context. cancellation_token . clone ( ) ;
956978 let keep_alive = self . session_config . keep_alive . unwrap_or ( Duration :: MAX ) ;
957979 loop {
980+ self . evict_expired_channels ( ) ;
958981 let keep_alive_timeout = tokio:: time:: sleep ( keep_alive) ;
959982 let event = tokio:: select! {
960983 event = self . event_rx. recv( ) => {
@@ -1076,18 +1099,31 @@ pub struct SessionConfig {
10761099 /// Defaults to 5 minutes. Set to `None` to disable (not recommended
10771100 /// for long-running servers behind proxies).
10781101 pub keep_alive : Option < Duration > ,
1102+ /// SSE retry interval for priming events on request-wise streams.
1103+ /// When set, the session layer prepends a priming event with the correct
1104+ /// stream-identifying event ID to each request-wise SSE stream.
1105+ /// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
1106+ pub sse_retry : Option < Duration > ,
1107+ /// How long to retain completed request-wise channel caches for late
1108+ /// resume requests. After this duration, completed entries are evicted
1109+ /// and resume will return an error. Default is 60 seconds.
1110+ pub completed_cache_ttl : Duration ,
10791111}
10801112
impl SessionConfig {
    /// Default buffered capacity of each SSE message channel.
    pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
    /// Default keep-alive interval (5 minutes).
    pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
    /// Default SSE `retry` interval carried by priming events (3 seconds).
    pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
    /// Default retention of completed request-wise caches for late
    /// resume requests (60 seconds).
    pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
}
10851119
impl Default for SessionConfig {
    /// Every field takes its `DEFAULT_*` constant; the optional `keep_alive`
    /// and `sse_retry` features are enabled (`Some`) by default.
    fn default() -> Self {
        Self {
            channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
            keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
            sse_retry: Some(Self::DEFAULT_SSE_RETRY),
            completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
        }
    }
}
0 commit comments