diff --git a/rs/hang-cli/src/client.rs b/rs/hang-cli/src/client.rs index 8d65d27e2..fa14f569e 100644 --- a/rs/hang-cli/src/client.rs +++ b/rs/hang-cli/src/client.rs @@ -11,7 +11,7 @@ pub async fn run_client(client: moq_native::Client, url: Url, name: String, publ tracing::info!(%url, %name, "connecting"); // Establish the connection, not providing a subscriber. - let session = client.connect_with_fallback(url, origin.consumer, None).await?; + let session = client.with_publish(origin.consumer).connect(url).await?; #[cfg(unix)] // Notify systemd that we're ready. diff --git a/rs/hang-cli/src/main.rs b/rs/hang-cli/src/main.rs index 1c6f321dc..968fb0e9c 100644 --- a/rs/hang-cli/src/main.rs +++ b/rs/hang-cli/src/main.rs @@ -98,10 +98,9 @@ async fn main() -> anyhow::Result<()> { Command::Serve { config, dir, name, .. } => { let web_bind = config.bind.unwrap_or("[::]:443".parse().unwrap()); - #[allow(unused_mut)] - let mut server = config.init()?; + let server = config.init()?; #[cfg(feature = "iroh")] - server.with_iroh(iroh); + let server = server.with_iroh(iroh); let web_tls = server.tls_info(); @@ -112,11 +111,10 @@ async fn main() -> anyhow::Result<()> { } } Command::Publish { config, url, name, .. } => { - #[allow(unused_mut)] - let mut client = config.init()?; + let client = config.init()?; #[cfg(feature = "iroh")] - client.with_iroh(iroh); + let client = client.with_iroh(iroh); run_client(client, url, name, publish).await } diff --git a/rs/hang-cli/src/server.rs b/rs/hang-cli/src/server.rs index 3e3d57ecb..02ce9def7 100644 --- a/rs/hang-cli/src/server.rs +++ b/rs/hang-cli/src/server.rs @@ -43,7 +43,7 @@ async fn run_session( origin.producer.publish_broadcast(&name, consumer); // Blindly accept the session (WebTransport or QUIC), regardless of the URL. - let session = session.accept(origin.consumer, None).await?; + let session = session.with_publish(origin.consumer).accept().await?; tracing::info!(id, "accepted session"); diff --git a/rs/hang/examples/video.rs b/rs/hang/examples/video.rs index a6e554f8d..d2d9e9e2c 100644 --- a/rs/hang/examples/video.rs +++ b/rs/hang/examples/video.rs @@ -27,10 +27,10 @@ async fn run_session(origin: moq_lite::OriginConsumer) -> anyhow::Result<()> { // The "anon" path is usually configured to bypass authentication; be careful! let url = url::Url::parse("https://cdn.moq.dev/anon/video-example").unwrap(); - // Establish a WebTransport/QUIC connection and MoQ handshake. - // None means we're not consuming anything from the session, otherwise we would provide an OriginProducer. - // Optional: Use connect_with_fallback if you also want to support WebSocket. - let session = client.connect(url, origin, None).await?; + // Establish a WebTransport/QUIC connection and MoQ handshake for publishing. + // with_publish() registers an OriginConsumer for outgoing data. + // Use with_consume() if you also want to subscribe/consume from the session. + let session = client.with_publish(origin).connect(url).await?; // Wait until the session is closed. session.closed().await.map_err(Into::into) diff --git a/rs/libmoq/src/session.rs b/rs/libmoq/src/session.rs index 223054fca..4e3ff477b 100644 --- a/rs/libmoq/src/session.rs +++ b/rs/libmoq/src/session.rs @@ -48,10 +48,14 @@ impl Session { let client = moq_native::ClientConfig::default() .init() .map_err(|err| Error::Connect(Arc::new(err)))?; + let session = client - .connect(url, publish, consume) + .with_publish(publish) + .with_consume(consume) + .connect(url) .await .map_err(|err| Error::Connect(Arc::new(err)))?; + callback.call(()); session.closed().await?; diff --git a/rs/moq-clock/src/main.rs b/rs/moq-clock/src/main.rs index f1a2ce50a..a1c14eb61 100644 --- a/rs/moq-clock/src/main.rs +++ b/rs/moq-clock/src/main.rs @@ -68,9 +68,7 @@ async fn main() -> anyhow::Result<()> { origin.producer.publish_broadcast(&config.broadcast, broadcast.consumer); - let session = client - .connect_with_fallback(config.url, Some(origin.consumer), None) - .await?; + let session = client.with_publish(origin.consumer).connect(config.url).await?; tokio::select! { res = session.closed() => res.context("session closed"), @@ -78,7 +76,7 @@ async fn main() -> anyhow::Result<()> { } } Command::Subscribe => { - let session = client.connect_with_fallback(config.url, None, origin.producer).await?; + let session = client.with_consume(origin.producer).connect(config.url).await?; // NOTE: We could just call `session.consume_broadcast(&config.broadcast)` instead, // However that won't work with IETF MoQ and the current OriginConsumer API the moment. diff --git a/rs/moq-lite/src/client.rs b/rs/moq-lite/src/client.rs new file mode 100644 index 000000000..d38b126b6 --- /dev/null +++ b/rs/moq-lite/src/client.rs @@ -0,0 +1,104 @@ +// TODO: Uncomment when observability feature is merged +// use std::sync::Arc; + +use crate::{ + Error, OriginConsumer, OriginProducer, Session, VERSIONS, + coding::{Decode, Encode, Stream}, + ietf, lite, setup, +}; + +/// A MoQ client session builder. +#[derive(Default, Clone)] +pub struct Client { + publish: Option, + consume: Option, + // TODO: Uncomment when observability feature is merged + // stats: Option>, +} + +impl Client { + pub fn new() -> Self { + Default::default() + } + + pub fn with_publish(mut self, publish: impl Into>) -> Self { + self.publish = publish.into(); + self + } + + pub fn with_consume(mut self, consume: impl Into>) -> Self { + self.consume = consume.into(); + self + } + + // TODO: Uncomment when observability feature is merged + // pub fn with_stats(mut self, stats: impl Into>>) -> Self { + // self.stats = stats.into(); + // self + // } + + /// Perform the MoQ handshake as a client negotiating the version. + pub async fn connect(&self, session: S) -> Result { + if self.publish.is_none() && self.consume.is_none() { + tracing::warn!("not publishing or consuming anything"); + } + + let mut stream = Stream::open(&session, setup::ServerKind::Ietf14).await?; + + let mut parameters = ietf::Parameters::default(); + parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64); + parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec()); + let parameters = parameters.encode_bytes(()); + + let client = setup::Client { + // Unfortunately, we have to pick a single draft range to support. + // moq-lite can support this handshake. + kind: setup::ClientKind::Ietf14, + versions: VERSIONS.into(), + parameters, + }; + + // TODO pretty print the parameters. + tracing::trace!(?client, "sending client setup"); + stream.writer.encode(&client).await?; + + let mut server: setup::Server = stream.reader.decode().await?; + tracing::trace!(?server, "received server setup"); + + if let Ok(version) = lite::Version::try_from(server.version) { + let stream = stream.with_version(version); + lite::start( + session.clone(), + stream, + self.publish.clone(), + self.consume.clone(), + version, + ) + .await?; + } else if let Ok(version) = ietf::Version::try_from(server.version) { + // Decode the parameters to get the initial request ID. + let parameters = ietf::Parameters::decode(&mut server.parameters, version)?; + let request_id_max = + ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0)); + + let stream = stream.with_version(version); + ietf::start( + session.clone(), + stream, + request_id_max, + true, + self.publish.clone(), + self.consume.clone(), + version, + ) + .await?; + } else { + // unreachable, but just in case + return Err(Error::Version(client.versions, [server.version].into())); + } + + tracing::debug!(version = ?server.version, "connected"); + + Ok(Session::new(session)) + } +} diff --git a/rs/moq-lite/src/lib.rs b/rs/moq-lite/src/lib.rs index 9494a1284..bed0e98bd 100644 --- a/rs/moq-lite/src/lib.rs +++ b/rs/moq-lite/src/lib.rs @@ -34,9 +34,11 @@ //! - Use [FrameProducer] and [FrameConsumer] for chunked frame writes/reads without allocating entire frames (useful for relaying). //! - Use [TrackProducer::create_group] instead of [TrackProducer::append_group] to produce groups out-of-order. +mod client; mod error; mod model; mod path; +mod server; mod session; mod setup; @@ -44,7 +46,9 @@ pub mod coding; pub mod ietf; pub mod lite; +pub use client::*; pub use error::*; pub use model::*; pub use path::*; +pub use server::*; pub use session::*; diff --git a/rs/moq-lite/src/server.rs b/rs/moq-lite/src/server.rs new file mode 100644 index 000000000..6cb77c35e --- /dev/null +++ b/rs/moq-lite/src/server.rs @@ -0,0 +1,111 @@ +// TODO: Uncomment when observability feature is merged +// use std::sync::Arc; + +use crate::{ + Error, OriginConsumer, OriginProducer, Session, VERSIONS, + coding::{Decode, Encode, Stream}, + ietf, lite, setup, +}; + +/// A MoQ server session builder. +#[derive(Default, Clone)] +pub struct Server { + publish: Option, + consume: Option, + // TODO: Uncomment when observability feature is merged + // stats: Option>, +} + +impl Server { + pub fn new() -> Self { + Default::default() + } + + pub fn with_publish(mut self, publish: impl Into>) -> Self { + self.publish = publish.into(); + self + } + + pub fn with_consume(mut self, consume: impl Into>) -> Self { + self.consume = consume.into(); + self + } + + // TODO: Uncomment when observability feature is merged + // pub fn with_stats(mut self, stats: impl Into>>) -> Self { + // self.stats = stats.into(); + // self + // } + + /// Perform the MoQ handshake as a server for the given session. + pub async fn accept(&self, session: S) -> Result { + if self.publish.is_none() && self.consume.is_none() { + tracing::warn!("not publishing or consuming anything"); + } + + // Accept with an initial version; we'll switch to the negotiated version later + let mut stream = Stream::accept(&session, ()).await?; + let mut client: setup::Client = stream.reader.decode().await?; + tracing::trace!(?client, "received client setup"); + + // Choose the version to use + let version = client + .versions + .iter() + .find(|v| VERSIONS.contains(v)) + .copied() + .ok_or_else(|| Error::Version(client.versions.clone(), VERSIONS.into()))?; + + // Only encode parameters if we're using the IETF draft because it has max_request_id + let parameters = if ietf::Version::try_from(version).is_ok() && client.kind == setup::ClientKind::Ietf14 { + let mut parameters = ietf::Parameters::default(); + parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64); + parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec()); + parameters.encode_bytes(()) + } else { + lite::Parameters::default().encode_bytes(()) + }; + + let server = setup::Server { version, parameters }; + tracing::trace!(?server, "sending server setup"); + + let mut stream = stream.with_version(client.kind.reply()); + stream.writer.encode(&server).await?; + + if let Ok(version) = lite::Version::try_from(version) { + let stream = stream.with_version(version); + lite::start( + session.clone(), + stream, + self.publish.clone(), + self.consume.clone(), + version, + ) + .await?; + } else if let Ok(version) = ietf::Version::try_from(version) { + // Decode the client's parameters to get their max request ID. + let parameters = ietf::Parameters::decode(&mut client.parameters, version)?; + let request_id_max = + ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0)); + + let stream = stream.with_version(version); + ietf::start( + session.clone(), + stream, + request_id_max, + false, + self.publish.clone(), + self.consume.clone(), + version, + ) + .await?; + } else { + // unreachable, but just in case + return Err(Error::Version(client.versions, VERSIONS.into())); + } + + tracing::debug!(?version, "connected"); + + Ok(Session::new(session)) + } +} diff --git a/rs/moq-lite/src/session.rs b/rs/moq-lite/src/session.rs index 0f1f6347b..4f15e851b 100644 --- a/rs/moq-lite/src/session.rs +++ b/rs/moq-lite/src/session.rs @@ -1,19 +1,6 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use crate::{ - Error, OriginConsumer, OriginProducer, - coding::{self, Decode, Encode, Stream}, - ietf, lite, setup, -}; - -/// A MoQ transport session, wrapping a WebTransport connection. -/// -/// Created via: -/// - [`Session::connect`] for clients. -/// - [`Session::accept`] for servers. -pub struct Session { - session: Arc, -} +use crate::{Error, coding, ietf, lite}; /// The versions of MoQ that are supported by this implementation. /// @@ -27,142 +14,22 @@ pub const VERSIONS: [coding::Version; 3] = [ /// The ALPN strings for supported versions. pub const ALPNS: [&str; 2] = [lite::ALPN, ietf::ALPN]; +/// A MoQ transport session, wrapping a WebTransport connection. +/// +/// Created via: +/// - [`crate::Client::connect`] for clients. +/// - [`crate::Server::accept`] for servers. +pub struct Session { + session: Arc, +} + impl Session { - fn new(session: S) -> Self { + pub(super) fn new(session: S) -> Self { Self { session: Arc::new(session), } } - /// Perform the MoQ handshake as a client, negotiating the version. - /// - /// Publishing is performed with [OriginConsumer] and subscribing with [OriginProducer]. - /// The connection remains active until the session is closed. - pub async fn connect( - session: S, - publish: impl Into>, - subscribe: impl Into>, - ) -> Result { - let mut stream = Stream::open(&session, setup::ServerKind::Ietf14).await?; - - let mut parameters = ietf::Parameters::default(); - parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64); - parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec()); - let parameters = parameters.encode_bytes(()); - - let client = setup::Client { - // Unfortunately, we have to pick a single draft range to support. - // moq-lite can support this handshake. - kind: setup::ClientKind::Ietf14, - versions: VERSIONS.into(), - parameters, - }; - - // TODO pretty print the parameters. - tracing::trace!(?client, "sending client setup"); - stream.writer.encode(&client).await?; - - let mut server: setup::Server = stream.reader.decode().await?; - tracing::trace!(?server, "received server setup"); - - if let Ok(version) = lite::Version::try_from(server.version) { - let stream = stream.with_version(version); - lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?; - } else if let Ok(version) = ietf::Version::try_from(server.version) { - // Decode the parameters to get the initial request ID. - let parameters = ietf::Parameters::decode(&mut server.parameters, version)?; - let request_id_max = - ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0)); - - let stream = stream.with_version(version); - ietf::start( - session.clone(), - stream, - request_id_max, - true, - publish.into(), - subscribe.into(), - version, - ) - .await?; - } else { - // unreachable, but just in case - return Err(Error::Version(client.versions, [server.version].into())); - } - - tracing::debug!(version = ?server.version, "connected"); - - Ok(Self::new(session)) - } - - /// Perform the MoQ handshake as a server. - /// - /// Publishing is performed with [OriginConsumer] and subscribing with [OriginProducer]. - /// The connection remains active until the session is closed. - pub async fn accept( - session: S, - publish: impl Into>, - subscribe: impl Into>, - ) -> Result { - // Accept with an initial version; we'll switch to the negotiated version later - let mut stream = Stream::accept(&session, ()).await?; - let client: setup::Client = stream.reader.decode().await?; - tracing::trace!(?client, "received client setup"); - - // Choose the version to use - let version = client - .versions - .iter() - .find(|v| VERSIONS.contains(v)) - .copied() - .ok_or_else(|| Error::Version(client.versions.clone(), VERSIONS.into()))?; - - // Only encode parameters if we're using the IETF draft because it has max_request_id - let parameters = if ietf::Version::try_from(version).is_ok() && client.kind == setup::ClientKind::Ietf14 { - let mut parameters = ietf::Parameters::default(); - parameters.set_varint(ietf::ParameterVarInt::MaxRequestId, u32::MAX as u64); - parameters.set_bytes(ietf::ParameterBytes::Implementation, b"moq-lite-rs".to_vec()); - parameters.encode_bytes(()) - } else { - lite::Parameters::default().encode_bytes(()) - }; - - let mut server = setup::Server { version, parameters }; - tracing::trace!(?server, "sending server setup"); - - let mut stream = stream.with_version(client.kind.reply()); - stream.writer.encode(&server).await?; - - if let Ok(version) = lite::Version::try_from(version) { - let stream = stream.with_version(version); - lite::start(session.clone(), stream, publish.into(), subscribe.into(), version).await?; - } else if let Ok(version) = ietf::Version::try_from(version) { - // Decode the parameters to get the initial request ID. - let parameters = ietf::Parameters::decode(&mut server.parameters, version)?; - let request_id_max = - ietf::RequestId(parameters.get_varint(ietf::ParameterVarInt::MaxRequestId).unwrap_or(0)); - - let stream = stream.with_version(version); - ietf::start( - session.clone(), - stream, - request_id_max, - false, - publish.into(), - subscribe.into(), - version, - ) - .await?; - } else { - // unreachable, but just in case - return Err(Error::Version(client.versions, VERSIONS.into())); - } - - tracing::debug!(?version, "connected"); - - Ok(Self::new(session)) - } - /// Close the underlying transport session. pub fn close(self, err: Error) { self.session.close(err.to_code(), err.to_string().as_ref()); diff --git a/rs/moq-native/examples/chat.rs b/rs/moq-native/examples/chat.rs index 596a75c56..b73f4d7a3 100644 --- a/rs/moq-native/examples/chat.rs +++ b/rs/moq-native/examples/chat.rs @@ -27,9 +27,7 @@ async fn run_session(origin: moq_lite::OriginConsumer) -> anyhow::Result<()> { let url = url::Url::parse("https://cdn.moq.dev/anon/chat-example").unwrap(); // Establish a WebTransport/QUIC connection and MoQ handshake. - // Optional: You could do this as two separate steps, but this is more convenient. - // Optional: Use connect_with_fallback if you also want to support WebSocket too. - let session = client.connect(url, origin, None).await?; + let session = client.with_publish(origin).connect(url).await?; // Wait until the session is closed. session.closed().await.map_err(Into::into) diff --git a/rs/moq-native/src/client.rs b/rs/moq-native/src/client.rs index c858360bf..89a91c5b9 100644 --- a/rs/moq-native/src/client.rs +++ b/rs/moq-native/src/client.rs @@ -45,6 +45,15 @@ pub struct ClientTls { #[serde(default, deny_unknown_fields)] #[non_exhaustive] pub struct ClientWebSocket { + /// Whether to enable WebSocket support. + #[arg( + id = "websocket-enabled", + long = "websocket-enabled", + env = "MOQ_CLIENT_WEBSOCKET_ENABLED", + default_value = "true" + )] + pub enabled: bool, + /// Delay in milliseconds before attempting WebSocket fallback (default: 200) /// If WebSocket won the previous race for a given server, this will be 0. #[arg( @@ -62,6 +71,7 @@ pub struct ClientWebSocket { impl Default for ClientWebSocket { fn default() -> Self { Self { + enabled: true, delay: Some(time::Duration::from_millis(200)), } } @@ -110,11 +120,13 @@ impl Default for ClientConfig { /// /// Create via [`ClientConfig::init`] or [`Client::new`]. #[derive(Clone)] +#[non_exhaustive] pub struct Client { + pub moq: moq_lite::Client, pub quic: quinn::Endpoint, pub tls: rustls::ClientConfig, pub transport: Arc, - pub websocket_delay: Option, + pub websocket: ClientWebSocket, #[cfg(feature = "iroh")] pub iroh: Option, } @@ -186,82 +198,73 @@ impl Client { quinn::Endpoint::new(endpoint_config, None, socket, runtime).context("failed to create QUIC endpoint")?; Ok(Self { + moq: moq_lite::Client::new(), quic, tls, transport, - websocket_delay: config.websocket.delay, + websocket: config.websocket, #[cfg(feature = "iroh")] iroh: None, }) } #[cfg(feature = "iroh")] - pub fn with_iroh(&mut self, iroh: Option) -> &mut Self { + pub fn with_iroh(mut self, iroh: Option) -> Self { self.iroh = iroh; self } - /// Establish a WebTransport/QUIC connection followed by a MoQ handshake. - pub async fn connect( - &self, - url: Url, - publish: impl Into>, - subscribe: impl Into>, - ) -> anyhow::Result { - #[cfg(feature = "iroh")] - if crate::iroh::is_iroh_url(&url) { - let session = self.connect_iroh(url).await?; - let session = moq_lite::Session::connect(session, publish, subscribe).await?; - return Ok(session); - } + pub fn with_publish(mut self, publish: impl Into>) -> Self { + self.moq = self.moq.with_publish(publish); + self + } - let session = self.connect_quic(url).await?; - let session = moq_lite::Session::connect(session, publish, subscribe).await?; - Ok(session) + pub fn with_consume(mut self, consume: impl Into>) -> Self { + self.moq = self.moq.with_consume(consume); + self } - /// Establish a WebTransport/QUIC connection or a WebSocket connection, whichever is available first. - /// - /// Establishes a MoQ handshake on the winning transport. - pub async fn connect_with_fallback( - &self, - url: Url, - publish: impl Into>, - subscribe: impl Into>, - ) -> anyhow::Result { + // TODO: Uncomment when observability feature is merged + // pub fn with_stats(mut self, stats: impl Into>>) -> Self { + // self.moq = self.moq.with_stats(stats); + // self + // } + + /// Establish a WebTransport/QUIC connection followed by a MoQ handshake. + pub async fn connect(&self, url: Url) -> anyhow::Result { #[cfg(feature = "iroh")] if crate::iroh::is_iroh_url(&url) { let session = self.connect_iroh(url).await?; - let session = moq_lite::Session::connect(session, publish, subscribe).await?; + let session = self.moq.connect(session).await?; return Ok(session); } // Create futures for both possible protocols let quic_url = url.clone(); let quic_handle = async { - match self.connect_quic(quic_url).await { - Ok(session) => Some(session), - Err(err) => { - tracing::warn!(%err, "QUIC connection failed"); - None - } + let res = self.connect_quic(quic_url).await; + if let Err(err) = &res { + tracing::warn!(%err, "QUIC connection failed"); } + res }; let ws_handle = async { - match self.connect_websocket(url).await { - Ok(session) => Some(session), - Err(err) => { - tracing::warn!(%err, "WebSocket connection failed"); - None - } + if !self.websocket.enabled { + return None; } + + let res = self.connect_websocket(url).await; + if let Err(err) = &res { + tracing::warn!(%err, "WebSocket connection failed"); + } + Some(res) }; // Race the connection futures Ok(tokio::select! { - Some(quic) = quic_handle => moq_lite::Session::connect(quic, publish, subscribe).await?, - Some(ws) = ws_handle => moq_lite::Session::connect(ws, publish, subscribe).await?, + Ok(quic) = quic_handle => self.moq.connect(quic).await?, + Some(Ok(ws)) = ws_handle => self.moq.connect(ws).await?, // If both attempts fail, return an error else => anyhow::bail!("failed to connect to server"), }) @@ -334,6 +337,8 @@ impl Client { } async fn connect_websocket(&self, mut url: Url) -> anyhow::Result { + anyhow::ensure!(self.websocket.enabled, "WebSocket support is disabled"); + let host = url.host_str().context("missing hostname")?.to_string(); let port = url.port().unwrap_or_else(|| match url.scheme() { "https" | "wss" | "moql" | "moqt" => 443, @@ -345,7 +350,7 @@ impl Client { // Apply a small penalty to WebSocket to improve odds for QUIC to connect first, // unless we've already had to fall back to WebSockets for this server. // TODO if let chain - match self.websocket_delay { + match self.websocket.delay { Some(delay) if !WEBSOCKET_WON.lock().unwrap().contains(&key) => { tokio::time::sleep(delay).await; tracing::debug!(%url, delay_ms = %delay.as_millis(), "QUIC not yet connected, attempting WebSocket fallback"); diff --git a/rs/moq-native/src/server.rs b/rs/moq-native/src/server.rs index d45e380d0..1c7a67c5a 100644 --- a/rs/moq-native/src/server.rs +++ b/rs/moq-native/src/server.rs @@ -95,6 +95,7 @@ impl ServerConfig { /// /// Create via [`ServerConfig::init`] or [`Server::new`]. pub struct Server { + moq: moq_lite::Server, quic: quinn::Endpoint, accept: FuturesUnordered>>, certs: Arc, @@ -171,17 +172,34 @@ impl Server { quic: quic.clone(), accept: Default::default(), certs, + moq: moq_lite::Server::new(), #[cfg(feature = "iroh")] iroh: None, }) } #[cfg(feature = "iroh")] - pub fn with_iroh(&mut self, iroh: Option) -> &mut Self { + pub fn with_iroh(mut self, iroh: Option) -> Self { self.iroh = iroh; self } + pub fn with_publish(mut self, publish: impl Into>) -> Self { + self.moq = self.moq.with_publish(publish); + self + } + + pub fn with_consume(mut self, consume: impl Into>) -> Self { + self.moq = self.moq.with_consume(consume); + self + } + + // TODO: Uncomment when observability feature is merged + // pub fn with_stats(mut self, stats: impl Into>>) -> Self { + // self.moq = self.moq.with_stats(stats); + // self + // } + #[cfg(unix)] async fn reload_certs(certs: Arc, tls_config: ServerTlsConfig) { use tokio::signal::unix::{SignalKind, signal}; @@ -229,13 +247,13 @@ impl Server { tokio::select! { res = self.quic.accept() => { let conn = res?; - self.accept.push(Self::accept_session(conn).boxed()); + self.accept.push(Self::accept_session(self.moq.clone(), conn).boxed()); } res = iroh_accept_fut => { #[cfg(feature = "iroh")] { let conn = res?; - self.accept.push(Self::accept_iroh_session(conn).boxed()); + self.accept.push(Self::accept_iroh_session(self.moq.clone(), conn).boxed()); } #[cfg(not(feature = "iroh"))] let _: () = res; @@ -257,7 +275,7 @@ impl Server { } } - async fn accept_session(conn: quinn::Incoming) -> anyhow::Result { + async fn accept_session(server: moq_lite::Server, conn: quinn::Incoming) -> anyhow::Result { let mut conn = conn.accept()?; let handshake = conn @@ -285,15 +303,21 @@ impl Server { let request = web_transport_quinn::Request::accept(conn) .await .context("failed to receive WebTransport request")?; - Ok(Request::WebTransport(request)) + Ok(Request { + server: server.clone(), + kind: RequestKind::WebTransport(request), + }) } - moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request::Quic(QuicRequest::accept(conn))), + moq_lite::lite::ALPN | moq_lite::ietf::ALPN => Ok(Request { + server: server.clone(), + kind: RequestKind::Quic(QuicRequest::accept(conn)), + }), _ => anyhow::bail!("unsupported ALPN: {alpn}"), } } #[cfg(feature = "iroh")] - async fn accept_iroh_session(conn: iroh::endpoint::Incoming) -> anyhow::Result { + async fn accept_iroh_session(server: moq_lite::Server, conn: iroh::endpoint::Incoming) -> anyhow::Result { let conn = conn.accept()?.await?; let alpn = String::from_utf8(conn.alpn().to_vec()).context("failed to decode ALPN")?; tracing::Span::current().record("id", conn.stable_id()); @@ -303,11 +327,17 @@ impl Server { let request = web_transport_iroh::H3Request::accept(conn) .await .context("failed to receive WebTransport request")?; - Ok(Request::IrohWebTransport(request)) + Ok(Request { + server: server.clone(), + kind: RequestKind::IrohWebTransport(request), + }) } moq_lite::lite::ALPN | moq_lite::ietf::ALPN => { let request = IrohQuicRequest::accept(conn); - Ok(Request::IrohQuic(request)) + Ok(Request { + server: server.clone(), + kind: RequestKind::IrohQuic(request), + }) } _ => Err(anyhow::anyhow!("unsupported ALPN: {alpn}")), } @@ -328,7 +358,7 @@ impl Server { } /// An incoming connection that can be accepted or rejected. -pub enum Request { +enum RequestKind { WebTransport(web_transport_quinn::Request), Quic(QuicRequest), #[cfg(feature = "iroh")] @@ -337,43 +367,60 @@ pub enum Request { IrohQuic(IrohQuicRequest), } +pub struct Request { + server: moq_lite::Server, + kind: RequestKind, +} + impl Request { /// Reject the session, returning your favorite HTTP status code. pub async fn reject(self, status: http::StatusCode) -> anyhow::Result<()> { - match self { - Self::WebTransport(request) => request.close(status).await?, - Self::Quic(request) => request.close(status), + match self.kind { + RequestKind::WebTransport(request) => request.close(status).await?, + RequestKind::Quic(request) => request.close(status), #[cfg(feature = "iroh")] - Request::IrohWebTransport(request) => request.close(status).await?, + RequestKind::IrohWebTransport(request) => request.close(status).await?, #[cfg(feature = "iroh")] - Request::IrohQuic(request) => request.close(status), + RequestKind::IrohQuic(request) => request.close(status), } Ok(()) } + pub fn with_publish(mut self, publish: impl Into>) -> Self { + self.server = self.server.with_publish(publish); + self + } + + pub fn with_consume(mut self, consume: impl Into>) -> Self { + self.server = self.server.with_consume(consume); + self + } + + // TODO: Uncomment when observability feature is merged + // pub fn with_stats(mut self, stats: impl Into>>) -> Self { + // self.server = self.server.with_stats(stats); + // self + // } + /// Accept the session, performing rest of the MoQ handshake. - pub async fn accept( - self, - publish: impl Into>, - subscribe: impl Into>, - ) -> anyhow::Result { - let session = match self { - Request::WebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?, - Request::Quic(request) => Session::accept(request.ok(), publish, subscribe).await?, + pub async fn accept(self) -> anyhow::Result { + let session = match self.kind { + RequestKind::WebTransport(request) => self.server.accept(request.ok().await?).await?, + RequestKind::Quic(request) => self.server.accept(request.ok()).await?, #[cfg(feature = "iroh")] - Request::IrohWebTransport(request) => Session::accept(request.ok().await?, publish, subscribe).await?, + RequestKind::IrohWebTransport(request) => self.server.accept(request.ok().await?).await?, #[cfg(feature = "iroh")] - Request::IrohQuic(request) => Session::accept(request.ok(), publish, subscribe).await?, + RequestKind::IrohQuic(request) => self.server.accept(request.ok()).await?, }; Ok(session) } /// Returns the URL provided by the client. pub fn url(&self) -> Option<&Url> { - match self { - Request::WebTransport(request) => Some(request.url()), + match &self.kind { + RequestKind::WebTransport(request) => Some(request.url()), #[cfg(feature = "iroh")] - Request::IrohWebTransport(request) => Some(request.url()), + RequestKind::IrohWebTransport(request) => Some(request.url()), _ => None, } } diff --git a/rs/moq-relay/src/cluster.rs b/rs/moq-relay/src/cluster.rs index 4637a75c8..09650902a 100644 --- a/rs/moq-relay/src/cluster.rs +++ b/rs/moq-relay/src/cluster.rs @@ -262,13 +262,12 @@ impl Cluster { async fn run_remote_once(&mut self, url: &Url) -> anyhow::Result<()> { tracing::info!(%url, "connecting to remote"); - // Connect to the remote node. - let publish = Some(self.primary.consumer.consume()); - let subscribe = Some(self.secondary.producer.clone()); - let session = self .client - .connect(url.clone(), publish, subscribe) + .clone() + .with_publish(self.primary.consumer.consume()) + .with_consume(self.secondary.producer.clone()) + .connect(url.clone()) .await .context("failed to connect to remote")?; diff --git a/rs/moq-relay/src/connection.rs b/rs/moq-relay/src/connection.rs index 12d1c4c0b..f73ce8d6d 100644 --- a/rs/moq-relay/src/connection.rs +++ b/rs/moq-relay/src/connection.rs @@ -50,7 +50,14 @@ impl Connection { // NOTE: subscribe and publish seem backwards because of how relays work. // We publish the tracks the client is allowed to subscribe to. // We subscribe to the tracks the client is allowed to publish. - let session = self.request.accept(subscribe, publish).await?; + let session = self + .request + .with_publish(subscribe) + .with_consume(publish) + // TODO: Uncomment when observability feature is merged + // .with_stats(stats) + .accept() + .await?; // Wait until the session is closed. session.closed().await.map_err(Into::into) diff --git a/rs/moq-relay/src/main.rs b/rs/moq-relay/src/main.rs index b188ce984..6ab08a3cf 100644 --- a/rs/moq-relay/src/main.rs +++ b/rs/moq-relay/src/main.rs @@ -31,17 +31,14 @@ async fn main() -> anyhow::Result<()> { let config = Config::load()?; let addr = config.server.bind.unwrap_or("[::]:443".parse().unwrap()); - let mut server = config.server.init()?; - - #[allow(unused_mut)] - let mut client = config.client.init()?; + let server = config.server.init()?; + let client = config.client.init()?; #[cfg(feature = "iroh")] - { + let (mut server, client) = { let iroh = config.iroh.bind().await?; - server.with_iroh(iroh.clone()); - client.with_iroh(iroh); - } + (server.with_iroh(iroh.clone()), client.with_iroh(iroh)) + }; let auth = config.auth.init().await?; diff --git a/rs/moq-relay/src/web.rs b/rs/moq-relay/src/web.rs index 525a79baf..620fa9bbb 100644 --- a/rs/moq-relay/src/web.rs +++ b/rs/moq-relay/src/web.rs @@ -216,7 +216,13 @@ where { // Wrap the WebSocket in a WebTransport compatibility layer. let ws = web_transport_ws::Session::new(socket, true); - let session = moq_lite::Session::accept(ws, subscribe, publish).await?; + let session = moq_lite::Server::new() + .with_publish(subscribe) + .with_consume(publish) + // TODO: Uncomment when observability feature is merged + // .with_stats(stats) + .accept(ws) + .await?; session.closed().await.map_err(Into::into) }