Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,8 @@ required-features = [
"transport-streamable-http-server",
]
path = "tests/test_custom_headers.rs"

[[test]]
name = "test_sse_channel_replacement_bug"
required-features = ["server", "client", "transport-streamable-http-server", "transport-streamable-http-client", "reqwest"]
path = "tests/test_sse_channel_replacement_bug.rs"
97 changes: 71 additions & 26 deletions crates/rmcp/src/transport/streamable_http_server/session/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ pub struct LocalSessionWorker {
tx_router: HashMap<HttpRequestId, HttpRequestWise>,
resource_router: HashMap<ResourceKey, HttpRequestId>,
common: CachedTx,
/// Shadow senders for secondary SSE streams (e.g. from POST EventSource
/// reconnections). These keep the HTTP connections alive via SSE keep-alive
/// without receiving notifications, preventing clients like Cursor from
/// entering infinite reconnect loops when multiple EventSource connections
/// compete to replace the common channel.
shadow_txs: Vec<Sender<ServerSseMessage>>,
event_rx: Receiver<SessionEvent>,
session_config: SessionConfig,
}
Expand All @@ -315,6 +321,8 @@ pub enum SessionError {
SessionServiceTerminated,
#[error("Invalid event id")]
InvalidEventId,
#[error("Conflict: Only one standalone SSE stream is allowed per session")]
Conflict,
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
Expand Down Expand Up @@ -513,36 +521,69 @@ impl LocalSessionWorker {
&mut self,
last_event_id: EventId,
) -> Result<StreamableHttpMessageReceiver, SessionError> {
// Clean up closed shadow senders before processing
self.shadow_txs.retain(|tx| !tx.is_closed());

match last_event_id.http_request_id {
Some(http_request_id) => {
let request_wise = self
.tx_router
.get_mut(&http_request_id)
.ok_or(SessionError::ChannelClosed(Some(http_request_id)))?;
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
}
None => {
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
self.common.tx = tx;
let index = last_event_id.index;
// sync messages after index
self.common.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: None,
inner: rx,
})
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
// Resume existing request-wise channel
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
} else {
// Request-wise channel completed (POST response already delivered).
// The client's EventSource is reconnecting after the POST SSE stream
// ended. Fall through to common channel handling below.
tracing::debug!(
http_request_id,
"Request-wise channel completed, falling back to common channel"
);
self.resume_or_shadow_common()
}
}
None => self.resume_or_shadow_common(),
}
}

/// Resume the common channel, or create a shadow stream if the primary is
/// still active.
///
/// When the primary common channel is dead (receiver dropped), replace it
/// so this stream becomes the new primary notification channel.
///
/// When the primary is still active, create a "shadow" stream — an idle SSE
/// connection kept alive by keep-alive pings. This prevents multiple
/// EventSource connections (e.g. from POST response reconnections) from
/// killing each other by repeatedly replacing the common channel sender.
fn resume_or_shadow_common(&mut self) -> Result<StreamableHttpMessageReceiver, SessionError> {
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
if self.common.tx.is_closed() {
// Primary common channel is dead — replace it.
tracing::debug!("Replacing dead common channel with new primary");
self.common.tx = tx;
} else {
// Primary common channel is still active. Create a shadow stream
// that stays alive via SSE keep-alive but doesn't receive
// notifications. This prevents competing EventSource connections
// from killing each other's channels.
tracing::debug!(
shadow_count = self.shadow_txs.len(),
"Common channel active, creating shadow stream"
);
self.shadow_txs.push(tx);
}
Ok(StreamableHttpMessageReceiver {
http_request_id: None,
inner: rx,
})
}

async fn close_sse_stream(
Expand Down Expand Up @@ -584,6 +625,9 @@ impl LocalSessionWorker {
let (tx, _rx) = tokio::sync::mpsc::channel(1);
self.common.tx = tx;

// Also close all shadow streams
self.shadow_txs.clear();

tracing::debug!("closed standalone SSE stream for server-initiated disconnection");
Ok(())
}
Expand Down Expand Up @@ -1036,6 +1080,7 @@ pub fn create_local_session(
tx_router: HashMap::new(),
resource_router: HashMap::new(),
common,
shadow_txs: Vec::new(),
event_rx,
session_config: config.clone(),
};
Expand Down
50 changes: 34 additions & 16 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ where
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
// unauthorized
// MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
// check if session exists
Expand All @@ -201,10 +201,10 @@ where
.await
.map_err(internal_error_response("check session"))?;
if !has_session {
// unauthorized
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
// check if last event id is provided
Expand All @@ -215,11 +215,20 @@ where
.map(|s| s.to_owned());
if let Some(last_event_id) = last_event_id {
// check if session has this event id
let stream = self
let stream = match self
.session_manager
.resume(&session_id, last_event_id)
.await
.map_err(internal_error_response("resume session"))?;
{
Ok(stream) => stream,
Err(e) if e.to_string().contains("Conflict:") => {
return Ok(Response::builder()
.status(http::StatusCode::CONFLICT)
.body(Full::new(Bytes::from(e.to_string())).boxed())
.expect("valid response"));
}
Err(e) => return Err(internal_error_response("resume session")(e)),
};
// Resume doesn't need priming - client already has the event ID
Ok(sse_stream_response(
stream,
Expand All @@ -228,11 +237,20 @@ where
))
} else {
// create standalone stream
let stream = self
let stream = match self
.session_manager
.create_standalone_stream(&session_id)
.await
.map_err(internal_error_response("create standalone stream"))?;
{
Ok(stream) => stream,
Err(e) if e.to_string().contains("Conflict:") => {
return Ok(Response::builder()
.status(http::StatusCode::CONFLICT)
.body(Full::new(Bytes::from(e.to_string())).boxed())
.expect("valid response"));
}
Err(e) => return Err(internal_error_response("create standalone stream")(e)),
};
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
Expand Down Expand Up @@ -313,10 +331,10 @@ where
.await
.map_err(internal_error_response("check session"))?;
if !has_session {
// unauthorized
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed())
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}

Expand Down Expand Up @@ -505,10 +523,10 @@ where
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned().into());
let Some(session_id) = session_id else {
// unauthorized
// MCP spec: servers that require a session ID SHOULD respond with 400 Bad Request
return Ok(Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed())
.status(http::StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from("Bad Request: Session ID is required")).boxed())
.expect("valid response"));
};
// close session
Expand Down
Loading