8 changes: 7 additions & 1 deletion crates/rmcp/Cargo.toml
@@ -97,6 +97,7 @@ server-side-http = [
"dep:http-body-util",
"dep:bytes",
"dep:sse-stream",
"dep:axum",
"tower",
]

@@ -201,4 +202,9 @@ path = "tests/test_elicitation.rs"
[[test]]
name = "test_task"
required-features = ["server", "client", "macros"]
path = "tests/test_task.rs"
path = "tests/test_task.rs"

[[test]]
name = "test_streamable_http_priming"
required-features = ["server", "client", "transport-streamable-http-server", "reqwest"]
path = "tests/test_streamable_http_priming.rs"
23 changes: 20 additions & 3 deletions crates/rmcp/src/transport/common/server_side_http.rs
@@ -59,8 +59,14 @@ impl sse_stream::Timer for TokioTimer {

#[derive(Debug, Clone)]
pub struct ServerSseMessage {
/// The event ID for this message. When set, clients can use this ID
/// with the `Last-Event-ID` header to resume the stream from this point.
pub event_id: Option<String>,
pub message: Arc<ServerJsonRpcMessage>,
/// The JSON-RPC message content. For priming events, set this to `None`.
pub message: Option<Arc<ServerJsonRpcMessage>>,
/// The retry interval hint for clients. Clients should wait this duration
/// before attempting to reconnect. This maps to the SSE `retry:` field.
pub retry: Option<Duration>,
}

pub(crate) fn sse_stream_response(
@@ -71,9 +77,20 @@ pub(crate) fn sse_stream_response(
use futures::StreamExt;
let stream = stream
.map(|message| {
let data = serde_json::to_string(&message.message).expect("valid message");
let mut sse = Sse::default().data(data);
let mut sse = if let Some(ref msg) = message.message {
let data = serde_json::to_string(msg.as_ref()).expect("valid message");
Sse::default().data(data)
} else {
// Priming event: empty data per SSE spec (just "data:\n")
Sse::default().data("")
};

sse.id = message.event_id;

if let Some(retry) = message.retry {
sse.retry = Some(retry.as_millis() as u64);
}

Result::<Sse, Infallible>::Ok(sse)
})
.take_until(async move { ct.cancelled().await });
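For reference, a priming event under the new `ServerSseMessage` shape carries only an event ID and a retry hint, with no JSON-RPC payload; on the wire it renders as an `id:` line, a `retry:` line in milliseconds, and an empty `data:` line. A minimal sketch, assuming `ServerSseMessage` is in scope and using illustrative values:

use std::time::Duration;

// Sketch: constructing a priming event. Only the struct shape comes from the
// change above; the event ID and 3-second retry interval are illustrative.
let priming = ServerSseMessage {
    event_id: Some("0".to_string()),
    message: None,
    retry: Some(Duration::from_secs(3)),
};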
141 changes: 138 additions & 3 deletions crates/rmcp/src/transport/streamable_http_server/session/local.rs
@@ -201,7 +201,7 @@ impl CachedTx {
Self::new(tx, None)
}

async fn send(&mut self, message: ServerJsonRpcMessage) {
fn next_event_id(&self) -> EventId {
let index = self.cache.back().map_or(0, |m| {
m.event_id
.as_deref()
@@ -211,14 +211,33 @@ impl CachedTx {
.index
+ 1
});
let event_id = EventId {
EventId {
http_request_id: self.http_request_id,
index,
}
}

async fn send(&mut self, message: ServerJsonRpcMessage) {
let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: Some(Arc::new(message)),
retry: None,
};
self.cache_and_send(message).await;
}

async fn send_priming(&mut self, retry: Duration) {
let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: Arc::new(message),
message: None,
retry: Some(retry),
};
self.cache_and_send(message).await;
}

async fn cache_and_send(&mut self, message: ServerSseMessage) {
if self.cache.len() >= self.capacity {
self.cache.pop_front();
self.cache.push_back(message.clone());
@@ -525,7 +544,53 @@ impl LocalSessionWorker {
}
}
}

async fn close_sse_stream(
&mut self,
http_request_id: Option<HttpRequestId>,
retry_interval: Option<Duration>,
) -> Result<(), SessionError> {
match http_request_id {
// Close a request-wise stream
Some(id) => {
let request_wise = self
.tx_router
.get_mut(&id)
.ok_or(SessionError::ChannelClosed(Some(id)))?;

// Send priming event if retry interval is specified
if let Some(interval) = retry_interval {
request_wise.tx.send_priming(interval).await;
}

// Close the stream by dropping the sender
let (tx, _rx) = tokio::sync::mpsc::channel(1);
request_wise.tx.tx = tx;

tracing::debug!(
http_request_id = id,
"closed SSE stream for server-initiated disconnection"
);
Ok(())
}
// Close the standalone (common) stream
None => {
// Send priming event if retry interval is specified
if let Some(interval) = retry_interval {
self.common.send_priming(interval).await;
}

// Close the stream by dropping the sender
let (tx, _rx) = tokio::sync::mpsc::channel(1);
self.common.tx = tx;

tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
Ok(())
}
}
}
}

#[derive(Debug)]
pub enum SessionEvent {
ClientMessage {
@@ -548,6 +613,13 @@ pub enum SessionEvent {
responder: oneshot::Sender<Result<ServerJsonRpcMessage, SessionError>>,
},
Close,
CloseSseStream {
/// The HTTP request ID to close. If `None`, closes the standalone (common) stream.
http_request_id: Option<HttpRequestId>,
/// Optional retry interval. If provided, a priming event is sent before closing.
retry_interval: Option<Duration>,
responder: oneshot::Sender<Result<(), SessionError>>,
},
}

#[derive(Debug, Clone)]
@@ -683,6 +755,60 @@ impl LocalSessionHandle {
rx.await
.map_err(|_| SessionError::SessionServiceTerminated)?
}

/// Close an SSE stream for a specific request.
///
/// This closes the SSE connection for a POST request stream, but keeps the session
/// and message cache active. Clients can reconnect using the `Last-Event-ID` header
/// via a GET request to resume receiving messages.
///
/// # Arguments
///
/// * `http_request_id` - The HTTP request ID of the stream to close
/// * `retry_interval` - Optional retry interval. If provided, a priming event is sent before the stream is closed
pub async fn close_sse_stream(
&self,
http_request_id: HttpRequestId,
retry_interval: Option<Duration>,
) -> Result<(), SessionError> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.event_tx
.send(SessionEvent::CloseSseStream {
http_request_id: Some(http_request_id),
retry_interval,
responder: tx,
})
.await
.map_err(|_| SessionError::SessionServiceTerminated)?;
rx.await
.map_err(|_| SessionError::SessionServiceTerminated)?
}

/// Close the standalone SSE stream.
///
/// This closes the standalone SSE connection (established via GET request),
/// but keeps the session and message cache active. Clients can reconnect using
/// the `Last-Event-ID` header via a GET request to resume receiving messages.
///
/// # Arguments
///
/// * `retry_interval` - Optional retry interval. If provided, a priming event is sent before the stream is closed
pub async fn close_standalone_sse_stream(
&self,
retry_interval: Option<Duration>,
) -> Result<(), SessionError> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.event_tx
.send(SessionEvent::CloseSseStream {
http_request_id: None,
retry_interval,
responder: tx,
})
.await
.map_err(|_| SessionError::SessionServiceTerminated)?;
rx.await
.map_err(|_| SessionError::SessionServiceTerminated)?
}
}

pub type SessionTransport = WorkerTransport<LocalSessionWorker>;
@@ -848,6 +974,15 @@ impl Worker for LocalSessionWorker {
InnerEvent::FromHttpService(SessionEvent::Close) => {
return Err(WorkerQuitReason::TransportClosed);
}
InnerEvent::FromHttpService(SessionEvent::CloseSseStream {
http_request_id,
retry_interval,
responder,
}) => {
let handle_result =
self.close_sse_stream(http_request_id, retry_interval).await;
let _ = responder.send(handle_result);
}
_ => {
// ignore
}
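To illustrate the two new handle methods, a hedged usage sketch follows. Only the method signatures come from the change above; the surrounding helper function, how the handle and request ID are obtained, and the import paths are assumptions.

use std::time::Duration;

// Sketch only: server-side code that disconnects SSE streams while keeping the
// session and its message cache alive. `handle` is a LocalSessionHandle obtained
// elsewhere; `request_id` identifies a POST request stream.
async fn disconnect_streams(
    handle: &LocalSessionHandle,
    request_id: HttpRequestId,
) -> Result<(), SessionError> {
    // Close the request-wise stream; a priming event with a 3-second retry hint
    // is emitted first so the client reconnects with Last-Event-ID.
    handle
        .close_sse_stream(request_id, Some(Duration::from_secs(3)))
        .await?;

    // Close the standalone (GET) stream without sending a priming event.
    handle.close_standalone_sse_stream(None).await?;
    Ok(())
}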
77 changes: 62 additions & 15 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
@@ -32,7 +32,10 @@ use crate::{
pub struct StreamableHttpServerConfig {
/// The ping message duration for SSE connections.
pub sse_keep_alive: Option<Duration>,
/// The retry interval for SSE priming events.
pub sse_retry: Option<Duration>,
/// If true, the server will create a session for each request and keep it alive.
/// When enabled, SSE priming events are sent to enable client reconnection.
pub stateful_mode: bool,
/// Cancellation token for the Streamable HTTP server.
///
@@ -45,6 +48,7 @@ impl Default for StreamableHttpServerConfig {
fn default() -> Self {
Self {
sse_keep_alive: Some(Duration::from_secs(15)),
sse_retry: Some(Duration::from_secs(3)),
stateful_mode: true,
cancellation_token: CancellationToken::new(),
}
@@ -216,6 +220,7 @@
.resume(&session_id, last_event_id)
.await
.map_err(internal_error_response("resume session"))?;
// Resume doesn't need priming - client already has the event ID
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
@@ -228,6 +233,19 @@
.create_standalone_stream(&session_id)
.await
.map_err(internal_error_response("create standalone stream"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
@@ -322,6 +340,19 @@
.create_stream(&session_id, message)
.await
.map_err(internal_error_response("get session"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
@@ -389,15 +420,28 @@
.initialize_session(&session_id, message)
.await
.map_err(internal_error_response("create stream"))?;
let stream = futures::stream::once(async move {
ServerSseMessage {
event_id: None,
message: Some(Arc::new(response)),
retry: None,
}
});
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
let mut response = sse_stream_response(
futures::stream::once({
async move {
ServerSseMessage {
event_id: None,
message: response.into(),
}
}
}),
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
);
@@ -424,14 +468,17 @@
// on service created
let _ = service.waiting().await;
});
// Stateless mode: no priming (no session to resume)
let stream = ReceiverStream::new(receiver).map(|message| {
tracing::info!(?message);
ServerSseMessage {
event_id: None,
message: Some(Arc::new(message)),
retry: None,
}
});
Ok(sse_stream_response(
ReceiverStream::new(receiver).map(|message| {
tracing::info!(?message);
ServerSseMessage {
event_id: None,
message: message.into(),
}
}),
stream,
self.config.sse_keep_alive,
self.config.cancellation_token.child_token(),
))
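Putting the config knobs together, a minimal sketch of enabling priming with a custom retry hint. The field names come from StreamableHttpServerConfig above; the exact re-export path of the struct is not shown here and the durations are illustrative.

use std::time::Duration;
use tokio_util::sync::CancellationToken;

// Sketch: stateful mode with priming events and a 5-second client retry hint.
// Assumes StreamableHttpServerConfig is in scope.
fn priming_config() -> StreamableHttpServerConfig {
    StreamableHttpServerConfig {
        sse_keep_alive: Some(Duration::from_secs(15)),
        sse_retry: Some(Duration::from_secs(5)),
        stateful_mode: true,
        cancellation_token: CancellationToken::new(),
    }
}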