From 104b99503afbe954562d9de8d0257a40ea3360f1 Mon Sep 17 00:00:00 2001 From: Dale Seo Date: Sat, 27 Dec 2025 22:27:37 -0500 Subject: [PATCH 1/2] feat: implement SEP-1699 SSE polling via server-side disconnect --- crates/rmcp/Cargo.toml | 8 +- .../src/transport/common/server_side_http.rs | 23 ++- .../streamable_http_server/session/local.rs | 141 ++++++++++++++- .../transport/streamable_http_server/tower.rs | 77 ++++++-- .../tests/test_streamable_http_priming.rs | 168 ++++++++++++++++++ 5 files changed, 395 insertions(+), 22 deletions(-) create mode 100644 crates/rmcp/tests/test_streamable_http_priming.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 2b63f66f..b86f2abe 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -97,6 +97,7 @@ server-side-http = [ "dep:http-body-util", "dep:bytes", "dep:sse-stream", + "dep:axum", "tower", ] @@ -201,4 +202,9 @@ path = "tests/test_elicitation.rs" [[test]] name = "test_task" required-features = ["server", "client", "macros"] -path = "tests/test_task.rs" \ No newline at end of file +path = "tests/test_task.rs" + +[[test]] +name = "test_streamable_http_priming" +required-features = ["server", "client", "transport-streamable-http-server", "reqwest"] +path = "tests/test_streamable_http_priming.rs" \ No newline at end of file diff --git a/crates/rmcp/src/transport/common/server_side_http.rs b/crates/rmcp/src/transport/common/server_side_http.rs index 51cd51f5..8b5aa8ad 100644 --- a/crates/rmcp/src/transport/common/server_side_http.rs +++ b/crates/rmcp/src/transport/common/server_side_http.rs @@ -59,8 +59,14 @@ impl sse_stream::Timer for TokioTimer { #[derive(Debug, Clone)] pub struct ServerSseMessage { + /// The event ID for this message. When set, clients can use this ID + /// with the `Last-Event-ID` header to resume the stream from this point. pub event_id: Option, - pub message: Arc, + /// The JSON-RPC message content. For priming events, set this to `None`. + pub message: Option>, + /// The retry interval hint for clients. Clients should wait this duration + /// before attempting to reconnect. This maps to the SSE `retry:` field. + pub retry: Option, } pub(crate) fn sse_stream_response( @@ -71,9 +77,20 @@ pub(crate) fn sse_stream_response( use futures::StreamExt; let stream = stream .map(|message| { - let data = serde_json::to_string(&message.message).expect("valid message"); - let mut sse = Sse::default().data(data); + let mut sse = if let Some(ref msg) = message.message { + let data = serde_json::to_string(msg.as_ref()).expect("valid message"); + Sse::default().data(data) + } else { + // Priming event: empty data per SSE spec (just "data:\n") + Sse::default().data("") + }; + sse.id = message.event_id; + + if let Some(retry) = message.retry { + sse.retry = Some(retry.as_millis() as u64); + } + Result::::Ok(sse) }) .take_until(async move { ct.cancelled().await }); diff --git a/crates/rmcp/src/transport/streamable_http_server/session/local.rs b/crates/rmcp/src/transport/streamable_http_server/session/local.rs index 1dca3faf..d68d63e1 100644 --- a/crates/rmcp/src/transport/streamable_http_server/session/local.rs +++ b/crates/rmcp/src/transport/streamable_http_server/session/local.rs @@ -201,7 +201,7 @@ impl CachedTx { Self::new(tx, None) } - async fn send(&mut self, message: ServerJsonRpcMessage) { + fn next_event_id(&self) -> EventId { let index = self.cache.back().map_or(0, |m| { m.event_id .as_deref() @@ -211,14 +211,33 @@ impl CachedTx { .index + 1 }); - let event_id = EventId { + EventId { http_request_id: self.http_request_id, index, + } + } + + async fn send(&mut self, message: ServerJsonRpcMessage) { + let event_id = self.next_event_id(); + let message = ServerSseMessage { + event_id: Some(event_id.to_string()), + message: Some(Arc::new(message)), + retry: None, }; + self.cache_and_send(message).await; + } + + async fn send_priming(&mut self, retry: Duration) { + let event_id = self.next_event_id(); let message = ServerSseMessage { event_id: Some(event_id.to_string()), - message: Arc::new(message), + message: None, + retry: Some(retry), }; + self.cache_and_send(message).await; + } + + async fn cache_and_send(&mut self, message: ServerSseMessage) { if self.cache.len() >= self.capacity { self.cache.pop_front(); self.cache.push_back(message.clone()); @@ -525,7 +544,53 @@ impl LocalSessionWorker { } } } + + async fn close_sse_stream( + &mut self, + http_request_id: Option, + retry_interval: Option, + ) -> Result<(), SessionError> { + match http_request_id { + // Close a request-wise stream + Some(id) => { + let request_wise = self + .tx_router + .get_mut(&id) + .ok_or(SessionError::ChannelClosed(Some(id)))?; + + // Send priming event if retry interval is specified + if let Some(interval) = retry_interval { + request_wise.tx.send_priming(interval).await; + } + + // Close the stream by dropping the sender + let (tx, _rx) = tokio::sync::mpsc::channel(1); + request_wise.tx.tx = tx; + + tracing::debug!( + http_request_id = id, + "closed SSE stream for server-initiated disconnection" + ); + Ok(()) + } + // Close the standalone (common) stream + None => { + // Send priming event if retry interval is specified + if let Some(interval) = retry_interval { + self.common.send_priming(interval).await; + } + + // Close the stream by dropping the sender + let (tx, _rx) = tokio::sync::mpsc::channel(1); + self.common.tx = tx; + + tracing::debug!("closed standalone SSE stream for server-initiated disconnection"); + Ok(()) + } + } + } } + #[derive(Debug)] pub enum SessionEvent { ClientMessage { @@ -548,6 +613,13 @@ pub enum SessionEvent { responder: oneshot::Sender>, }, Close, + CloseSseStream { + /// The HTTP request ID to close. If `None`, closes the standalone (common) stream. + http_request_id: Option, + /// Optional retry interval. If provided, a priming event is sent before closing. + retry_interval: Option, + responder: oneshot::Sender>, + }, } #[derive(Debug, Clone)] @@ -683,6 +755,60 @@ impl LocalSessionHandle { rx.await .map_err(|_| SessionError::SessionServiceTerminated)? } + + /// Close an SSE stream for a specific request. + /// + /// This closes the SSE connection for a POST request stream, but keeps the session + /// and message cache active. Clients can reconnect using the `Last-Event-ID` header + /// via a GET request to resume receiving messages. + /// + /// # Arguments + /// + /// * `http_request_id` - The HTTP request ID of the stream to close + /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent + pub async fn close_sse_stream( + &self, + http_request_id: HttpRequestId, + retry_interval: Option, + ) -> Result<(), SessionError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::CloseSseStream { + http_request_id: Some(http_request_id), + retry_interval, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } + + /// Close the standalone SSE stream. + /// + /// This closes the standalone SSE connection (established via GET request), + /// but keeps the session and message cache active. Clients can reconnect using + /// the `Last-Event-ID` header via a GET request to resume receiving messages. + /// + /// # Arguments + /// + /// * `retry_interval` - Optional retry interval. If provided, a priming event is sent + pub async fn close_standalone_sse_stream( + &self, + retry_interval: Option, + ) -> Result<(), SessionError> { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.event_tx + .send(SessionEvent::CloseSseStream { + http_request_id: None, + retry_interval, + responder: tx, + }) + .await + .map_err(|_| SessionError::SessionServiceTerminated)?; + rx.await + .map_err(|_| SessionError::SessionServiceTerminated)? + } } pub type SessionTransport = WorkerTransport; @@ -848,6 +974,15 @@ impl Worker for LocalSessionWorker { InnerEvent::FromHttpService(SessionEvent::Close) => { return Err(WorkerQuitReason::TransportClosed); } + InnerEvent::FromHttpService(SessionEvent::CloseSseStream { + http_request_id, + retry_interval, + responder, + }) => { + let handle_result = + self.close_sse_stream(http_request_id, retry_interval).await; + let _ = responder.send(handle_result); + } _ => { // ignore } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 08789566..37d4a008 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -32,7 +32,10 @@ use crate::{ pub struct StreamableHttpServerConfig { /// The ping message duration for SSE connections. pub sse_keep_alive: Option, + /// The retry interval for SSE priming events. + pub sse_retry: Option, /// If true, the server will create a session for each request and keep it alive. + /// When enabled, SSE priming events are sent to enable client reconnection. pub stateful_mode: bool, /// Cancellation token for the Streamable HTTP server. /// @@ -45,6 +48,7 @@ impl Default for StreamableHttpServerConfig { fn default() -> Self { Self { sse_keep_alive: Some(Duration::from_secs(15)), + sse_retry: Some(Duration::from_secs(3)), stateful_mode: true, cancellation_token: CancellationToken::new(), } @@ -216,6 +220,7 @@ where .resume(&session_id, last_event_id) .await .map_err(internal_error_response("resume session"))?; + // Resume doesn't need priming - client already has the event ID Ok(sse_stream_response( stream, self.config.sse_keep_alive, @@ -228,6 +233,19 @@ where .create_standalone_stream(&session_id) .await .map_err(internal_error_response("create standalone stream"))?; + // Prepend priming event if sse_retry configured + let stream = if let Some(retry) = self.config.sse_retry { + let priming = ServerSseMessage { + event_id: Some("0".into()), + message: None, + retry: Some(retry), + }; + futures::stream::once(async move { priming }) + .chain(stream) + .left_stream() + } else { + stream.right_stream() + }; Ok(sse_stream_response( stream, self.config.sse_keep_alive, @@ -322,6 +340,19 @@ where .create_stream(&session_id, message) .await .map_err(internal_error_response("get session"))?; + // Prepend priming event if sse_retry configured + let stream = if let Some(retry) = self.config.sse_retry { + let priming = ServerSseMessage { + event_id: Some("0".into()), + message: None, + retry: Some(retry), + }; + futures::stream::once(async move { priming }) + .chain(stream) + .left_stream() + } else { + stream.right_stream() + }; Ok(sse_stream_response( stream, self.config.sse_keep_alive, @@ -389,15 +420,28 @@ where .initialize_session(&session_id, message) .await .map_err(internal_error_response("create stream"))?; + let stream = futures::stream::once(async move { + ServerSseMessage { + event_id: None, + message: Some(Arc::new(response)), + retry: None, + } + }); + // Prepend priming event if sse_retry configured + let stream = if let Some(retry) = self.config.sse_retry { + let priming = ServerSseMessage { + event_id: Some("0".into()), + message: None, + retry: Some(retry), + }; + futures::stream::once(async move { priming }) + .chain(stream) + .left_stream() + } else { + stream.right_stream() + }; let mut response = sse_stream_response( - futures::stream::once({ - async move { - ServerSseMessage { - event_id: None, - message: response.into(), - } - } - }), + stream, self.config.sse_keep_alive, self.config.cancellation_token.child_token(), ); @@ -424,14 +468,17 @@ where // on service created let _ = service.waiting().await; }); + // Stateless mode: no priming (no session to resume) + let stream = ReceiverStream::new(receiver).map(|message| { + tracing::info!(?message); + ServerSseMessage { + event_id: None, + message: Some(Arc::new(message)), + retry: None, + } + }); Ok(sse_stream_response( - ReceiverStream::new(receiver).map(|message| { - tracing::info!(?message); - ServerSseMessage { - event_id: None, - message: message.into(), - } - }), + stream, self.config.sse_keep_alive, self.config.cancellation_token.child_token(), )) diff --git a/crates/rmcp/tests/test_streamable_http_priming.rs b/crates/rmcp/tests/test_streamable_http_priming.rs new file mode 100644 index 00000000..778dfedf --- /dev/null +++ b/crates/rmcp/tests/test_streamable_http_priming.rs @@ -0,0 +1,168 @@ +use std::time::Duration; + +use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, +}; +use tokio_util::sync::CancellationToken; + +mod common; +use common::calculator::Calculator; + +#[tokio::test] +async fn test_priming_on_stream_start() -> anyhow::Result<()> { + let ct = CancellationToken::new(); + + // stateful_mode: true automatically enables priming with DEFAULT_RETRY_INTERVAL (3 seconds) + let service: StreamableHttpService = + StreamableHttpService::new( + || Ok(Calculator::new()), + Default::default(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = tcp_listener.local_addr()?; + + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + // Send initialize request + let client = reqwest::Client::new(); + let response = client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#) + .send() + .await?; + + assert_eq!(response.status(), 200); + + let body = response.text().await?; + + // Split SSE events by double newline + let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect(); + assert!(events.len() >= 2); + + // Verify priming event (first event) + let priming_event = events[0]; + assert!(priming_event.contains("id: 0")); + assert!(priming_event.contains("retry: 3000")); + assert!(priming_event.contains("data:")); + + // Verify initialize response (second event) + let response_event = events[1]; + assert!(response_event.contains(r#""jsonrpc":"2.0""#)); + assert!(response_event.contains(r#""id":1"#)); + + ct.cancel(); + handle.await?; + + Ok(()) +} + +#[tokio::test] +async fn test_priming_on_stream_close() -> anyhow::Result<()> { + use std::sync::Arc; + + use rmcp::transport::streamable_http_server::session::SessionId; + + let ct = CancellationToken::new(); + let session_manager = Arc::new(LocalSessionManager::default()); + + // stateful_mode: true automatically enables priming with DEFAULT_RETRY_INTERVAL (3 seconds) + let service = StreamableHttpService::new( + || Ok(Calculator::new()), + session_manager.clone(), + StreamableHttpServerConfig { + stateful_mode: true, + sse_keep_alive: None, + cancellation_token: ct.child_token(), + ..Default::default() + }, + ); + + let router = axum::Router::new().nest_service("/mcp", service); + let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; + let addr = tcp_listener.local_addr()?; + + let handle = tokio::spawn({ + let ct = ct.clone(); + async move { + let _ = axum::serve(tcp_listener, router) + .with_graceful_shutdown(async move { ct.cancelled_owned().await }) + .await; + } + }); + + // Send initialize request to create a session + let client = reqwest::Client::new(); + let response = client + .post(format!("http://{addr}/mcp")) + .header("Content-Type", "application/json") + .header("Accept", "application/json, text/event-stream") + .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#) + .send() + .await?; + + let session_id: SessionId = response.headers()["mcp-session-id"].to_str()?.into(); + + // Open a standalone GET stream (send() returns when headers are received) + let response = client + .get(format!("http://{addr}/mcp")) + .header("Accept", "text/event-stream") + .header("mcp-session-id", session_id.to_string()) + .send() + .await?; + + assert_eq!(response.status(), 200); + + // Spawn a task to read the response body (blocks until stream closes) + let read_task = tokio::spawn(async move { response.text().await.unwrap() }); + + // Close the standalone stream with a 5-second retry hint + let sessions = session_manager.sessions.read().await; + let session = sessions.get(&session_id).unwrap(); + session + .close_standalone_sse_stream(Some(Duration::from_secs(5))) + .await?; + drop(sessions); + + // Wait for the read task to complete and verify the response + let body = read_task.await?; + + // Verify the stream received two priming events: + // 1. At stream start (retry: 3000) + // 2. Before close (retry: 5000) + let events: Vec<&str> = body.split("\n\n").filter(|e| !e.is_empty()).collect(); + assert_eq!(events.len(), 2); + + // First event: priming at stream start + let start_priming = events[0]; + assert!(start_priming.contains("id:")); + assert!(start_priming.contains("retry: 3000")); + assert!(start_priming.contains("data:")); + + // Second event: priming before close + let close_priming = events[1]; + assert!(close_priming.contains("id:")); + assert!(close_priming.contains("retry: 5000")); + assert!(close_priming.contains("data:")); + + ct.cancel(); + handle.await?; + + Ok(()) +} From 9b6f017e9d5b48f0b9a272dd6b86d02344214d9d Mon Sep 17 00:00:00 2001 From: Dale Seo Date: Sat, 27 Dec 2025 23:59:02 -0500 Subject: [PATCH 2/2] test: add tests for priming behavior on stream start and close --- crates/rmcp/tests/test_with_js.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index ef0516bc..c1e5d81a 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -72,6 +72,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { stateful_mode: true, sse_keep_alive: None, cancellation_token: ct.child_token(), + ..Default::default() }, ); let router = axum::Router::new().nest_service("/mcp", service);