diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index e602926..351fc41 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use futures::StreamExt as _; use futures::channel::mpsc; use futures::channel::oneshot; +use futures::stream; use futures_concurrency::stream::StreamExt as _; use rustc_hash::FxHashMap; use uuid::Uuid; @@ -58,8 +59,11 @@ pub(super) async fn incoming_protocol_actor( mut handler: impl HandleDispatchFrom, protocol_compat: ProtocolCompat, ) -> Result<(), crate::Error> { - let mut my_rx = transport_rx - .map(IncomingProtocolMsg::Transport) + let transport_with_eof = futures::StreamExt::chain( + transport_rx.map(IncomingProtocolMsg::Transport), + stream::iter([IncomingProtocolMsg::TransportClosed]), + ); + let mut my_rx = transport_with_eof .merge(dynamic_handler_rx.map(IncomingProtocolMsg::DynamicHandler)) .merge(reply_rx.map(IncomingProtocolMsg::Reply)); @@ -76,6 +80,11 @@ pub(super) async fn incoming_protocol_actor( while let Some(message_result) = my_rx.next().await { tracing::trace!(message = ?message_result, actor = "incoming_protocol_actor"); match message_result { + IncomingProtocolMsg::TransportClosed => { + tracing::debug!("Transport closed (EOF), shutting down incoming actor"); + return Err(crate::Error::internal_error().data("transport closed".to_string())); + } + IncomingProtocolMsg::Reply(message) => match message { ReplyMessage::Subscribe { id, @@ -253,6 +262,7 @@ pub(super) async fn incoming_protocol_actor( #[derive(Debug)] enum IncomingProtocolMsg { Transport(Result), + TransportClosed, DynamicHandler(DynamicHandlerMessage), Reply(ReplyMessage), } diff --git a/src/agent-client-protocol/src/util.rs b/src/agent-client-protocol/src/util.rs index 770dcb3..57de80d 100644 --- a/src/agent-client-protocol/src/util.rs +++ b/src/agent-client-protocol/src/util.rs @@ -105,8 +105,9 @@ pub async fn both( /// Run `background` until `foreground` completes. /// /// Returns the result of `foreground`. If `background` errors before -/// `foreground` completes, the error is propagated. If `background` -/// completes with `Ok(())`, we continue waiting for `foreground`. +/// `foreground` completes, the error is propagated and `foreground` is +/// cancelled. If `background` completes with `Ok(())`, we continue +/// waiting for `foreground`. pub async fn run_until( background: impl Future>, foreground: impl Future>, diff --git a/src/agent-client-protocol/tests/jsonrpc_transport_close.rs b/src/agent-client-protocol/tests/jsonrpc_transport_close.rs new file mode 100644 index 0000000..5235e6c --- /dev/null +++ b/src/agent-client-protocol/tests/jsonrpc_transport_close.rs @@ -0,0 +1,178 @@ +//! Tests for transport close (EOF) detection. +//! +//! Verifies that `connect_with` and `connect_to` return when the remote +//! end of the transport closes, rather than hanging forever. + +use agent_client_protocol::{ByteStreams, Dispatch, Handled, role::UntypedRole}; +use futures::{AsyncRead, AsyncWrite}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + +fn setup_test_streams() -> ( + impl AsyncRead, + impl AsyncWrite, + impl AsyncRead, + impl AsyncWrite, +) { + let (client_writer, server_reader) = tokio::io::duplex(1024); + let (server_writer, client_reader) = tokio::io::duplex(1024); + + ( + server_reader.compat(), + server_writer.compat_write(), + client_reader.compat(), + client_writer.compat_write(), + ) +} + +/// When the remote side's `connect_with` returns (dropping the transport), +/// the local `connect_with` with a blocking main_fn should also return +/// rather than hanging forever. +#[tokio::test(flavor = "current_thread")] +async fn connect_with_returns_on_remote_transport_close() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = ByteStreams::new(server_writer, server_reader); + let client_transport = ByteStreams::new(client_writer, client_reader); + + // Server accepts and immediately returns (closing transport). + tokio::task::spawn_local(async move { + UntypedRole + .builder() + .name("server") + .connect_with(server_transport, async move |_cx| Ok(())) + .await + .ok(); + }); + + // Client blocks on an mpsc channel (simulating a proxy pattern). + let (_tx, mut rx) = mpsc::unbounded_channel::(); + let result = tokio::time::timeout( + Duration::from_secs(2), + UntypedRole.builder().name("client").connect_with( + client_transport, + async move |_cx| { + // Block until channel closes — but transport EOF should + // cancel this future before that happens. + while rx.recv().await.is_some() {} + Ok(()) + }, + ), + ) + .await; + + assert!( + result.is_ok(), + "connect_with should return on transport close, not time out" + ); + // The connection closed before main_fn finished, so it returns Err. + assert!(result.unwrap().is_err()); + }) + .await; +} + +/// `connect_to` (server mode) should return when the remote end disconnects. +#[tokio::test(flavor = "current_thread")] +async fn connect_to_returns_on_remote_transport_close() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = ByteStreams::new(server_writer, server_reader); + let client_transport = ByteStreams::new(client_writer, client_reader); + + // Client connects and immediately returns. + tokio::task::spawn_local(async move { + UntypedRole + .builder() + .name("client") + .connect_with(client_transport, async move |_cx| Ok(())) + .await + .ok(); + }); + + // Server uses connect_to (pending foreground) — should detect EOF. + let result = tokio::time::timeout( + Duration::from_secs(2), + UntypedRole + .builder() + .name("server") + .connect_to(server_transport), + ) + .await; + + assert!( + result.is_ok(), + "connect_to should return on transport close, not time out" + ); + }) + .await; +} + +/// When using `on_receive_dispatch` with a forwarding channel pattern, +/// transport close should still be detected and the connection should exit. +#[tokio::test(flavor = "current_thread")] +async fn connect_with_on_receive_dispatch_returns_on_transport_close() { + use tokio::task::LocalSet; + + let local = LocalSet::new(); + + local + .run_until(async { + let (server_reader, server_writer, client_reader, client_writer) = setup_test_streams(); + + let server_transport = ByteStreams::new(server_writer, server_reader); + let client_transport = ByteStreams::new(client_writer, client_reader); + + // Server immediately exits. + tokio::task::spawn_local(async move { + UntypedRole + .builder() + .name("server") + .connect_with(server_transport, async move |_cx| Ok(())) + .await + .ok(); + }); + + // Client with an on_receive_dispatch handler and a blocking main_fn. + let (_outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::(); + let result = tokio::time::timeout( + Duration::from_secs(2), + UntypedRole + .builder() + .name("client") + .on_receive_dispatch( + async move |_dispatch: Dispatch, + _cx: agent_client_protocol::ConnectionTo| { + Ok(Handled::Yes) + }, + agent_client_protocol::on_receive_dispatch!(), + ) + .connect_with(client_transport, async move |cx| { + while let Some(dispatch) = outgoing_rx.recv().await { + cx.send_proxied_message(dispatch)?; + } + Ok(()) + }), + ) + .await; + + assert!( + result.is_ok(), + "connect_with with handler should return on transport close, not time out" + ); + assert!(result.unwrap().is_err()); + }) + .await; +}