Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/agent-client-protocol/src/jsonrpc/incoming_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::collections::HashMap;
use futures::StreamExt as _;
use futures::channel::mpsc;
use futures::channel::oneshot;
use futures::stream;
use futures_concurrency::stream::StreamExt as _;
use rustc_hash::FxHashMap;
use uuid::Uuid;
Expand Down Expand Up @@ -58,8 +59,11 @@ pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
mut handler: impl HandleDispatchFrom<Counterpart>,
protocol_compat: ProtocolCompat,
) -> Result<(), crate::Error> {
let mut my_rx = transport_rx
.map(IncomingProtocolMsg::Transport)
let transport_with_eof = futures::StreamExt::chain(
transport_rx.map(IncomingProtocolMsg::Transport),
stream::iter([IncomingProtocolMsg::TransportClosed]),
);
let mut my_rx = transport_with_eof
.merge(dynamic_handler_rx.map(IncomingProtocolMsg::DynamicHandler))
.merge(reply_rx.map(IncomingProtocolMsg::Reply));

Expand All @@ -76,6 +80,11 @@ pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
while let Some(message_result) = my_rx.next().await {
tracing::trace!(message = ?message_result, actor = "incoming_protocol_actor");
match message_result {
IncomingProtocolMsg::TransportClosed => {
tracing::debug!("Transport closed (EOF), shutting down incoming actor");
return Err(crate::Error::internal_error().data("transport closed".to_string()));
}

IncomingProtocolMsg::Reply(message) => match message {
ReplyMessage::Subscribe {
id,
Expand Down Expand Up @@ -253,6 +262,7 @@ pub(super) async fn incoming_protocol_actor<Counterpart: Role>(
#[derive(Debug)]
enum IncomingProtocolMsg<Counterpart: Role> {
Transport(Result<RawJsonRpcMessage, crate::Error>),
TransportClosed,
DynamicHandler(DynamicHandlerMessage<Counterpart>),
Reply(ReplyMessage),
}
Expand Down
5 changes: 3 additions & 2 deletions src/agent-client-protocol/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ pub async fn both<E>(
/// Run `background` until `foreground` completes.
///
/// Returns the result of `foreground`. If `background` errors before
/// `foreground` completes, the error is propagated. If `background`
/// completes with `Ok(())`, we continue waiting for `foreground`.
/// `foreground` completes, the error is propagated and `foreground` is
/// cancelled. If `background` completes with `Ok(())`, we continue
/// waiting for `foreground`.
pub async fn run_until<T, E>(
background: impl Future<Output = Result<(), E>>,
foreground: impl Future<Output = Result<T, E>>,
Expand Down
178 changes: 178 additions & 0 deletions src/agent-client-protocol/tests/jsonrpc_transport_close.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
//! Tests for transport close (EOF) detection.
//!
//! Verifies that `connect_with` and `connect_to` return when the remote
//! end of the transport closes, rather than hanging forever.

use agent_client_protocol::{ByteStreams, Dispatch, Handled, role::UntypedRole};
use futures::{AsyncRead, AsyncWrite};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};

fn setup_test_streams() -> (
impl AsyncRead,
impl AsyncWrite,
impl AsyncRead,
impl AsyncWrite,
) {
let (client_writer, server_reader) = tokio::io::duplex(1024);
let (server_writer, client_reader) = tokio::io::duplex(1024);

(
server_reader.compat(),
server_writer.compat_write(),
client_reader.compat(),
client_writer.compat_write(),
)
}

/// When the remote side's `connect_with` returns (dropping the transport),
/// the local `connect_with` with a blocking main_fn should also return
/// rather than hanging forever.
#[tokio::test(flavor = "current_thread")]
async fn connect_with_returns_on_remote_transport_close() {
use tokio::task::LocalSet;

let local = LocalSet::new();

local
.run_until(async {
let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();

let server_transport = ByteStreams::new(server_writer, server_reader);
let client_transport = ByteStreams::new(client_writer, client_reader);

// Server accepts and immediately returns (closing transport).
tokio::task::spawn_local(async move {
UntypedRole
.builder()
.name("server")
.connect_with(server_transport, async move |_cx| Ok(()))
.await
.ok();
});

// Client blocks on an mpsc channel (simulating a proxy pattern).
let (_tx, mut rx) = mpsc::unbounded_channel::<Dispatch>();
let result = tokio::time::timeout(
Duration::from_secs(2),
UntypedRole.builder().name("client").connect_with(
client_transport,
async move |_cx| {
// Block until channel closes — but transport EOF should
// cancel this future before that happens.
while rx.recv().await.is_some() {}
Ok(())
},
),
)
.await;

assert!(
result.is_ok(),
"connect_with should return on transport close, not time out"
);
// The connection closed before main_fn finished, so it returns Err.
assert!(result.unwrap().is_err());
})
.await;
}

/// `connect_to` (server mode) should return when the remote end disconnects.
#[tokio::test(flavor = "current_thread")]
async fn connect_to_returns_on_remote_transport_close() {
use tokio::task::LocalSet;

let local = LocalSet::new();

local
.run_until(async {
let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();

let server_transport = ByteStreams::new(server_writer, server_reader);
let client_transport = ByteStreams::new(client_writer, client_reader);

// Client connects and immediately returns.
tokio::task::spawn_local(async move {
UntypedRole
.builder()
.name("client")
.connect_with(client_transport, async move |_cx| Ok(()))
.await
.ok();
});

// Server uses connect_to (pending foreground) — should detect EOF.
let result = tokio::time::timeout(
Duration::from_secs(2),
UntypedRole
.builder()
.name("server")
.connect_to(server_transport),
)
.await;

assert!(
result.is_ok(),
"connect_to should return on transport close, not time out"
);
})
.await;
}

/// When using `on_receive_dispatch` with a forwarding channel pattern,
/// transport close should still be detected and the connection should exit.
#[tokio::test(flavor = "current_thread")]
async fn connect_with_on_receive_dispatch_returns_on_transport_close() {
use tokio::task::LocalSet;

let local = LocalSet::new();

local
.run_until(async {
let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams();

let server_transport = ByteStreams::new(server_writer, server_reader);
let client_transport = ByteStreams::new(client_writer, client_reader);

// Server immediately exits.
tokio::task::spawn_local(async move {
UntypedRole
.builder()
.name("server")
.connect_with(server_transport, async move |_cx| Ok(()))
.await
.ok();
});

// Client with an on_receive_dispatch handler and a blocking main_fn.
let (_outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<Dispatch>();
let result = tokio::time::timeout(
Duration::from_secs(2),
UntypedRole
.builder()
.name("client")
.on_receive_dispatch(
async move |_dispatch: Dispatch,
_cx: agent_client_protocol::ConnectionTo<UntypedRole>| {
Ok(Handled::Yes)
},
agent_client_protocol::on_receive_dispatch!(),
)
.connect_with(client_transport, async move |cx| {
while let Some(dispatch) = outgoing_rx.recv().await {
cx.send_proxied_message(dispatch)?;
}
Ok(())
}),
)
.await;

assert!(
result.is_ok(),
"connect_with with handler should return on transport close, not time out"
);
assert!(result.unwrap().is_err());
})
.await;
}